基于队列的模型重启后误差很大,尽管之前已经训练过

时间:2017-11-13 22:04:03

标签: python numpy tensorflow

我有一个模型,最初是用简单的 feed_dict 方式输入数据的,后来改成了使用队列(queue)输入。feed_dict 版本的模型工作正常,但不知为何,使用队列的版本在重启(从保存的权重恢复)后,不能像 feed_dict 版本那样以中断前的误差水平继续训练。我无法解释为什么重启后最初几轮迭代的误差会如此之大。(该模型实现的是一个 4 维输出的分类器,diff 是预测结果与训练数据之间的差异,因此 error 和 diff 两者应当是相关的。)

我唯一能想到的可能原因是:这是否与队列(Queue)的容量有关?

def _save_weights(weights_filename, snapshot):
    """Write ``snapshot`` as JSON to ``weights_filename``.

    The snapshot is serialised to a string *before* the file is opened, so a
    serialisation failure cannot leave a truncated weights file behind. If
    serialisation fails, the snapshot and the error are printed instead and
    the caller keeps training (best-effort save).
    """
    try:
        text = json.dumps(snapshot, indent=True)
    except (TypeError, ValueError) as e:
        print(snapshot)
        print(e)
        return
    with open(weights_filename, "w") as weights_file:
        weights_file.write(text)


def multithreaded_train_rows_layered(initial_weights, weights_filename, rows,
                                     learning_rate=0.01):
    """Train a one-hidden-layer ReLU network on ``rows`` fed through a FIFOQueue.

    The graph computes ``relu(x @ w1 + b1) @ w2 + b2`` and minimises the mean
    squared error against the output columns using Adam. Training runs in
    rounds of 10 optimizer steps until ``compare(y_true, results)`` reports a
    diff of 0. Whenever a round reaches a new best loss, the weights are
    written to ``weights_filename`` as JSON.

    Args:
        initial_weights: ``None`` to start from random weights, or a dict in
            the format this function writes ("w1", "b1", "w2", "b2" and
            optionally "iteration" and "error") to resume from.
        weights_filename: path of the JSON file the best weights are saved to.
        rows: list of dicts. Every row must contain all feature and output
            keys returned by the project helper ``features_outputs``.
        learning_rate: Adam learning rate (default 0.01, as before).

    Returns:
        ``rows``, unchanged.
    """
    keys = sorted(rows[0].keys())
    features, outputs = features_outputs(keys)
    n_features, n_outputs = len(features), len(outputs)

    # x_true has an extra leading axis of length 1. enqueue_many() therefore
    # enqueues the whole data set as ONE queue element, and every dequeue()
    # yields all rows at once.
    x_true = np.array([[[float(row[feature]) for feature in features]
                        for row in rows]])
    y_true = np.array([[float(row[output]) for output in outputs]
                       for row in rows])

    # A single float32 component per queue element (x_true's first axis has
    # length 1, so this matches one dtype per element of x_true).
    q = tf.FIFOQueue(capacity=40, dtypes=[tf.float32])
    enq_op = q.enqueue_many(x_true)
    tf.train.add_queue_runner(tf.train.QueueRunner(q, [enq_op]))
    inputs = q.dequeue()

    if initial_weights is None:
        print('Creating weights %d %d %d' % (n_features, n_features, n_outputs))
        w1 = tf.Variable(tf.random_normal((n_features, n_features)), name="w1")
        b1 = tf.Variable(tf.constant(0.1, shape=[n_features]), name="b1")
        w2 = tf.Variable(tf.random_normal((n_features, n_outputs)), name="w2")
        b2 = tf.Variable(tf.constant(0.1, shape=[n_outputs]), name="b2")
    else:
        print('Using supplied weights %d %d %d' % (
            len(initial_weights['w1']), len(initial_weights['w2']),
            len(initial_weights['w2'][0])))
        w1 = tf.Variable(initial_weights['w1'], name="w1")
        b1 = tf.Variable(initial_weights['b1'], name="b1")
        w2 = tf.Variable(initial_weights['w2'], name="w2")
        b2 = tf.Variable(initial_weights['b2'], name="b2")

    y = tf.matmul(tf.nn.relu(tf.matmul(inputs, w1) + b1), w2) + b2
    loss_op = tf.reduce_mean(tf.square(y - y_true), name="loss")

    # NOTE: Adam's moment estimates live in optimizer slot variables that are
    # NOT written to the JSON file, so every restart begins with fresh (zero)
    # moments. Adam's first bias-corrected steps move each weight by roughly
    # ``learning_rate`` regardless of gradient size. That matches the loss
    # spike seen right after a resume. Pass a smaller learning_rate when
    # resuming (or use plain gradient descent) to damp it.
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)

    print('Starting session')
    config = tf.ConfigProto(intra_op_parallelism_threads=2,
                            inter_op_parallelism_threads=2)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            iteration = 0
            best_error = float('inf')
            if initial_weights is not None:
                iteration = initial_weights.get('iteration', 0)
                # Never replace the saved file with weights worse than the
                # ones it already holds (e.g. during the post-restart spike).
                best_error = float(initial_weights.get('error', best_error))

            diff = 1
            while diff > 0:
                iteration += 1
                for _ in range(10):
                    sess.run(train_op)

                # Evaluate in a run of its own. When variables are fetched in
                # the same run as an in-place optimizer update, TF1 does not
                # guarantee a consistent pre- or post-update snapshot, so the
                # saved weights could disagree with the reported loss.
                (results, error, w1_computed, b1_computed,
                 w2_computed, b2_computed) = sess.run(
                    [y, loss_op, w1, b1, w2, b2])
                this_error = float(error)

                diff, locs = compare(y_true, results)
                if locs < 50:
                    print("iteration: %s error: %s diff: %s locs: %s"
                          % (iteration, this_error, diff, locs))
                else:
                    print("iteration: %s error: %s diff: %s"
                          % (iteration, this_error, diff))

                if this_error < best_error:
                    best_error = this_error
                    _save_weights(weights_filename, {
                        "iteration": iteration,
                        "w1": w1_computed.tolist(), "b1": b1_computed.tolist(),
                        "w2": w2_computed.tolist(), "b2": b2_computed.tolist(),
                        "error": this_error, "inputs": features})
        finally:
            # Stop the queue-runner threads even if training raised.
            coord.request_stop()
            coord.join(threads)
    return rows

控制台日志:

iteration: 155 error: 429.936584473 diff: 10258
iteration: 156 error: 1.47592782974 diff: 7480
iteration: 157 error: 46.404006958 diff: 20614
iteration: 158 error: 21.8979129791 diff: 20465
iteration: 159 error: 3.92602086067 diff: 18277
iteration: 160 error: 0.696662843227 diff: 12208
iteration: 161 error: 0.289250463247 diff: 11400
iteration: 162 error: 0.227930724621 diff: 10794
iteration: 163 error: 0.181772902608 diff: 11081
iteration: 164 error: 0.105023987591 diff: 10537
iteration: 165 error: 0.0506426692009 diff: 7624
iteration: 166 error: 0.0422255061567 diff: 4108
iteration: 167 error: 0.0387616828084 diff: 3845
iteration: 168 error: 0.0329982601106 diff: 3310
iteration: 169 error: 0.0305742416531 diff: 3913
iteration: 170 error: 0.0279670264572 diff: 2877
iteration: 171 error: 0.0260633360595 diff: 2617
iteration: 172 error: 0.0243702083826 diff: 2550
iteration: 173 error: 0.0229212064296 diff: 2355
iteration: 174 error: 0.021658487618 diff: 2191
iteration: 175 error: 0.0205383263528 diff: 2093
iteration: 176 error: 0.0195389948785 diff: 1950
iteration: 177 error: 0.0186410453171 diff: 1853
iteration: 178 error: 0.0178280193359 diff: 1740
iteration: 179 error: 0.0170874521136 diff: 1650
iteration: 180 error: 0.0164082739502 diff: 1550
iteration: 181 error: 0.015782084316 diff: 1471
iteration: 182 error: 0.0152017641813 diff: 1413
iteration: 183 error: 0.0146614015102 diff: 1359

[重启]

Using supplied weights 1182 1182 4
Starting session
iteration: 184 error: 277.624237061 diff: 25028
iteration: 185 error: 93.7772903442 diff: 24692
iteration: 186 error: 32.6932220459 diff: 24810
iteration: 187 error: 13.0545969009 diff: 22863
iteration: 188 error: 4.19472408295 diff: 22383
iteration: 189 error: 0.274255514145 diff: 10213
iteration: 190 error: 0.441053062677 diff: 5573
iteration: 191 error: 0.188079148531 diff: 5710
iteration: 192 error: 0.103881075978 diff: 10231
iteration: 193 error: 0.0453641563654 diff: 4084
iteration: 194 error: 0.0464548133314 diff: 4007
iteration: 195 error: 0.0357639603317 diff: 4082
iteration: 196 error: 0.0287289172411 diff: 2713
iteration: 197 error: 0.0246742982417 diff: 2695
iteration: 198 error: 0.0216691829264 diff: 2425
iteration: 199 error: 0.0191784724593 diff: 2305
iteration: 200 error: 0.0170567110181 diff: 2099
iteration: 201 error: 0.0152298733592 diff: 1997
iteration: 202 error: 0.0136485118419 diff: 1869
iteration: 203 error: 0.012274238281 diff: 1761
iteration: 204 error: 0.0110755767673 diff: 1659
iteration: 205 error: 0.0100243529305 diff: 1537
iteration: 206 error: 0.00909862946719 diff: 1452
iteration: 207 error: 0.00828060507774 diff: 1371
iteration: 208 error: 0.0075555741787 diff: 1286
iteration: 209 error: 0.00691094947979 diff: 1230
iteration: 210 error: 0.00633637560531 diff: 1164
iteration: 211 error: 0.00582302454859 diff: 1103
iteration: 212 error: 0.00536326272413 diff: 1031
iteration: 213 error: 0.00495055690408 diff: 983
iteration: 214 error: 0.00457921577618 diff: 936
iteration: 215 error: 0.00424438575283 diff: 886
iteration: 216 error: 0.00394174130633 diff: 832
iteration: 217 error: 0.00366767356172 diff: 791
iteration: 218 error: 0.00341899320483 diff: 751
iteration: 219 error: 0.00319294095971 diff: 704
iteration: 220 error: 0.00298711610958 diff: 660
iteration: 221 error: 0.00279938825406 diff: 632
iteration: 222 error: 0.00262793083675 diff: 596
iteration: 223 error: 0.00247104465961 diff: 566
iteration: 224 error: 0.00232721073553 diff: 540
iteration: 225 error: 0.00219512404874 diff: 514
iteration: 226 error: 0.00207365490496 diff: 487

[重启]

Using supplied weights 1182 1182 4
Starting session
iteration: 227 error: 163.869613647 diff: 18455
iteration: 228 error: 21.2328510284 diff: 18491
iteration: 229 error: 0.492245942354 diff: 4858
iteration: 230 error: 6.67014503479 diff: 8085
iteration: 231 error: 0.820626139641 diff: 3161
iteration: 232 error: 1.1164072752 diff: 17916
iteration: 233 error: 0.173011898994 diff: 1788
iteration: 234 error: 0.0890910178423 diff: 2969
iteration: 235 error: 0.0811333060265 diff: 10824
iteration: 236 error: 0.0661801099777 diff: 2528
iteration: 237 error: 0.0558460131288 diff: 3008

1 个答案:

答案 0 :(得分:0)

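原来这是 AdamOptimizer 导致的。换成 GradientDescentOptimizer(GDO)后,这种现象就消失了。原因在于:Adam 的一阶、二阶矩估计(优化器的 slot 变量)并没有被保存到权重文件中,每次重启都会从零重新开始。在最初几步里,经过偏差校正的 Adam 更新对每个参数的步长都约等于学习率,与梯度大小无关,因此刚重启时误差会突然增大。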

sess.run()的赋值也必须写成元组形式才能按预期工作。(注:在 Python 中,`a, b = ...` 与 `(a, b) = ...` 的解包行为完全相同,因此这一改动本身并不会改变结果。)

# Unpack the seven fetched values. The surrounding parentheses are optional:
# Python unpacks a parenthesized tuple target exactly like a bare one.
(_, results, error,
 w1_computed, b1_computed, w2_computed, b2_computed) = sess.run(
    [train_op, y, loss_op, w1, b1, w2, b2])