Python 3.6 Anaconda
我正在这样做:
def load_data(file, episode):
data = load_file_data(file)
return map(list, zip(*[iter(data)] * episode))
def load_file_data(file):
with open(file, 'rb') as myFile:
# data = six.moves.cPickle.load(myFile)
data = pickle.load(myFile, encoding='latin')
return data
def make_supervised_data(data, data_dict):
supervised_data = []
if not os.path.exists('supervised_data.pkl'):
i = 0
for episode in data:
if i % 1000 == 13:
print(i)
i += 1
supervised_data.append(episode_supervised_data(episode,
data_dict))
with open("supervised_data.pkl", "wb") as myFile:
pickle.dump(supervised_data, myFile, -1)
with open('supervised_data.pkl', 'rb') as myFile:
supervised_data = pickle.load(myFile, encoding='latin')
return supervised_data
data = episodic_data.load_data("data.pkl",episode=10)
data_dict = episodic_data.load_file_data("data_dict.pkl")
supervised_y_data = episodic_data.make_supervised_data(data, data_dict)
x_train, x_test, y_train, y_test = train_test_split(data, supervised_y_data,
test_size=0.10, random_state=123)
我收到此错误:
x_train, x_test, y_train, y_test = train_test_split(data, supervised_y_data, test_size=0.10, random_state=123)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\model_selection_split.py", line 1689, in train_test_split
arrays = indexable(*arrays)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 206, in indexable
check_consistent_length(*result)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 177, in check_consistent_length
lengths = [_num_samples(X) for X in arrays if X is not None]
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 177, in
lengths = [_num_samples(X) for X in arrays if X is not None]
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 126, in _num_samples
" a valid collection." % x)
TypeError: Singleton array array(<map object at 0x00000190A1D75EF0>, dtype=object) cannot be considered a valid collection.