tf.reshape在急切模式和不急切模式下的行为似乎有所不同

时间:2020-07-23 05:58:18

标签: python tensorflow

我正在尝试将我急切的代码转换为估算器。我遇到的问题是以下几行代码:

parts = tf.strings.split(file_path, os.path.sep)
label_t = tf.where(label_list == parts[-2])
label = tf.reshape(label_t, [])

其中file_path是包含文件路径的tf.Tensor。

上面的代码在预想模式下运行良好,但是当传递给估计器的输入函数时。我收到ValueError(如下所示)。

这是一个渴望执行的例子:

import os
import pathlib

import tensorflow as tf

def decode_img(img, img_size):
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.cast(img, tf.float32)
    img = (img / 127.5) - 1  # Rescale input channels to a range of [-1, 1]
    img = tf.image.resize(img, img_size)
    return img

def get_class_names(data_dir):
    print(type(data_dir))
    return tf.constant(sorted([p.name for p in data_dir.glob("*")]))

def make_processor(label_list):
    print(type(label_list))
    def _processor(file_path):
        parts = tf.strings.split(file_path, os.path.sep)
        label_t = tf.where(label_list == parts[-2])
        label = tf.reshape(label_t, []) # This is where the error happens in non-eager mode
        img = tf.io.read_file(file_path)
        img = decode_img(img, (224, 224))
        return img, label

    return _processor

DATA_DIR = pathlib.Path("./data/train")
LABEL_LIST = get_class_names(DATA_DIR)

IMG_HEIGHT, IMG_WIDTH = 224, 224
IMG_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)
IMG_SIZE = (IMG_HEIGHT, IMG_WIDTH)

processor = make_processor(LABEL_LIST)
ds = tf.data.Dataset.list_files(str(DATA_DIR / "*/*"), shuffle=True)

processed = ds.map(processor)
for _, label in processed.take(1):
    print(label)

它会按预期打印出以下内容:

tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64)

这是我在估算器的输入函数中尝试执行的操作:

def input_fn():
    ds = tf.data.Dataset.list_files(str(DATA_DIR / "*/*"), shuffle=False)
    ds = ds.map(processor) # The processor is the same as in the previous example
    return ds


base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, include_top=False, weights="imagenet"
)
base_model.trainable = False

model = tf.keras.Sequential(
    [base_model, tf.keras.layers.GlobalAveragePooling2D(), tf.keras.layers.Dense(1)]
)

model.compile(
    optimizer=tf.keras.optimizers.RMSprop(lr=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
    )

estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
estimator.train(input_fn=input_fn)

上面的代码无法运行,并显示一条较长的跟踪,结尾为:

ValueError: Cannot reshape a tensor with 0 elements to shape [] (1 elements) for 'Reshape' (op: 'Reshape') with input shapes: [?,0], [0].

我如何使tf.reshape在估算器中工作,因为它在渴望的版本中起作用?

编辑: 这是模拟数据集ds的方法:

ds = tf.data.Dataset.from_tensor_slices(
    [b"data/train/dog/dog.8437.jpg", "data/train/cat/cat.9423.jpg"]
)

0 个答案:

没有答案