加载自定义 TensorFlow 模型时 ML.NET 报错:“System.InvalidOperationException: TensorFlow exception triggered while loading model from ...”(从……加载模型时触发了 TensorFlow 异常)

时间:2019-07-10 09:37:17

标签: c# tensorflow ml.net

我正在尝试用 Keras 构建一个 CNN,然后在 C# 中通过 ML.NET 用它进行推理。我改编了 ML.NET 的图像分类示例(Image Classification Sample)。使用预训练模型时一切正常:我用示例中的 InceptionV3 和 TensorFlow Models 仓库中的 MobileNet 都测试过。但当我尝试导入自己的模型时,出现了以下错误:

System.InvalidOperationException: "TensorFlow exception triggered while loading model from <path>"

我用于创建模型的代码如下:

import argparse
import math
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.applications.mobilenet import MobileNet
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model, load_model
from keras.preprocessing import image
from keras.utils import to_categorical
from PIL import Image


def data_gen(df, num_classes, batch_size=32, input_shape=(224, 224, 3)):
    """
        Endlessly yields random batches of (images, one-hot labels).

        Each batch draws `batch_size` rows of `df` uniformly at random, with
        replacement. Each image is opened, converted to the channel layout
        requested by `input_shape`, and resized to its height/width.
        @param df DataFrame with an 'ImgPath' column (image file paths) and a
                  'VG' column (integer class labels, expected in [0, num_classes)).
        @param num_classes Number of classes used for one-hot encoding.
        @param batch_size Number of samples per batch.
        @param input_shape (height, width, channels). Channels of 1 and 3 are
                           converted to 'L' / 'RGB'; other values are left as
                           PIL decodes them.
        @return A generator of (images, labels), where images has shape
                (batch_size, height, width, channels) and labels has shape
                (batch_size, num_classes).
    """
    height, width, channels = input_shape
    # PIL mode that produces exactly `channels` channels. 'LA' (used before)
    # has 2 channels and broke the reshape for grayscale input.
    mode = {1: 'L', 3: 'RGB'}.get(channels)
    paths = df['ImgPath']
    targets = df['VG']
    while True:
        # Positional row indices; .iloc below keeps the lookup positional even if
        # df's index is not a clean 0..n-1 range.
        idx = np.random.choice(len(df), size=batch_size)
        images = []
        for img_path in paths.iloc[idx]:
            # Context manager closes the underlying file handle.
            with Image.open(str(img_path)) as img:
                if mode is not None:
                    img = img.convert(mode)
                # PIL sizes are (width, height), the reverse of the array layout.
                # LANCZOS is the same filter the removed ANTIALIAS alias named.
                img = img.resize((width, height), Image.LANCZOS)
                images.append(np.asarray(img))
        images = np.array(images).reshape(len(images), height, width, channels)
        labels = to_categorical(np.asarray(targets.iloc[idx]), num_classes=num_classes)
        yield images, labels


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
        Freezes the state of a session into a graph of constants.

        Creates a new computation graph where variable nodes are replaced by
        constants taking their current value in the session. Nodes that are not
        needed to compute the requested outputs are pruned. Note that every
        global variable is also added to the output list, so all variables (not
        only those feeding `output_names`) are kept in the result.
        @param session The TensorFlow session to be frozen.
        @param keep_var_names A list of variable names that should not be frozen,
                            or None to freeze all the variables in the graph.
        @param output_names Names of the relevant graph outputs. The caller's
                            list is not modified.
        @param clear_devices Remove the device directives from the graph for better portability.
        @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        variable_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(variable_names).difference(keep_var_names or []))
        # Build a new list: `output_names += ...` would extend the caller's list in place.
        output_names = list(output_names or []) + variable_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ''
        return tf.compat.v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)


def create_model(num_classes, compile=True, input_shape=None):
    """
        Builds a classifier on top of an ImageNet-pretrained MobileNet.

        The MobileNet base (loaded without its top) is followed by global average
        pooling, a 1024-unit ReLU dense layer and a softmax output layer.
        @param num_classes Number of output classes.
        @param compile If True, compile with the Adam optimizer, categorical
                       cross-entropy loss and an accuracy metric. (The name
                       shadows the builtin but is kept for callers.)
        @param input_shape Optional input shape forwarded to MobileNet, e.g.
                           (224, 224, 3). None keeps MobileNet's default.
        @return The Keras model.
    """
    base_model = MobileNet(weights='imagenet', include_top=False, input_shape=input_shape)

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(base_model.input, predictions)

    if compile:
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # summary() prints by itself and returns None; wrapping it in print() emitted "None".
    model.summary()

    return model


