我有方法reformat
,其中使用numpy我将label(256,)
转换为label(256,2)
形状。
现在我想在具有形状(256,)的Tensor上进行相同的操作
我的代码看起来像这样(num_labels = 2): -
def reformat(dataset, labels):
dataset = dataset.reshape((-1, image_size, image_size,num_channels)).astype(np.float32)
labels = (np.arange(num_labels)==labels[:,None]).astype(np.float32)
return dataset, labels
答案 0 :(得分:2)
您可以使用tf.expand_dims()添加新尺寸。
In [1]: import tensorflow as tf
x = tf.constant([3., 2.])
tf.expand_dims(x, 1).shape
Out[1]: TensorShape([Dimension(2), Dimension(1)])
您也可以使用tf.reshape(),但建议您使用expand_dims,因为如果可以满足新的形状,这也会为新维度带来一些值。
In [1]: tf.reshape(x, [2, 1])
Out[1]: TensorShape([Dimension(2), Dimension(1)])