通过.pb文件训练TensorFlow模型

时间:2020-08-12 19:40:06

标签: python tensorflow machine-learning

我正在尝试从.pb文件中重新训练TensorFlow模型。我正在使用以下函数来检索它并在Python中加载图形:

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

在这里,我尝试列出其可训练的变量和操作:

with tf.Session(graph=tf_graph) as sess:
    print("Trainable variables: {}".format(tf.trainable_variables())) 
    variables = [op for op in tf_graph.get_operations()]
    for var in variables:
        print("{}".format(var.name), end = ' ,')

这是上面代码的输出: Code output

如上所示,它说没有可以训练的变量,当我尝试以下代码训练图形时:

with tf.Session() as sess:
    random_input  = tf.convert_to_tensor(np.random.rand(1, 3, 2848, 4256)) # Input dimensions
    random_output = sess.run(random_input) 

    random_y0 = tf.convert_to_tensor(np.random.rand(1, 3, 2848, 4256))

    loss = tf.reduce_sum(tf.square(random_y0 - random_output))
    train = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

    sess.run(tf.global_variables_initializer())

    print("Training")

    for step in range(1):
        sess.run(train)

它给了我错误:

train = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/tensorflow_core/python/training/optimizer.py", line 410, in minimize
    ([str(v) for _, v in grads_and_vars], loss))
ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables

如果有人能弄清为什么它找不到任何可训练的变量,以及我的代码有误,我将不胜感激。非常感谢!

0 个答案:

没有答案