Get the layer outputs of a submodel in Keras

Date: 2020-05-02 17:41:18

Tags: python keras plaidml

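I am currently trying to implement a custom loss function in Keras. However, because my model combines VGG with my own custom network, I cannot get at the outputs of the submodel.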

Here is the network definition:

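import keras
from keras import regularizers
from keras.applications.vgg19 import VGG19
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Lambda
from keras.models import Model

model = VGG19(weights='imagenet', include_top=False)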
#model.summary()

def change_model(model, new_input_shape=(None, 40, 40, 3)):
    ''' Change the input size of the provided network'''
    # replace input shape of first layer
    model._layers[0].batch_input_shape = new_input_shape

    # rebuild model architecture by exporting and importing via json
    new_model = keras.models.model_from_json(model.to_json())

    # copy weights from old model to new one
    for layer in new_model.layers:
        try:
            layer.set_weights(model.get_layer(name=layer.name).get_weights())
            print("Loaded layer {}".format(layer.name))
        except:
            print("Could not transfer weights for layer {}".format(layer.name))

    return new_model

new_model = change_model(model,new_input_shape=(None, 1024, 1024, 3))
new_model.summary()
for layer in new_model.layers:
    layer.trainable = False

vector_1 = new_model.get_layer("block4_conv4").output

def create_detector_network(kernel_reg = 0.):
    input = Input(shape=(128, 128, 512))
    x = Conv2D(128, kernel_size=3, strides=1, name='detect_1', padding='same', kernel_regularizer=regularizers.l2(kernel_reg))(input)
    x = BatchNormalization()(x)
    x = Conv2D(1+pow(8,2), kernel_size=1, strides=1, name='detect_2', kernel_regularizer=regularizers.l2(kernel_reg))(x)
    x = BatchNormalization()(x)
    prob = Activation('softmax')(x)
    prob = Lambda(lambda x: x[:,:, :, :-1], output_shape= (128, 128, 64))(prob)  #x[:, :, :-1]
    prob = keras.layers.UpSampling2D(size=(8, 8), data_format=None, interpolation='nearest')(prob)
    prob = Conv2D(1, kernel_size=1, strides=1, name='reduce_dim')(prob)

    return Model(input, [prob, x])



detector_model = create_detector_network()

detector_model.summary()

output = detector_model(vector_1)

full_model  = Model(inputs=new_model.input, outputs=output)
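As a side note on `change_model()`: with `include_top=False`, the Keras `VGG19` constructor accepts an `input_shape` argument directly, so the JSON round-trip may not be needed just to change the input size. A minimal sketch, assuming TensorFlow 2.x's bundled Keras (`from tensorflow import keras`) rather than the standalone Keras/PlaidML setup above; it uses `weights=None` and a small 64x64 input only so that it runs offline and quickly, and the model name `vgg_block4` is a hypothetical choice:

```python
from tensorflow import keras

# include_top=False lets VGG19 take a custom input_shape directly.
# weights=None keeps this sketch offline; use weights="imagenet" in practice.
base = keras.applications.VGG19(weights=None, include_top=False,
                                input_shape=(64, 64, 3))
base.trainable = False  # freezes every VGG layer in one step

# Three 2x poolings come before block4_conv4, so its spatial size is input / 8
features = base.get_layer("block4_conv4").output
extractor = keras.Model(inputs=base.input, outputs=features, name="vgg_block4")

print(extractor.output.shape)  # (None, 8, 8, 512)
```

With `input_shape=(1024, 1024, 3)` and `weights="imagenet"`, the same pattern would yield the `(None, 128, 128, 512)` tensor that the detector network expects.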

The model summaries are shown below.

My network:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         (None, 128, 128, 512)     0         
_________________________________________________________________
detect_1 (Conv2D)            (None, 128, 128, 128)     589952    
_________________________________________________________________
batch_normalization_3 (Batch (None, 128, 128, 128)     512       
_________________________________________________________________
detect_2 (Conv2D)            (None, 128, 128, 65)      8385      
_________________________________________________________________
batch_normalization_4 (Batch (None, 128, 128, 65)      260       
_________________________________________________________________
activation_2 (Activation)    (None, 128, 128, 65)      0         
_________________________________________________________________
lambda_2 (Lambda)            (None, 128, 128, 64)      0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 1024, 1024, 64)    0         
_________________________________________________________________
reduce_dim (Conv2D)          (None, 1024, 1024, 1)     65        
=================================================================
Total params: 599,174
Trainable params: 598,788
Non-trainable params: 386
_________________________________________________________________
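The parameter counts in this summary can be checked by hand: a `Conv2D` layer has `k * k * in_channels * filters` kernel weights plus one bias per filter, and a `BatchNormalization` layer stores four values per channel, two of which (the moving mean and moving variance) are non-trainable. That accounts for the 386 non-trainable parameters. A small sketch of the arithmetic (the helper names are illustrative only):

```python
def conv2d_params(k, c_in, filters, bias=True):
    # k x k kernel over c_in input channels for each filter, plus one bias per filter
    return k * k * c_in * filters + (filters if bias else 0)


def batchnorm_params(channels):
    # gamma and beta are trainable; moving_mean and moving_variance are not
    return 4 * channels, 2 * channels  # (total, non-trainable)


detect_1 = conv2d_params(3, 512, 128)      # 589952
bn_3, bn_3_frozen = batchnorm_params(128)  # 512, 256
detect_2 = conv2d_params(1, 128, 65)       # 8385
bn_4, bn_4_frozen = batchnorm_params(65)   # 260, 130
reduce_dim = conv2d_params(1, 64, 1)       # 65

# Activation, Lambda and UpSampling2D have no parameters
total = detect_1 + bn_3 + detect_2 + bn_4 + reduce_dim
non_trainable = bn_3_frozen + bn_4_frozen
print(total, total - non_trainable, non_trainable)  # 599174 598788 386
```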

VGG + my network:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 1024, 1024, 3)     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 1024, 1024, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 1024, 1024, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 512, 512, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 512, 512, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 512, 512, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 256, 256, 128)     0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 256, 256, 256)     295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 256, 256, 256)     590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 256, 256, 256)     590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, 256, 256, 256)     590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 128, 128, 256)     0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 128, 128, 512)     1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 128, 128, 512)     2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 128, 128, 512)     2359808   
_________________________________________________________________
block4_conv4 (Conv2D)        (None, 128, 128, 512)     2359808   
_________________________________________________________________
model_7 (Model)              (None, 128, 128, 65)      599109    
=================================================================
Total params: 11,184,261
Trainable params: 598,723
Non-trainable params: 10,585,538
_________________________________________________________________

This is how I train the network:

losses = {
    "detect_2": "mse",
    "reduce_dim": "mse",
}

full_model.compile(optimizer='adam',
              loss=losses,
              loss_weights={'prob': 1.0, 'main': 0})

full_model.summary()

history = full_model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=7,
    validation_data=validation_generator,
    validation_steps=80)

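Now I want to use the outputs of the 'detect_2' and 'reduce_dim' layers and compute the loss/accuracy from them. However, when I run the code, I get the following error: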

ValueError: Unknown entry in loss dictionary: "detect_2". Only expected the following keys: ['model_7', 'model_7']

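Clearly something is wrong somewhere, since a dictionary cannot have the same key twice. Can someone tell me what I need to change to get the outputs of these layers?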
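The duplicated `model_7` key comes from how Keras nests models: calling `detector_model(vector_1)` inserts the whole detector into `full_model` as a single layer, so `detect_2` and `reduce_dim` are not top-level layers of `full_model`, and both outputs end up keyed by the nested model's name, as the error message shows. A minimal sketch illustrating this, assuming TensorFlow 2.x's `tf.keras` and tiny hypothetical shapes (the name `detector` is chosen here for illustration):

```python
from tensorflow import keras

# Tiny stand-in for the detector network (hypothetical shapes)
inner_in = keras.Input(shape=(8, 8, 16))
d2 = keras.layers.Conv2D(65, 1, name="detect_2")(inner_in)
rd = keras.layers.Conv2D(1, 1, name="reduce_dim")(d2)
detector = keras.Model(inner_in, [d2, rd], name="detector")

# Calling the detector on a new tensor adds it to the outer graph as ONE layer
outer_in = keras.Input(shape=(8, 8, 16))
full_model = keras.Model(outer_in, detector(outer_in))

top_level_names = [layer.name for layer in full_model.layers]
print(top_level_names)  # contains "detector", but not "detect_2" or "reduce_dim"

# The inner layers still exist, but are reachable only through the nested model
inner = full_model.get_layer("detector").get_layer("detect_2")
print(inner.name)  # detect_2
```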

2 answers:

Answer 0 (score: 0)

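If 'detect_2' and 'reduce_dim' are layers of your model, you need to declare their outputs as outputs of the network. You can then attach a loss (built-in or custom) to each of them; inside a custom loss, `y_true` and `y_pred` hold the ground truth and the prediction for the output that loss is attached to. Something like this: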


    return Model([input], [detect_2_output, reduce_dim_output])

Then use them however you need.
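A fuller sketch of this idea, assuming TensorFlow 2.x's `tf.keras` and small hypothetical shapes (a 4x4x32 input stands in for the VGG `block4_conv4` features). Two details differ from the question's code: the detector returns the `detect_2` convolution output itself (the question returns `x`, which at that point is the output of the second `BatchNormalization`), and the outer model's outputs are given as a dict, so `compile()` matches the loss keys against the dict keys rather than against layer names hidden inside the nested model. `custom_mse` is a hypothetical stand-in for the real custom loss:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def create_detector_network(feat_shape=(4, 4, 32)):
    """Small stand-in for the question's detector with both outputs exposed."""
    inp = keras.Input(shape=feat_shape)
    x = keras.layers.Conv2D(128, 3, padding="same", name="detect_1")(inp)
    x = keras.layers.BatchNormalization()(x)
    detect_2 = keras.layers.Conv2D(65, 1, name="detect_2")(x)
    prob = keras.layers.BatchNormalization()(detect_2)
    prob = keras.layers.Activation("softmax")(prob)
    prob = keras.layers.Lambda(lambda t: t[..., :-1])(prob)  # drop the 65th channel
    prob = keras.layers.UpSampling2D(size=(8, 8))(prob)
    reduce_dim = keras.layers.Conv2D(1, 1, name="reduce_dim")(prob)
    # Return the detect_2 conv output itself, not the BatchNorm that follows it
    return keras.Model(inp, [detect_2, reduce_dim], name="detector")


def custom_mse(y_true, y_pred):
    # Hypothetical custom loss: Keras calls it once per output it is attached to,
    # so y_true / y_pred are that single output's target and prediction.
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


features_in = keras.Input(shape=(4, 4, 32))  # stands in for block4_conv4 features
detect_2_out, reduce_dim_out = create_detector_network()(features_in)

# Dict outputs: compile() matches the loss keys against these dict keys
full_model = keras.Model(features_in,
                         {"detect_2": detect_2_out, "reduce_dim": reduce_dim_out})
full_model.compile(optimizer="adam",
                   loss={"detect_2": custom_mse, "reduce_dim": "mse"})

x = np.random.rand(4, 4, 4, 32).astype("float32")
y = {"detect_2": np.random.rand(4, 4, 4, 65).astype("float32"),
     "reduce_dim": np.random.rand(4, 32, 32, 1).astype("float32")}
history = full_model.fit(x, y, epochs=1, batch_size=2, verbose=0)
preds = full_model.predict(x, verbose=0)
```

The targets are passed as a dict with the same keys, so each target is paired with the output of the same name.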

Answer 1 (score: 0)

I think the model and all the concatenations are fine...

Try compiling it this way, avoiding the layer names, because those layers are hidden inside another model:

losses = ['mse','mse']

full_model.compile(optimizer='adam',
          loss=losses ,
          loss_weights=[1,0])
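A runnable sketch of this list-based approach, assuming TensorFlow 2.x's `tf.keras` and a tiny hypothetical two-output `detector`. Losses, loss weights and targets are all matched to the outputs by position, so their order must follow the model's output list. With the question's `Model(input, [prob, x])`, the first entry would apply to `prob` and the second to `x`:

```python
import numpy as np
from tensorflow import keras

# Hypothetical nested two-output model standing in for the detector
inner_in = keras.Input(shape=(4, 4, 32))
detect_2 = keras.layers.Conv2D(65, 1, name="detect_2")(inner_in)
reduce_dim = keras.layers.Conv2D(1, 1, name="reduce_dim")(detect_2)
detector = keras.Model(inner_in, [detect_2, reduce_dim], name="detector")

# Calling the nested model hides detect_2 / reduce_dim inside the "detector" layer
outer_in = keras.Input(shape=(4, 4, 32))
full_model = keras.Model(outer_in, detector(outer_in))

# Losses and weights are matched to outputs by position, not by layer name
full_model.compile(optimizer="adam", loss=["mse", "mse"], loss_weights=[1.0, 0.0])

x = np.random.rand(8, 4, 4, 32).astype("float32")
y = [np.random.rand(8, 4, 4, 65).astype("float32"),  # target for output 0 (detect_2)
     np.random.rand(8, 4, 4, 1).astype("float32")]   # target for output 1 (reduce_dim)
history = full_model.fit(x, y, epochs=1, batch_size=4, verbose=0)
preds = full_model.predict(x, verbose=0)
```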