Keras tensorflow训练从头开始与imagenet权重

时间:2018-02-03 14:24:38

标签: python tensorflow keras

我正在尝试使用inception_resnet_v2来训练我自己的人体姿势估计模型。这是一个回归问题,我需要一个完全连接的层而不需要激活功能,因此标准API对我的情况不利。

以下代码运行正常,但我担心session.run(init)会将权重重置为随机而不是imagenet

有出路吗?如果我删除会话init,则会抛出错误。

def inception_resnet_v2():

    model = InceptionResNetV2(include_top=False,
                      weights='imagenet',
                      input_shape=(FLAGS.resize_input_image,
                                   FLAGS.resize_input_image, 3))
    x = model.output
    x = Flatten()(x)
    x = Dense(28, activation=None, name='predictions')(x)
    model = Model(input=model.input, output=x)

    print(model.summary())

    return model 

model = inception_resnet_v2()
network = model(images)

_, mean_loss = regression_loss.loss_func(joints_gt, is_valid_joint, network)
train_op = optimizer.rms_prop(mean_loss=mean_loss,   global_step=global_step)

init = tf.group(tf.global_variables_initializer(),
            tf.local_variables_initializer())

with tf.Session(config=config) as sess:

    # Initialize variables
    sess.run(init)

1 个答案:

答案 0 :(得分:1)

答案只是保存模型并在初始化变量后重新加载。

    def inception_resnet_v2():

model = InceptionResNetV2(include_top=False,
                  weights='imagenet',
                  input_shape=(FLAGS.resize_input_image,
                               FLAGS.resize_input_image, 3))
x = model.output
x = Flatten()(x)
x = Dense(28, activation=None, name='predictions')(x)
model = Model(input=model.input, output=x)

print(model.summary())

return model 

model = inception_resnet_v2()
**model.save('mymodel.h5')**
network = model(images)

_, mean_loss = regression_loss.loss_func(joints_gt, is_valid_joint, network)
train_op = optimizer.rms_prop(mean_loss=mean_loss,      global_step=global_step)

init = tf.group(tf.global_variables_initializer(),
            tf.local_variables_initializer())

with tf.Session(config=config) as sess:

    # Initialize variables
    sess.run(init)
    **model.load('mymodel.h5')**