如何在Tensorflow String Tensor上执行字符串查找和替换?

时间:2018-05-02 16:16:25

标签: tensorflow tensorflow-datasets

我目前正在使用Tensorflow数据集api对指定路径上的图像执行某些扩充。文件名本身包含说明是否要扩充文件的信息。所以我想要做的是读取数据集中的文件和每个文件,在文件名中执行查找,如果找到特定的子字符串,则设置bool标志并用"&#34替换子字符串;

我得到的错误是:

  

属性错误:' Tensor'对象没有属性'找到'

我无法执行"发现"在带有dtype字符串条目的张量上,因为find不是Tensor的一部分,所以我试图找出如何执行上述操作。我已经在下面分享了一些代码,我认为这些代码展示了我想要做的事情。性能很重要,所以如果有人发现我正在通过数据集API错误地执行此操作,我宁愿以正确的方式执行此操作。

def preproc_img(filenames):
  def parse_fn(filename):
    augment_inst = False
    if cfg.SPLIT_INTO_INST:
      #*****************************************************
      #*** THIS IS WHERE THE LOGIC IS CURRENTLY BREAKING ***
      #*****************************************************
      if filename.find('_data_augmentation') != -1:
        augment_inst = True
        filename = filename.replace('_data_augmentation', '')

    image_string = tf.read_file(filename)
    img = tf.image.decode_image(image_string, channels=3)
    return dict(zip([filename], [img]))   

  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  dataset = dataset.map(parse_fn)
  iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()


def perform_train():
  if __name__ == '__main__':
    filenames = helper.get_image_paths()
    next_batch = preproc_img(filenames)

  with tf.Session() as sess:
    with sess .graph.as_default():
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())

      dat = sess.run(next_batch)
      # I would now go about calling any of my tf op code below

1 个答案:

答案 0 :(得分:3)

您可以使用tf.regex_replace替换tf.string张量中的文字。

filename = tf.regex_replace(filename, "_data_augmentation", "")