如何使用数据集API将Iterator的输出映射到Tensorflow中的丢失函数占位符

时间:2017-09-11 13:21:45

标签: python tensorflow dataset

下面是tensorflow网站关于使用数据集api来使用来自tfrecords的数据的代码

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)

iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()

loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop

通常我将网络定义为

x = tf.placeholder(tf.float32, [None, INPUT_SIZE], name='INPUT')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_SIZE], name='OUTPUT')

w1 = tf.Variable(tf.truncated_normal([INPUT_SIZE, L1_SIZE], stddev=0.1))
b1 = tf.Variable(tf.constant(0.1, shape=[L1_SIZE]))
w2 = tf.Variable(tf.truncated_normal([L1_SIZE, L2_SIZE], stddev=0.1))
b2 = tf.Variable(tf.constant(0.1, shape=[L2_SIZE]))

w3 = tf.Variable(tf.truncated_normal([L2_SIZE, OUTPUT_SIZE], stddev=0.1))
b3 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_SIZE]))

input_layer = tf.nn.relu(tf.matmul(x, w1) + b1)
hidden_layer1_dropout = tf.nn.dropout(input_layer, DROPOUT1)

hidden_layer2 = tf.nn.relu(tf.matmul(hidden_layer1_dropout, w2) + b2)
hidden_layer2_dropout = tf.nn.dropout(hidden_layer2, DROPOUT2)

y = tf.nn.softmax(tf.matmul(hidden_layer2_dropout, w3) + b3)

和我的损失函数

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

但是现在看起来没有必要再使用feed_dict,但我对如何以这种新方式定义损失函数感到困惑,示例代码只显示一行

loss = model_function(next_example, next_label)

任何人都可以帮忙详细说明如何定义损失函数,如何将要素和标签映射到占位符?非常感谢

1 个答案:

答案 0 :(得分:6)

在使用DataSet apis时不再需要占位符,因为读取的数据已经是else的一部分。

我们不需要在python代码中读取文件,并且在训练时提供它们,但是在tf.Graph中读取数据作为tensorflow操作,对于主要在cpp中运行的tensorflow操作,它将更有效率。

就像你的情况一样,这是行:

var chartGraphContent =
        <div className={"chartContent"}>
            {this.state.modalityGraph['nca'] > 0 ?
                <div className={"chart-container"}>
                    <Chart
                        chartType="ColumnChart"
                        data = { this.state.modalityGraph?this.state.modalityGraph.chartData['units']:emptyDataRows }
                        options={chartOptions}
                        graph_id="modalitiesChart"
                        width="100%"
                        height="250px"
                    /> 
                </div>
                : "<span>Else Block</span>"
            }
        </div>;

变成:

var ifBlockCode = function ifBlockCode(){
    return (
        <div className={"chart-container"}>
            <Chart
                chartType="ColumnChart"
                data = { this.state.modalityGraph?this.state.modalityGraph.chartData['units']:emptyDataRows }
                options={chartOptions}
                graph_id="modalitiesChart"
                width="100%"
                height="250px"
            /> 
        </div>
    )
}

var elseBlockCode = function elseBlockCode(){
    return (
        <span>Else Block</span>
    )
}
var chartGraphContent =
<div className={"chartContent"}>
    {this.state.modalityGraph['nca'] > 0 ?
        {this.ifBlockCode} : {this.elseBlockCode}
    }
</div>;

在调用tf.Graph

删除 x = tf.placeholder(tf.float32, [None, INPUT_SIZE], name='INPUT') y_ = tf.placeholder(tf.float32, [None, OUTPUT_SIZE], name='OUTPUT')