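I want to build a Keras model with two inputs and two outputs, where both branches use the same architecture and weights. Both outputs are then used to compute a single loss.
Here is a picture of the architecture I want.
Here is my pseudocode: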
model = LeNet(inputs=[input1, input2, input3], outputs=[output1, output2, output3])
model.compile(optimizer='adam',
              loss=my_custom_loss_function([output1, output2, output3], target),
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
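Will this approach work?
Do I need to use a different Keras API?
Answer 0 (score: 1)
The architecture is fine. Here is a toy example, including training data, of how it can be defined using Keras' functional API: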
from keras.models import Model
from keras.layers import Dense, Input
import numpy as np
# two separate inputs
in_1 = Input((10, 10))
in_2 = Input((10, 10))
# both inputs share these layers (and therefore their weights)
dense_1 = Dense(10)
dense_2 = Dense(10)
# both inputs are passed through the same layers
out_1 = dense_1(dense_2(in_1))
out_2 = dense_1(dense_2(in_2))
# create and compile the model
model = Model(inputs=[in_1, in_2], outputs=[out_1, out_2])
model.compile(optimizer='adam', loss='mse')
model.summary()
# train the model on some dummy data
i_1 = np.random.rand(10, 10, 10)
i_2 = np.random.rand(10, 10, 10)
model.fit(x=[i_1, i_2], y=[i_1, i_2])
Edit: since you want the losses to be computed together, you can use Concatenate() (imported from keras.layers):
output = Concatenate()([out_1, out_2])
Any loss function you pass to model.compile will then be applied to output in its combined state. After you get the output from a prediction, you can split it back into its original parts.
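The Concatenate approach from this answer can be sketched end to end as follows. This is a minimal illustration, not the answerer's code: it assumes the TensorFlow backend (imports go through tensorflow.keras), and the names combined_loss and target, the batch size of 8, and the choice of summing two mean squared errors are all hypothetical placeholders for whatever joint loss you actually need.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.models import Model

# Two inputs routed through the same two Dense layers (shared weights).
in_1 = Input((10, 10))
in_2 = Input((10, 10))
dense_1 = Dense(10)
dense_2 = Dense(10)
out_1 = dense_1(dense_2(in_1))
out_2 = dense_1(dense_2(in_2))

# Join both outputs along the last axis: shape (batch, 10, 20).
output = Concatenate()([out_1, out_2])


def combined_loss(y_true, y_pred):
    """Hypothetical joint loss: split both tensors back into their two
    halves and add the per-branch mean squared errors."""
    y_true = tf.cast(y_true, y_pred.dtype)
    true_1, true_2 = tf.split(y_true, 2, axis=-1)
    pred_1, pred_2 = tf.split(y_pred, 2, axis=-1)
    loss_1 = tf.reduce_mean(tf.square(pred_1 - true_1), axis=-1)
    loss_2 = tf.reduce_mean(tf.square(pred_2 - true_2), axis=-1)
    return loss_1 + loss_2


model = Model(inputs=[in_1, in_2], outputs=output)
model.compile(optimizer="adam", loss=combined_loss)

# Dummy data; the target is concatenated the same way as the output.
i_1 = np.random.rand(8, 10, 10).astype("float32")
i_2 = np.random.rand(8, 10, 10).astype("float32")
target = np.concatenate([i_1, i_2], axis=-1)
model.fit(x=[i_1, i_2], y=target, epochs=1, verbose=0)

# Split the combined prediction back into the two original outputs.
pred = model.predict([i_1, i_2], verbose=0)
pred_1, pred_2 = np.split(pred, 2, axis=-1)
```

Because the loss receives a single tensor, the targets must be concatenated in the same order and along the same axis as the outputs; otherwise the halves recovered inside the loss will not line up.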