将示例提供给tf预报器.from_saved_model(),以使用tf集线器模块训练的估算器

时间:2018-08-06 19:06:08

标签: tensorflow tensorflow-hub

我尝试使用tf hub modules导出用于文本分类的模型,然后使用predictor.from_saved_model()从模型中为单个字符串示例进行预测。我看到some examples有类似的想法,但在使用tf集线器模块构建功能时仍然无法解决问题。这是我的工作:

        train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['label_ids'], num_epochs= None, shuffle=True)

    # Prediction on the whole training set.
    predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
        train_df, train_df['label_ids'], shuffle=False)

    embedded_text_feature_column = hub.text_embedding_column(
        key='sentence',
        module_spec='https://tfhub.dev/google/nnlm-de-dim128/1')

    #Estimator
    estimator = tf.estimator.DNNClassifier(
        hidden_units=[500, 100],
        feature_columns=[embedded_text_feature_column],
        n_classes=num_of_class,
        optimizer=tf.train.AdagradOptimizer(learning_rate=0.003) )

    # Training
    estimator.train(input_fn=train_input_fn, steps=1000)

    #prediction on training set
    train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)

    print('Training set accuracy: {accuracy}'.format(**train_eval_result))

    feature_spec = tf.feature_column.make_parse_example_spec([embedded_text_feature_column])
    serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

    export_dir_base = self.cfg['model_path']
    servable_model_path = estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)

    # Example message for inference
    message = "Was ist denn los"
    saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path)
    content_tf_list = tf.train.BytesList(value=[str.encode(message)])
    example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'sentence': tf.train.Feature(
                        bytes_list=content_tf_list
                    )
                }
            )
        )

    with tf.python_io.TFRecordWriter('the_message.tfrecords') as writer:
        writer.write(example.SerializeToString())

    reader = tf.TFRecordReader()
    data_path = 'the_message.tfrecords'
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    _, serialized_example = reader.read(filename_queue)
    output_dict = saved_model_predictor({'inputs': [serialized_example]})

输出:

Traceback (most recent call last):
  File "/Users/dimitrs/component-pythia/src/pythia.py", line 321, in _train
    model = algo.generate_model(samples, generation_id)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/algorithm.py", line 56, in generate_model
    model = self._process_training(samples, generation)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/tf_hub_classifier.py", line 91, in _process_training
    output_dict = saved_model_predictor({'inputs': [serialized_example]})
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/contrib/predictor/predictor.py", line 77, in __call__
    return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Unable to get element as bytes.

serialized_example所建议的输入不是serving_input_receiver_fn吗?

1 个答案:

答案 0 :(得分:1)

所以,我只需要 # Example message for inference message = "Was ist denn los" saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path) content_tf_list = tf.train.BytesList(value=[message.encode('utf-8')]) sentence = tf.train.Feature(bytes_list=content_tf_list) sentence_dict = {'sentence': sentence} features = tf.train.Features(feature=sentence_dict) example = tf.train.Example(features=features) serialized_example = example.SerializeToString() output_dict = saved_model_predictor({'inputs': [serialized_example]}) 在文件上编写示例需要在读回文件之前启动会话。只需序列化就足够了:

void  ft_perm(char *star)
    {
        struct stat fileStat;
        if(stat(star,&fileStat) < 0)
            return;
        printf((S_ISDIR(fileStat.st_mode))  ? "d" : "-");
        printf( (fileStat.st_mode & S_IRUSR) ? "r" : "-");
        printf( (fileStat.st_mode & S_IWUSR) ? "w" : "-");
        printf( (fileStat.st_mode & S_IXUSR) ? "x" : "-");
        printf( (fileStat.st_mode & S_IRGRP) ? "r" : "-");
        printf( (fileStat.st_mode & S_IWGRP) ? "w" : "-");
        printf( (fileStat.st_mode & S_IXGRP) ? "x" : "-");
        printf( (fileStat.st_mode & S_IROTH) ? "r" : "-");
        printf( (fileStat.st_mode & S_IWOTH) ? "w" : "-");
        printf( (fileStat.st_mode & S_IXOTH) ? "x" : "-");
        return;
    }

    char *ft_group(char *tim)
    {
        int i;
        int j;
        char *s;

        s = (char *)malloc (sizeof(char) * 16);
        i = 3;
        j = 0;
        while( i < 16 )
       {
           s[j] = tim[i];
           i++;
           j++;
       }
       return(s);
    }

    int count ()
    {
        DIR *dir;
        struct dirent *sd;
        int i;

        dir = opendir(".");
        if (!dir)
        {
            printf("error");
            exit(1);
        }
        i = 0;
        while((sd = readdir(dir) ))
        {
            i++;
        }
        return(i);
    }

    int main(int argc, char **argv)
    {
        struct stat statbuf;
        struct group *grp;
        struct passwd *pwd;
        pwd = getpwuid((geteuid()));
        struct dirent **sd; 
        ;
        int i = 1;
        int r = 1;
        while( r < 10)
        {
        stat(sd[i]->d_name,&statbuf);

        ft_perm(sd[i]);
        printf("1  %s",pwd->pw_name);

        if ((grp = getgrgid(statbuf.st_gid)) != NULL)
            printf(" %-8.8s", grp->gr_name);
        else
            printf(" %-8d", statbuf.st_gid);
        printf(" %d", (int)statbuf.st_size);
        printf(" %s\n", ft_group(ctime(&statbuf.st_atime)));
        r++;
        i++;
        }
    }