Tensorflow: string split behaves unexpectedly in tf.data.Dataset

Posted: 2018-10-26 04:24:19

Tags: python tensorflow dataset

I'm using the tf.data.Dataset API in TensorFlow. I have two numpy arrays, where data is 2-D and labels is 1-D. I've created a Dataset:

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.map(lambda x, y: ({'reviews': x}, y))
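A quick way to see what preprocess will receive is to inspect the dataset's element spec. from_tensor_slices slices its inputs along the first axis, so each element has one dimension fewer than the array it came from; when the reviews are stored one string per row, each 'reviews' value is a scalar string, which matches the rank-0 input reported in the error below. The sketch assumes TensorFlow 2.x (for `element_spec`) and uses hypothetical toy reviews in place of the question's data.

```python
import tensorflow as tf

# Hypothetical stand-in for the question's arrays: one review string per row.
reviews = ['this movie was great', 'not my cup of tea']
labels = [1, 0]

dataset = tf.data.Dataset.from_tensor_slices((reviews, labels))
dataset = dataset.map(lambda x, y: ({'reviews': x}, y))

# element_spec is a (features dict, label) pair of TensorSpecs.
review_spec = dataset.element_spec[0]['reviews']
print(review_spec.shape)  # () -> each review is a rank-0 (scalar) string
```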

I have a preprocessing function that I want to use, shown below:

def preprocess(x, y):
    # split on whitespace
    x['reviews'] = tf.string_split(x['reviews'])
    return x, y

I tried applying it with map like this:

dataset = dataset.map(preprocess)

But I get this error:

ValueError: Shape must be rank 1 but is rank 0 for 'StringSplit' (op: 'StringSplit') with input shapes: [], [].

I searched around and found someone suggesting this approach in the preprocessing function:

x['reviews'] = tf.string_split([x['reviews']])

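But it's not clear to me why this is needed. It no longer raises the error above, but my data ends up with the wrong shape. For example, this is what I see for the first element of the dataset: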

({'sequence': array([[ 6391,  3352, 10236,   244,  1362,   244,  9350,  7649,  6391,
         6324,  6063,  3620,   244,  8153,  6542, 10056,  7303,  1955,
         1362,  6194, 10250,  6391,   550,   244,  7577,   850,  3620,
         5807, 10325,  1362,  6542,   595,  9060,  9052,  9459,   351,
         4676,  9354,  7648,  3082,  7694,  8497, 10703,  1610,  9454,
        10236,   244,  7965,  8018,  9392,  6391,  6063,  2878,  1318,
         3169,  8198,  9354,  4131,  3620,  3082,  3352,  9052,  8018,
         7527,  3419,  1907,  8835,   796,   244,  8957,  4325,  8171,
         9454,  7602,  4435,  7648,  3169,  2083,  9454,  4789,  9620,
         9261,   556,  3524,  8497,  9174,  8299,  5871,  9052,  2888,
         9846,  1610,  1362,  4930,  2150,  1362,  8018,  3867,   341,
         7694,  8497,  6063,  3620,   244,  5807,  6089,  3169,  6350,
         1174,  7694,   949,  1292,   244,  9052,  9440,  3690,  1362,
         1907,  9011,  4156,  6081,   145,  1174,  7694,  9986,   949,
         1292,  3169,  1455,  6372,  9760,  5013,  3169,  1455,  5942,
         4365,  1362,  1907,   244,  5813,   244,  7994,  3525,  3550,
         7509,  6372,  9760,  7860,  9052,  2888,  7694,  8497,  1610,
         1316,   326,  1174,  3039,  3524,  9703,  3620,  6612,  1455,
          556,  9011,  3169,  1927,  9052,   409,  4059,  9354,   700,
         5503,  3550,  9052,  2083,  1963,   595,  3169,  7715, 10236,
         9442,  1174, 10087,  3169,  5312,  7474,  9052,  3525,  3169,
         5826,  7885,  6944,  7130,  5821,  2878,  7184,   153,  3169,
         8633,  8574,  1283,   606,  7902,  6110,  3082,  6406,  3169,
         8316,  6126,   688, 10236,  9440,  3082, 10584,  2143,  5460,
         5809,  1362,  2878, 10439,  3419,  1907,  4598,  4156, 10239,
         1450,  5514,  5010,  9350,   244,   651]])}, 0)

So the dictionary value is a 2-D array when it should only be 1-D. Where am I going wrong?

Thanks!

1 Answer:

Answer 0 (score: 0)

Not accepting scalars seems to be a limitation of tf.string_split. Please file an issue at https://github.com/tensorflow/tensorflow/issues
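For comparison, newer releases handle this case directly: in TensorFlow 2.x, `tf.strings.split` accepts a scalar string and returns a 1-D tensor of tokens, so neither the list wrapping nor the squeeze is needed. The sketch below assumes TensorFlow 2.x; the toy strings are hypothetical.

```python
import tensorflow as tf

# A scalar input yields a 1-D tensor of tokens (split on whitespace by default).
tokens = tf.strings.split(tf.constant('ab c de'))
print(tokens.shape)  # (3,)

# The same call works per element inside Dataset.map, where each review is a scalar.
dataset = tf.data.Dataset.from_tensor_slices((['ab c de', 'f g'], [0, 1]))
dataset = dataset.map(lambda x, y: ({'reviews': tf.strings.split(x)}, y))
```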

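As for a workaround, the suggestion to wrap the string in a list is a good one, but after splitting you also need to squeeze the result so that you get a vector of tokens rather than a 2-D tensor.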

import tensorflow as tf
tf.enable_eager_execution()
scalar = tf.constant('ab c de')
print(scalar.shape)  # () scalar
vector = scalar[None]
print(vector.shape)  # (1,) vector
output = tf.sparse.to_dense(tf.string_split(vector), default_value='')
print(output)  # tf.Tensor([[b'ab' b'c' b'de']], shape=(1, 3), dtype=string)
squeezed = tf.squeeze(output, axis=0)
print(squeezed.shape)  # (3,) vector
print(squeezed)  # tf.Tensor([b'ab' b'c' b'de'], shape=(3,), dtype=string)
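Applied to the question's pipeline, the same wrap, split, densify, and squeeze steps could go inside preprocess. This is a sketch assuming TensorFlow 2.x, where the question's `tf.string_split` is available as `tf.compat.v1.string_split`; the toy reviews and labels are hypothetical.

```python
import tensorflow as tf

# Hypothetical toy data standing in for the question's arrays.
reviews = ['ab c de', 'f g']
labels = [0, 1]

dataset = tf.data.Dataset.from_tensor_slices((reviews, labels))
dataset = dataset.map(lambda x, y: ({'reviews': x}, y))

def preprocess(x, y):
    # StringSplit needs a rank-1 input, so wrap the scalar review in a length-1 vector.
    sparse = tf.compat.v1.string_split([x['reviews']])
    # The result is a SparseTensor with dense shape [1, num_tokens]; densify it...
    dense = tf.sparse.to_dense(sparse, default_value='')
    # ...then drop the leading size-1 dimension to get a 1-D vector of tokens.
    x['reviews'] = tf.squeeze(dense, axis=0)
    return x, y

dataset = dataset.map(preprocess)

for features, label in dataset:
    print(features['reviews'].shape, label.numpy())
```

Because each row holds a single string, taking `sparse.values` would give the same 1-D token vector without densifying; the version above simply mirrors the steps shown in the answer.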