我正在尝试在不使用Python的情况下在Java中创建TensorFlow模型。 我设法为Java做了很多Python代码,但我缺少一些要完成的元素。 我在优化器上阻塞了。 Python中的原始代码是一个非常简单的模型。
import tensorflow as tf
# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
我开始转换为Java,我离结束不远,但我仍然停留在优化器上。
try (Graph g = new Graph()) {
//# Batch of input and target output (1x1 matrices)
//x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
Output<OperationBuilder> x = g.opBuilder("Placeholder", "input")
.setAttr("dtype", DataType.FLOAT)
.build().output(0);
//y = tf.placeholder(tf.float32, shape=[None, 1, 1], name=target')
Output<OperationBuilder> y = g.opBuilder("Placeholder", "target")
.setAttr("dtype", DataType.FLOAT)
.build().output(0);
//# Trivial linear model
//y_ = tf.identity(tf.layers.dense(x, 1), name='output')
Tensor t = Tensor.create(new int[] {0});
Output reductionIndices = g.opBuilder("Const", "layer")
.setAttr("dtype", t.dataType()).setAttr("value", t)
.build().output(0);
Output dense = g.opBuilder("layersdense", "dense")
.setAttr("T", DataType.FLOAT)
.setAttr("Tidx", DataType.INT32)
.addInput(input).addInput(reductionIndices)
.build().output(0);
Tensor<?> t2 = Tensor.create(dense);
Output<OperationBuilder> y_ = g.opBuilder("Identity", "output")
.setAttr("value", t2)
.build().output(0);
//# Optimize loss
//loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
Output<OperationBuilder> sub=g.opBuilder("Sub","sub")
.addInput(y_).addInput(y)
.build().output(0);
Output<OperationBuilder> sq = g.opBuilder("Square", "Square")
.addInput(sub)
.build().output(0);
//optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
Code java ???
//train_op = optimizer.minimize(loss, name='train')
Code java ???
}
答案 0 :(得分:0)
在推断输出的某些后期处理过程中,我也遇到了很多麻烦。我可以建议的解决方案是导入整个Graph文件。时代。请注意,我还没有尝试过这种方法,但是根据我的经验,避免使用Java特定的图形生成器将对您有很大帮助。 祝你好运