在训练期间将非张量参数传递给 Keras 模型/使用张量进行索引

时间:2021-01-21 14:34:11

标签: python tensorflow keras deep-learning

我正在尝试训练一个 Keras 模型,该模型在模型本身中加入了数据增强功能。模型的输入是不同类的图像,模型应该为每个类生成一个增强模型,用于增强过程。我的代码大致如下:

from keras.models import Model
from keras.layers import Input
...further imports...

def get_main_model(input_shape, n_classes):
    encoder_model = get_encoder_model()
    input = Input(input_shape, name="input")
    label_input = Input((1,), name="label_input")
    aug_models = [get_augmentation_model() for i in range(n_classes)]
    augmentation = aug_models[label_input](input)
    x = encoder_model(input)
    y = encoder_model(augmentation)
    model = Model(inputs=[input, label_input], outputs=[x, y])
    model.add_loss(custom_loss_function(x, y))
    return model 

然后我想通过模型传递一批数据,该模型由一组图像(传递给 input)和一个相应的标签数组(传递给 label_input)组成。但是,这不起作用,因为输入到 label_input 的任何内容都被 Tensorflow 转换为张量,并且不能用于以下索引。我试过的是以下内容:

  • augmentation = aug_models[int(label_input)](input) --> 不起作用 因为label_input is a tensor
  • augmentation = aug_models[tf.make_ndarray(label_input)](input) --> 转换不起作用(我猜是因为 label_input 是一个符号张量)
  • tf.gather(aug_models, label_input) --> 不起作用,因为操作的结果是一个 Keras 模型实例,Tensorflow 试图将其转换为张量(显然失败)

Tensorflow 中是否有任何技巧可以让我在训练期间将参数传递给模型,该参数不会转换为张量,或者我可以告诉模型选择哪个增强模型的不同方式?提前致谢!

1 个答案:

答案 0 :(得分:2)

要对 input 张量的每个元素应用不同的增强(例如以 label_input 为条件),您需要:

  1. 首先,为批次的每个元素计算每个可能的增强。
  2. 其次,根据标签选择所需的增强。

不幸的是,索引是不可能的,因为 inputlabel_input 张量都是多维的(例如,如果您要对批次的每个元素应用 相同 增强,然后就可以使用任何有条件的 tensorflow 语句,例如 tf.case)。


这是一个最小的工作示例,展示了如何实现这一目标:

input = tf.ones((3, 1))  # Shape=(bs, 1)
label_input = tf.constant([3, 2, 1])  # Shape=(bs, 1)
aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
nb_classes = len(aug_models)

augmented_data = tf.stack([aug_model(input) for aug_model in aug_models])  # Shape=(nb_classes, bs, 1)
selector = tf.transpose(tf.one_hot(label_input, depth=nb_classes))  # Shape=(nb_classes, bs)
augmentation = tf.reduce_sum(selector[..., None] * augmented_data, axis=0)  # Shape=(bs, 1) 
print(augmentation)

# prints:
# tf.Tensor(
# [[4.]
#  [3.]
#  [2.]], shape=(3, 1), dtype=float32)

注意:您可能需要将这些操作包装到 Keras Lambda layer 中。