我知道过去已经有人以不同的方式回答过这个问题,但我无法弄明白如何把那些方法套用到我的代码上,因此需要帮助。我使用 Cornell Movie Dialogs Corpus 作为数据集,最终目标是为聊天机器人训练一个 LSTM 模型。但我卡在了最开始的独热编码(one-hot encoding)这一步:内存不足。请注意,我用于训练的虚拟机有 86GB 内存,但问题依旧。在 nmt_special_utils_mod.py 中,独热编码超出了可用内存,程序无法通过这一阶段。任何能替代下面这两行、又不丧失原有功能的方法都会有帮助:
# The two lines that run out of memory. Each builds a dense one-hot tensor with one
# vector of length len(human_vocab) (resp. len(machine_vocab)) per word position of
# every example. np.array(list(map(...))) also holds the per-example list and the
# stacked copy at the same time.
Xoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), X)))
Yoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(machine_vocab)), Y)))
下面给出全部代码,以便把问题说清楚:
import_corpus_mod.py(修改 1:更新了低频词删除功能)
def data_load(train_data_path='D:\\Script\\Python\\NLP\\chatbotSeq2SeqWithAtt\\ChatBot\\',
              threshold=5):
    """
    Load the Cornell movie-dialog files and return cleaned question/answer pairs.

    Every pair of consecutive lines inside a conversation becomes one
    (question, answer) example. Text is lower-cased and normalised by
    ``clean_text``. Words that occur fewer than ``threshold`` times across all
    questions and answers are replaced with the token '<unk>'.

    Arguments:
    train_data_path -- directory holding 'movie_lines.txt' and
        'movie_conversations_short.txt' (defaults to the original hard-coded path)
    threshold -- minimum number of occurrences a word needs in order to be kept

    Returns:
    cleanest_questions, cleanest_answers -- two aligned lists with one
        space-separated string per example

    Raises:
    OSError -- if either file cannot be opened
    KeyError -- if a conversation references a line id missing from movie_lines.txt
    """
    import os
    import re
    from collections import Counter

    # (pattern, replacement) pairs applied in this order. Later patterns see the
    # output of earlier ones, so append new rules at the end.
    substitutions = [(re.compile(pattern), replacement) for pattern, replacement in (
        (r"i'm", "i am"),
        (r"he's", "he is"),
        (r"she's", "she is"),
        (r"it's", "it is"),
        (r"let's", "let us"),
        (r"that's", "that is"),
        (r"what's", "what is"),
        (r"where's", "where is"),
        (r"how's", "how is"),
        (r"howz", "how is"),
        (r"\'ll", " will"),
        (r"\'ve", " have"),
        (r"\'re", " are"),
        (r"\'d", " would"),
        (r"don't", "do not"),
        (r"won't", "will not"),
        (r"can't", "cannot"),
        (r"wouldn't", "would not"),
        (r"wasn't", "was not"),
        (r"haven't", "have not"),
        (r"\s+", " "),
        (r"[-()\"#/@;:<>+=~|{}.?,]", ""),
    )]

    def clean_text(text):
        '''Lower-case the text, expand common contractions and strip punctuation.'''
        text = text.lower()
        for pattern, replacement in substitutions:
            text = pattern.sub(replacement, text)
        return text

    def read_lines(file_name):
        # `with` closes the file; errors='ignore' drops undecodable bytes.
        with open(os.path.join(train_data_path, file_name),
                  encoding='utf-8', errors='ignore') as handle:
            return handle.read().split('\n')

    # Map each line id (first field) to its text (fifth field).
    id2line = {}
    for line in read_lines('movie_lines.txt'):
        fields = line.split(' +++$+++ ')
        if len(fields) == 5:
            id2line[fields[0]] = fields[4]

    # Each conversation line ends with a list such as "['L194', 'L195']".
    # Skip blank lines rather than always dropping the last element, so a file
    # without a trailing newline does not lose its final conversation.
    conversations_ids = []
    for conversation in read_lines('movie_conversations_short.txt'):
        if not conversation.strip():
            continue
        ids = conversation.split(' +++$+++ ')[-1].strip()[1:-1]
        conversations_ids.append(ids.replace("'", "").replace(" ", "").split(","))

    # Consecutive lines of a conversation form (question, answer) pairs.
    questions, answers = [], []
    for ids in conversations_ids:
        for first, second in zip(ids, ids[1:]):
            questions.append(id2line[first])
            answers.append(id2line[second])

    clean_questions = [clean_text(question) for question in questions]
    clean_answers = [clean_text(answer) for answer in answers]

    # Keep words occurring at least `threshold` times. A set makes the
    # per-word membership test O(1) instead of scanning a list.
    word2count = Counter()
    for sentence in clean_questions + clean_answers:
        word2count.update(sentence.split())
    vocabulary = {word for word, count in word2count.items() if count >= threshold}

    def replace_rare(sentence):
        return ' '.join(word if word in vocabulary else '<unk>' for word in sentence.split())

    cleanest_questions = [replace_rare(question) for question in clean_questions]
    cleanest_answers = [replace_rare(answer) for answer in clean_answers]
    return cleanest_questions, cleanest_answers
nmt_data_load_asmain_words.py(修改 1:更新了低频词删除)
from tqdm import tqdm
from import_corpus_mod import data_load
def load_dataset(clean_questions, clean_answers):
    """
    Build the (question, answer) dataset and the word-level vocabularies.

    Arguments:
    clean_questions -- list of cleaned question strings, one per example
    clean_answers -- list of cleaned answer strings, aligned with clean_questions

    Returns:
    dataset -- list of (question, answer) tuples; pairs where either side is None are skipped
    human -- dict word -> index for the question words; '<pad>' gets the last index
    machine -- dict word -> index for the answer words; '<pad>' gets the last index
    inv_machine -- dict index -> word, the inverse of machine
    inv_human -- dict index -> word, the inverse of human
    """
    # '<unk>' is always included. string_to_int looks up vocab['<unk>'] for any
    # unseen word, so a vocabulary without it raises KeyError on new input.
    human_vocab = {'<unk>'}
    machine_vocab = {'<unk>'}
    dataset = []
    for hu, mc in tqdm(zip(clean_questions, clean_answers), total=len(clean_questions)):
        # Check both sides; a None answer used to crash on mc.split().
        if hu is None or mc is None:
            continue
        dataset.append((hu, mc))
        human_vocab.update(hu.split())
        machine_vocab.update(mc.split())
    human = {word: index for index, word in enumerate(sorted(human_vocab) + ['<pad>'])}
    machine = {word: index for index, word in enumerate(sorted(machine_vocab) + ['<pad>'])}
    inv_machine = {v: k for k, v in machine.items()}
    inv_human = {v: k for k, v in human.items()}
    return dataset, human, machine, inv_machine, inv_human
# NOTE(review): these statements execute whenever this module is imported. As a
# result, `from nmt_data_load_asmain_words import load_dataset` (as nmt_code_mod.py
# does) reads and cleans the whole corpus one extra time. Consider moving them
# under `if __name__ == "__main__":`.
clean_questions, clean_answers = data_load()
(dataset, human_vocab, machine_vocab,
 inv_machine_vocab, inv_human_vocab) = load_dataset(clean_questions=clean_questions,
                                                    clean_answers=clean_answers)
nmt_special_utils_mod.py
import numpy as np
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt
import sys
# Module-level lists meant for integer-encoded sentences. Nothing in the visible
# code reads or writes them.
X_into_int, Y_into_int = [], []
def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty, dtype='float32'):
    """
    Encode (question, answer) pairs as padded index arrays and their one-hot versions.

    Arguments:
    dataset -- list of (question, answer) string pairs
    human_vocab -- dict word -> index for the questions (used by string_to_int)
    machine_vocab -- dict word -> index for the answers (used by string_to_int)
    Tx -- number of time steps each question is padded/truncated to
    Ty -- number of time steps each answer is padded/truncated to
    dtype -- dtype of the one-hot arrays; float32 needs half the memory of float64

    Returns:
    X -- int array (m, Tx) of question word indices
    Y -- int array (m, Ty) of answer word indices
    Xoh -- one-hot array (m, Tx, len(human_vocab))
    Yoh -- one-hot array (m, Ty, len(machine_vocab))

    Note: Xoh/Yoh still need m * T * vocab_size elements. For a full corpus,
    feed X/Y to a model with an Embedding layer, or one-hot encode per batch.
    """
    X = _encode_sentences([question for question, _ in dataset], Tx, human_vocab)
    Y = _encode_sentences([answer for _, answer in dataset], Ty, machine_vocab)
    Xoh = _one_hot_sequences(X, len(human_vocab), dtype)
    Yoh = _one_hot_sequences(Y, len(machine_vocab), dtype)
    return X, Y, Xoh, Yoh


def _encode_sentences(sentences, length, vocab):
    """Return an int array (len(sentences), length) of padded/truncated word indices."""
    rows = [string_to_int(sentence, length, vocab) for sentence in sentences]
    # reshape keeps the array 2-D even when `sentences` is empty
    return np.array(rows, dtype=int).reshape(len(rows), length)


def _one_hot_sequences(indices, num_classes, dtype='float32'):
    """
    Return a one-hot array of shape indices.shape + (num_classes,).

    Gives the same values as applying keras' to_categorical to every row.
    It fills one preallocated array instead of building a Python list of
    per-row arrays and then copying it, so peak memory is a single result.
    """
    out = np.zeros(indices.shape + (num_classes,), dtype=dtype)
    # Write a 1 at position indices[..., t] along the last axis.
    np.put_along_axis(out, indices[..., np.newaxis], 1, axis=-1)
    return out
def string_to_int(line, length, vocab):
    """
    Convert a sentence into a fixed-length list of word indices.

    The sentence is split on whitespace and truncated to `length` words, or
    padded with '<pad>' tokens up to `length`. Each word is then looked up in
    `vocab`; words missing from `vocab` map to vocab['<unk>'].

    Arguments:
    line -- input sentence, e.g. 'hello how are you'
    length -- number of time steps; decides whether the output is padded or truncated
    vocab -- dict mapping words to integer indices

    Returns:
    ints -- list of `length` integers (empty when length <= 0)

    Raises:
    KeyError -- if a word (including '<pad>') is missing from vocab and vocab has no '<unk>'
    """
    words = line.split()[:max(length, 0)]
    words += ['<pad>'] * (length - len(words))
    # Conditional lookup so vocab['<unk>'] is only needed for unknown words.
    return [vocab[word] if word in vocab else vocab['<unk>'] for word in words]
def int_to_string(ints, inv_vocab):
    """
    Map a sequence of vocabulary indices back to their tokens.

    Arguments:
    ints -- iterable of integer indices into the machine vocabulary
    inv_vocab -- dict mapping each index to its token

    Returns:
    list of tokens, one per index in `ints`

    Raises:
    KeyError -- if an index is missing from `inv_vocab`
    """
    return list(map(inv_vocab.__getitem__, ints))
# Sample date strings from the date-translation version of this code.
# nmt_code_mod.py star-imports this module and then defines its own EXAMPLES.
EXAMPLES = [
    '3 May 1979',
    '5 Apr 09',
    '20th February 2016',
    'Wed 10 Jul 2007',
]
def softmax(x, axis=1):
    """Softmax activation that normalises along `axis`.

    For 2-D tensors this delegates to K.softmax, and `axis` is not passed on.
    For higher ranks it computes a max-subtracted, numerically stable softmax
    along `axis`.

    # Arguments
        x : Tensor.
        axis: Integer, axis along which the softmax normalization is applied.
    # Returns
        Tensor, output of softmax transformation.
    # Raises
        ValueError: In case `dim(x) == 1`.
    """
    rank = K.ndim(x)
    if rank == 2:
        return K.softmax(x)
    if rank <= 2:
        raise ValueError('Cannot apply softmax to a tensor that is 1D')
    exps = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return exps / K.sum(exps, axis=axis, keepdims=True)
def plot_attention_map(model, input_vocabulary, inv_output_vocabulary, text, n_s = 128, num = 6, Tx = 30, Ty = 10):
    """
    Plot the attention weights the model produces for one input sentence.

    Arguments:
    model -- trained Keras model taking [one-hot input, s0, c0]
    input_vocabulary -- dict word -> index for the input side
    inv_output_vocabulary -- dict index -> word for the output side
    text -- input sentence (whitespace-separated words)
    n_s -- decoder state size, used for the zero initial states
    num -- index in model.layers of the layer whose outputs are read as attention weights
    Tx -- input length the model was built with (previously ignored; (10, 30) was hard-coded)
    Ty -- output length the model was built with

    Returns:
    attention_map -- array (Ty, Tx) with the weight of every output step over every input step
    """
    attention_map = np.zeros((Ty, Tx))
    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    layer = model.layers[num]

    encoded = np.array(string_to_int(text, Tx, input_vocabulary)).reshape((1, Tx))
    encoded = np.array(list(map(lambda x: to_categorical(x, num_classes=len(input_vocabulary)), encoded)))

    # layer.get_output_at(t) reads the layer's output from its t-th call.
    get_attention = K.function(model.inputs, [layer.get_output_at(t) for t in range(Ty)])
    attention_outputs = get_attention([encoded, s0, c0])
    for t in range(Ty):
        # Each output is indexed as [sample, input step, 0].
        attention_map[t] = attention_outputs[t][0, :Tx, 0]

    prediction = model.predict([encoded, s0, c0])
    predicted_ids = [int(np.argmax(step_output[0])) for step_output in prediction]
    predicted_text = int_to_string(predicted_ids, inv_output_vocabulary)

    # The model is word-level (string_to_int splits on whitespace), so label the
    # x axis with at most Tx words rather than with individual characters.
    input_words = text.split()[:Tx]
    input_length = len(input_words)
    output_length = Ty

    # Plot the attention_map
    plt.clf()
    fig = plt.figure(figsize=(8, 8.5))
    ax = fig.add_subplot(1, 1, 1)
    image = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')
    cbaxes = fig.add_axes([0.2, 0, 0.6, 0.03])
    cbar = fig.colorbar(image, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)
    ax.set_yticks(range(output_length))
    ax.set_yticklabels(predicted_text[:output_length])
    ax.set_xticks(range(input_length))
    ax.set_xticklabels(input_words, rotation=45)
    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')
    ax.grid()
    return attention_map
nmt_code_mod.py 主要代码
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 10 16:31:44 2018
@author: Anirban
"""
from keras.layers import Bidirectional, Concatenate, Dot, Input, LSTM
from keras.layers import RepeatVector, Dense, Activation
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.models import Model
import keras.backend as K
import numpy as np
from nmt_data_load_asmain_words import load_dataset
from import_corpus_mod import data_load
from nmt_special_utils_mod import *
epochs = 50
# NOTE(review): importing nmt_data_load_asmain_words above already runs data_load()
# at import time, so the corpus is loaded twice.
clean_questions, clean_answers = data_load()
dataset, human_vocab, machine_vocab, inv_machine_vocab, inv_human_vocab = load_dataset(clean_questions, clean_answers)
# Number of training examples. Take it from the dataset actually built by
# load_dataset, which can skip pairs, so that s0/c0 below match Xoh's first dimension.
m = len(dataset)
Tx = 8  # words per padded/truncated question
Ty = 8  # words per padded/truncated answer
X, Y, Xoh, Yoh = preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty)
print("X.shape:", X.shape)
print("Y.shape:", Y.shape)
print("Xoh.shape:", Xoh.shape)
print("Yoh.shape:", Yoh.shape)
# Layers shared by every call to one_step_attention. They live at module level
# so all decoder steps reuse the same weights. Creation order is unchanged.
repeator = RepeatVector(n=Tx)
concatenator = Concatenate(axis=-1)
densor1 = Dense(units=20, activation="tanh")
densor2 = Dense(units=1, activation="relu")
# Custom softmax(axis=1) from nmt_special_utils, normalising over the Tx input positions.
activator = Activation(activation=softmax, name='attention_weights')
dotor = Dot(axes=1)
def one_step_attention(a, s_prev):
    """
    Compute the attention context vector for one decoder step.

    Steps:
    - Repeat the previous decoder state across the Tx input positions.
    - Concatenate it with the encoder states.
    - Score each position with two small Dense layers.
    - Normalise the scores with the custom softmax over the time axis.
    - Take the weighted sum of the encoder states.

    Arguments:
    a -- hidden state output of the Bi-LSTM, tensor of shape (m, Tx, 2*n_a)
    s_prev -- previous hidden state of the (post-attention) LSTM, tensor of shape (m, n_s)

    Returns:
    context -- context vector, input of the next (post-attention) LSTM cell
    """
    tiled_state = repeator(s_prev)
    scores = densor2(densor1(concatenator([a, tiled_state])))
    weights = activator(scores)
    return dotor([weights, a])
# Hidden sizes of the Bi-LSTM encoder (n_a) and the post-attention LSTM (n_s).
n_a, n_s = 32, 64
# Decoder layers created once and reused at every one of the Ty output steps.
post_activation_LSTM_cell = LSTM(units=n_s, return_state=True)
output_layer = Dense(units=len(machine_vocab), activation=softmax)
def model(Tx, Ty, n_a, n_s, human_vocab_size, machine_vocab_size, embedding_dim=None):
    """
    Build the Bi-LSTM encoder / attention / LSTM decoder model.

    Arguments:
    Tx -- length of the input sequence
    Ty -- length of the output sequence
    n_a -- hidden state size of the Bi-LSTM
    n_s -- hidden state size of the post-attention LSTM
    human_vocab_size -- size of the python dictionary "human_vocab"
    machine_vocab_size -- size of the python dictionary "machine_vocab"
    embedding_dim -- input encoding mode:
        None (default): the input is one-hot, shape (Tx, human_vocab_size);
            train with Xoh.
        int: the input is integer word indices, shape (Tx,), fed through a
            learned Embedding of that size; train with X. This avoids
            materialising the one-hot tensor. It also adds one layer, which
            shifts layer indices such as plot_attention_map's `num`.

    Returns:
    model -- Keras model instance with inputs [X, s0, c0] and a list of Ty outputs
    """
    if embedding_dim is None:
        X = Input(shape=(Tx, human_vocab_size))
        encoder_input = X
    else:
        from keras.layers import Embedding
        X = Input(shape=(Tx,))
        encoder_input = Embedding(human_vocab_size, embedding_dim, input_length=Tx)(X)
    # Initial hidden and cell state of the decoder LSTM.
    s0 = Input(shape=(n_s,), name='s0')
    c0 = Input(shape=(n_s,), name='c0')
    s = s0
    c = c0
    outputs = []

    # Pre-attention Bi-LSTM. The old input_shape=(m, Tx, 2*n_a) argument referenced
    # the global m and has no effect when a layer is called on a tensor.
    a = Bidirectional(LSTM(n_a, return_sequences=True))(encoder_input)

    for t in range(Ty):
        context = one_step_attention(a, s)
        s, _, c = post_activation_LSTM_cell(context, initial_state=[s, c])
        out = output_layer(s)
        outputs.append(out)

    # Model inputs must be the Input tensors themselves (X, not the Embedding output).
    # Otherwise Keras raises "Input layers to a `Model` must be `InputLayer` objects".
    return Model(inputs=[X, s0, c0], outputs=outputs)
# Rebinding `model` to the built instance shadows the builder function above.
model = model(Tx, Ty, n_a, n_s, len(human_vocab), len(machine_vocab))
opt = Adam(lr=0.05, beta_1=0.9, beta_2=0.999, decay=0.01)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# Zero initial decoder states, one row per training example.
s0, c0 = np.zeros((m, n_s)), np.zeros((m, n_s))
# The model has Ty separate outputs, so the targets become a list of Ty arrays.
outputs = list(Yoh.swapaxes(0, 1))
model.fit(x=[Xoh, s0, c0], y=outputs, epochs=epochs, batch_size=5)
EXAMPLES = ['can we make this quick roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad again'
            ,'the thing is cameron i am at the mercy of a particularly hideous breed of loser my sister i cannot date until she does'
            ,'Hello how are you']
#EXAMPLES = ['13 May 1979', 'Tue 11 Jul 2007','Saturday May 9 2018', 'March 3 2001','March 3rd 2001', '1 March 2001','23 May 2017']
# Each example is predicted as a batch of one, so the initial states need a
# batch of one as well. The training s0/c0 have m rows.
s0_single = np.zeros((1, n_s))
c0_single = np.zeros((1, n_s))
for example in EXAMPLES:
    # The training text was lower-cased in data_load, so lower-case here too.
    # string_to_int maps words missing from human_vocab to '<unk>'
    # (KeyError if the vocabulary has no '<unk>').
    source = np.asarray([string_to_int(example.lower(), Tx, human_vocab)])
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source)))
    prediction = model.predict([source, s0_single, c0_single])
    # One argmax per output step; ravel gives a flat list of Ty indices.
    prediction = np.argmax(prediction, axis=-1).ravel()
    output = [inv_machine_vocab[int(i)] for i in prediction]
    # Strip only trailing padding. The old count-and-slice dropped real words
    # whenever '<pad>' appeared in the middle of the prediction.
    while output and output[-1] == '<pad>':
        output.pop()
    print("source:", example)
    print("output:", ' '.join(output))
注:这段代码与 2016 年一篇非常著名的研究论文所附的代码相同,原代码用于把各种格式的日期时间转换为机器可读的日期时间。我想把它用于我们的聊天机器人——带注意力机制的(双向)Seq2Seq 模型。代码本身可以运行:只加载 1000 条对话时一切正常,但加载完整语料库时会因内存不足而失败。
编辑
感谢大家在这个问题上的协助——非常感谢你们花时间通读代码并努力寻找最佳解决方案。按照你们的建议,我已经更新了 import_corpus_mod.py,加入了阈值 threshold = 5,在最开始就把出现次数少于 5 次的低频词替换为 <unk>。
现在,根据另一条建议以及您分享的代码,我在 nmt_special_utils_mod.py 中注释掉了以下两行:
#Xoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), X)))
#Yoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(machine_vocab)), Y)))
然后,按照您的指导,我把输入改成了下面这样:
# Integer word indices go in; Embedding maps each index to a 100-dimensional vector.
Xi = Input(shape=(Tx,))
X = Embedding( human_vocab_size, 100, embeddings_initializer='uniform', input_length=Tx , trainable=True )(Xi)
s0 = Input(shape=(n_s,), name='s0')
c0 = Input(shape=(n_s,), name='c0')
s = s0
c = c0
# NOTE(review): the TypeError below comes from keeping Model(inputs=[X, s0, c0], ...).
# X is now the Embedding output, so the Model must be built with inputs=[Xi, s0, c0].
结果出现了很多错误:
runfile('D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot/nmt_code_mod.py', wdir='D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot')
Reloaded modules: nmt_data_load_asmain_words, import_corpus_mod, nmt_special_utils_mod
100%|██████████| 384/384 [00:00<00:00, 24615.06it/s]
100%|██████████| 384/384 [00:00<?, ?it/s]
X.shape: (384, 8)
Y.shape: (384, 8)
D:\Python\Anaconda3\lib\site-packages\keras\engine\topology.py:1592: UserWarning: Model inputs must come from a Keras Input layer, they cannot be the output of a previous non-Input layer. Here, a tensor specified as input to "model_2" was not an Input tensor, it was generated by layer embedding_1.
Note that input tensors are instantiated via `tensor = Input(shape)`.
The tensor that caused the issue was: embedding_1/Gather:0
str(x.name))
Traceback (most recent call last):
File "<ipython-input-44-addb6f9e6bc1>", line 1, in <module>
runfile('D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot/nmt_code_mod.py', wdir='D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot')
File "D:\Python\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 705, in runfile
execfile(filename, namespace)
File "D:\Python\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot/nmt_code_mod.py", line 138, in <module>
model = model(Tx, Ty, n_a, n_s, len(human_vocab), len(machine_vocab))
File "D:/Script/Python/NLP/chatbotSeq2SeqWithAtt/ChatBot/nmt_code_mod.py", line 132, in model
model = Model(inputs=[X,s0,c0],outputs=outputs)
File "D:\Python\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "D:\Python\Anaconda3\lib\site-packages\keras\engine\topology.py", line 1652, in __init__
layer.__class__.__name__))
TypeError: Input layers to a `Model` must be `InputLayer` objects. Received inputs: [<tf.Tensor 'embedding_1/Gather:0' shape=(?, 8, 100) dtype=float32>, <tf.Tensor 's0_1:0' shape=(?, 64) dtype=float32>, <tf.Tensor 'c0_1:0' shape=(?, 64) dtype=float32>]. Input 0 (0-based) originates from layer type `Embedding`
因此,我已将 nmt_code_mod.py 和 nmt_special_utils_mod.py 的代码还原。
答案 0(得分:1)
我不建议使用独热编码加稠密矩阵。如果您的词汇量为 100,000 个单词,那么一个 100,000 × 100,000 的矩阵(按每个元素 8 字节的 float64 计算约为 80 GB)就会消耗超过 70 GB 的内存。
您可以尝试使用稀疏矩阵,但我估计这需要修改其余的代码。可以参考这个回答。
您也可以使用词嵌入(word embedding)表示:它很紧凑、节省内存,并且被当前所有最先进的 NLP 系统所采用。
无论如何,您都需要对模型做一处修改:使用合适的嵌入层(Embedding layer)来处理输入。该层只存储一份嵌入矩阵,之后构建训练样本时,每个词只需用一个表示其在词汇表中索引的整数来表示。
如果仍想使用独热编码,可以借助 Keras 初始化器(initializer),用 N×N 的单位矩阵构建嵌入层,其中 N 是词汇量。然后就可以把整数形式的单词索引作为输入传入。这样会增大模型的规模,但会减小每个批次的数据量。
如果想使用 word2vec 嵌入,可以加载一个 N×V 的嵌入矩阵,其中 N 是词汇量,V 是嵌入维度。通常 V 设置为 100 或 200 维,远小于 N,因此可以节省大量内存。
编辑:说明您的案例中嵌入的用法:
您这样做:
# Current encoder input: a dense one-hot vector of size human_vocab_size at each of
# the Tx positions, which is why the training array Xoh is so large.
X = Input(shape=(Tx, human_vocab_size))
s0 = Input(shape=(n_s,), name='s0')
c0 = Input(shape=(n_s,), name='c0')
s = s0
c = c0
您可以改用下面的方式实现独热编码:
# Integer word indices go in.
Xi = Input(shape=(Tx,))
# One-hot lookup: row i of the N x N identity matrix is the one-hot vector of word i.
# The initializer must be the string 'identity' (or an Identity() instance), not the
# class keras.initializers.Identity itself; Keras would call the class with
# (shape, dtype=...). trainable=False keeps the vectors exactly one-hot.
X = Embedding( human_vocab_size, human_vocab_size, embeddings_initializer='identity', trainable=False, input_length=Tx )(Xi)
s0 = Input(shape=(n_s,), name='s0')
c0 = Input(shape=(n_s,), name='c0')
s = s0
c = c0
# Build the Model from the Input tensor Xi, not from X:
# Model(inputs=[Xi, s0, c0], outputs=outputs)
这样,训练样本只需包含单词索引,而不必包含独热向量,从而节省训练样本占用的空间,但模型会变大。如果仍然太大,就只能改用稠密嵌入,做法如下:
# Integer word indices go in; a trainable Embedding maps each to a 100-dimensional
# vector, initialised uniformly at random.
Xi = Input(shape=(Tx,))
X = Embedding( human_vocab_size, 100, embeddings_initializer='uniform', input_length=Tx , trainable=True )(Xi)
s0 = Input(shape=(n_s,), name='s0')
c0 = Input(shape=(n_s,), name='c0')
s = s0
c = c0
# NOTE: the Model must then be built with inputs=[Xi, s0, c0]; X is the Embedding
# output, not an Input tensor.
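这会以紧凑的表示(维度为 100,而不是 human_vocab_size)随机初始化嵌入,可以节省大量内存。注意:使用嵌入层后,构建 Model 时 inputs 中应传入 Input 张量 Xi,而不是 Embedding 的输出 X,即 Model(inputs=[Xi, s0, c0], outputs=outputs);否则会出现上文所示的 TypeError。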
最后,您还可以通过把所有文本都转换为小写,或者用特殊标记“RARE”替换稀有词(在语料库中只出现一两次的词)来缩小词汇量。
答案 1(得分:1)
问题不在于独热编码本身,而在于把整个数据集一次性存放在内存中。更明智的做法是使用生成器或 Sequence,按需动态加载和编码数据;处理大型图像数据集时通常就是这样做的。
我建议您先完成所有预处理,把未编码的输入/输出对保存到 csv 文件中,然后创建一个延迟加载并编码的生成器:
class MySequence(Sequence):
    """
    Keras Sequence that one-hot encodes the training data one batch at a time.

    Only the integer index arrays are held in memory. The one-hot tensors of
    shape (batch, T, vocab_size) are built inside __getitem__ for a single
    batch, so memory no longer scales with m * T * vocab_size.

    Arguments:
    data -- pair (X, Y) of integer index arrays with shapes (m, Tx) and (m, Ty),
        e.g. the first two values returned by preprocess_data
    batch_size -- number of examples per batch
    human_vocab_size -- one-hot width of X; inferred as X.max() + 1 when None
        (pass len(human_vocab) so that every vocabulary index fits)
    machine_vocab_size -- one-hot width of Y; inferred as Y.max() + 1 when None
    n_s -- decoder state size. When given, each batch is returned in the
        attention model's format ([x, s0, c0], list of Ty targets); otherwise
        as a plain (x, y) pair.
    """

    def __init__(self, data, batch_size, human_vocab_size=None, machine_vocab_size=None, n_s=None):
        super().__init__()
        self.data_file = data  # the object passed in, kept under its original attribute name
        self.x, self.y = (np.asarray(part) for part in data)
        self.batch_size = batch_size
        self.human_vocab_size = int(self.x.max()) + 1 if human_vocab_size is None else human_vocab_size
        self.machine_vocab_size = int(self.y.max()) + 1 if machine_vocab_size is None else machine_vocab_size
        self.n_s = n_s

    def __len__(self):
        # Number of batches; the last one may be smaller than batch_size.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, batch_id):
        batch = slice(batch_id * self.batch_size, (batch_id + 1) * self.batch_size)
        x = self._one_hot(self.x[batch], self.human_vocab_size)
        y = self._one_hot(self.y[batch], self.machine_vocab_size)
        if self.n_s is None:
            return x, y
        # Same layout as model.fit([Xoh, s0, c0], list(Yoh.swapaxes(0, 1))) in nmt_code_mod.py.
        s0 = np.zeros((len(x), self.n_s))
        c0 = np.zeros((len(x), self.n_s))
        return [x, s0, c0], list(y.swapaxes(0, 1))

    @staticmethod
    def _one_hot(indices, num_classes):
        """Return a float32 one-hot array of shape indices.shape + (num_classes,)."""
        out = np.zeros(indices.shape + (num_classes,), dtype='float32')
        np.put_along_axis(out, indices[..., np.newaxis], 1, axis=-1)
        return out
请注意,生成器(或 Sequence[i])每次只返回一个批次。