Using a Dimension value in tf.cond()

Time: 2019-08-02 15:01:08

Tags: python tensorflow

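I am reading raw floats from files, where each file is a single 4-channel time series (one observation). Because of artifacts in the dataset, some files may contain an extra column that I want to get rid of.

After reading a file with tf.read_file(filename), I try to drop the first column of the tensor if it has one column too many. I am using tf.cond(), where the pred argument checks whether the length of the first row equals the desired number of columns.

I am reading 140 GB of data, so I cannot use numpy (the data does not fit in memory), and I want to use the solution as a tf.data.Dataset map function.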

import tensorflow as tf

desired_number_of_columns = 4

# Mock reading from file - tf.read_file(filename). The file contains 5 columns but 4 are required.
raw_single_observation = tf.constant(' '.join(['0.0'] * 5) + '\n' + ' '.join(['0.0'] * 5))

# Split file content on newline and whitespace - produces one long list of strings (representing numbers).
single_observation = tf.string_split([raw_single_observation], sep='\r\n ').values
# Casts strings to floats.
single_observation = tf.strings.to_number(single_observation, tf.float32)

# Extract rows from the file. Will be used to get the length of the first row.
rows = tf.string_split([raw_single_observation], sep='\r\n').values
# Split the first row on whitespaces and extract the number of columns.
actual_number_of_columns = tf.string_split([rows[0]], sep=' ').values.shape[0]

# If the number of columns doesn't match - drop the first column.
single_observation = tf.cond(
    actual_number_of_columns == desired_number_of_columns,
    lambda: tf.reshape(single_observation, (-1, desired_number_of_columns), name="desired_number_of_columns"),
    lambda: tf.reshape(single_observation, (-1, actual_number_of_columns), name="one_too_many_columns")[:, 1:]
)

tf.print(single_observation)

I get the following error:

Traceback (most recent call last):
  File "C:/xxx/src/debug/debug_dimension_in_cond.py", line 24, in <module>
    lambda: tf.reshape(single_observation, (-1, actual_number_of_columns), name="one_too_many_columns")[:, 1:]
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1965, in cond
    p_2, p_1 = switch(pred, pred)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 305, in switch
    data, dtype=dtype, name="data", as_ref=True)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\ops.py", line 1518, in internal_convert_to_tensor_or_composite
    accept_composite_tensors=True)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\ops.py", line 1224, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\constant_op.py", line 305, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\constant_op.py", line 246, in constant
    allow_broadcast=True)
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\constant_op.py", line 284, in _constant_impl
    allow_broadcast=allow_broadcast))
  File "C:\Users\xxx\.conda\envs\test-env\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 454, in make_tensor_proto
    raise ValueError("None values not supported.")
ValueError: None values not supported.
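
A likely reading of this traceback: `.values.shape[0]` is the static shape, which is fixed when the graph is built. The number of tokens produced by a string split depends on the file contents, so that entry is unknown, and in TF 1.x comparing an unknown Dimension with `==` yields None, which tf.cond() then cannot convert to a tensor. The sketch below only illustrates the static versus dynamic distinction; it assumes a TensorFlow version that provides `tf.compat.v1` (1.13 or later, including 2.x), and the variable names are illustrative rather than taken from the question.

```python
import tensorflow as tf

# Build the ops in an explicit graph so the behaviour matches graph mode,
# which is also how tf.data map functions are traced.
graph = tf.Graph()
with graph.as_default():
    row = tf.constant('0.0 0.0 0.0 0.0 0.0')
    tokens = tf.compat.v1.string_split([row], sep=' ').values

    # Static shape: known at graph-construction time. The token count
    # depends on the string contents, so this entry is unknown.
    static_columns = tokens.shape.as_list()[0]

    # Dynamic shape: an int32 tensor that gets its value when the graph runs.
    dynamic_columns = tf.shape(tokens)[0]

with tf.compat.v1.Session(graph=graph) as sess:
    runtime_columns = sess.run(dynamic_columns)

print(static_columns)   # None
print(runtime_columns)  # 5
```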

I don't know how to make dropping the first column conditional on the number of columns.
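
One possible approach, sketched under assumptions rather than offered as a confirmed fix: take the column count from `tf.shape()` so that the predicate passed to tf.cond() is a real tensor. The names `parse_observation`, `evaluate` and `DESIRED_COLUMNS` are hypothetical helpers introduced for this illustration, and `tf.compat.v1.string_split` is used so the same code can run on TF 1.13+ and 2.x.

```python
import tensorflow as tf

DESIRED_COLUMNS = 4  # mirrors desired_number_of_columns in the question


def evaluate(tensor):
    """Return a NumPy value in eager mode (TF 2.x) or via a session (TF 1.x)."""
    if tf.executing_eagerly():
        return tensor.numpy()
    with tf.compat.v1.Session() as sess:
        return sess.run(tensor)


def parse_observation(raw):
    """Parse whitespace-separated floats and drop a surplus first column."""
    # Flat vector holding every number in the file.
    tokens = tf.compat.v1.string_split([raw], sep='\r\n ').values
    values = tf.strings.to_number(tokens, tf.float32)

    # Column count of the first row, as a runtime (dynamic) tensor.
    rows = tf.compat.v1.string_split([raw], sep='\r\n').values
    first_row = tf.compat.v1.string_split([rows[0]], sep=' ').values
    n_cols = tf.shape(first_row)[0]

    return tf.cond(
        tf.equal(n_cols, DESIRED_COLUMNS),
        lambda: tf.reshape(values, (-1, DESIRED_COLUMNS)),
        lambda: tf.reshape(values, tf.stack([-1, n_cols]))[:, 1:],
    )


# Mock file contents: one with a surplus first column, one without.
five_columns = tf.constant('1 2 3 4 5\n6 7 8 9 10')
four_columns = tf.constant('1 2 3 4\n5 6 7 8')

result_five = evaluate(parse_observation(five_columns))
result_four = evaluate(parse_observation(four_columns))
print(result_five)
print(result_four)
```

Because `parse_observation` uses only TensorFlow ops, it should be usable as a tf.data.Dataset map function after tf.read_file(filename). If a file never has more than one surplus column, reshaping to `(-1, n_cols)` and slicing `[:, -DESIRED_COLUMNS:]` would give the same result without tf.cond().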

0 answers:

No answers yet.