Tensorflow-使用数据集API填充或截断序列

时间:2018-09-04 17:59:25

标签: tensorflow sequence tensorflow-datasets

我正在尝试使用Dataset API来准备import time for n in (1000, 10000, 100000): long = test_cases[0] * n started = time.time() fix(long) needed = time.time() - started print("len_input = {:7d}, time={:7.1f} ms ".format(len(long), needed * 1000)) 的文本序列。处理后,我为每个记录都有一个张量字典。每个记录包含两个序列。

我正在使用len_input = 37000, time= 12.2 ms len_input = 370000, time= 110.6 ms len_input = 3700000, time= 1124.4 ms 进行填充

TFRecordDataset

这会将每个序列填充到批次中的最大序列长度。但是,我想选择一个任意的序列长度,并在实际序列长度较小的情况下将其填充为该长度,否则会截断该序列。

例如,当我尝试将padded_batch替换为dataset = dataset.padded_batch(batch_size, padded_shapes= { 'seq1': tf.TensorShape([None]), 'seq2': tf.TensorShape([None]) }) 时,我遇到了None

  

DataLossError:尝试填充到小于输入元素的大小。

是否有一种方法可以实现与序列上的100类似的功能?

2 个答案:

答案 0 :(得分:0)

没有简单的方法来填充或截断,但是可以使用map函数来获取包含具有所需长度的元素的数据集。这是一个简单的示例:

k = 4
def pad_or_trunc(t):
    dim = tf.size(t)
    return tf.cond(tf.equal(dim, k), lambda: t, lambda: tf.cond(tf.greater(dim, k), lambda: tf.slice(t, [0], [k]), lambda: tf.concat([t, tf.zeros(k-dim, dtype=tf.int32)], 0)))

vals = tf.constant([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
dset1 = tf.data.Dataset.from_tensor_slices(vals)
dset2 = dset1.map(pad_or_trunc)
iter = dset2.make_one_shot_iterator()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter.get_next()))
        except tf.errors.OutOfRangeError:
            break

答案 1 :(得分:0)

您可以先使用tf.slicetf.math.greater截断所有更长的序列,然后使用padded_batch填充序列。

一个例子可能像这样:

import tensorflow as tf
import numpy as np

# data generator
def gen():
  for i in [np.array([1, 1, 1]), np.array([2, 2, 2, 2]), np.array([3, 3, 3, 3, 3])]:
    yield i

cut_or_pad = 4 # 100 in your example

def cut_if_longer(el):
  if tf.greater(tf.shape(el), cut_or_pad): # only slice if longer
    return tf.slice(el, begin=[0], size=[cut_or_pad])
  return el

# data pipeline
dataset = tf.data.Dataset.from_generator( gen, (tf.int32), (tf.TensorShape([None])))
dataset = dataset.map( lambda el: cut_if_longer(el))
dataset = dataset.padded_batch(batch_size=2, padded_shapes=[cut_or_pad], padding_values=-1)

list(dataset.take(2).as_numpy_iterator())