How to change the dataset in TensorFlow

Asked: 2018-01-09 20:41:18

Tags: python tensorflow tensorflow-datasets

I have a Dataset API doohickey that is part of my TensorFlow graph. How do I swap it out when I want to use different data?

import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

variable = tf.Variable(3, dtype=tf.int64)
model = variable*next_element

#pretend like this is me training my model, or something
with tf.Session() as sess:
    sess.run(variable.initializer)
    try:
        while True:
            print(sess.run(model)) # (0,3,6)
    except tf.errors.OutOfRangeError:
        pass

dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  

### HOW TO DO THIS THING?
with tf.Session() as sess:
    sess.run(variable.initializer) #This would be a saver restore operation, normally...
    try:
        while True:
            print(sess.run(model)) # (0,3)... hopefully
    except tf.errors.OutOfRangeError:
        pass

2 Answers:

Answer 0 (score: 2)

I don't believe this is possible. You are asking to change the computation graph itself, and that is not allowed in TensorFlow. Rather than explain it myself, I found the accepted answer in this post particularly clear on the point: Is it possible to modify an existing TensorFlow computation graph?

Now, that said, I think there is a fairly simple/clean way to achieve your goal. Essentially, you want to reset the graph and rebuild the Dataset part of it. Of course, you want to reuse the model part of the code, so just put that model in a class or function to allow reuse. A simple example built on your code:

import tensorflow as tf

# the part of the graph you want to reuse
def get_model(next_element):
    variable = tf.Variable(3, dtype=tf.int64)
    return variable * next_element

# the first graph you want to build
tf.reset_default_graph()

# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# reusable part
model = get_model(next_element)

#pretend like this is me training my model, or something
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            print(sess.run(model)) # (0,3,6)
    except tf.errors.OutOfRangeError:
        pass

# now the second graph
tf.reset_default_graph()

# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  

# reusable part
model = get_model(next_element)

### HOW TO DO THIS THING?
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            print(sess.run(model)) # (0,3)... hopefully
    except tf.errors.OutOfRangeError:
        pass

Final note: you will also see some references to the tf.contrib.graph_editor docs here. They state explicitly that you cannot accomplish what you want with graph_editor (see in that link: "here is an example of what you cannot do"; although you can get pretty close). Even so, it would not be good practice; they had good reason to make the graph append-only, and I think the approach I suggest above is the cleaner way to achieve what you seek.
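For completeness, TF 1.x also offers a way to do this swap without resetting the graph: a reinitializable iterator built with tf.data.Iterator.from_structure, whose single get_next() tensor can be re-pointed at different datasets at runtime. This is not part of the answer above, just a minimal sketch of that alternative (written against tf.compat.v1 so it also runs under TensorFlow 2.x):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf1 = tf.compat.v1

ds_a = tf.data.Dataset.range(3)
ds_b = tf.data.Dataset.range(2)

# one iterator built from the datasets' common structure (int64 scalars here),
# so the same get_next() tensor can serve either dataset
iterator = tf1.data.Iterator.from_structure(
    tf1.data.get_output_types(ds_a),
    tf1.data.get_output_shapes(ds_a))
next_element = iterator.get_next()

init_a = iterator.make_initializer(ds_a)
init_b = iterator.make_initializer(ds_b)

variable = tf1.Variable(3, dtype=tf.int64)
model = variable * next_element

results = []
with tf1.Session() as sess:
    sess.run(variable.initializer)
    # swap datasets by running the matching initializer; the graph never changes
    for init in (init_a, init_b):
        sess.run(init)
        try:
            while True:
                results.append(int(sess.run(model)))
        except tf.errors.OutOfRangeError:
            pass
print(results)  # [0, 3, 6, 0, 3]
```

The model part of the graph is built exactly once; only the sess.run(init_...) call decides which dataset feeds it.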

Answer 1 (score: 0)

One approach I would suggest is to use placeholders together with tf.data.Dataset, although this will make things slower. You would have something like the following:

train_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, None, 1]) # just an example
# Then build the tf.data.Dataset from the placeholder
train_data = tf.data.Dataset.from_tensor_slices(train_data_ph).shuffle(10000).batch(batch_size)

Now, when running the graph in a session, you have to feed the data in through the placeholder (when initializing the dataset's iterator), so you can feed it whatever data you like...

Hope this helps!
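A minimal runnable sketch of this placeholder idea, assuming simple int64 vectors in place of the float data and dropping the shuffle/batch stages for clarity (written against tf.compat.v1 so it also runs under TensorFlow 2.x):

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf1 = tf.compat.v1

# placeholder that will receive whatever data you want on each run
data_ph = tf1.placeholder(tf.int64, shape=[None])
dataset = tf.data.Dataset.from_tensor_slices(data_ph)
iterator = tf1.data.make_initializable_iterator(dataset)
next_element = iterator.get_next()

variable = tf1.Variable(3, dtype=tf.int64)
model = variable * next_element

results = []
with tf1.Session() as sess:
    sess.run(variable.initializer)
    # swap datasets by re-initializing the iterator with new feed data
    for arr in (np.arange(3, dtype=np.int64), np.arange(2, dtype=np.int64)):
        sess.run(iterator.initializer, feed_dict={data_ph: arr})
        try:
            while True:
                results.append(int(sess.run(model)))
        except tf.errors.OutOfRangeError:
            pass
print(results)  # [0, 3, 6, 0, 3]
```

The slowdown the answer mentions comes from copying the fed array into the graph on every iterator initialization, which is why the TF docs recommend this mainly for small in-memory datasets.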