Keras custom loss function with samples from the complete input dataset

Time: 2019-09-21 21:59:39

Tags: python tensorflow keras

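I am trying to design a custom loss function for a variational autoencoder in Keras that has two parts: a reconstruction loss and a divergence loss. However, instead of using a Gaussian distribution for the divergence loss, I want to randomly sample from the input and then compute the divergence loss on the sampled inputs. I don't know how to sample inputs from the complete dataset and then compute a loss on them. The encoder model is: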

x_input = Input((input_size,))
enc1 = Dense(encoder_size[0], activation='relu')(x_input)
drop = Dropout(keep_prob)(enc1)
enc2 = Dense(encoder_size[1], activation='relu')(drop)
drop = Dropout(keep_prob)(enc2)
mu = Dense(latent_dim, activation='linear', name='encoder_mean')(drop)
encoder = Model(x_input,mu)

The loss should be structured like this:

# the input is the placeholder for the complete input
def loss(x, y, input):
    reconstruction_loss = mean_squared_error(x, y)
    sample_num = 100
    sample_input = sample_from_input(input, sample_num)
    sample_encoded = encoder.predict(sample_input) <-- this would not work with placeholder
    sample_prior = gaussian(mean=0, std=1)
    # perform KL divergence between sample_encoded and sample_prior

I haven't found anything similar. It would be great if someone could point me in the right direction.

1 Answer:

Answer 0: (score: 1)

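There are several problems in your code. First, when you create a custom loss function, Keras expects it to take only two arguments, y_true and y_pred (or their equivalents). So in your case you cannot pass input as an explicit argument. If you want to pass additional arguments, you have to use a nested function.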
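To make the nested-function idea concrete, here is a minimal, framework-free sketch in plain Python. The names make_loss, sample_num, and seed are hypothetical, and the "extra" term is only a stand-in for a real divergence term. It shows how the inner two-argument function can still see the full dataset through the closure; in an actual Keras loss the inner function would receive and return tensors rather than Python lists.

```python
import random


def make_loss(input_data, sample_num=3, seed=0):
    """Return a two-argument loss that also sees `input_data` via closure."""
    # Fixed seed only so that this sketch is reproducible (an assumption).
    rng = random.Random(seed)

    def loss(y_true, y_pred):
        # Reconstruction term: mean squared error over the two sequences.
        reconstruction = sum(
            (t - p) ** 2 for t, p in zip(y_true, y_pred)
        ) / len(y_true)
        # Extra term computed from values sampled out of the captured dataset.
        sampled = rng.sample(input_data, sample_num)
        extra = sum(sampled) / sample_num  # placeholder, not a real divergence
        return reconstruction + extra

    return loss
```

The outer function is called once with the extra data, and only the returned inner function is handed to whatever expects a (y_true, y_pred) signature.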

Next, inside the predict function you cannot pass a TensorFlow placeholder; you have to pass the equivalent NumPy array. I therefore suggest rewriting sample_from_input so that it takes input_data, a set of file paths for the input samples, reads the files, and returns their contents as a NumPy array. Then pass the file path where your data lives as the argument to custom_loss.

I have attached only the relevant part of the code.

def custom_loss(input_data):
    def loss(y_true, y_pred):
        reconstruction_loss = mean_squared_error(y_true, y_pred)
        sample_num = 100
        sample_input = sample_from_input(input_data)  # sample_input is a Numpy array
        sample_encoded = encoder.predict(sample_input)
        sample_prior = gaussian(mean=0, std=1)
        # perform KL divergence between sample_encoded and sample_prior
        divergence_loss = # Your logic returning a numeric value
        return reconstruction_loss + divergence_loss
    return loss

encoder.compile(optimizer='adam', loss=custom_loss('<<input_data_path>>'))
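The divergence step is left open above. One possible way to fill it in, sketched here in plain Python, is to fit a diagonal Gaussian to the sampled encodings by moment matching and use the closed-form KL divergence to a standard normal, which per latent dimension is 0.5 * (var + mean^2 - 1 - ln var). This is an assumption about what the divergence could be, not the asker's method, and the function name is hypothetical. In an actual Keras loss the same arithmetic would have to be written with tensor operations so that gradients can flow.

```python
import math


def gaussian_kl_to_standard_normal(encoded):
    """KL( N(mean, var) || N(0, 1) ), summed over latent dimensions.

    `encoded` is a list of latent vectors (one list of floats per sample).
    Mean and variance are moment estimates taken from those samples.
    """
    n = len(encoded)
    dims = len(encoded[0])
    total = 0.0
    for d in range(dims):
        column = [row[d] for row in encoded]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n  # population variance
        var = max(var, 1e-8)  # guard against log(0) for constant columns
        total += 0.5 * (var + mean ** 2 - 1.0 - math.log(var))
    return total
```

Encodings that already have zero mean and unit variance in every dimension give a divergence of zero, and the value grows as the sample statistics drift away from the prior.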
