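TensorFlow: how to use conv2d_transpose with variable image sizes?

Date: 2018-05-29 05:43:30

Tags: python tensorflow

I am trying to find a way to apply a transposed convolution to images of variable size. I used the tf.nn.conv2d_transpose API, but it failed.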

import tensorflow as tf

def conv2d_transpose(inputs, filters_shape, strides, name, padding="SAME", activation=None):
  filters = get_conv_filters(filters_shape, name)

  inputs_shape = inputs.get_shape().as_list()
  output_shape = tf.stack(calc_output_shape(inputs_shape, filters_shape, strides, padding))  # tf.pack was renamed to tf.stack
  strides = [1,*strides,1]

  conv_transpose = tf.nn.conv2d_transpose(inputs, filters, output_shape=output_shape,
                                          strides=strides, padding=padding, name=name+"transpose")

  if activation is not None:
    conv_transpose = activation(conv_transpose)

  return conv_transpose

def get_conv_filters(filters_size, name):
  conv_weights = tf.Variable(tf.truncated_normal(filters_size), name=name + "weights")
  return conv_weights

def calc_output_shape(inputs_shape, filters_shape, strides, padding): # For conv_transpose
  batch_size, inputs_height, inputs_width, n_channel = inputs_shape
  # conv2d_transpose filters are laid out as [height, width, output_channels, in_channels]
  filters_height, filters_width, output_channels, input_channels = filters_shape
  strides_height, strides_width = strides

  if padding == "SAME":
    output_height = inputs_height*strides_height
    output_width = inputs_width*strides_width

  else: # padding="VALID"
    output_height = (inputs_height-1)*strides_height+filters_height
    output_width = (inputs_width-1)*strides_width+filters_width

  return [batch_size, output_height, output_width, output_channels]

Calling it as follows then raises an error.

input_images = tf.placeholder(tf.float32, [None, None, None, 3])
transpose_layer = conv2d_transpose(input_images, filters_shape=[3,3,3,3], strides=[2,2], name="conv_3_transpose", padding="SAME", activation=tf.nn.relu)

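I think the cause of this error is that the input shape is not fixed: the batch size, height and width of the placeholder are all None, so the error occurs while computing output_shape. How can I overcome this?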
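The failure can be reproduced without TensorFlow. For the placeholder above, `inputs.get_shape().as_list()` returns `[None, None, None, 3]`, and the SAME branch of `calc_output_shape` then multiplies `None` by the stride. A minimal sketch, with the static shape written out by hand as an assumption rather than read from a real graph (`same_output_size` is a hypothetical helper used only for illustration):

```python
# Static shape of a [None, None, None, 3] placeholder, written out by hand
# (assumption: this is what get_shape().as_list() returns for it).
static_shape = [None, None, None, 3]
strides = [2, 2]


def same_output_size(inputs_shape, strides):
    """SAME-padding branch of calc_output_shape from the question."""
    _, height, width, _ = inputs_shape
    return height * strides[0], width * strides[1]


try:
    same_output_size(static_shape, strides)
    error = None
except TypeError as exc:
    # None does not support multiplication, so the shape arithmetic fails
    # before tf.nn.conv2d_transpose is ever called.
    error = exc

print(type(error).__name__)
```

With a fully known shape such as `[1, 4, 5, 3]` the same function returns normally, which suggests the problem lies in the unknown dimensions rather than in the formula.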

1 answer:

Answer 0 (score: 2)

Use dynamic shapes; you can find the details here. Your inputs_shape should be:

inputs_shape = tf.shape(inputs)
...
batch_size, inputs_height, inputs_width, n_channel = inputs_shape[0],inputs_shape[1],inputs_shape[2],inputs_shape[3]
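The formulas from the question need no change once the dimensions come from `tf.shape(inputs)`: the arithmetic is the same whether a dimension is a Python int or a scalar tensor, since both support `*`, `+` and `-`. The sketch below checks the two formulas with plain ints; `transpose_output_dims` is a hypothetical helper name, not part of TensorFlow.

```python
def transpose_output_dims(height, width, filters_shape, strides, padding="SAME"):
    """Spatial output size of a transposed convolution.

    height and width may be Python ints or, inside a TensorFlow graph,
    scalar tensors such as tf.shape(inputs)[1] and tf.shape(inputs)[2].
    """
    filters_height, filters_width = filters_shape[0], filters_shape[1]
    strides_height, strides_width = strides

    if padding == "SAME":
        return height * strides_height, width * strides_width

    # padding == "VALID"
    return ((height - 1) * strides_height + filters_height,
            (width - 1) * strides_width + filters_width)


# Worked check with a 5x7 input, 3x3 filters and stride 2.
same_dims = transpose_output_dims(5, 7, [3, 3, 3, 3], [2, 2], "SAME")
valid_dims = transpose_output_dims(5, 7, [3, 3, 3, 3], [2, 2], "VALID")
print(same_dims, valid_dims)
```

In the question's `conv2d_transpose`, these two dimensions, together with `tf.shape(inputs)[0]` and the output channel count, would be passed through `tf.stack` as before to form `output_shape`.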