TensorFlow 2 保存的模型(SavedModel)中缺少稀疏输入的名称

时间:2020-11-12 01:53:28

标签: tensorflow tensorflow-serving

环境:tensorflow 2.3,ubuntu 16.04,python 3.7.7

将 tf.keras.Input 与 `sparse=True` 一起使用时,服务签名中输入张量的名称不可读(例如 args_0、args_0_1、args_0_2)。因此,当一个模型中使用多个稀疏输入时,很难区分它们。

源代码/日志

def make_parse_example_test(serialized_example):
    """Decode one serialized example into a ``(features, label)`` pair.

    Args:
        serialized_example: A single serialized example record, as accepted
            by ``tf.io.parse_single_example``.

    Returns:
        A tuple ``(features, label)``. ``features`` is the parsed dict with
        the ``'label'`` entry removed (leaving ``'features'`` and
        ``'scid_index'``); ``label`` is that removed entry.
    """
    # Parsing spec: fixed-length float label and dense features, plus a
    # variable-length int64 id list.
    # NOTE: an 'emb_arr' feature (FixedLenFeature([100], tf.float32)) is
    # currently disabled.
    label_spec = tf.io.FixedLenFeature([1], tf.float32)
    dense_spec = tf.io.FixedLenFeature([38], tf.float32)
    ids_spec = tf.io.VarLenFeature(tf.int64)
    parse_spec = {
        'label': label_spec,
        'features': dense_spec,
        'scid_index': ids_spec,
    }

    parsed = tf.io.parse_single_example(serialized_example, features=parse_spec)
    # Split the target off so the remaining dict holds model inputs only.
    target = parsed.pop('label')
    return parsed, target

def batch_input(file_dir, batchsize):
    """Build a batched, prefetching dataset of parsed TFRecord examples.

    Args:
        file_dir: Path to a single TFRecord file, or to a directory; when it
            is a directory, every entry returned by ``os.listdir`` is read
            as a TFRecord file.
        batchsize: Number of parsed examples per batch.

    Returns:
        A ``tf.data`` dataset whose elements are batches of the
        ``(features, label)`` pairs produced by ``make_parse_example_test``.
    """
    # Decide whether we were given a directory or a single file.
    if os.path.isdir(file_dir):
        # os.path.join inserts the separator when ``file_dir`` lacks a
        # trailing one; plain string concatenation would glue the directory
        # and file names together into a non-existent path.
        filenamequeues = [
            os.path.join(file_dir, entry) for entry in os.listdir(file_dir)
        ]
    else:
        filenamequeues = [file_dir]
    print(filenamequeues)

    dataset = tf.data.TFRecordDataset(filenamequeues)
    # Parse each record individually, then batch the parsed examples.
    # No shuffling is performed here.
    dataset = dataset.map(make_parse_example_test, num_parallel_calls=4)
    dataset = dataset.batch(batchsize)
    # NOTE(review): -1 looks like the tf.data AUTOTUNE sentinel, letting the
    # runtime pick the prefetch buffer size — confirm for the TF version used.
    dataset = dataset.prefetch(-1)
    return dataset

class EmbeddingLayer(tf.keras.layers.Layer):
    """Keras layer that embeds sparse ids via ``safe_embedding_lookup_sparse``.

    Holds one ``[input_dim, output_dim]`` embedding table, initialised from
    a truncated normal distribution.

    Args:
        input_dim: Number of rows in the embedding table.
        output_dim: Number of columns in the embedding table.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (``name``, ``dtype``, ``trainable``, ...). ``trainable``
            defaults to ``True`` as before.
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        # Preserve the original default (trainable layer) while accepting the
        # base-layer kwargs that ``from_config`` passes back in.
        kwargs.setdefault('trainable', True)
        super(EmbeddingLayer, self).__init__(**kwargs)
        # Remember the constructor arguments so get_config() can report them.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.params = tf.Variable(
            tf.random.truncated_normal([input_dim, output_dim]), trainable=True)

    def call(self, inputs):
        """Return the result of the sparse lookup of ``inputs`` in the table."""
        param = tf.nn.safe_embedding_lookup_sparse(self.params, inputs)
        return param

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Without ``input_dim``/``output_dim`` in the config the layer could
        not be rebuilt with ``EmbeddingLayer.from_config``.
        """
        config = super(EmbeddingLayer, self).get_config()
        config.update({
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
        })
        return config

def models():
    """Assemble the two-input functional Keras model.

    The model takes a 38-wide float input named ``'features'`` and a sparse
    int64 input named ``'scid_index'``. The sparse ids go through an
    ``EmbeddingLayer(780000, 50)``, are concatenated with the dense
    features, and pass through three ``Dense`` layers (256, 256 and 1 units,
    all with ``leaky_relu``) followed by a sigmoid activation.

    Returns:
        An uncompiled ``tf.keras.models.Model`` with inputs
        ``[features, scid_index]`` and a single output.
    """
    layers = tf.keras.layers

    dense_in = layers.Input(shape=[38], name='features', dtype=tf.float32)
    sparse_ids = layers.Input(
        shape=[None], name='scid_index', dtype=tf.int64, sparse=True)

    # Embed the sparse ids and join them with the dense features.
    id_embedding = EmbeddingLayer(780000, 50)(sparse_ids)
    hidden = layers.concatenate([dense_in, id_embedding])

    # Fully connected stack, created in order h1 -> h2 -> h3.
    for layer_name, units in (('h1', 256), ('h2', 256), ('h3', 1)):
        hidden = layers.Dense(
            units, activation=leaky_relu, name=layer_name)(hidden)

    probability = layers.Activation(activation="sigmoid")(hidden)
    return tf.keras.models.Model(
        inputs=[dense_in, sparse_ids],
        outputs=[probability],
    )
# Input pipeline over a single TFRecord part file, 512 examples per batch.
train = batch_input('./part-r-00001', 512)
model = models()

# Binary-classification setup with AUC / precision / recall / accuracy metrics.
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[
        tf.keras.metrics.AUC(),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.BinaryAccuracy(),
    ],
)

# Short run: one epoch of five steps, class 1 weighted twice as much as class 0.
model.fit(
    train,
    steps_per_epoch=5,
    epochs=1,
    class_weight={0: 1, 1: 2},
    verbose=1,
)
# Print the Keras-level input names for comparison with the exported signature.
print(model.input_names)
model.save("./model3test")

[检查export_model] $ saved_model_cli show --dir ./model3test --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['args_0'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 2)
        name: serving_default_args_0:0
    inputs['args_0_1'] tensor_info:
        dtype: DT_INT64
        shape: (-1)
        name: serving_default_args_0_1:0
    inputs['args_0_2'] tensor_info:
        dtype: DT_INT64
        shape: (2)
        name: serving_default_args_0_2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['label'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall_18:0
  Method name is: tensorflow/serving/predict

0 个答案:

没有答案