在不同时期训练不同的输出

时间:2019-10-08 21:23:44

标签: python tensorflow keras epoch multipleoutputs

在Keras中,多输出训练中每个输出或某些输出的训练是否有可能在不同的时期开始?例如,输出之一将其他一些输出作为其输入。但是,一开始的输出结果还为时过早,给模型带来了巨大的计算负担。我希望将其训练推迟到一段时间后的输出是一个自定义层,该层必须对其输入应用某些图像处理操作,该输入是另一输出生成的图像,但是开始时生成的图像是毫无意义的,我认为在第一个纪元应用此自定义图层只是浪费时间。有没有办法做到这一点?就像我们对每个输出的损失进行加权一样,我们是否有不同的起点来计算每个输出的损失?

1 个答案:

答案 0 :(得分:1)

  1. 构建不包含更高版本输出的模型。
  2. 将模型训练到所需的程度。
  3. 构建一个将旧模型纳入其中的新模型。
  4. 使用所需的新损失函数编译新模型。
  5. 训练该模型。

要详细说明第3步,可以像Keras功能API中的图层一样使用Keras模型。

您可以像这样建立普通模型:

  java.lang.NoClassDefFoundError: Failed resolution of: Lcom/google/common/base/CharMatcher;
        at com.google.common.base.Splitter.on(Splitter.java:125)
        at io.grpc.internal.GrpcUtil.<clinit>(GrpcUtil.java:203)
        at io.grpc.internal.AbstractManagedChannelImplBuilder.<clinit>(AbstractManagedChannelImplBuilder.java:84)
        at io.grpc.okhttp.OkHttpChannelBuilder.forTarget(OkHttpChannelBuilder.java:119)
        at io.grpc.okhttp.OkHttpChannelProvider.builderForTarget(OkHttpChannelProvider.java:48)
        at io.grpc.okhttp.OkHttpChannelProvider.builderForTarget(OkHttpChannelProvider.java:27)
        at io.grpc.ManagedChannelBuilder.forTarget(ManagedChannelBuilder.java:73)
        at com.google.firebase.firestore.remote.GrpcCallProvider.initChannel(com.google.firebase:firebase-firestore@@21.1.1:92)
        at com.google.firebase.firestore.remote.GrpcCallProvider.lambda$new$0(com.google.firebase:firebase-firestore@@21.1.1:62)
        at com.google.firebase.firestore.remote.GrpcCallProvider$$Lambda$1.call(com.google.firebase:firebase-firestore@@21.1.1)
        at com.google.android.gms.tasks.zzv.run(Unknown Source)
        at com.google.firebase.firestore.util.ThrottledForwardingExecutor.lambda$execute$0(com.google.firebase:firebase-firestore@@21.1.1:54)
        at com.google.firebase.firestore.util.ThrottledForwardingExecutor$$Lambda$1.run(com.google.firebase:firebase-firestore@@21.1.1)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1112)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:587)
        at java.lang.Thread.run(Thread.java:818)
     Caused by: java.lang.ClassNotFoundException: Didn't find class "com.google.common.base.CharMatcher" on path: DexPathList[[zip file "/data/app/com.byte_artisan.mchat2-1/base.apk"],nativeLibraryDirectories=[/data/app/com.byte_artisan.mchat2-1/lib/arm, /vendor/lib, /system/lib]]
        at dalvik.system.BaseDexClassLoader.findClass(BaseDexClassLoader.java:56)
        at java.lang.ClassLoader.loadClass(ClassLoader.java:511)
        at java.lang.ClassLoader.loadClass(ClassLoader.java:469)
        at com.google.common.base.Splitter.on(Splitter.java:125) 
        at io.grpc.internal.GrpcUtil.<clinit>(GrpcUtil.java:203) 
        at io.grpc.internal.AbstractManagedChannelImplBuilder.<clinit>(AbstractManagedChannelImplBuilder.java:84) 
        at io.grpc.okhttp.OkHttpChannelBuilder.forTarget(OkHttpChannelBuilder.java:119) 
        at io.grpc.okhttp.OkHttpChannelProvider.builderForTarget(OkHttpChannelProvider.java:48) 
        at io.grpc.okhttp.OkHttpChannelProvider.builderForTarget(OkHttpChannelProvider.java:27) 
        at io.grpc.ManagedChannelBuilder.forTarget(ManagedChannelBuilder.java:73) 
        at com.google.firebase.firestore.remote.GrpcCallProvider.initChannel(com.google.firebase:firebase-firestore@@21.1.1:92) 
        at com.google.firebase.firestore.remote.GrpcCallProvider.lambda$new$0(com.google.firebase:firebase-firestore@@21.1.1:62) 
        at com.google.firebase.firestore.remote.GrpcCallProvider$$Lambda$1.call(com.google.firebase:firebase-firestore@@21.1.1) 
        at com.google.android.gms.tasks.zzv.run(Unknown Source) 
        at com.google.firebase.firestore.util.ThrottledForwardingExecutor.lambda$execute$0(com.google.firebase:firebase-firestore@@21.1.1:54) 
        at com.google.firebase.firestore.util.ThrottledForwardingExecutor$$Lambda$1.run(com.google.firebase:firebase-firestore@@21.1.1) 
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1112) 
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:587) 
        at java.lang.Thread.run(Thread.java:818) 
        Suppressed: java.lang.NoClassDefFoundError: com.google.common.base.CharMatcher
        at dalvik.system.DexFile.defineClassNative(Native Method)
        at dalvik.system.DexFile.defineClass(DexFile.java:226)
        at dalvik.system.DexFile.loadClassBinaryName(DexFile.java:219)
        at dalvik.system.DexPathList.findClass(DexPathList.java:321)
        at dalvik.system.BaseDexClassLoader.findClass(BaseDexClassLoader.java:54)
                ... 18 more
        Suppressed: java.lang.ClassNotFoundException: com.google.common.base.CharMatcher
        at java.lang.Class.classForName(Native Method)
        at java.lang.BootClassLoader.findClass(ClassLoader.java:781)
        at java.lang.BootClassLoader.loadClass(ClassLoader.java:841)
        at java.lang.ClassLoader.loadClass(ClassLoader.java:504)
                ... 17 more
     Caused by: java.lang.NoClassDefFoundError: Class not found using the boot class loader; no stack available

但是,如果您有另一个标准Keras模型,则可以像其他任何层一样使用它。例如,如果我们有一个名为input = Input((100,)) x = Dense(50)(input) x = Dense(1, activation='sigmoid')(x) model = Model(input, x) 的模型(使用Sequential()Model()keras.models.load_model()创建),则可以这样输入:

model1

这等效于分别放置input = Input((100,)) x = model1(input) x = Dense(1, activation='sigmoid')(x) model = Model(input, x) 中的每一层。