matlab:指定要索引的维度

时间:2019-02-06 22:41:53

标签: matlab

想象一个函数,您要在用户指定的n维矩阵维中输出代码片段

function result=a(x,dim)
  window=1:10;
  dim=3;
  result=x(:,:,window);
end

如何将window设置为所需尺寸?例如。如果为dim=2;,则为result=x(:,window,:)

我现在想到的方式是评估一个将window放在正确位置的字符串命令-或使用许多if then else块。有什么更好的方法?

1 个答案:

答案 0 :(得分:1)

您可以按照示例here使用单元格数组来定义索引。

具体来说,如果您有矩阵

length

您可以定义所需的索引:

def _pad_tensors_to_same_length(logits, labels):
    """Pad x and y so that the results have the same length (second dimension)."""
    with tf.name_scope("pad_to_same_length"):
        logits_length = tf.shape(logits)[1]
        labels_length = tf.shape(labels)[1]

        max_length = tf.maximum(logits_length, labels_length)

        logits = tf.pad(logits, [[0, 0], [0, max_length - logits_length], [0, 0]])
        labels = tf.pad(labels, [[0, 0], [0, max_length - labels_length]])
        return logits, labels

然后从矩阵def padded_cross_entropy_loss(logits, labels, vocab_size): """Calculate cross entropy loss while ignoring padding. Args: logits: Tensor of size [batch_size, length_logits, vocab_size] labels: Tensor of size [batch_size, length_labels] vocab_size: int size of the vocabulary Returns: Returns the cross entropy loss """ with tf.name_scope("loss", values=[logits, labels]): logits, labels = _pad_tensors_to_same_length(logits, labels) # Calculate cross entropy with tf.name_scope("cross_entropy", values=[logits, labels]): xentropy = tf.nn.softmax_cross_entropy_with_logits_v2( logits=logits, labels=targets) weights = tf.to_float(tf.not_equal(labels, 0)) return xentropy * weights 中获取这些索引,例如

x = ones(7,5,9);

因此,您可以编写类似的功能

% get all the indexes in all dimensions
all_indexes = {':', ':', ':'};

% get all indexes in dimensions 1 and 3, and just indices 1:4 in dimension 2
indexes_2 = {':', 1:4, ':'}; 

这将返回除指定尺寸之外的所有尺寸的单元格,并且在该方向上将返回x给定范围内的单元格。因此,在以下代码中,a = x(all_indexes{:}); b = x(indexes_2{:}); function result=extract_cells(x, dim, window) % Create blank cell array {':', ':', ...} with entries ':' for each dimension % Edit (c/o Cris Luengo): need to use ndims(x) to get the number of dimensions num_dims = ndims(x) dims = cell(1, num_dims); dims(:) = {':'}; % Set the specified window of cells in the specified dimension dims{dim} = window; % Pick out the required cells result=x(dims{:}); end 是等效的。

window
相关问题