I am trying to export a text classification model built with TF Hub modules and then use predictor.from_saved_model() to run a prediction for a single string example. I have seen some examples with a similar idea, but I still cannot get it to work when the features are built with a TF Hub module. Here is my code:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.contrib import predictor

train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df['label_ids'], num_epochs=None, shuffle=True)

# Prediction on the whole training set.
predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
    train_df, train_df['label_ids'], shuffle=False)

embedded_text_feature_column = hub.text_embedding_column(
    key='sentence',
    module_spec='https://tfhub.dev/google/nnlm-de-dim128/1')

# Estimator
estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=[embedded_text_feature_column],
    n_classes=num_of_class,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))

# Training
estimator.train(input_fn=train_input_fn, steps=1000)

# Evaluation on the training set
train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
print('Training set accuracy: {accuracy}'.format(**train_eval_result))

# Export the model with a parsing serving input receiver
feature_spec = tf.feature_column.make_parse_example_spec([embedded_text_feature_column])
serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
export_dir_base = self.cfg['model_path']
servable_model_path = estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)

# Example message for inference
message = "Was ist denn los"
saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path)
content_tf_list = tf.train.BytesList(value=[str.encode(message)])
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'sentence': tf.train.Feature(
                bytes_list=content_tf_list
            )
        }
    )
)
with tf.python_io.TFRecordWriter('the_message.tfrecords') as writer:
    writer.write(example.SerializeToString())
reader = tf.TFRecordReader()
data_path = 'the_message.tfrecords'
filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
_, serialized_example = reader.read(filename_queue)
output_dict = saved_model_predictor({'inputs': [serialized_example]})
Output:
Traceback (most recent call last):
  File "/Users/dimitrs/component-pythia/src/pythia.py", line 321, in _train
    model = algo.generate_model(samples, generation_id)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/algorithm.py", line 56, in generate_model
    model = self._process_training(samples, generation)
  File "/Users/dimitrs/component-pythia/src/algorithm_layer/tf_hub_classifier.py", line 91, in _process_training
    output_dict = saved_model_predictor({'inputs': [serialized_example]})
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/contrib/predictor/predictor.py", line 77, in __call__
    return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/Users/dimitrs/anaconda3/envs/pythia/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: Unable to get element as bytes.
Isn't serialized_example the input that the serving_input_receiver_fn suggests?
Answer 0 (score: 1)
So, I just needed:

# Example message for inference
message = "Was ist denn los"
saved_model_predictor = predictor.from_saved_model(export_dir=servable_model_path)
content_tf_list = tf.train.BytesList(value=[message.encode('utf-8')])
sentence = tf.train.Feature(bytes_list=content_tf_list)
sentence_dict = {'sentence': sentence}
features = tf.train.Features(feature=sentence_dict)
example = tf.train.Example(features=features)
serialized_example = example.SerializeToString()
output_dict = saved_model_predictor({'inputs': [serialized_example]})
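For reference, here is a minimal sketch of reading a prediction out of output_dict. It assumes the classifier's default classification signature, which typically exposes 'scores' (per-class probabilities) and 'classes'; check the actual signature of your exported model if the keys differ.

# Sketch: assumes the exported classification signature returns
# 'scores' with shape [batch, num_of_class] and 'classes'.
import numpy as np

scores = output_dict['scores'][0]
predicted_class = int(np.argmax(scores))
print('Predicted class: {} (probability {:.3f})'.format(
    predicted_class, scores[predicted_class]))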
Writing the Example out to a file and reading it back would require running a session before the predictor call; simply serializing it with example.SerializeToString() is enough.
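To answer the original question: yes, the receiver created by build_parsing_serving_input_receiver_fn does expect serialized tf.Example protos. Roughly, it behaves like the following simplified sketch (illustrative only, with names chosen for the example, not the exact library code):

# Simplified sketch of what build_parsing_serving_input_receiver_fn builds.
def sketch_serving_input_receiver_fn():
    # A batch of serialized tf.train.Example strings is the expected input.
    serialized = tf.placeholder(dtype=tf.string, shape=[None],
                                name='input_example_tensor')
    receiver_tensors = {'examples': serialized}
    # Each serialized Example is parsed against the feature spec,
    # here the 'sentence' string feature used by the hub text column.
    features = tf.parse_example(serialized, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)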