Is it possible to change the input shape of a TensorFlow pretrained model?

Time: 2019-10-07 21:18:00

Tags: tensorflow transfer-learning

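I have a pretrained TensorFlow model for image segmentation that takes 6 bands as input. I would like to change the model's input size so that it takes 4 bands and I can retrain it on my own dataset, but I have not managed to do so and I am not sure whether it is even possible.

I tried getting the input node by name and replacing it through import_graph_def, without success; the replacement seems to require that the dimensions match.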

graph = tf.get_default_graph()
tf_new_input = tf.placeholder(shape=(4, 256, 256), dtype='float32', name='new_input')
tf.import_graph_def(graph_def, input_map={"ImageInputLayer": tf_new_input})

But I get the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 4 and 6 for 'import/ImageInputLayer_Sub' (op: 'Sub') with input shapes: [4,256,256], [6,256,256]
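The message points at a Sub op inside the imported graph whose second operand has shape [6, 256, 256], presumably a constant stored with the model (an assumption, since the graph itself is not shown). Elementwise subtraction cannot broadcast a leading dimension of 4 against 6. The same rule can be seen in a small NumPy sketch, where `stored_constant` is a hypothetical stand-in for that operand:

```python
import numpy as np

# 4-band input, shaped like the new placeholder in the question
new_input = np.zeros((4, 256, 256), dtype=np.float32)
# Hypothetical stand-in for the 6-band operand of the Sub op
stored_constant = np.zeros((6, 256, 256), dtype=np.float32)

try:
    new_input - stored_constant
    broadcast_ok = True
except ValueError as err:
    # NumPy rejects the subtraction because 4 and 6 are incompatible
    broadcast_ok = False
    message = str(err)
```

So remapping the input alone is not enough: whatever reaches that Sub op still has to carry 6 channels.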

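2 answers:

Answer 0: (score: 0)

You have to convert the 4-channel placeholder input into a 6-channel input, and the input image shape should be the same as the 6-channel model expects. Any operation that does this will work, but conv2d is an easy one to apply before feeding the result into the existing model. Here is how you can do it.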

with tf.Graph().as_default() as old_graph:
  # You have to load your 6 channel input graph here
  saver.restore(tf.get_default_session(), <<save_path>>)
  # Assuming that input node is named as 'input_node' and 
  # final node is named as 'softmax_node'

with tf.Graph().as_default() as new_graph:
  tf_new_input = tf.placeholder(shape=(None, 256, 256, 4), dtype='float32')

  # Map 4 channeled input to 6 channel and 
  # image input shape should be same as expected by old model.
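  # NOTE: tf.nn.conv2d expects a rank-4 filter tensor as its second argument.
  # Passing the shape tuple directly raises the rank error described in Answer 1.
  new_node = tf.nn.conv2d(tf_new_input, (3, 3, 4, 6), strides=1, padding='SAME')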

  # If you want to obtain output node so that you can further perform operations.
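  # import_graph_def expects a GraphDef proto, so convert the old Graph first.
  # It returns a list with one entry per name in return_elements.
  softmax_node = tf.import_graph_def(old_graph.as_graph_def(), input_map={'input_node:0': new_node},
                                     return_elements=['softmax_node:0'])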
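To make the channel-mapping idea concrete outside TensorFlow, here is a minimal NumPy sketch. It is a simplification, not the answer's code: it uses a 1x1 kernel instead of the 3x3 one above, a channels-last layout, and random weights. The names `batch` and `weights` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two 256x256 images with 4 bands each (N, H, W, C)
batch = rng.normal(size=(2, 256, 256, 4)).astype(np.float32)
# A 1x1 convolution is a per-pixel linear map: 4 input bands -> 6 output bands
weights = rng.normal(size=(4, 6)).astype(np.float32)

# Contract the channel axis of every pixel with the weight matrix
adapted = np.einsum('nhwc,cd->nhwd', batch, weights)
```

The result has 6 channels per pixel, which is the shape the 6-band model expects at its input. In the TensorFlow version the weights would be a variable rather than fixed random numbers.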

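Answer 1: (score: 0)

user1190882 answered this well. I am posting the code here for future reference. I had to make a small change and create the filter as a separate variable because of this error: Shape must be rank 4 but is rank 1 for 'Conv2D'. Also, since the model's input format is "channels first", I changed the placeholder shape accordingly and added the data_format flag.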

with tf.Graph().as_default() as new_graph:
  # Channels-first placeholder for the 4-band input
  tf_new_input = tf.placeholder(shape=(None, 4, 256, 256), dtype='float32')
  # Creating separate variable for filter
  filterc = tf.Variable(tf.random_normal([3, 3, 4, 6]))
  new_node = tf.nn.conv2d(tf_new_input, filterc, strides=[1, 1, 1, 1], padding='SAME', data_format='NCHW')
  # import_graph_def expects a GraphDef proto, so convert the old Graph first
  tf.import_graph_def(old_graph.as_graph_def(), input_map={'ImageInputLayer': new_node})