TF:如何从用户输入数据创建数据集

时间:2017-10-18 07:16:33

标签: python tensorflow tensorflow-datasets

我最近开始使用tensorflow,更具体地说,使用新的数据集API。 通过将数据集的迭代器插入到代表输入和标签的图形节点,我成功地使用数据集将训练数据提供给我的简单模型。类似的东西:

input = input_dataset.make_one_shot_iterator().get_next() 
label = label_dataset.make_one_shot_iterator().get_next()

现在我想知道当我必须对用户输入进行推断时要做什么,也就是说,用户给了我一个输入值,我必须做出预测。如果我有占位符,我会将用户输入放在feed_dict中,但是对于数据集api,我几乎不知道如何做类似的事情。我是否只有一个单独的图表用于我的input变量是占位符的推理?

我已经尝试按照here所述制作一个可输入迭代器,但这只适用于字符串的占位符,而我的输入是int32。

感谢您的任何建议。

1 个答案:

答案 0 :(得分:0)

出于特定目的,tensorflow提供tf.placeholder_with_default API

# Create a Dataset
dataset = tf.data.Dataset.zip((input_dataset, label_dataset)).batch(32).repeat(...)

# Create Iterator
input, label = dataset.make_one_shot_iterator()

# Create Placholders
x = tf.placeholder_with_default(input, shape=[...], name='input')
y = tf.placeholder_with_default(label, shape-[...], name='label')

def nn_model(features, labels):
    logits = ...    
    loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)
    return optimizer, loss

# Create Model
train_op, loss_op = nn_model(x, y)

# Training
sess.run(train_op)

# Inference
sess.run(logits, feed_dict={x:..., y:...})