在Tensorflow中初始化一个依赖于批处理的变量

时间:2018-04-13 06:33:20

标签: python tensorflow

我有一个运行良好且准确的张量流代码,但占用了大量内存。具体来说,在我的代码中,我有一个看起来像这样的for循环:

K = 10
myarray1 = tf.placeholder(tf.float32, shape=[None,5,5]) # shape = [None, 5, 5]
myarray2 = tf.Variable( np.zeros([K,5,5]), dtype=tf.float32 )
min_value = tf.Variable(myarray1, validate_shape=False, trainable=False)
for k in range(0,K):
    tmp = myarray1*myarray2[k]
    idx = tf.where(tmp<min_value)
    tf.scatter_nd_assign(min_value, idx, tmp[idx], use_locking=True)

result = min_value

不幸的是,由于K在我的应用程序中变得很大,因此需要大量内存。所以,我希望有一个更好的方法来做到这一点。例如,在numpy / python中,您只需在遍历循环时跟踪最小值,并在每次迭代时更新它。好像我可以使用tf.assign,因为:

sess.run(tf.global_variables_initializer())

虽然此代码构建了图形(当validate_shape = False时),但它无法运行,因为它抱怨min_value尚未初始化。问题是,当我运行初始化程序时:

sess.run(tf.variables_initializer(tf.trainable_variables()))

{{1}}

它抱怨说我没有吃占位符。这实际上是有道理的,因为min_value的定义取决于图中的myarray1。

我真正想要做的是定义一个虚拟变量,该变量不依赖于myarray1的值,但确实与其形状相匹配。我希望将这些值初始化为某个数字(在这种情况下,大的东西很好),因为我将手动确保这些值在网络中被覆盖。

注意:据我所知,目前您无法定义具有未知形状的变量,除非您输入所需形状的另一个变量并设置validate_shape = False)。也许有另一种方式?

任何帮助/建议表示赞赏。

1 个答案:

答案 0 :(得分:0)

试试这个,如果不知道如何提供占位符,请阅读教程。

K = 10
myarray1 = tf.placeholder(tf.float32, shape=[None,5,5]) # shape = [None, 5, 5]

###################ADD THIS ####################
sess=tf.Session()
FOO = tf.run(myarray1,feed_dict={myarray1: YOURDATA}) #get myarray1 value
#replace all myarray1 below with FOO
################################################

myarray2 = tf.Variable( np.zeros([K,5,5]), dtype=tf.float32 )
min_value = tf.Variable(FOO, validate_shape=False, trainable=False)
for k in range(0,K):
    tmp = FOO*myarray2[k]
    idx = tf.where(tmp<min_value)
    tf.scatter_nd_assign(min_value, idx, tmp[idx], use_locking=True)

result = min_value

-------高于新的15.April.2018 ------

由于我不知道您的输入数据,我想尝试一些步骤。

Step_1:为输入数据创建一个位置

x = tf.placeholder(tf.float32, shape=[None,2])

Step_2:获取批量数据

batch_x=[[1,2],[3,4]]   #example
#since x=[None,2], the batch size would be batch_x_size/x_size=2 

Step_3:进行会话

sess=tf.Session()

如果您有变量,请在计算前添加以下代码进行初始化

init=tf.gobal_variables_initializer()
sess.run(init)

Step_4:

yourplaceholderdictiornay={x: batch_x}
sess.run(x, feed_dict=yourplaceholderdictiornay)

始终输入占位符,以便获取要计算的值。

有一个Tensorflow and Deep Learning without a PHD非常有用的PDF文件,你也可以在youtube上找到这个标题。