我要在一个循环中创建多个张量,我想将它们全部聚合到一个张量张量对象中。我想我缺少了scala tensorflow API的某些部分,因为在使用tf.concat
的python中,同样的事情非常容易。
scala中tf.concat / tf.split的等效命令是什么?
这基本上是我想在Scala中复制的python程序(找不到用于Scala的tf.concat和tf.split):
import tensorflow as tf
# Initialize two constants
x1 = tf.constant([1,2,3,4])
x2 = tf.constant([5,6,7,8])
x11 = tf.constant([4,3,2,1])
x22 = tf.constant([8,7,6,5])
# concatenate x1, x11 and x2, x22 tensors
x111 = tf.concat([x1, x11], 0)
x222 = tf.concat([x2, x22], 0)
# Multiply
result_bat = tf.multiply(x111, x222)
# Intialize the Session
sess = tf.Session()
# Print the result
res = sess.run(result_bat)
res10, res11 = tf.split(res, 2, 0)
print(sess.run(res10))
print(sess.run(res11))
# Close the session
sess.close()