Tensorflow数据集API:渐变是"无"?

时间:2018-05-03 12:23:30

标签: python-3.x tensorflow gradient

我遇到了Tensorflow数据集API的问题。 我想传递一些每个样本的参数,但我无法优化它们。

sample_data = tf.placeholder(...)
design = tf.placeholder(...)

mixture_prob = tf.Variable(..., shape=[num_mixtures, num_samples])

# transpose to get 'num_samples' to axis 0:
mixture_log_prob_t = tf.transpose(tf.log(mixture_prob, name="mixture_log_prob"))
assert mixture_log_prob_t.shape == [num_samples, num_mixtures]

以下是我的问题的原因: 我有一些样本数据和设计矩阵。 此外,每个样本都有'num_mixtures'我想要优化的参数。

data = tf.data.Dataset.from_tensor_slices((
    sample_data,
    design,
    mixture_log_prob_t
))
data = data.repeat()
data = data.shuffle(batch_size * 4)
data = data.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

iterator = data.make_initializable_iterator()

batch_sample_data, batch_design, batch_mixture_log_prob = iterator.get_next()
batch_mixture_log_prob = tf.transpose(batch_mixture_log_prob)

现在,当运行" optimizer.gradient()"我得到了#34;没有"对于此参数:

>>> model.gradient
[(None, <tf.Variable 'mixture_prob/logit_prob:0' shape=(2, 2000) dtype=float32_ref>), ...]

0 个答案:

没有答案