如何学习字段​​描述和可能的类别之间的关系

时间:2017-07-03 09:43:57

标签: tensorflow

我看起来像这样的产品之间的关系

enter image description here

seq2seq是一种正确的方法,因此我的模型可以学习产品(左栏)和类别(右栏)中的文本之间的关系,然后能够在给定产品描述的情况下预测未来的类别吗? / p>

1 个答案:

答案 0 :(得分:1)

Seq2Seq基本上有两个不同的递归神经网络连接在一起:一个接收输入text tokens的编码器RNN和一个根据编码器RNN的输出开始产生text tokens的解码器RNN。它是序列网络的序列。但是我看到你的情况,输入是一个序列,输出是一个基于输入的类别。您最好尝试使用LSTM网络,将您的输入序列通过embedding layer,然后将hidden state的最后LSTM传递给dense layer进行分类

适用于您的用例的LSTM模型:

输入和输出的占位符

# input batch of text sequences of length `seq_length`
X = tf.placeholder(tf.int32, [None, seq_length], name='input')

# Output class labels 
y = tf.placeholder(tf.float32, [None], name='labels')

嵌入图层

# For every word in your vocab you need a embedding vector. 
# The below weights are not trainable as we will `init` with `pre-trained` embeddings. If you dont want to do that set it to True.
W = tf.get_variable(initializer=tf.random_uniform([vocab_size, embedding_size]), name='embed', trainable=False)

# Get the embedding representation of the inputs
embed = tf.nn.embedding_lookup(W, X)

LSTM图层

# Create basic LSTMCell, with number of hidden units as a input param
def lstm_cell():
  return tf.contrib.rnn.BasicLSTMCell(n_hidden) 

# Create a stack of LSTM cells (if you need)
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(n_layers)])

# Create a dynamic RNN to handle the sequence of inputs
output, _ = tf.nn.dynamic_rnn(stacked_lstm, x, dtype=tf.float32)

# get the output of the last hidden state
last_hidden = output[:, -1, :]

最终致密层

# output dimension should be `n_classes`.
logits = dense_layer(last_hidden ...)

这应该是你的模特。