如何在Tensorflow中使用自定义/非默认tf.Graph?

时间:2017-04-23 18:30:08

标签: python session graph parallel-processing tensorflow

我是Tensorflow的新手,我正在阅读https://www.amazon.com/TensorFlow-Machine-Learning-Cookbook-McClure/dp/1786462168。我在tf.Session中注意到的一个论点是graph。我喜欢完全控制流程,我想知道如何正确使用tf.Graph tf.Session以及如何将计算添加到特定图表?此外,什么是规范语法(如果有),人们在Tensorflow中向特定图表添加操作?

类似于以下内容:

t = np.linspace(0,2*np.pi)
fig, ax = plt.subplots()
ax.scatter(x=t, y=np.sin(t))

与之相比:

plt.scatter(x=t, y=np.sin(t))

我如何才能与tf.Graph()具有相同的灵活性?

G = tf.Graph()

t_query = np.linspace(0,2*np.pi,50)
pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

with tf.Session() as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
# ...
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

现在尝试指定graph参数:

with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-51-d73a1f0963e3> in <module>()
     26 #         -1.27877384e-01,   1.74845553e-07], dtype=float32)
     27 with tf.Session(graph=G) as sess:
---> 28     r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})

...     RuntimeError:会话图为空。在调用run()之前向图表添加操作。

使用David Parks回答更新此问题:

# Custom Function
def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

# Synth graph
G = tf.Graph()

# Build Graph
with G.as_default():
    t_query = np.linspace(0,2*np.pi,50)
    pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

# Run session using Graph
with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
r
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
#          3.75266999e-01,   4.90717560e-01,   5.98110557e-01,
# ...
#         -4.90717530e-01,  -3.75267059e-01,  -2.53654718e-01,
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

Bonus:在Tensorflow中是否有一个特定的术语来命名占位符变量?与pd.DataFrame一样df_data

1 个答案:

答案 0 :(得分:4)

你通常这样做:

with tf.Graph().as_default():
  # build your model
    with tf.Session() as sess:
      sess.run(...)

我有时使用多个图表来运行与训练集分开的测试集,这与上面的示例类似,您可以这样做:

g = tf.Graph()
with g.as_default():
  # build your model
  with tf.Session() as sess:
    sess.run(...)

正如您也指出的那样,您可以避免使用with而只执行sess = tf.Session(graph=g),并且您必须自己关闭会话。大多数用例将通过使用python的with

进行简化

当您有两张图表时,只要您使用该图表,就会将每个as_default()设置为默认图表。

示例:

g1 = tf.Graph()
g2 = tf.Graph()

with g1.as_default():
  # do stuff like normal, g1 is the graph that will be used
  with tf.Session() as session_on_g1:
    session_on_g1.run(...)

with g2.as_default():
  # do stuff like normal, g2 is the graph that will be used
  with tf.Session() as session_on_g2:
    session_on_g2.run(...)