Tensorflow tf_strided_slice阐述

时间:2017-07-24 21:46:35

标签: python tensorflow

我试图了解tf.strided_slice的工作原理。为此,我写了以下代码:

import numpy as np
import tensorflow as tf

# parameters
record_size = 841

# create a random vector of 1682 integers in range [0.255]
content = np.random.randint(255,size=[1682])

depth_major = tf.reshape(
  tf.strided_slice(content, [0],
                   [record_size]),
                   [1, 29, 29])

depth_major1 = tf.reshape(
  tf.strided_slice(content, [record_size+1],
                   [2*record_size]),
                   [1, 29, 29])

# Initializing the variables
init = tf.global_variables_initializer()

with tf.Session as sess:
  sess.run(depth_major)
  print("depth_major", depth_major.shape)

当我执行上面的例子时,我收到以下错误:

ValueError: Cannot reshape a tensor with 840 elements to shape [1,29,29] (841 elements) for 'Reshape_1' (op: 'Reshape') with input shapes: [840], [3] and with input tensors computed as partial shapes: input[1] = [1,29,29].

我简直无法理解为什么元素数量是840,因为我从[0]开始到[record_size]结束?

1 个答案:

答案 0 :(得分:0)

错误是抱怨第二个 Reshape。您可以看到此信息,因为错误消息中的Reshape名称为Reshape_1。第二次重塑中的strided_slicerecord_size+1变为2*record_size,其大小为1682 - 842 = 840。因此关于错误大小的错误消息。我想您要指定record_size而不是record_size+1

希望有所帮助!