连接numpy数组时出现值错误

时间:2018-08-31 10:51:07

标签: python numpy keras concatenation generative-adversarial-network

Im加载mnist数据集如下,

(X_train, y_train), (X_test, y_test) = mnist.load_data()

但是,由于我需要加载和训练自己的数据集,因此我编写了以下小脚本,该脚本将提供确切的训练和测试值

def load_train(path):
X_train = []
y_train = []
print('Read train images')
for j in range(10):
    files = glob(path + "*.jpeg")
    for fl in files:
        img = get_im(fl)
        print(fl)
        X_train.append(img)
        y_train.append(j)

return np.asarray(X_train), np.asarray(y_train)

相关模型在训练时会生成一个大小为(64,28,28,1)的numpy数组。我将生成的图像中的image_batch连接起来,如下所示,

    X = np.concatenate((image_batch, generated_images))

但是我遇到以下错误,

  

ValueError:所有输入数组的维数必须相同

img_batch的大小为(64,28,28) generate_images的大小为(64,28,28,1)

如何在X_train中扩展img_batch的尺寸,以便与generate_images连接?还是有其他方法可以代替loadmnist来加载自定义图像?

2 个答案:

答案 0 :(得分:3)

python中有一个名为np.expand_dims()的函数,它可以沿参数提供的轴扩展任何数组的维数。根据您的情况,使用img_batch = np.expand_dims(img_batch, axis=3)

另一种方法是使用@Ioannis Nasios建议的reshape函数。 img_batch = img_batch.reshape(64,28,28,1)

答案 1 :(得分:1)

with models.DAG(
        'composer_sample_ml',
        # Continue to run DAG once per day
        schedule_interval=datetime.timedelta(days=1),
        default_args=default_dag_args) as dag:

    train_model = mlengine_operator.MLEngineTrainingOperator(
        task_id='train_model',
        project_id='PROJECT_ID',
        job_id='{}_{}'.format('iris_train_job', str(uuid.uuid4())),
        package_uris='gs://BUCKET_ID/scikit_learn_job_dir/packages/PACKAGE_ID/iris_sklearn_trainer-0.1.tar.gz',
        training_python_module='iris_sklearn_trainer.iris',
        training_args=["--jobDir='gs://BUCKET_ID/scikit_learn_job_dir'"],
        region='us-central1',
        scale_tier='BASIC',
        runtimeVersion = '1.8',
        pythonVersion = '2.7'
    )

    train_model