我正在使用tensorfow 1.8和python 3.6。我拥有最基本的脚本,这些脚本试图训练和评估示例虹膜数据集。当我调用train()方法时,我得到一个错误:“ ValueError:Tensor(” batch_size:0“,shape =(),dtype = int64,device = / device:CPU:0)必须与Tensor来自同一图(“ TakeDataset:0”,shape =(),dtype = variant)。”
我的代码与示例代码premade_estimator.py和iris_data.py非常相似,以至于我看不到在哪里创建新图。
import tensorflow as tf
if __name__ == '__main__':
# get data from file first
data_file = "c:/..."
cols = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Label"]
col_def = [[0.0], [0.0], [0.0], [0.0], [""]]
dataset = tf.data.TextLineDataset(data_file)
length = 150
def parse_line(line):
fields = tf.decode_csv(line, col_def)
features = dict(zip(cols, fields))
label = features.pop("Label")
return features, label
dataset = dataset.map(parse_line).shuffle(200).repeat()
print(dataset)
train_size = int(.8 * length)
print(f"train size {train_size}")
train_ds = dataset.take(train_size)
print(train_ds)
features = []
for i in range(5):
features.append(tf.feature_column.numeric_column(cols[i]))
def prep_train_dataset(ds):
return ds.batch(10)
classifier = tf.estimator.DNNClassifier(feature_columns=features, hidden_units=[10, 10], n_classes=3)
classifier.train(lambda: prep_train_dataset(train_ds), steps=12)
...
我没有看到关于此错误的任何帖子,都没有提及“ with graph”子句。我想念什么?