How do I update a model (variable) inside a custom Python op (tf.py_func)?

Time: 2018-02-22 15:37:07

Tags: python tensorflow

I need to write a custom op in Python (using py_func) that generates output based on a model, and another op that updates the model. In the sample code below I have a very simple scalar model w (in reality it will be an n x m matrix). I figured out how to "read" the model, as demonstrated in the custom_model_read_op function (in reality it is far more complex). But how do I create something similar that updates w in some custom, complicated way (using custom_model_update_op)? I assume this is possible, since Optimizer ops such as SGD are able to do it.

import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
    y = i*float(w)
    return y
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64], name="read_func")

def custom_model_update_op(i, w):
    # ==> How to update w (the model stored in a Variable above) based on the value of i and some crazy logic?
    return 0
crazy_update = tf.py_func(custom_model_update_op, [X, w], [tf.int64], name="update_func")

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(10):
        y_out, __ = sess.run([y, crazy_update], feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out[0]))

Thanks in advance!

1 Answer:

Answer 0: (score: 1)

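Well, I'm not sure this is the best way, but it does what I need. I don't have a separate py_func that updates w; instead I update it inside read_op, pass it back as a return value, and finally use the assign function to modify the variable outside the custom op. I'd appreciate it if any TensorFlow expert could confirm that this is a good, legitimate way to do it.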

import tensorflow as tf
import numpy

# Create a model
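# dtype is set to float64 so the updated w returned by py_func matches its declared tf.float64 output
w = tf.Variable(numpy.random.randn(), dtype=tf.float64, name="weight")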
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
    y = i*float(w)
    w = custom_model_update(w)
    return y, w
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64, tf.float64], name="read_func")

def custom_model_update(w):
    # update w (the model stored in a Variable above) based on the value of i and some crazy logic
    return w + 1

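# Build the assign op once; creating it inside the loop would add a new op to the graph on every iteration
w_new = tf.placeholder(w.dtype.base_dtype, shape=(), name="w_new")
assign_op = w.assign(w_new)

with tf.Session() as sess:

    tf.global_variables_initializer().run()

    for i in range(10):
        # y is a list of two tensors, so run() returns [output, updated w]
        y_out, w_modified = sess.run(y, feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out))
        sess.run(assign_op, feed_dict={w_new: w_modified})
        print("w=", "{:.4f}".format(sess.run(w)))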
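
A possible variation on the approach above, offered as a sketch rather than a confirmed best practice: the assignment can be wired into the graph so that a single sess.run call both evaluates the py_func and writes the new value back into w. The names read_and_update, reference_trace and run_in_graph are hypothetical helpers introduced here for illustration. The graph part assumes the tensorflow.compat.v1 module is available (TensorFlow 1.14+ or 2.x), and the "w + 1" update is only the stand-in logic from the answer. The plain-Python reference_trace gives the values the graph version is expected to reproduce, which makes it easy to check that w really changes between steps.

```python
def read_and_update(i, w):
    """Body of the custom op: returns the output y and the updated weight."""
    y = float(i) * float(w)
    new_w = float(w) + 1.0  # stand-in for the real update logic
    return y, new_w


def reference_trace(w0, steps):
    """TensorFlow-free reference: the (y, w) pairs the loop should produce."""
    trace, w = [], float(w0)
    for i in range(steps):
        y, w = read_and_update(i, w)
        trace.append((y, w))
    return trace


def run_in_graph(w0, steps):
    """Same loop, with the assignment built into a TF 1.x-style graph."""
    # Assumption: the v1 compatibility API is installed; imported lazily so
    # the two functions above stay usable without TensorFlow.
    import tensorflow.compat.v1 as tf

    graph = tf.Graph()
    with graph.as_default():
        w = tf.Variable(w0, dtype=tf.float64, name="weight")
        X = tf.placeholder(tf.int32, shape=(), name="X")
        y, new_w = tf.py_func(read_and_update, [X, w],
                              [tf.float64, tf.float64], name="read_func")
        new_w.set_shape(())  # py_func outputs carry no static shape
        update_w = tf.assign(w, new_w)  # created once, outside the loop
        init = tf.global_variables_initializer()

    trace = []
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        for i in range(steps):
            # update_w depends on new_w, so the py_func reads w before
            # the assignment overwrites it.
            y_out, w_out = sess.run([y, update_w], feed_dict={X: i})
            trace.append((float(y_out), float(w_out)))
    return trace


print(reference_trace(0.5, 3))
```

If the graph wiring behaves as intended, run_in_graph(0.5, 3) should return the same pairs as reference_trace(0.5, 3). Because the update is a graph op that consumes the py_func output, no extra sess.run call is needed per step.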