def get_model(filepath, num_classes):
    """
        Loads a saved Keras model and adapts its output layer to `num_classes`.

        If the file cannot be loaded, the error is printed and a fresh model is
        built with create_model(num_classes) instead. If the loaded model's last
        dimension differs from `num_classes`, the final layer is replaced by a new
        softmax Dense layer attached to the second-to-last layer's output. The
        returned model is always (re)compiled.
        @param filepath Path to the saved model file (h5).
        @param num_classes Number of output classes required.
        @return A compiled Keras model.
    """
    # Only the load is guarded. Errors after a successful load must not be
    # misreported as a bad path and silently replace the loaded model.
    try:
        model = load_model(filepath)
    except Exception as e:  # deliberate best-effort fallback to a new model
        print(e)
        print('Wrong model path. Creating new model.')
        return create_model(num_classes)

    # Read the class count from the model's static output shape instead of
    # running a dummy 224x224x3 prediction.
    if model.output_shape[-1] != num_classes:
        print('Replacing output layer')
        # Name kept as 'dense_2', presumably so the exported output node name
        # stays stable for the consumer of the frozen graph; confirm.
        output = Dense(num_classes, activation='softmax', name='dense_2')(model.layers[-2].output)
        model = Model(model.input, output)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # summary() prints by itself and returns None.
    model.summary()
    return model


def train_model(model, filepath, epochs, batch_size, num_classes, saving_directory, data_quality):
    """
        Trains `model` on the images listed in a CSV file, then exports a frozen graph.

        Rows with missing values, or whose 'Q' is not strictly greater than
        `data_quality`, are dropped. After training, the graph (with the current
        weights as constants) is written to ./model.pb.
        @param model A compiled Keras model.
        @param filepath Path to the CSV file; it must contain 'ImgPath', 'VG' and 'Q' columns.
        @param epochs Number of training epochs.
        @param batch_size Number of images per batch.
        @param num_classes Number of classes used to one-hot encode 'VG'.
        @param saving_directory Directory for per-epoch .h5 checkpoints (created if missing).
        @param data_quality Minimum (exclusive) 'Q' value a row must have to be kept.
        @raise ValueError If no rows remain after filtering.
    """
    # Prepare data (new frames instead of inplace ops on a filtered slice).
    df = pd.read_csv(filepath)
    df = df.dropna()
    df = df[df['Q'] > data_quality].reset_index(drop=True)
    if df.empty:
        raise ValueError('No rows left after filtering; check the CSV file and data_quality.')
    # Shift labels down by one so they are 0-based class indices for
    # to_categorical (assumes the CSV stores 1-based labels; confirm).
    df['VG'] = df['VG'] - 1

    # Training
    if saving_directory:
        os.makedirs(saving_directory, exist_ok=True)
    # os.path.join works whether or not the directory has a trailing slash.
    checkpoint = ModelCheckpoint(os.path.join(saving_directory, 'model-{epoch:04d}.h5'),
                                 monitor='loss', verbose=1, save_best_only=True)
    # steps_per_epoch must be an integer. Round up so an epoch draws at least
    # len(df) samples (batches are sampled at random) and is never 0 steps.
    steps_per_epoch = max(1, math.ceil(len(df) / batch_size))
    model.fit_generator(data_gen(df, num_classes, batch_size=batch_size, input_shape=(224, 224, 3)),
                        epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=[checkpoint])

    # Export. Freeze the session standalone Keras uses for this model. Do NOT
    # open a new tf.Session and run tf.global_variables_initializer() here:
    # that re-initializes every variable, so the frozen graph would contain
    # fresh initial weights instead of the pretrained/trained ones.
    output_names = [out.op.name for out in model.outputs]
    print(output_names)
    frozen_graph = freeze_session(K.get_session(), output_names=output_names)
    tf.io.write_graph(frozen_graph, "./", "model.pb", as_text=False)



if __name__ == '__main__':
    # Command-line entry point: parse options, obtain a model, then train/export it.
    cli = argparse.ArgumentParser(description='Train Vehicle Classification Network')
    # (flags, keyword spec) pairs, registered in order so --help output is unchanged.
    cli_options = (
        (('-f', '--filename'), dict(type=str, required=True, help='Path to data csv file')),
        (('-m', '--model_path'), dict(type=str, default='', help='Path to model file (h5)')),
        (('-e', '--epochs'), dict(type=int, default=10, help='Number of epochs')),
        (('-b', '--batch_size'), dict(type=int, default=32, help='Batch Size')),
        (('-sd', '--saving_directory'), dict(type=str, default='models/', help='Model saving directory')),
        (('-nc', '--num_classes'), dict(type=int, default=7, help='Number of classes')),
        (('-q', '--data_quality'), dict(type=int, default=10, help='Min Q value')),
    )
    for flags, spec in cli_options:
        cli.add_argument(*flags, **spec)
    opts = cli.parse_args()

    # Resume from the given .h5 file when one was supplied; otherwise build from scratch.
    net = (get_model(opts.model_path, opts.num_classes)
           if opts.model_path
           else create_model(opts.num_classes))
    train_model(
        net,
        opts.filename,
        opts.epochs,
        opts.batch_size,
        opts.num_classes,
        opts.saving_directory,
        opts.data_quality,
    )

0 个答案:

没有答案