我尝试使用theta
作为下面显示的代码,仅训练一个Tensorflow变量tf.select
。
import numpy as np
import tensorflow as tf
from numpy.random import RandomState
from tensorflow.contrib.losses import mean_squared_error
tf.logging.set_verbosity(tf.logging.INFO)
random_state = RandomState()
def generate_data(n_train, n_test, theta=0.2):
X_train = random_state.uniform(size=[n_train])
X_test = random_state.uniform(size=[n_test])
y_train = np.where((X_train > theta), np.sin(X_train), np.cos(X_train))
y_test = np.where((X_test > theta), np.sin(X_test), np.cos(X_test))
return (X_train, y_train), (X_test, y_test)
(X_train, y_train), (X_test, y_test) = generate_data(1000, 300)
intN = np.int64
floatN = np.float32
intT = tf.int64
floatT = tf.float32
x = tf.placeholder(dtype=floatT, name=r'x', shape=[None])
y_true = tf.placeholder(dtype=floatT, name=r'y_true', shape=[None])
theta = tf.get_variable(dtype=floatT, name=r'theta', initializer=tf.zeros_initializer(shape=[], dtype=floatT))
y_pred = tf.select(
tf.greater(x, theta),
tf.sin(x),
tf.cos(x), name=r'y_pred')
loss = mean_squared_error(y_pred, y_true)
optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
with tf.Session() as session:
all_variables = tf.global_variables() + tf.local_variables()
all_variables_initializer = tf.variables_initializer(all_variables)
session.run(all_variables_initializer)
_, batch_loss = session.run([optimizer, loss], feed_dict={
x: X_train,
y_true: y_train,
})
tf.logging.info(r'batch loss {:.4f}'.format(batch_loss))
但我得到了
ValueError:没有为任何变量提供渐变,检查图表中不支持渐变的ops,变量之间[' Tensor(" theta / read:0",shape =(), dtype = float32)']和丢失Tensor(" mean_squared_error / value:0",shape =(),dtype = float32)。
那怎么办呢?如果我真的想使用tf.cond
或tf.where
或tf.select
或tf.case
而不是像sigmoid
这样的话,我的意思是如何做条件分支(分段函数) Tensorflow一般吗?