我有下面的代码用于在包含特征序列的数据集上运行keras LSTM分类器网络,并且运行良好,没有问题(省略了详细信息以保持文本可读性)。我的Tensorflow版本是1.13:
int index = _widgets.indexWhere((widget) {
CardSmallSquare card = (((widget as GestureDetector).child as StreamBuilder).child as CardSmallSquare);
if (card.title == "Title") {
return true;
}
return false;
});
我想在分布式模式下运行它,所以我用一个估计器对象封装了keras分类器。网上尚不清楚分布式学习文档,但是据我了解(并希望如此),我应该像这样修改main()函数:
def parse_tfrecord(example):
features = tf.parse_single_example(example, featuresDict)
label = features['label']
data = tf.decode_raw(features['data'], tf.int64)
return data, label
def read_datasets(pattern, numFiles, numEpochs=None, batchSize=None):
files = tf.data.Dataset.list_files(pattern)
def _parse(x):
x = tf.data.TFRecordDataset(x, compression_type='GZIP')
return x
dataset = files.interleave(_parse, cycle_length=numFiles, block_length=1).map(parse_tfrecord)
dataset = dataset.batch(batchSize)
dataset = dataset.repeat(numEpochs)
return dataset
def keras_model(...)
def main(args):
...
train_data = read_datasets(...)
val_data = read_datasets(...)
test_data = read_datasets(...)
model = keras_model(...)
model.compile(...)
model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, validation_data=val_data, validation_steps = val_steps)
if __name__ == '__main__':
...
main(args)
但是我得到以下错误回溯:
def main(args):
...
train_data = read_datasets(...)
val_data = read_datasets(...)
test_data = read_datasets(...)
model = keras_model(...)
model.compile(...)
#model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, validation_data=val_data, validation_steps = val_steps)
runConfig = tf.estimator.RunConfig(
session_config=config,
model_dir=log_dir,
save_summary_steps=1,
save_checkpoints_steps=train_steps
)
estimator = tf.keras.estimator.model_to_estimator(model, model_dir=log_dir, config=runConfig)
train_input_fn = train_data.make_one_shot_iterator().get_next
eval_input_fn = val_data.make_one_shot_iterator().get_next
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=epochs*train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, start_delay_secs=1, throttle_secs=1, steps=None)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
if __name__ == '__main__':
...
TF_CONFIG = {
'task': ...
'cluster': ...
}
os.environ['TF_CONFIG'] = json.dumps(TF_CONFIG)
main(args)
有人可以帮我解决这个问题吗?