根据数据集动态设置最后一层的单元(units)数量

时间:2019-04-07 12:50:21

标签: python tensorflow keras

我尝试根据数据集动态更改最后一层的单元(units)数量。下面是我代码的简化版本,但它不起作用。

class cnn_model:
 """Question's attempt at a CNN whose last layer has num_classes units.

 NOTE(review): this does not work as intended. Statements in a class body
 run once, when the ``class`` statement executes, so the Dense layer below
 is created with whatever ``num_classes`` is at that moment (1). Later
 assignments such as ``cnn_model.num_classes = 2`` change the attribute but
 do not re-run the ``model.add`` calls below.
 """
 # Class attribute shared by every user of cnn_model; read once, below.
 num_classes = 1

 # NOTE(review): `model` and `layers` are not defined in this snippet; unless
 # they exist at module level elsewhere, the next line raises NameError while
 # the class body executes. The body also never assigns `model`, so it does
 # not become an attribute of cnn_model.
 model.add(layers.Conv2D(128, (3, 3), activation='relu'))
 model.add(layers.MaxPooling2D((2, 2)))
 model.add(layers.Conv2D(256, (3, 3), activation='relu'))
 model.add(layers.MaxPooling2D((2, 2)))
 # Output layer sized from the class attribute at definition time. With
 # num_classes = 1, a softmax always outputs 1.0. No Flatten precedes it, so
 # Dense operates only on the last axis of the conv feature map.
 model.add(layers.Dense(num_classes, activation='softmax'))


@staticmethod
def train_two():
 """Question's attempt to get a 2-class model from cnn_model (does not work).

 NOTE(review): this decorator is applied at module level, not inside a
 class. On Python versions before 3.10, staticmethod objects are not
 callable, so calling train_two() would raise TypeError there.
 """
 # Binds the class object itself (no parentheses), not a new instance.
 cnn_mod = cnn_model
 # Mutates the shared class attribute; the already-built layers are unchanged.
 cnn_mod.num_classes = 2
 # NOTE(review): cnn_model's body never assigns `model`, so this lookup would
 # raise AttributeError unless `model` is set on the class elsewhere. The
 # local is also unused, so the function returns None.
 model = cnn_mod.model

@staticmethod
def train_three():
 """Question's attempt to get a 3-class model from cnn_model (does not work).

 NOTE(review): this decorator is applied at module level, not inside a
 class. On Python versions before 3.10, staticmethod objects are not
 callable, so calling train_three() would raise TypeError there.
 """
 # Binds the class object itself (no parentheses), not a new instance.
 cnn_mod = cnn_model
 # Mutates the shared class attribute; the already-built layers are unchanged.
 cnn_mod.num_classes = 3
 # NOTE(review): cnn_model's body never assigns `model`, so this lookup would
 # raise AttributeError unless `model` is set on the class elsewhere. The
 # local is also unused, so the function returns None.
 model = cnn_mod.model

1 个答案:

答案 0 :(得分:1)

类体中的语句只会在定义类时执行一次,之后再修改 num_classes 并不会重新构建模型。因此应当在 __init__ 中构建模型,并在实例化 CNNModel 时把类别数作为参数传入,这样每个实例都会按自己的 num_classes 创建输出层:

import tensorflow as tf
from tensorflow.keras import layers

class CNNModel:
    """Small Conv2D classifier whose output layer is sized per instance.

    The layers are created in ``__init__``, so every instance gets its own
    ``tf.keras.models.Sequential`` model whose final softmax ``Dense`` layer
    has ``num_classes`` units.

    Args:
        num_classes: Number of units in the final softmax layer. Use at
            least 2; a 1-unit softmax always outputs 1.0.
        input_shape: Optional per-sample input shape, e.g. ``(64, 64, 3)``.
            When given, the model is built immediately. When omitted
            (the default), Keras builds it on first call/fit, as before.

    Attributes:
        num_classes: The value passed to the constructor.
        model: The ``Sequential`` model built by the constructor.
    """

    def __init__(self, num_classes=2, input_shape=None):
        self.num_classes = num_classes
        self.model = tf.keras.models.Sequential()
        if input_shape is not None:
            # Declaring the input up front lets Keras build the model (and
            # model.summary()) right away instead of on the first call.
            self.model.add(tf.keras.Input(shape=input_shape))
        self.model.add(layers.Conv2D(128, (3, 3), activation='relu'))
        self.model.add(layers.MaxPooling2D((2, 2)))
        self.model.add(layers.Conv2D(256, (3, 3), activation='relu'))
        self.model.add(layers.MaxPooling2D((2, 2)))
        # Collapse each sample's feature map into one vector. Without this,
        # Dense is applied along the last axis only, and the model outputs a
        # class distribution per spatial position instead of one per image.
        self.model.add(layers.Flatten())
        self.model.add(layers.Dense(self.num_classes, activation='softmax'))

# Demo: request three output classes, then pull out the Keras model itself.
cnnmodel = CNNModel(num_classes=3)
kerasmodel = cnnmodel.model

print(cnnmodel.num_classes)  # prints 3

我还建议您阅读 PEP 8 中的 Naming Conventions(命名约定)和 Indentation(缩进)部分。