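What is the best way to handle sparse vectors of size ~30 000, in which every index is zero except a single index with value 1 (1-HOT vectors)?
In my dataset I have sequences of values, and I convert each value into a 1-HOT vector. Here is what I currently do: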
# Create some queues to read data from .csv files
...
# Parse example(/line) from the data file
example = tf.decode_csv(value, record_defaults=record_defaults)
# example now looks like (e.g.) [[5], [1], [4], [38], [571], [9]]
# [5] indicates the length of the sequence
# 1, 4, 38, 571 is the input sequence
# 4, 38, 571, 9 is the target sequence
# Create 1-HOT vectors for each value in the sequence
sequence_length = example[0]
one_hots = example[1:]
one_hots = tf.reshape(one_hots, [-1])
one_hots = tf.one_hot(one_hots, depth=n_classes)
# All values except the last are the input features; all except the first are the targets
features = one_hots[:-1]
targets = one_hots[1:]
...
# The sequence_length, features and targets are added to a list
# and the list is sent into a batch with tf.train.batch_join(...).
# So now I can get batches and feed into my RNN
...
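To make the cost of this approach concrete, here is a minimal NumPy stand-in for the pipeline above. It assumes n_classes = 30000 and reuses the sample row from the comments; np.zeros plus fancy indexing replaces tf.one_hot, and the variable names only mirror the question's code.

```python
import numpy as np

# Hypothetical parsed CSV row, matching the comment in the question:
# the first value is the sequence length, the rest are token ids.
example = [5, 1, 4, 38, 571, 9]
n_classes = 30000  # assumed vocabulary size (~30 000 per the question)

sequence_length = example[0]
ids = np.array(example[1:], dtype=np.int64)

# NumPy stand-in for tf.one_hot(ids, depth=n_classes):
# one row per id, with a single 1.0 in column `id`.
one_hots = np.zeros((ids.size, n_classes), dtype=np.float32)
one_hots[np.arange(ids.size), ids] = 1.0

features = one_hots[:-1]  # one-hot rows for 1, 4, 38, 571
targets = one_hots[1:]    # one-hot rows for 4, 38, 571, 9

print(features.shape)   # (4, 30000)
print(features.nbytes)  # 480000 bytes for the features of just 4 time steps
print(ids.nbytes)       # 40 bytes for all five ids stored as int64
```

Almost all of those 480 000 bytes are zeros; the integer ids already carry the same information.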
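This works, but I believe it can be done more efficiently. I looked at SparseTensor, but I could not figure out how to build SparseTensors from the example tensor returned by tf.decode_csv. I also read somewhere that it is better to parse the data after retrieving it in batches; is that still true?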
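For reference, a one-hot matrix in sparse (COO) form is just three arrays: one (row, column) pair per nonzero entry, the values, and the dense shape. These are the arguments tf.SparseTensor(indices, values, dense_shape) takes. The sketch below builds them in NumPy from a hypothetical list of parsed ids and assumes a vocabulary of 30 000. The indices come out in row-major order, which many TensorFlow sparse ops expect.

```python
import numpy as np

n_classes = 30000  # assumed vocabulary size
ids = np.array([1, 4, 38, 571, 9], dtype=np.int64)  # hypothetical parsed ids

# COO form of the one-hot matrix: row i has a single 1 in column ids[i].
indices = np.stack([np.arange(ids.size, dtype=np.int64), ids], axis=1)  # (5, 2)
values = np.ones(ids.size, dtype=np.float32)
dense_shape = (ids.size, n_classes)

# Densify only to check that the three arrays describe the same one-hot matrix.
dense = np.zeros(dense_shape, dtype=np.float32)
dense[indices[:, 0], indices[:, 1]] = values
```

The column of each index pair is simply the id itself, so a sparse one-hot tensor stores nothing beyond the ids.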
Here is a pastebin of the full code. Line 32 is where I currently create the 1-HOT vectors.
Answer 0 (score: 0)
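Rather than converting the input into sparse 1-hot vectors, it is preferable to use tf.nn.embedding_lookup, which simply selects the relevant rows of the matrix you would otherwise multiply by. The result is equivalent to multiplying the matrix by the 1-hot vector, without ever building that vector.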
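The equivalence can be checked in a few lines of NumPy. In this sketch, plain row indexing E[ids] stands in for tf.nn.embedding_lookup, and the sizes and ids mirror the answer's example below.

```python
import numpy as np

vocab_size, embed_dim = 10, 3  # small sizes, as in the answer's example
rng = np.random.default_rng(0)
E = rng.random((vocab_size, embed_dim))  # embedding / weight matrix
ids = np.array([4, 5, 2, 9])

# Dense route: build one-hot rows, then multiply them by E.
one_hot = np.eye(vocab_size)[ids]  # shape (4, 10)
via_matmul = one_hot @ E           # shape (4, 3)

# Lookup route (the idea behind tf.nn.embedding_lookup): gather rows of E.
via_lookup = E[ids]                # shape (4, 3)

print(np.allclose(via_matmul, via_lookup))  # True
```

The matmul touches every entry of the (4, 10) one-hot matrix. The lookup copies only the four needed rows, and that gap grows with the vocabulary size.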
Here is a usage example:
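import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API, as in the question

embed_dim = 3
vocab_size = 10
E = np.random.rand(vocab_size, embed_dim)
print(E)
embeddings = tf.Variable(E)
examples = tf.Variable(np.array([4, 5, 2, 9]).astype('int32'))
# Selects rows 4, 5, 2 and 9 of E: the same result as multiplying their one-hot rows by E
examples_embedded = tf.nn.embedding_lookup(embeddings, examples)
s = tf.InteractiveSession()
s.run(tf.global_variables_initializer())
print('')
print(examples_embedded.eval())
See also the example in the im2txt project (lines …) for how to feed this kind of data into an RNN.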
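The same idea applies to the targets. Instead of one-hot target rows, the integer ids can be kept as labels. TensorFlow's tf.nn.sparse_softmax_cross_entropy_with_logits accepts integer labels, while tf.nn.softmax_cross_entropy_with_logits expects one-hot (or probability) labels. The NumPy sketch below uses hypothetical ids, a 10-class vocabulary in place of ~30 000, and random logits as stand-in RNN outputs. It shows that both forms of the loss agree.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes = 10                       # small stand-in for the ~30 000 vocabulary
ids = np.array([1, 4, 3, 7, 9])      # hypothetical token ids for one sequence
input_ids, target_ids = ids[:-1], ids[1:]
logits = rng.standard_normal((input_ids.size, n_classes))  # stand-in RNN outputs

# Numerically stable log-softmax over the class axis.
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# Dense labels: one-hot targets, as in the question's code.
one_hot_targets = np.eye(n_classes)[target_ids]
dense_loss = -(one_hot_targets * log_probs).sum(axis=1)

# Sparse labels: index the log-probabilities with the integer targets directly.
sparse_loss = -log_probs[np.arange(target_ids.size), target_ids]

print(np.allclose(dense_loss, sparse_loss))  # True
```

With integer ids on both sides (embedding lookup for inputs, sparse labels for targets), the pipeline never needs to build a 30 000-wide vector.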