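I have to process two text files containing hotel reviews. Each review has a value next to it indicating whether it is a truthful review or a deceptive one. To load the training and test sets, I have this code: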
import csv
x_train = list()
y_train = list()
with open('TRAINING_ALL.txt', encoding='utf-8') as infile:
    reader = csv.reader(infile, delimiter='\t')
    for row in reader:
        x_train.append(row[0])
        y_train.append(int(row[1]))
x_test = list()
y_test = list()
with open('TEST_ALL.txt', encoding='utf-8') as infile:
    reader = csv.reader(infile, delimiter='\t')
    for row in reader:
        x_test.append(row[0])
        y_test.append(int(row[1]))
Then I have to classify the reviews with a neural network. However, I ran into a problem in the data-loading part:
print('Loading data...')
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')
print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
I get:
Loading data...
480 train sequences
320 test sequences
Pad sequences (samples x time)
So far so good: it reads the correct number of sequences. Then comes the error:
ValueError: invalid literal for int() with base 10: "ould take a quick dip in the pool. I toured the hotel as my niece is planning her wedding and just so happens to live close to the hotel. The ' Chagall Ballroom ', was elegant enough for such an occa
What is the correct input for this code?
Note that the code originally worked fine as follows (loading the dataset from imdb):
print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')
print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
Could it be that x_train and x_test are not in the correct format?
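One way to check the format is to compare element types. imdb.load_data returns each review as a list of integer word indices, whereas csv.reader yields each review as a raw string. The sketch below uses made-up sample values (not the real files) and a hypothetical helper, describe, to show the difference and to reproduce the same kind of ValueError quoted in the traceback:

```python
# Hypothetical sample values standing in for the real data.
imdb_style_review = [1, 14, 22, 16, 43]          # list of integer word indices
file_style_review = "We loved the pool and spa"  # raw text, as read by csv.reader


def describe(sample):
    """Return a short description of a sample's type (hypothetical helper)."""
    if isinstance(sample, str):
        return "str"
    return "list of " + type(sample[0]).__name__


print(describe(imdb_style_review))
print(describe(file_style_review))

# Converting review text to int fails with the same kind of message
# as the traceback in the question.
try:
    int(file_style_review)
except ValueError as err:
    message = str(err)
    print(message)
```

If the elements of x_train turn out to be strings, the text probably needs to be converted to integer sequences before it is passed to sequence.pad_sequences.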
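Answer 0 (score: 1)
When you load data from the csv file, you also load the first row, which contains the column names. You can easily check this by looking at the first element of x_train and x_test. If that is the case, you can skip the first row like this: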
import csv
x_train = list()
y_train = list()
with open('TRAINING_ALL.txt', encoding='utf-8') as infile:
    reader = csv.reader(infile, delimiter='\t')
    next(reader)  # skip the header row
    for row in reader:
        x_train.append(row[0])
        y_train.append(int(row[1]))
x_test = list()
y_test = list()
with open('TEST_ALL.txt', encoding='utf-8') as infile:
    reader = csv.reader(infile, delimiter='\t')
    next(reader)  # skip the header row
    for row in reader:
        x_test.append(row[0])
        y_test.append(int(row[1]))
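Note also that the traceback in the question appears after "Pad sequences" is printed and quotes review text, which suggests that pad_sequences may be receiving raw strings rather than the integer sequences that imdb.load_data provides. If so, skipping the header alone would not be enough. A minimal sketch of one way to encode the text, assuming simple whitespace tokenization: the helpers build_vocab, texts_to_sequences, and pad are hypothetical names written for this illustration, and pad is only a stand-in that mimics the default left-padding and left-truncation of pad_sequences so the sketch runs without Keras.

```python
from collections import Counter


def build_vocab(texts, max_features=None):
    """Map words to integer indices; 0 is reserved for padding, 1 for unknown words."""
    counts = Counter(word for text in texts for word in text.lower().split())
    most_common = counts.most_common(max_features)
    return {word: i + 2 for i, (word, _) in enumerate(most_common)}


def texts_to_sequences(texts, vocab):
    """Turn each text into a list of integer word indices."""
    return [[vocab.get(word, 1) for word in text.lower().split()] for text in texts]


def pad(sequences, maxlen):
    """Left-pad with 0 and keep the last maxlen items (assumes maxlen > 0)."""
    padded = []
    for seq in sequences:
        seq = seq[-maxlen:]
        padded.append([0] * (maxlen - len(seq)) + seq)
    return padded


# Made-up sample reviews standing in for x_train and x_test.
sample_train = ["great hotel great pool", "bad hotel"]
sample_test = ["great unseen"]

vocab = build_vocab(sample_train)  # built from the training texts only
train_seqs = texts_to_sequences(sample_train, vocab)
test_seqs = texts_to_sequences(sample_test, vocab)
padded = pad(train_seqs, maxlen=3)
print(padded)
```

In the real script, the integer lists produced this way could be passed to sequence.pad_sequences in place of the raw strings. Building the vocabulary from the training texts only keeps words that occur solely in the test set mapped to the unknown index.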