TensorFlow tf.map_fn drops a dimension

Time: 2018-07-18 21:26:03

Tags: python tensorflow

I'm trying to remap the values in an input tensor using a defaultdict:

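from collections import defaultdict

import tensorflow as tf

class MyDataSet(object):
    def __init__(self):
        self.class_map = MyDataSet.remap_class()

    @staticmethod
    def remap_class():
        # Any value without an explicit entry falls back to class 11.
        class_remap = defaultdict(lambda: 11)
        class_remap[128] = 0
        class_remap[130] = 1
        class_remap[132] = 2
        # ...
        return class_remap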

    def parser(self, serialized_example):
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.string),
            })
        label = tf.decode_raw(features['label'], tf.uint8)
        label.set_shape([256 * 512])
        label = tf.cast(tf.reshape(label, [256, 512]), tf.int32)

        output_label = tf.map_fn(lambda x: self.class_map(x), label)

    #...
    dataset = tf.data.TFRecordDataset(filenames).repeat()
    dataset = dataset.map(self.parser, num_parallel_calls=batch_size)
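As a quick check of the Python side on its own: remap_class builds an ordinary defaultdict keyed by Python ints, and it has to return that table for self.class_map to receive it. A minimal standalone sketch, with the elided `# ...` entries left out; note that the lookup uses square brackets, since a defaultdict is not callable:

```python
from collections import defaultdict


def remap_class():
    # Any pixel value without an explicit entry falls back to class 11.
    class_remap = defaultdict(lambda: 11)
    class_remap[128] = 0
    class_remap[130] = 1
    class_remap[132] = 2
    return class_remap  # the caller needs the table back


class_map = remap_class()
print(class_map[130])  # 1
print(class_map[7])    # 11
```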

label has shape (256, 512), but output_label has shape (256,). If I try to reshape output_label with
output_label = tf.reshape(output_label, [256, 512])

I get the exception

ValueError: Cannot reshape a tensor with 256 elements to shape [256,512] (131072 elements) for 'Reshape_2' (op: 'Reshape') with input shapes: [256], [2] and with input tensors computed as partial shapes: input[1] = [256,512].

If I instead try to set the shape of output_label with
output_label.set_shape([256, 512])

I get the exception

ValueError: Shapes (256,) and (256, 512) must have the same rank

How can I map the values into output_label while keeping the same shape as label?
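The (256,) shape follows from how tf.map_fn works: it unstacks its input along dimension 0 and calls the function once per slice, so on a (256, 512) label the function receives 256 rows of shape (512,), and if it returns one scalar per call the stacked result is (256,). The sketch below mimics that shape behaviour in NumPy; map_fn_like is a hypothetical helper (not a TensorFlow API), and the three-entry table and random label are illustrative stand-ins. It contrasts mapping over rows with mapping over the flattened label and reshaping afterwards:

```python
import numpy as np


def map_fn_like(fn, elems):
    # Mimics tf.map_fn's shape behaviour: unstack elems along dimension 0,
    # call fn once per slice, then stack the per-call results.
    return np.stack([fn(x) for x in elems])


# Hypothetical stand-in for the decoded label: values 0..255, shape (256, 512).
rng = np.random.default_rng(0)
label = rng.integers(0, 256, size=(256, 512)).astype(np.int32)

# Illustrative mapping table: three entries, everything else maps to 11.
table = {128: 0, 130: 1, 132: 2}


def per_row(row):
    # On the 2-D label, fn is handed a whole (512,) row; returning a single
    # scalar per call leaves only one value per row.
    return np.int32(table.get(int(row[0]), 11))


def per_pixel(x):
    # On the flattened label, fn is handed one pixel value at a time.
    return np.int32(table.get(int(x), 11))


rows_out = map_fn_like(per_row, label)
print(rows_out.shape)  # (256,)

flat_out = map_fn_like(per_pixel, label.reshape(-1))  # 131072 calls
output_label = flat_out.reshape(256, 512)             # restore the 2-D shape
print(output_label.shape)  # (256, 512)
```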

1 answer:

Answer 0 (score: 0)

The workaround is to operate on the flattened tensor, so change the parser to:

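    def parser(self, serialized_example):
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.string),
            })
        label = tf.decode_raw(features['label'], tf.uint8)
        label.set_shape([256 * 512])
        # Keep the label flat (1-D, 256 * 512 elements): tf.map_fn iterates
        # over dimension 0, so each call now sees a single pixel value.
        label = tf.cast(label, tf.int32)

        output_label = tf.map_fn(lambda x: self.class_map(x), label)
        # Restore the 2-D shape only after the per-element mapping.
        output_label = tf.reshape(output_label, [256, 512])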
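The flatten-then-reshape workaround still calls a function once per pixel. An alternative that is not part of the answer: because the label is decoded from uint8, every value lies in 0..255, so the defaultdict can be evaluated once into a dense 256-entry table and applied to the whole array by indexing. In a TensorFlow graph the corresponding lookup would be tf.gather(tf.constant(lookup), tf.cast(label, tf.int32)), whose result has the label's shape. The sketch below uses NumPy so it runs standalone; the three-entry table and the random label are illustrative assumptions:

```python
from collections import defaultdict

import numpy as np

# Same kind of table as remap_class() builds (only three entries shown).
class_remap = defaultdict(lambda: 11)
class_remap[128] = 0
class_remap[130] = 1
class_remap[132] = 2

# The label is decoded from uint8, so every value lies in 0..255: evaluate the
# mapping once for all 256 possible values to get a dense lookup table.
lookup = np.array([class_remap[v] for v in range(256)], dtype=np.int32)

# Hypothetical decoded label with the question's shape and dtype.
rng = np.random.default_rng(0)
label = rng.integers(0, 256, size=(256, 512)).astype(np.uint8)

# Integer-array indexing applies the table to every element at once, and the
# result takes the shape of the index array.
output_label = lookup[label]
print(output_label.shape)  # (256, 512)
print(output_label.dtype)  # int32
```

The table costs 256 entries regardless of how many classes are mapped, which is why it only suits small integer label ranges such as uint8.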