tf.split的输出是什么?

时间:2017-06-19 21:58:08

标签: python arrays tensorflow tensor

假设我有这个:

  

TensorShape([Dimension(None),Dimension(32)])

我在张量_X上使用tf.split,其尺寸如上:

_X = tf.split(_X, 128, 0) 

这种新张量的形状是什么?输出是一个列表,因此很难知道这个新张量的形状。

3 个答案:

答案 0 :(得分:9)

tf.split()返回张量对象列表。你可以知道每个张量对象的形状如下

import tensorflow as tf

X = tf.random_uniform([256, 32]);
Y = tf.split(X,128,0)
Y_shape = tf.shape(Y[1])

sess = tf.Session()
X_v,Y_v,Y_shape_v = sess.run([X,Y,Y_shape]) 
# numpy style
print X_v.shape
print len(Y_v)
print Y_v[100].shape
# TF style
print len(Y)
print Y_shape_v

输出:

(256, 32)
128
(2, 32)
128
[ 2 32]

我希望这有帮助!

答案 1 :(得分:5)

tf.split(X, row = n, column = m)用于将变量的数据集分成n个行数和m个列数。

例如,我们的数据集x的大小为(10,10), 然后tf.split(x, 2, 0)会破坏x(5, 10)的数据集tf.split(x, 2, 2)

但如果我们选择(5, 5), 然后我们将得到4组大小为npm install --prefix="/your/path" . 的数据。

答案 2 :(得分:0)

新版的tensorflow定义了split函数,如下所示:

tf.split(     值,     num_or_size_splits,     轴= 0,     num = None,     name ='split' )

但是,当我尝试在R中运行它时

X = tf$random_uniform(minval=0,
                      maxval=10,shape(256, 32),name = "X");

Y = tf$split(X,num_or_size_splits = 2,axis = 0)

它报告错误消息:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split. Argument provided: 2.0