我有以下代码在tf1.3和tf1.4中有效。当我在t1.2中尝试时,代码运行但只是挂起。我只使用tf1.2,因为我想在Google云端ml引擎上进行测试,而引擎在此阶段仅支持tf1.2:
这是我的输入CSV文件:
A B Result
2 2 4
2 3 5
这是我的代码:
csv_defaults = OrderedDict([("A", [0]), ("B", [0]), ("Result", [0])]);
file_path = "InputFile.csv";
def csv_decoder(line):
parsed = tf.decode_csv(line, list(csv_defaults.values()), field_delim="\t");
return parsed[0];
def test():
dataset = (TextLineDataset(file_path)
.skip(1)
.map(csv_decoder)
.batch(512));
iterator = dataset.make_one_shot_iterator();
columns = iterator.get_next();
return columns;
input_fn = test();
with tf.Session() as sess:
columns = sess.run(input_fn);
print(columns);
这是tf 1.4
中的输出[2 2]
当我在tf 1.2中运行相同的代码时,代码只会挂起并且不会返回任何内容..
我从https://github.com/tensorflow/tensorflow/issues/13751知道,在tf 1.2中,parse_csv函数不能返回dict,tuple或namedtuple(我也尝试过它们)。所以我已经把它剥离下来,只返回一个张量。在bug中,@ mrry建议提取功能的值,然后手动创建元组。函数解析器(记录)返回一个张量,似乎工作。我的parse_csv也返回一个张量,但它仍然不起作用。有人可以帮助我吗?
对不起,如果我遗漏了一些明显的东西。我过去几周只使用过tf并且已经搜索过很多答案。