如何在TensorFlow模型中将标量张量转换为标量?

时间:2017-10-08 16:27:25

标签: python tensorflow

我正在尝试使用mnist_deep.py

向TensorFlow MNIS示例tf.contrib.image.rotate()添加数据增强功能
rotate_angle = 0.1


def deepnn(x):
    ...
    with tf.name_scope('rotate'):
        angle = tf.tf.placeholder(tf.float32)
        x_image = tf.contrib.image.rotate(x_image, angle)  # Wrong!
    ...
    return angle


...
angle = deepnn(x)
with tf.Session() as sess:
    angle.eval({angle: rotate_angle}

这不起作用,因为tf.contrib.image.rotate()仅接受普通标量作为角度。

我试过TensorFlow: cast a float64 tensor to float32但遗憾的是,提到的函数现在也返回张量。

如何在模型中将张量标量转换为标量?我想重复使用相同的模型,并为训练和测试提供不同的角度。

1 个答案:

答案 0 :(得分:0)

我认为你不需要奇怪的转换,而是需要重新组织代码。我找到了解决问题的可能方法,我希望它适合您:

import tensorflow as tf
import numpy as np

rotate_angle = 0.1

def deepnn(x,angle):
    x_image = tf.contrib.image.rotate(x, angle)     
    return x_image

angle = tf.placeholder(tf.float32,shape=())
input_image_placeholder = tf.placeholder(tf.float32,shape=(100,100,3))


rotated_x_image = deepnn(input_image_placeholder,angle)


sess = tf.Session()

input_image = np.ones(dtype=float,shape=(100,100,3))

curr_rotated_x_image = sess.run(rotated_x_image,{angle:rotate_angle,input_image_placeholder:input_image})

print(curr_rotated_x_image)

sess.close()

我不认为在函数内声明占位符是个好主意所以我把它移到了外面。如果这个解决方案没问题,请告诉我!