sklearn单例数组问题

时间:2018-10-22 11:45:04

标签: python-3.x unicode scikit-learn

Python 3.6 Anaconda

我正在这样做:

def load_data(file, episode):
    data = load_file_data(file)
    return map(list, zip(*[iter(data)] * episode))


def load_file_data(file):
    with open(file, 'rb') as myFile:
        # data = six.moves.cPickle.load(myFile)
        data = pickle.load(myFile, encoding='latin')
    return data

def make_supervised_data(data, data_dict):
    supervised_data = []
    if not os.path.exists('supervised_data.pkl'):
        i = 0
        for episode in data:
            if i % 1000 == 13:
                print(i)
            i += 1
            supervised_data.append(episode_supervised_data(episode, 
                                                       data_dict))
        with open("supervised_data.pkl", "wb") as myFile:
            pickle.dump(supervised_data, myFile, -1)
    with open('supervised_data.pkl', 'rb') as myFile:
        supervised_data = pickle.load(myFile, encoding='latin')
    return supervised_data

data = episodic_data.load_data("data.pkl",episode=10)

data_dict = episodic_data.load_file_data("data_dict.pkl")

supervised_y_data  = episodic_data.make_supervised_data(data, data_dict)

x_train, x_test, y_train, y_test = train_test_split(data, supervised_y_data, 
  test_size=0.10, random_state=123)

我收到此错误:

x_train, x_test, y_train, y_test = train_test_split(data, supervised_y_data, test_size=0.10, random_state=123)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\model_selection_split.py", line 1689, in train_test_split
arrays = indexable(*arrays)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 206, in indexable
check_consistent_length(*result)
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 177, in check_consistent_length
lengths = [_num_samples(X) for X in arrays if X is not None]
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 177, in 
lengths = [_num_samples(X) for X in arrays if X is not None]
File "C:\AnacondaPython3_6\DeepLearning_Udemy\lib\site-packages\sklearn\utils\validation.py", line 126, in _num_samples
" a valid collection." % x)
TypeError: Singleton array array(<map object at 0x00000190A1D75EF0>, dtype=object) cannot be considered a valid collection.

0 个答案:

没有答案