删除张量流中张量的每个第3个元素

时间:2017-11-03 16:58:28

标签: python numpy tensorflow

我正在寻找张量流中np.delete的类似 - 所以我有批量张量 - 每个批次都有(batch_size, variable_length)形状,我希望得到一个形状张量(batch_size, 2 * variable_length / 3)。如所看到的,每个批次具有不同的长度,其从tfrecord存储和读取。我what API在这里我有点不知所措。相关(对于numpy):

解决方案只是np.delete(x, slice(2, None, 3))(在执行重塑以满足batch_size之后)

根据评论中的要求,我发布了用于解析单个示例原型的代码 - 尽管我有兴趣将张量的第n个(第3个)元素作为独立的问题删除。

@classmethod
def parse_single_example(cls, example_proto):
    instance = cls()
    features_dict = cls._get_features_dict(example_proto)
    instance.path_length = features_dict['path_length']
    ...
    instance.coords = tf.decode_raw(features_dict['coords'], DATA_TYPE) # the tensor
    ...
    return instance.coords, ...

@classmethod
def _get_features_dict(cls, value):
    features_dict = tf.parse_single_example(value,
        features={'coords': tf.FixedLenFeature([], tf.string),
                  ...
                  'path_length': tf.FixedLenFeature([], tf.int64)})
    return features_dict

2 个答案:

答案 0 :(得分:0)

Disclamer :由于您未提供minimum, complete and verifiable example,因此我的代码无法完全测试。您需要尝试根据自己的需要进行调整。

这是使用tf.data API执行此操作的方法。 请注意,由于您没有显示班级的整体布局,因此我必须对您的数据的访问方式和位置做出一些假设。

首先,我假设你的班级'构造函数知道.tfrecord文件的存储位置。具体来说,我假设TFRECORD_FILENAMESlist,其中包含要从中提取记录的文件的所有文件路径。

在类构造函数中,您需要在其上实例化TFRecordDatasetmap()函数,以修改数据集包含的数据

class MyClass():
    def __init__(self):
        # more init stuff
        def parse_example(serialized_example):
            features_dict = tf.parse_single_example(value,
              features={'coords': tf.FixedLenFeature([], tf.string),
              ...
              'path_length': tf.FixedLenFeature([], tf.int64)})
            return features_dict

        def skip_every_third_pyfunc(coords):
            # you mention something about a reshape, I guess that goes here as well
            return np.delete(coords, slice(None, None, 3)) 

        self.dataset = (tf.data.TFRecordDataset(TFRECORD_FILENAMES)
                        .map(parse_example)
                        .map( lambda features_dict : { **features_dict, **{'coords': tf.py_func(skip_every_third_pyfunc, features_dict['coords'], features_dict['coords'].dtype)} } )
        self.iterator = self.dataset.make_one_shot_iterator() # adapt this to your needs
        self.features_dict = self.iterator.get_next() # I'm putting this here because I don't know where you'll need it

请注意,在skip_every_third_pyfunc中,您可以使用numpy函数,因为我们正在使用tf.py_func将python函数包装为张量操作(链接中的所有警告都适用)。

第二次.map()调用中的丑陋lambda是必要的,因为你使用了一个特征字典而不是返回一个张量元组。 py_func的参数将numpy数组作为输入并返回numpy数组。为了保持dict格式,我们使用 python 3.5+ ** operator 。如果你使用旧版本的python,你可以定义自己的merge_two_dicts函数,并按照this answer在lambda调用中替换它。

答案 1 :(得分:0)

这是一种避免tf.py_func

的方法
import numpy as np
import tensorflow as tf

slices = ([[1, 2, 3, 4, 5, 6]], [2])
d = tf.contrib.data.Dataset.from_tensor_slices(slices)
d = d.map(lambda coords, _pl: tf.boolean_mask(coords, tf.tile(
  np.array([True, True, False]), tf.reshape(tf.cast(_pl, tf.int32), [1]))))

it = d.make_one_shot_iterator()

with tf.Session() as sess:
  print(sess.run(it.get_next()))
  # [1 2 4 5]

就像所有事情一样,张量流有点难以正确 - 请注意转换(平铺失败为int64'倍数'参数(这是我从tf记录中读取的长度类型)),而且相当不直观重塑需要。将这个例子概括为接受可变长度数组是一个练习。

我会对此代码的gather_nd版本感兴趣。