如何在tensorflow C api中创建字符串类型张量?

时间:2016-12-14 09:18:29

标签: c tensorflow

参数列表中的data究竟应该是什么?

TF_Tensor* tensorStr = TF_NewTensor(TF_STRING, nullptr, 0, &data[0], 8, no_op, nullptr);

我试过了:

char * data = "blah";
char* data[] = {"blah"};
char data[1][4] = {{'b','l','a','h'}};
一切都运气不好。当输入到输入。我总是得到:

Malformed TF_STRING tensor; element 0 out of range

3 个答案:

答案 0 :(得分:2)

创建字符串张量的有效(但有点难看)代码示例:

std::string input_str = "abracdabra";  // any input string
size_t encoded_size = TF_StringEncodedSize(input_str.size());
size_t total_size = 8 + encoded_size;  // 8 extra bytes - for start_offset 
char *input_encoded = (char*)malloc(total_size);
for (int i =0; i < 8; ++i) {  // fills start_offset
    input_encoded[i] = 0;
}
TF_StringEncode(input_str.c_str(), input_str.size(), input_encoded+8, encoded_size, status); // fills the rest of tensor data
if (TF_GetCode(status) != TF_OK){
    fprintf(stderr, "ERROR: something wrong with encoding: %s", TF_Message(status));
    return 1;
}
TF_Tensor* input = TF_NewTensor(TF_STRING, NULL, 0, input_encoded, total_size, &Deallocator, 0);

为什么会起作用:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h#L213根据此链接,字符串张量的数据由两部分组成。最后一个是通过TF_StringEncode函数编码的输入字符串。第一个是数组'start_offset',我不完全理解它的作用,但看起来像八个零做的伎俩)

张量创建的另一个例子可以在C API测试中找到:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api_test.cc#L1934

答案 1 :(得分:1)

除非Tensor是标量(只保留一个数字),否则您需要传递维度信息。你传递的是nullptr,而当前代码中的dims是零,这就是它引发错误的原因。您可以在此处查看如何为字符串调用TF_NewTensor的示例: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.cc#L441

答案 2 :(得分:-2)

我对tensorflow一无所知,但C中的字符串是一个数组(通常是未指定的legnth)字符,所以很可能你应该准备一个字符串数组作为指针的数组字符数组。像这样:

char *str1 = "abcde";
char *str2 = "GHIJ";
char *str3 = "123";
char *str4 = "ONE two THREE!!!";
....

char *data[] = {str1, str2, str3, str4};

或:

char str1[] = "abcde";
char str2[] = "GHIJ";
char str3[] = "123";
char str4[] = "ONE two THREE!!!";
....

char *data[] = {str1, str2, str3, str4};