如何将单个输入传递给张量流贴图函数

时间:2018-10-11 04:32:57

标签: python tensorflow

我正在玩tensorflow只是为了“拉屎”。我知道有更好的方法来做我正在做的事情。

我有一个值列表,我想计算其方差。我正在使用地图功能来计算( xi -平均值)。 xi 是列表的元素,平均值是我之前计算的平均值。

tf_y_mean = tf.reduce_mean(tf_y_train)
tf_y_mean = tf.Print(tf_y_mean, [tf_y_mean], message="tf_y_mean")

tf_y_mean = tf.tile([tf_y_mean], [tf.shape(tf_y_train)[0]])

# used in the map function to calculate the distance
def tf_value_minus_mean_square(elem):
    sub = elem[0] - elem[1]
    return tf.pow(sub, 2)

tf_variance = tf.map_fn(tf_value_minus_mean_square, [tf_y_train, tf_y_mean], dtype=tf.float32)

为了使用map函数,我必须使用tile和reshape将均值转换为具有与我的值列表相同的第一维大小的张量。

有没有办法传递“单一”的意思而不是这样做?

非常感谢!

在下面填写代码

import tensorflow as tf

y_train = [2.48647765732222, -0.303258845569525, -4.05313903208675, -4.33589931314926, -6.17420514970129,
           -5.60395159162841, -3.50690652820209, -2.32567755531334, -4.63772499017313, -0.232708358092122,
           -1.98576143074415, 1.02839153133005, -2.26396081470418, -0.450826849993658, 1.16715548976294,
           6.65243873752780, 4.14520332259566, 5.26766016927822, 6.34028299499329, 9.62643862532723,
           14.7841620484121]
x_train = [i for i in range(len(y_train))]


tf_y_train = tf.placeholder(dtype=tf.float32, shape=21, name='y_train')
tf_x_train = tf.placeholder(dtype=tf.float32, shape=21, name='x_train')

tf_y_sum = tf.reduce_sum(tf_y_train)
tf_y_sum = tf.Print(tf_y_sum, [tf_y_sum], message="tf_y_sum")

tf_y_mean = tf.reduce_mean(tf_y_train)
tf_y_mean = tf.Print(tf_y_mean, [tf_y_mean], message="tf_y_mean")

tf_y_mean = tf.tile([tf_y_mean], [tf.shape(tf_y_train)[0]])


# used in the map function to calculate the distance
def tf_value_minus_mean_square(elem):
    sub = elem[0] - elem[1]
    return tf.pow(sub, 2)


tf_variance = tf.map_fn(tf_value_minus_mean_square, [tf_y_train, tf_y_mean], dtype=tf.float32)

tf_variance = tf.reduce_sum(tf_variance)
tf_variance = tf.divide(tf_variance, tf.cast(tf.shape(tf_y_train)[0], tf.float32))


with tf.Session() as sess:
    print(sess.run(tf_variance, feed_dict={tf_y_train: y_train, tf_x_train: x_train}))

0 个答案:

没有答案