使用Tensorflow Estimator API和SemSeg的图像

时间:2018-04-18 16:55:07

标签: python tensorflow tensorflow-estimator

我尝试使用Tensorflow Estimator API实现我的模型。 作为数据输入功能,我使用

def input_fn():
    train_in_np = sorted(io_utils.loadDataset(join(basepath, r"leftImg8bit/train/*/*")))
    train_out_np = sorted(io_utils.loadDataset(join(basepath, r"gtFine/train/*/*_ignoreLabel.png")))

    train_in = tf.constant(train_in_np)
    train_out = tf.constant(train_out_np)

    tr_data = tf.data.Dataset.from_tensor_slices((train_in, train_out))
    tr_data = tr_data.shuffle(len(train_in_np))
    tr_data = tr_data.repeat(epoch_cnt+1)

    tr_data = tr_data.apply(tf.contrib.data.map_and_batch(parse_files, batch_size=batchSize))
    tr_data = tr_data.prefetch(buffer_size=batchSize)

    iterator = tr_data.make_initializable_iterator()
    return iterator.get_next()

io_utils.loadDataset只返回文件路径列表。 数据本身由

解析
def parse_files(in_file, gt_file):
    image_in = tf.read_file(in_file)
    image_in = tf.image.decode_image(image_in, channels=3)
    image_in.set_shape([None, None, 3])
    image_in = tf.cast(image_in, tf.float32)

    mean, std = tf.nn.moments(image_in, [0, 1])
    image_in = image_in - mean
    image_in = image_in / std

    gt = tf.read_file(gt_file)
    gt = tf.image.decode_image(gt, channels=1)    
    gt.set_shape([None, None, 1])
    gt = tf.cast(gt, tf.int32)

    return {'img':image_in}, gt

我的估算工具以

开头
def estimator_fcn_model_fn(features, labels, mode, params):
    x = tf.feature_column.input_layer(features, params['feature_columns'])

并且要素列定义为

my_feature_columns = []
my_feature_columns.append(tf.feature_column.numeric_column(key='img'))

我跳过的其余代码的可清除性。 我的问题在于功能的形状:

  

Blockquote ValueError :('对于具有等级的输入不支持卷积',2)

x,features和feature_columns的打印输出:

  

Tensor(“input_layer / concat:0”,shape =(?,1),dtype = float32)

     

{'img':tf.Tensor'IteratorGetNext:0'shape =(?,?,?,3)dtype = float32}

     

[_ NumericColumn(key ='img',shape =(1,),default_value = None,dtype = tf.float32,normalizer_fn = None)]

有谁知道如何解决这个问题,我猜它会与特征列有关,但我不知道将这个应用于图像的人。

1 个答案:

答案 0 :(得分:0)

由于Y.Luo给出了提示tf.feature_column.input_layer不适合图像。一种更简单的方法是通过密钥直接使用特征字典,这可以通过参数传递以获得更大的灵活性。

x = tf.feature_column.input_layer(features, params['feature_columns'])

而不是

{{1}}