每次会话从tensorflow数据集采样到相同张量多次。run()调用

时间:2019-06-05 16:41:38

标签: python tensorflow tensorflow-datasets

考虑以下示例:

import tensorflow as tf
import numpy as np

X = np.arange(4).reshape(4, 1) + (np.arange(3) / 10).reshape(1, 3)

batch = tf.data.Dataset.from_tensor_slices(X) \
        .batch(2).make_one_shot_iterator().get_next()

def foo(x):
    return x + 1

tensor = foo(batch)

现在,我正在寻找一种方法,可以使每个tensor调用多次session.run()进行采样,即:

def bar(x):
    return x - 1

result1 = bar(tensor)
with tf.control_dependencies([result1]):
    op = <create operation to sample from dataset into `tensor` again>
    with tf.control_dependencies([op]):
        result2 = bar(tensor)

sess = tf.Session()
print(*sess.run([result1, result2]), sep='\n\n')

应输出:

[[0.  0.1 0.2]
 [1.  1.1 1.2]]

[[2.  2.1 2.2]
 [3.  3.1 3.2]]

那有可能吗?我知道一个人可以多次调用get_next()以在不同张量对象中获取多个数据集样本,但是一个样本可以进入相同张量对象吗?

对我来说,用例是这样的:此代码的foobar部分是分开的,而foo部分不知道需要多少次样本每次运行。

P.S。 我正在使用tf 1.12。也可以选择1.13,但不是tf 2。

1 个答案:

答案 0 :(得分:0)

是的,有可能。

到目前为止,您已经尝试过的一些见解:

  1. 每次需要从数据集中的新值时,都可以使用从make_one_shot_iterator()返回的数据集迭代器
  2. 您可以使自己的函数成为tf图的一部分,以将结果通过foo()
  3. 传递

像这样的东西给你想要的输出(据我了解)

import tensorflow as tf
import numpy as np

X = np.arange(4).reshape(4, 1) + (np.arange(3) / 10).reshape(1, 3)

iterator = tf.data.Dataset.from_tensor_slices(X) \
        .batch(2).make_one_shot_iterator()

def foo(x):
    return x + 1

def get_tensor():
  return foo(iterator.get_next())

tensor = get_tensor()

def bar(x):
    return x - 1

result1 = bar(tensor)
with tf.control_dependencies([result1]):
    op = get_tensor()
    with tf.control_dependencies([op]):
        result2 = bar(op)

sess = tf.Session()
print(*sess.run([result1, result2]), sep='\n\n')