如何将此RNN文本分类代码更改为文本生成?

时间:2019-09-23 02:53:15

标签: tensorflow machine-learning nlp recurrent-neural-network text-classification

我有一段使用TensorFlow RNN进行文本分类的代码,应该如何修改它以进行文本生成?

下面的文本分类代码使用3D输入和2D输出。要进行文本生成,是否应将其改为3D输入和3D输出?如果是,应该如何修改?

示例数据为:

t0      t1      t2
british gray    is => cat (y=0)
0       1       2
white   samoyed is => dog (y=1)
3       4       2 

对于分类任务,输入“british gray is”会得到“cat”。我希望实现的是:输入“british”时,应生成下一个单词“gray”。

# Toy RNN text classifier (TensorFlow 1.x graph API).
#
# Each three-word sentence is encoded as word ids and an LSTM is trained to
# map the whole sentence to a single class label:
#
#   t0      t1      t2
#   british gray    is => cat (y=0)
#   0       1       2
#   white   samoyed is => dog (y=1)
#   3       4       2
import tensorflow as tf

tf.reset_default_graph()

# data
Bsize = 2  # sentences per batch
Times = 3  # time steps (words) per sentence
Max_X = 4  # largest word id in X, used to scale the inputs into [0, 1]
Max_Y = 1  # largest class id in Y, used to scale the labels into [0, 1]

X = [[[0], [1], [2]], [[3], [4], [2]]]
Y = [[0], [1]]

# normalise
X = [
    [[Value / Max_X for Value in Step] for Step in Sentence]
    for Sentence in X
]
Y = [[Value / Max_Y for Value in Label] for Label in Y]

# model: 3D input (batch, time, feature), 2D expected output (batch, label)
Inputs = tf.placeholder(tf.float32, [Bsize, Times, 1])
Expected = tf.placeholder(tf.float32, [Bsize, 1])

# single LSTM layer
Layer1 = tf.keras.layers.LSTM(20)
Hidden1 = Layer1(Inputs)

# Alternative: stacked LSTM layers (then feed Hidden2 to the output layer).
# Layers = tf.keras.layers.RNN([
#     tf.keras.layers.LSTMCell(30),  # hidden 1
#     tf.keras.layers.LSTMCell(20),  # hidden 2
# ])
# Hidden2 = Layers(Inputs)

# Output layer: the 20 LSTM units are projected onto one sigmoid unit.
Weight3 = tf.Variable(tf.random_uniform([20, 1], -1, 1))
Bias3 = tf.Variable(tf.random_uniform([1], -1, 1))
Output = tf.sigmoid(tf.matmul(Hidden1, Weight3) + Bias3)

Loss = tf.reduce_sum(tf.square(Expected - Output))
Optim = tf.train.GradientDescentOptimizer(1e-1)
Training = Optim.minimize(Loss)

Feed = {Inputs: X, Expected: Y}

# The context manager closes the session (and frees its resources) even if
# training raises.
with tf.Session() as Sess:
    # train
    Sess.run(tf.global_variables_initializer())

    for I in range(1000):  # number of feeds, 1 feed = 1 batch
        if I % 100 == 0:
            Lossvalue = Sess.run(Loss, Feed)
            print("Loss:", Lossvalue)

        Sess.run(Training, Feed)

    Lastloss = Sess.run(Loss, Feed)
    print("Loss:", Lastloss, "(Last)")

    # eval
    Results = Sess.run(Output, Feed)
    print("\nEval:")
    print(Results)

print("\nDone.")
# eof

1 个答案:

答案 0 :(得分:0)

我找到了把这段代码改为执行文本生成任务的方法:使用3D输入(X)和3D标签(Y),如下面的源代码所示:

源代码:

# Toy RNN next-word predictor (TensorFlow 1.x graph API).
#
# Each label sequence is the input sequence shifted one step to the left, so
# at every time step the network is trained to output the id of the next word:
#
#   t0       t1       t2
#   british  gray     is  cat
#   0        1        2   (3)  <=x
#   1        2        3        <=y
#   white    samoyed  is  dog
#   4        5        2   (6)  <=x
#   5        2        6        <=y
import tensorflow as tf

tf.reset_default_graph()

# data
Bsize = 2  # sentences per batch
Times = 3  # time steps (words) per sentence
Max_X = 5  # largest word id in X, used to scale the inputs into [0, 1]
Max_Y = 6  # largest word id in Y, used to scale the labels into [0, 1]

X = [[[0], [1], [2]], [[4], [5], [2]]]
Y = [[[1], [2], [3]], [[5], [2], [6]]]

# normalise
X = [
    [[Value / Max_X for Value in Step] for Step in Sentence]
    for Sentence in X
]
Y = [
    [[Value / Max_Y for Value in Step] for Step in Sentence]
    for Sentence in Y
]

# model: 3D input and 3D expected output, both (batch, time, feature)
Input = tf.placeholder(tf.float32, [Bsize, Times, 1])
Expected = tf.placeholder(tf.float32, [Bsize, Times, 1])

# Alternative: single LSTM layer. It needs return_sequences=True so that one
# output per time step is produced, matching the 3D labels.
# Layer1 = tf.keras.layers.LSTM(20, return_sequences=True)
# Hidden1 = Layer1(Input)

# multi LSTM layers; return_sequences=True keeps the output of every step
Layers = tf.keras.layers.RNN(
    [
        tf.keras.layers.LSTMCell(30),  # hidden 1
        tf.keras.layers.LSTMCell(20),  # hidden 2
    ],
    return_sequences=True,
)
Hidden2 = Layers(Input)

# Output layer applied at every time step. tensordot contracts the last axis
# of Hidden2 (20 units) with the first axis of Weight3, so it does not depend
# on tf.matmul broadcasting a rank-2 weight over the batch axis.
Weight3 = tf.Variable(tf.random_uniform([20, 1], -1, 1))
Bias3 = tf.Variable(tf.random_uniform([1], -1, 1))
Output = tf.sigmoid(tf.tensordot(Hidden2, Weight3, axes=1) + Bias3)

Loss = tf.reduce_sum(tf.square(Expected - Output))
Optim = tf.train.GradientDescentOptimizer(1e-1)
Training = Optim.minimize(Loss)

Feed = {Input: X, Expected: Y}
Epochs = 10000
# Report the loss ten times per run; integer arithmetic avoids a float
# modulus, and max() keeps the interval non-zero when Epochs < 10.
Report_every = max(1, Epochs // 10)

# The context manager closes the session (and frees its resources) even if
# training raises.
with tf.Session() as Sess:
    # train
    Sess.run(tf.global_variables_initializer())

    for I in range(Epochs):  # number of feeds, 1 feed = 1 batch
        if I % Report_every == 0:
            Lossvalue = Sess.run(Loss, Feed)
            print("Loss:", Lossvalue)

        Sess.run(Training, Feed)

    Lastloss = Sess.run(Loss, Feed)
    print("Loss:", Lastloss, "(Last)")

    # eval: undo the label scaling and round to the nearest word id
    Results = [
        [[round(Value * Max_Y) for Value in Step] for Step in Sentence]
        for Sentence in Sess.run(Output, Feed).tolist()
    ]
    print("\nEval:")
    print(Results)

print("\nDone.")
# eof