tf.where导致优化器在tensorflow中失败

时间:2017-04-21 13:28:21

标签: python python-3.x tensorflow

我想检查一下是否可以用tensorflow而不是pymc3来解决this问题。实验的想法是我要定义一个包含switchpoint的probibalistic系统。我可以使用抽样作为推理方法,但我开始想知道为什么我不能用梯度下降来做这件事。

我决定在tensorflow中进行渐变搜索,但似乎当涉及tf.where时,tensorflow很难执行渐变搜索。

您可以在下面找到代码。

import tensorflow as tf
import numpy as np

x1 = np.random.randn(50)+1
x2 = np.random.randn(50)*2 + 5
x_all = np.hstack([x1, x2])
len_x = len(x_all)
time_all = np.arange(1, len_x + 1)

mu1 = tf.Variable(0, name="mu1", dtype=tf.float32)
mu2 = tf.Variable(5, name = "mu2", dtype=tf.float32)
sigma1 = tf.Variable(2, name = "sigma1", dtype=tf.float32)
sigma2 = tf.Variable(2, name = "sigma2", dtype=tf.float32)
tau = tf.Variable(10, name = "tau", dtype=tf.float32)

mu = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu2)
sigma = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma2)

likelihood_arr = tf.log(tf.sqrt(1/(2*np.pi*tf.pow(sigma, 2)))) -tf.pow(x_all - mu, 2)/(2*tf.pow(sigma, 2))
total_likelihood = tf.reduce_sum(likelihood_arr, name="total_likelihood")

optimizer = tf.train.RMSPropOptimizer(0.01)
opt_task = optimizer.minimize(-total_likelihood)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print("these variables should be trainable: {}".format([_.name for _ in tf.trainable_variables()]))
    for step in range(10000):
        _lik, _ = sess.run([total_likelihood, opt_task])
        if step % 1000 == 0:
            variables = {_.name:_.eval() for _ in [mu1, mu2, sigma1, sigma2, tau]}
            print("step: {}, values: {}".format(str(step).zfill(4), variables))

您会注意到即使tensorflow似乎知道变量及其渐变,tau参数也不会改变。什么是错的任何线索?这是可以在tensorflow中计算的东西还是我需要不同的模式?

2 个答案:

答案 0 :(得分:3)

tau仅用于condition的{​​{1}}参数:where),这是一个布尔张量。由于计算梯度仅对连续值有意义,因此输出相对于tf.where(time_all < tau, ...的梯度将为零。

即使忽略tau,您在表达式tf.where中使用了tau,它几​​乎无处不在,因此渐变为零。

由于梯度为零,无法通过梯度下降法学习time_all < tau

根据您的问题,可能不是在两个值之间进行硬切换,而是使用加权和而不是tau,其中p*val1 + (1-p)*val2以连续的方式依赖p。< / p>

答案 1 :(得分:1)

指定的解决方案是正确的答案,但不包含我的问题的代码解决方案。以下片段做了;

import tensorflow as tf
import numpy as np
import os
import uuid

TENSORBOARD_PATH = "/tmp/tensorboard-switchpoint"
# tensorboard --logdir=/tmp/tensorboard-switchpoint

x1 = np.random.randn(35)-1
x2 = np.random.randn(35)*2 + 5
x_all = np.hstack([x1, x2])
len_x = len(x_all)
time_all = np.arange(1, len_x + 1)

mu1 = tf.Variable(0, name="mu1", dtype=tf.float32)
mu2 = tf.Variable(0, name = "mu2", dtype=tf.float32)
sigma1 = tf.Variable(2, name = "sigma1", dtype=tf.float32)
sigma2 = tf.Variable(2, name = "sigma2", dtype=tf.float32)
tau = tf.Variable(15, name = "tau", dtype=tf.float32)
switch = 1./(1+tf.exp(tf.pow(time_all - tau, 1)))

mu = switch*mu1 + (1-switch)*mu2
sigma = switch*sigma1 + (1-switch)*sigma2

likelihood_arr = tf.log(tf.sqrt(1/(2*np.pi*tf.pow(sigma, 2)))) - tf.pow(x_all - mu, 2)/(2*tf.pow(sigma, 2))
total_likelihood = tf.reduce_sum(likelihood_arr, name="total_likelihood")

optimizer = tf.train.AdamOptimizer()
opt_task = optimizer.minimize(-total_likelihood)
init = tf.global_variables_initializer()

tf.summary.scalar("mu1", mu1)
tf.summary.scalar("mu2", mu2)
tf.summary.scalar("sigma1", sigma1)
tf.summary.scalar("sigma2", sigma2)
tf.summary.scalar("tau", tau)
tf.summary.scalar("likelihood", total_likelihood)
merged_summary_op = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(init)
    print("these variables should be trainable: {}".format([_.name for _ in tf.trainable_variables()]))
    uniq_id = os.path.join(TENSORBOARD_PATH, "switchpoint-" + uuid.uuid1().__str__()[:4])
    summary_writer = tf.summary.FileWriter(uniq_id, graph=tf.get_default_graph())
    for step in range(40000):
        lik, opt, summary = sess.run([total_likelihood, opt_task, merged_summary_op])
        if step % 100 == 0:
            variables = {_.name:_.eval() for _ in [total_likelihood]}
            summary_writer.add_summary(summary, step)
            print("i{}: {}".format(str(step).zfill(5), variables))