我正在尝试使用inception_resnet_v2
来训练我自己的人体姿势估计模型。这是一个回归问题,我需要一个完全连接的层而不需要激活功能,因此标准API对我的情况不利。
以下代码运行正常,但我担心session.run(init)
会将权重重置为随机而不是imagenet
。
有出路吗?如果我删除会话init,则会抛出错误。
def inception_resnet_v2():
model = InceptionResNetV2(include_top=False,
weights='imagenet',
input_shape=(FLAGS.resize_input_image,
FLAGS.resize_input_image, 3))
x = model.output
x = Flatten()(x)
x = Dense(28, activation=None, name='predictions')(x)
model = Model(input=model.input, output=x)
print(model.summary())
return model
model = inception_resnet_v2()
network = model(images)
_, mean_loss = regression_loss.loss_func(joints_gt, is_valid_joint, network)
train_op = optimizer.rms_prop(mean_loss=mean_loss, global_step=global_step)
init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
with tf.Session(config=config) as sess:
# Initialize variables
sess.run(init)
答案 0 :(得分:1)
答案只是保存模型并在初始化变量后重新加载。
def inception_resnet_v2():
model = InceptionResNetV2(include_top=False,
weights='imagenet',
input_shape=(FLAGS.resize_input_image,
FLAGS.resize_input_image, 3))
x = model.output
x = Flatten()(x)
x = Dense(28, activation=None, name='predictions')(x)
model = Model(input=model.input, output=x)
print(model.summary())
return model
model = inception_resnet_v2()
**model.save('mymodel.h5')**
network = model(images)
_, mean_loss = regression_loss.loss_func(joints_gt, is_valid_joint, network)
train_op = optimizer.rms_prop(mean_loss=mean_loss, global_step=global_step)
init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
with tf.Session(config=config) as sess:
# Initialize variables
sess.run(init)
**model.load('mymodel.h5')**