TensorFlow: how to remove padding (a specific value) from a tensor

Date: 2018-12-27 23:28:33

Tags: python tensorflow

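I am using the Dataset API to batch data from a tfrecords file. The rows of the data have different lengths. Because batch() requires all rows to have the same size, I need to use padded_batch() instead, which pads every row in a batch to the length of the longest row in that batch.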
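To make the padding behaviour concrete, here is a minimal sketch, assuming TensorFlow 2.4 or later (eager execution and the `output_signature` argument of `tf.data.Dataset.from_generator`). It feeds the same four rows used in the question's example straight from Python instead of a tfrecords file, and pads with -1:

```python
import tensorflow as tf

# Rows of different lengths (same values as in the question's example).
rows = [[0, 1, 2, 3], [2, 3, 4], [3, 6, 5, 4, 3], [3, 9]]

# Assumption: reading from a Python generator instead of a tfrecords file.
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int64),
)

# Pad every row in a batch with -1 up to the length of the longest row in that batch.
dataset = dataset.padded_batch(
    3, padded_shapes=[None], padding_values=tf.constant(-1, dtype=tf.int64)
)

batches = [batch.numpy().tolist() for batch in dataset]
print(batches)
# [[[0, 1, 2, 3, -1], [2, 3, 4, -1, -1], [3, 6, 5, 4, 3]], [[3, 9]]]
```

Note that the padded length is chosen per batch: the last batch contains only `[3, 9]`, so it is not padded at all.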

After batching, is it possible to remove these padded values?

Here is a minimal example in which I use -1 as the padding value:

import math
import numpy as np
import tensorflow as tf

#Set up data (rows have different lengths, so store them in an object array)
cells = np.array([[0,1,2,3], [2,3,4], [3,6,5,4,3], [3,9]], dtype=object)
mells = np.array([[0], [2], [3], [9]])
print(cells)

#Write data to tfrecords
writer = tf.python_io.TFRecordWriter('test.tfrecords')
for index in range(mells.shape[0]):
    example = tf.train.Example(features=tf.train.Features(feature={
        'num_value':tf.train.Feature(int64_list=tf.train.Int64List(value=mells[index])),
        'list_value':tf.train.Feature(int64_list=tf.train.Int64List(value=cells[index]))
    }))
    writer.write(example.SerializeToString())
writer.close()

#Open tfrecords using dataset api and batch data
filenames = ["test.tfrecords"]
dataset = tf.data.TFRecordDataset(filenames)
def _parse_function(example_proto):
    keys_to_features = {'num_value':tf.VarLenFeature(tf.int64),
                        'list_value':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return tf.sparse.to_dense(parsed_features['num_value']), \
           tf.sparse.to_dense(parsed_features['list_value'])
# Parse the record into tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset (buffer_size=1 keeps the original order)
dataset = dataset.shuffle(buffer_size=1)
# Repeat the input indefinitely
dataset = dataset.repeat()
# Generate batches
dataset = dataset.padded_batch(3, padded_shapes=([None], [None]),
                               padding_values=(tf.constant(-1, dtype=tf.int64),
                                               tf.constant(-1, dtype=tf.int64)))
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()

with tf.Session() as sess:
    print(sess.run([i, data]))
    print(sess.run([i, data]))

So far I have tried using a boolean mask, following Filter out non-zero values in a tensor.

However, my attempt just flattens all the tensors in the batch. Here is the code I used:

filenames = ["test.tfrecords"]
dataset = tf.data.TFRecordDataset(filenames)
def _parse_function(example_proto):
    keys_to_features = {'num_value':tf.VarLenFeature(tf.int64),
                        'list_value':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return tf.sparse.to_dense(parsed_features['num_value']), \
           tf.sparse.to_dense(parsed_features['list_value'])
# Parse the record into tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset (buffer_size=1 keeps the original order)
dataset = dataset.shuffle(buffer_size=1)
# Repeat the input indefinitely
dataset = dataset.repeat()
# Generate batches
dataset = dataset.padded_batch(3, padded_shapes=([None], [None]),
                               padding_values=(tf.constant(-1, dtype=tf.int64),
                                               tf.constant(-1, dtype=tf.int64)))
# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()

neg1 = tf.constant(-1, dtype=tf.int64)
where1 = tf.not_equal(data, neg1)

# With a 2-D mask, tf.boolean_mask returns a 1-D tensor of the kept values
result = tf.boolean_mask(data, where1)

with tf.Session() as sess:
    print(sess.run([data, result]))

This is the result:

[array([[ 0,  1,  2,  3, -1],
       [ 2,  3,  4, -1, -1],
       [ 3,  6,  5,  4,  3]]), array([0, 1, 2, 3, 2, 3, 4, 3, 6, 5, 4, 3])]

I need something that preserves the row structure of the tensor, so that the result would look something like:

array([[ 0,  1,  2,  3],
       [ 2,  3,  4],
       [ 3,  6,  5,  4,  3]])
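The flattening happens because `tf.boolean_mask` with a mask of the same rank as the data returns a 1-D tensor of the selected values. A minimal sketch of two ways to keep the rows, assuming TensorFlow 2.x with eager execution (not the question's tf.Session setup); `data` is the padded batch printed above, hard-coded for illustration:

```python
import tensorflow as tf

# The padded batch from the question, hard-coded for illustration.
data = tf.constant([[0, 1, 2, 3, -1],
                    [2, 3, 4, -1, -1],
                    [3, 6, 5, 4, 3]], dtype=tf.int64)
mask = tf.not_equal(data, -1)

# tf.boolean_mask with a rank-2 mask returns a 1-D tensor: row boundaries are lost.
flat = tf.boolean_mask(data, mask)

# Option 1: tf.ragged.boolean_mask keeps the row dimension and drops the
# masked-out values inside each row, returning a RaggedTensor.
kept = tf.ragged.boolean_mask(data, mask)

# Option 2: count the kept values per row and split the flat result back into rows.
row_lengths = tf.reduce_sum(tf.cast(mask, tf.int64), axis=1)
rebuilt = tf.RaggedTensor.from_row_lengths(flat, row_lengths)

print(flat.numpy().tolist())  # [0, 1, 2, 3, 2, 3, 4, 3, 6, 5, 4, 3]
print(kept.to_list())         # [[0, 1, 2, 3], [2, 3, 4], [3, 6, 5, 4, 3]]
print(rebuilt.to_list())      # [[0, 1, 2, 3], [2, 3, 4], [3, 6, 5, 4, 3]]
```

Both options remove every -1, including any that might occur in the middle of a row, so they assume -1 never appears as a real data value.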

1 answer:

Answer 0 (score: 0)

Use ragged tensors, changing the code from above:
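A minimal sketch of the ragged-tensor approach, assuming TensorFlow 2.x with eager execution rather than the tf.Session code above. `tf.RaggedTensor.from_tensor` with a `padding` value strips the trailing run of that value from each row, which matches how padded_batch() appends padding at the end of rows. The hard-coded `data` stands in for one padded batch from the question:

```python
import tensorflow as tf

# Assumption: one padded batch as produced by padded_batch(..., padding_values=-1).
data = tf.constant([[0, 1, 2, 3, -1],
                    [2, 3, 4, -1, -1],
                    [3, 6, 5, 4, 3]], dtype=tf.int64)

# Drop the trailing -1 values of every row; the result is a RaggedTensor.
ragged = tf.RaggedTensor.from_tensor(data, padding=tf.constant(-1, dtype=tf.int64))

print(ragged.to_list())                       # [[0, 1, 2, 3], [2, 3, 4], [3, 6, 5, 4, 3]]
print(ragged.row_lengths().numpy().tolist())  # [4, 3, 5]
```

Because only trailing padding is removed, a -1 in the middle of a row is kept, but a genuine -1 at the very end of a row would be stripped along with the padding.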
