如何在Tensorflow Estimator中为火车输入功能进行批处理中的预处理

时间:2019-04-27 04:09:04

标签: python tensorflow deep-learning tensorflow-estimator

我已经使用Estimator编写了一个简单的TF模型。简单的版本有效,但是在训练过程中输入量很大,我的代码遇到了内存错误。为此,我想批量发送数据到火车功能,并在批处理而不是整个数据中执行预处理。我下面的代码会出错。有人可以请我回顾一下并帮助我找出问题所在吗?

def train_fn_custom(features, labels, batch_size):
    print(type(features)) // gives <class 'numpy.ndarray'>
    print(type(labels))
    def _preprocess_function(features, labels):        
        print(type(features)) // gives <class 'tensorflow.python.framework.ops.Tensor'>
        print(type(labels))
        features=tokenizer.texts_to_sequences(features, maxlen)
        features=sequence.pad_sequences(features, maxlen=maxlen)
        features=features.astype(np.float32)

        labels=utils.to_categorical(labels, nb_classes)
        labels=labels.astype(np.float32)
        return features, labels

    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.map(_preprocess_function)
    dataset = dataset.batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()

batch_size = 128
customTrain = lambda: train_fn_custom(X1, y_train, batch_size)
estimator_model.train(input_fn=customTrain, steps=10000)

错误:

in texts_to_sequences_generator
    for text in texts:
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 459, in __iter__
    "Tensor objects are only iterable when eager execution is "
TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn.

0 个答案:

没有答案