我正在尝试创建一个由多个不同的数据集更新的网络。
代码如下:
# Network
def UNet(args):
with tf.variable_scope('UNet',reuse=args.reuse):
input_imgs_ph=tf.placeholder(
tf.float32,
[None,
int(args.input_size),
int(args.input_size),
3])
# ============================================================
# conv2d
o_conv1=tf.layers.conv2d(
inputs=input_imgs_ph,
filters=100,
kernel_size=[3,3],
strides=1,
padding="same",
kernel_regularizer=tf.contrib.layers.l2_regularizer(reg))
o_conv1=tf.nn.leaky_relu(o_conv1,alpha=0.2,name='leaky_relu_after_conv1')
o_conv1=tf.keras.layers.AveragePooling2D((2,2),name='avgpool_after_conv1')(o_conv1)
# ...
return input_imgs_ph,end_conv_r,end_conv_s
# ============================================================
# train_dataset1.py
# Train algorithm over dataset1
def train(mpisintel_d_bs_pa,args):
# ============================================================
with tf.variable_scope("UNet"):
# Perform augmentation
# ...
# ============================================================
# Create placeholders
# ...
# ============================================================
# Construct network graph structure
input_imgs_ph,end_conv_r,end_conv_s=networks.UNet(args)
# ...
# ============================================================
# Loss function
# ...
# ============================================================
# c optimizer: AdamOptimizer node
optimizer=tf.train.AdamOptimizer(0.001).minimize(r_data_loss)
# ============================================================
# c init: initializer node which initialzes trainable Variables
init=tf.global_variables_initializer()
# ============================================================
feed_dict={
input_imgs_ph:cgmit_tr_3c_imgs,
cgmit_gt_R_3c_imgs_ph:cgmit_gt_R_3c_imgs,
cgmit_gt_S_1c_imgs_ph:cgmit_gt_S_1c_imgs,
cgmit_mask_3c_imgs_ph:cgmit_mask_3c_imgs}
return init,feed_dict,r_data_loss,optimizer
# ============================================================
# train_dataset2.py
# Train algorithm over dataset2
def train(mpisintel_d_bs_pa,args):
# ============================================================
with tf.variable_scope("UNet"):
# Perform augmentation
# ...
# ============================================================
# Create placeholders
# ...
# ============================================================
# Construct network graph structure
input_imgs_ph,end_conv_r,end_conv_s=networks.UNet(args)
# ...
# ============================================================
# Loss function
# ...
# ============================================================
# c optimizer: AdamOptimizer node
optimizer=tf.train.AdamOptimizer(0.001).minimize(r_data_loss)
# ============================================================
# c init: initializer node which initialzes trainable Variables
init=tf.global_variables_initializer()
# ============================================================
feed_dict={
input_imgs_ph:cgmit_tr_3c_imgs,
cgmit_gt_R_3c_imgs_ph:cgmit_gt_R_3c_imgs,
cgmit_gt_S_1c_imgs_ph:cgmit_gt_S_1c_imgs,
cgmit_mask_3c_imgs_ph:cgmit_mask_3c_imgs}
return init,feed_dict,r_data_loss,optimizer
# ============================================================
# train_jointly.py
# --------------------------------------------------------------------------------
# Create iterator for dataset1 (mpisintel dataset)
path_data_mpisintel=tf.data.Dataset.from_tensor_slices(mpisintel_d_list)
path_data_mpisintel=path_data_mpisintel.prefetch(
buffer_size=int(args.batch_size)).batch(int(args.batch_size)).repeat()
path_data_mpisintel_iter=path_data_mpisintel.make_one_shot_iterator()
# Construct train algorithm graph for dataset1
init_dense,feed_dict_dense,r_data_loss_dense,optimizer_dense=\
train_over_dense_dataset.train(mpisintel_d_bs_pa,args)
# --------------------------------------------------------------------------------
# Create iterator for dataset2 (iiw dataset)
path_data_iiw=tf.data.Dataset.from_tensor_slices(iiw_d_list)
path_data_iiw=path_data_iiw.prefetch(
buffer_size=int(args.batch_size)).batch(int(args.batch_size)).repeat()
path_data_iiw_iter=path_data_iiw.make_one_shot_iterator()
# Construct train algorithm graph for dataset2
init_iiw,feed_dict_iiw,r_data_loss_iiw,optimizer_iiw=\
train_over_iiw_dataset.train(iiw_d_bs_pa,args)
# ============================================================
# Train network over epochs
for one_ep in range(epoch):
# Train network over batch of dataset1
with tf.Session() as sess_dense:
# Initialize Variables
sess_dense.run(init_dense)
mpisintel_d_bs_pa=path_data_mpisintel_iter.get_next()
loaded_input_img=load_image_from_paths(mpisintel_d_bs_pa)
loss_val,optim=sess_dense.run(
[r_data_loss_dense,optimizer_dense],
feed_dict=loaded_input_img)
# print("loss_val",loss_val)
# loss_val 9.534523
# ============================================================
# Train network over batch of dataset2
with tf.Session() as sess_iiw:
# Initialize Variables
sess_iiw.run(init_iiw)
iiw_d_bs_pa=path_data_iiw_iter.get_next()
loaded_input_img=load_image_from_paths(iiw_d_bs_pa)
loss_val,optim=sess_iiw.run(
[r_data_loss_iiw,optimizer_iiw],
feed_dict=loaded_input_img)
问题:
1 ..我尝试多次使用以下网络
def UNet(args):
with tf.variable_scope('UNet',reuse=args.reuse):
在每个训练算法图中
train_over_dense_dataset.train(mpisintel_d_bs_pa,args)
train_over_iiw_dataset.train(iiw_d_bs_pa,args)
这会导致错误。
2 ..所以我尝试使用tf.reset_default_graph()
但是要实现我想使用的培训架构很困难。
例如,我不禁在for loop
中创建整个图形
否则,我找不到满足我意图的方法。