Keras Tensorflow 自定义层仅调用一次

时间:2021-05-05 10:52:27

标签: python tensorflow keras tensorflow2.0 keras-layer

经过两天的激烈搜索,我没有在 keras docs/stack of 中找到答案。

我正在尝试创建一个自定义层,以某种概率对每个批次执行增强(现在为 1 以进行演示)。 数据为单通道图像。

我需要一种方法来知道增强是否在训练时实际发生,因为增强函数内的打印仅在定义网络时发生一次,而在激活 fit 函数时发生另一次。

非常感谢任何帮助,谢谢!

基本的可运行工作示例,注意来自 call() 的打印数量:

from tensorflow.keras.layers import Input, Dense, Conv2D, Flatten
from tensorflow.keras.models import Sequential
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers


# Custom Augmentation layer
class AugmentationLayer(tf.keras.layers.Layer):
    def __init__(self, p):
        super(AugmentationLayer, self).__init__()
        self.p = p

    def call(self, inputs, training=None):
        if not training:
            return inputs
        print('Enter Custom layer call, training = ', training)
        
        if random.random() < self.p:  
            inputs *= 2
        return inputs

# Training Script

# parameters
height = 38
width = 22
numClasses = 2
ydata = np.array([[1, 0], [1, 0]])   # hot-encoded labels

# create 2 images for network training data
Xdata = []
for i in range(2):
    xx = np.random.uniform(-1, 1, (height, width))
    Xdata.append(xx)

# create array from list
Xdata = np.array(Xdata)

# make dims n_samples X height X width X channels = (2, 38, 22, 1)
Xdata = np.reshape(Xdata, (Xdata.shape[0], Xdata.shape[1], Xdata.shape[2], 1))

model = Sequential()
model.add(Input(shape=(height, width, 1)))
model.add(AugmentationLayer(1))                   # Added custom layer
model.add(Conv2D(32, (2, 2), activation='relu'))
model.add(Flatten())
model.add(Dense(numClasses, activation='softmax'))

model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=optimizers.Adam(learning_rate=0.001))

model.summary()

model.fit(
    Xdata,
    ydata,
    epochs=10,
    batch_size=2,
    verbose=1)

2 个答案:

答案 0 :(得分:0)

如果您使用 tf.print(),它也会在图形执行中运行时打印。您可以在此处阅读更多相关信息:https://www.tensorflow.org/api_docs/python/tf/print

答案 1 :(得分:0)

对于将来的某个人,我会补充一点,在编译模型时,出于性能目的,有一个默认值会阻止某些函数从自定义层内部运行(这就是常规 print() 不起作用的原因)。其中包括:print()、np.save() 等等。

model.compile(..., run_eagerly=False)   # default

如果要启用上述所有功能,请使用以下命令编译模型:

model.compile(..., run_eagerly=True)