我正在阅读this code,我想了解它的实施情况。
我想知道的第一件事是,一些张量对象(占位符)的形状是什么,例如xs
,h_init
,y_init
,{{ 1}},y_sample
等
我写了一行代码,例如print(xs.shape)
,但它不起作用。
我如何理解这些参数(张量)的形状?我可以在NumPy中写下类似的内容吗?
定义这些张量的代码部分如下所示:
x_init = tf.placeholder(tf.float32, shape=(args.init_batch_size,) + obs_shape)
xs = [tf.placeholder(tf.float32, shape=(args.batch_size, ) + obs_shape)
for i in range(args.nr_gpu)]
# if the model is class-conditional we'll set up label placeholders +
# one-hot encodings 'h' to condition on if args.class_conditional:
num_labels = train_data.get_num_labels()
y_init = tf.placeholder(tf.int32, shape=(args.init_batch_size,))
h_init = tf.one_hot(y_init, num_labels)
y_sample = np.split(
np.mod(np.arange(args.batch_size * args.nr_gpu), num_labels), args.nr_gpu)
h_sample = [tf.one_hot(tf.Variable(
y_sample[i], trainable=False), num_labels) for i in range(args.nr_gpu)]
答案 0 :(得分:1)
形状由不同的命令行参数组合而成:
obs_shape
是输入图像的形状,例如(32, 32, 3)
args.init_batch_size
和args.batch_size
是命令行中的值。例如,可以是30
和40
。然后,x_init
的形状是init_batch_size
和obs_shape
:(30, 32, 32, 3)
的串联。相应地,xs
中每个项目的形状为(40, 32, 32, 3)
。
您无法评估xs.shape
,因为xs
是占位符的列表。您可以改为评估xs[0].shape
。
y_sample
和h_sample
也是张量列表。第一个包含(batch_size, num_labels)
个张量,第二个包含(num_labels, )
。