我遇到了Tensorflow数据集API的问题。 我想传递一些每个样本的参数,但我无法优化它们。
sample_data = tf.placeholder(...)
design = tf.placeholder(...)
mixture_prob = tf.Variable(..., shape=[num_mixtures, num_samples])
# transpose to get 'num_samples' to axis 0:
mixture_log_prob_t = tf.transpose(tf.log(mixture_prob, name="mixture_log_prob"))
assert mixture_log_prob_t.shape == [num_samples, num_mixtures]
以下是我的问题的原因: 我有一些样本数据和设计矩阵。 此外,每个样本都有'num_mixtures'我想要优化的参数。
data = tf.data.Dataset.from_tensor_slices((
sample_data,
design,
mixture_log_prob_t
))
data = data.repeat()
data = data.shuffle(batch_size * 4)
data = data.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
iterator = data.make_initializable_iterator()
batch_sample_data, batch_design, batch_mixture_log_prob = iterator.get_next()
batch_mixture_log_prob = tf.transpose(batch_mixture_log_prob)
现在,当运行" optimizer.gradient()"我得到了#34;没有"对于此参数:
>>> model.gradient
[(None, <tf.Variable 'mixture_prob/logit_prob:0' shape=(2, 2000) dtype=float32_ref>), ...]