这是我的代码,使用DynamicPartition操作构建图形,使用掩码将矢量[1,2,3,4,5,6]分割为两个矢量[1,2,3]和[4,5,6] [1,1,1,0,0,0]:
@Test
public void dynamicPartition2() {
Graph graph = new Graph();
Output a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.INT64)
.setAttr("value", Tensor.create(new long[]{6}, LongBuffer.wrap(new long[] {1, 2, 3, 4, 5, 6})))
.build().output(0);
Output partitions = graph.opBuilder("Const", "partitions")
.setAttr("dtype", DataType.INT32)
.setAttr("value", Tensor.create(new long[]{6}, IntBuffer.wrap(new int[] {1, 1, 1, 0, 0, 0})))
.build().output(0);
graph.opBuilder("DynamicPartition", "result")
.addInput(a)
.addInput(partitions)
.setAttr("num_partitions", 2)
.build().output(0);
try (Session s = new Session(graph)) {
List<Tensor> outputs = s.runner().fetch("result").run();
try (Tensor output = outputs.get(0)) {
LongBuffer result = LongBuffer.allocate(3);
output.writeTo(result);
assertArrayEquals("Shape", new long[]{3}, output.shape());
assertArrayEquals("Values", new long[]{4, 5, 6}, result.array());
}
//Test will fail here
try (Tensor output = outputs.get(1)) {
LongBuffer result = LongBuffer.allocate(3);
output.writeTo(result);
assertArrayEquals("Shape", new long[]{3}, output.shape());
assertArrayEquals("Values", new long[]{1, 2, 3}, result.array());
}
}
}
调用s.runner().fetch("result").run()
后,返回长度为1的列表,其值为[4,5,6]。我的图表似乎只产生一个输出。
如何获得分裂向量的其余部分?
答案 0 :(得分:1)
DynamicPartition
操作返回多个输出(每个分区一个),但Session.Runner.fetch
调用仅请求第0个输出。
Java API缺少Python API所具有的一堆便利糖,但您可以通过显式请求所有输出来执行您想要的操作。换句话说,改变自:
List<Tensor> outputs = s.runner().fetch("result").run();
到
List<Tensor> outputs = s.runner().fetch("result", 0).fetch("result", 1).run();
希望有所帮助。
答案 1 :(得分:0)
不确定java(我不知道它,也没有调查的环境),但在python中一切正常。例如这个
import tensorflow as tf
a = tf.constant([1, 2, 3, 4, 5, 6])
b = tf.constant([1, 1, 1, 0, 0, 0])
c = tf.dynamic_partition(a, b, 2)
with tf.Session() as sess:
v1, v2 = sess.run(c)
print v1
print v2
返回正确的分区。