tf.train.Saver()和GPU的问题-TensorFlow

时间:2018-10-31 11:41:45

标签: python tensorflow gpu

我的代码结构如下:

with tf.device('/gpu:1'):
...
model = get_model(input_pl)
...
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
           ...
           for n in range(num_batches):
              ...
              sess.run(...)
           # eval epoch
        saver.save(sess, ...)

我想在训练阶段之后保存模型。当我运行它会给我这个错误:

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/SaveV2': Could not satisfy explicit device specification '/device:GPU:1' because no supported kernel for GPU devices is available.

阅读this question时,我以这种方式更改了代码:

saver = tf.train.Saver()
with tf.device('/gpu:1'):
...
model = get_model(pointcloud_pl)
...
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
           ...
           for n in range(num_batches):
              ...
              sess.run(...)
           # eval epoch
        saver.save(sess, ...)

但是现在我收到此错误:

ValueError: No variables to save

我也尝试过这种方式:

with tf.Session() as sess:
    saver = tf.train.Saver()
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
        ...
            for n in range(num_batches):
               ...
               sess.run()
            # eval epoch
        saver.save(sess, ...)

我仍然遇到相同的错误。该错误始终在saver = tf.train.Saver()行中。

我该如何解决这个问题?

1 个答案:

答案 0 :(得分:1)

解决了这个问题:

  1. tf.Session()
  2. 模型
  3. saver = tf.train.Saver()
  4. with tf.device():

这里是示例代码

with tf.Session() as sess:
    ...
    model = get_model(input_pl)
    saver = tf.train.Saver()
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
        ...
            for n in range(num_batches):
               ...
               sess.run()
            # eval epoch
        saver.save(sess, ...)