几天前我解决了一个同样的问题。之后，我尝试用另一种方法从 aaron_sheep 模型加载同一个预训练数据集。但在加载环境时，这次返回的结果却不一样。
我随即运行了下面的代码。它开始下载 .npz 文件，但下载结束后返回了错误：
from magenta.models.sketch_rnn.sketch_rnn_train import \
(load_env,
load_checkpoint,
reset_graph,
download_pretrained_models,
PRETRAINED_MODELS_URL)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (get_bounds,
to_big_strokes,
to_normal_strokes)
# Base URL of the hosted sketch-rnn datasets.
#
# The dataset folder is named "aaron_sheep" (with an underscore). A path such
# as ".../master/aaron/sheep/" does not exist on GitHub. The download then
# returns a non-.npz response (presumably an HTML 404 page starting with a
# newline). np.load falls back to pickle.load on it and fails with
# "invalid load key, '\x0a'" / "Failed to interpret file ... as a pickle".
#
# No trailing slash: the dataset file name is appended with a '/' separator,
# and a trailing slash here produces the doubled "//" seen in the failing log.
DATA_DIR = 'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep'
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'

# load_env reads model_config.json from MODEL_DIR, then loads the dataset
# from DATA_DIR in inference mode. It returns the train/valid/test splits
# followed by the model, eval and sample hyper-parameter sets.
# NOTE(review): NumPy >= 1.16.3 defaults np.load(allow_pickle=False), which
# may reject this object-array .npz even once the URL is correct. Confirm
# against the installed numpy/magenta versions.
(train_set, valid_set, test_set,
 hps_model, eval_hps_model, sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
我期望它输出以下内容：
INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.
但实际输出如下：
INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron/sheep//aaron_sheep.npz
---------------------------------------------------------------------------
UnpicklingError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\numpy\lib\npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
446 try:
--> 447 return pickle.load(fid, **pickle_kwargs)
448 except Exception:
UnpicklingError: invalid load key, '\x0a'.
During handling of the above exception, another exception occurred:
OSError Traceback (most recent call last)
<ipython-input-15-25a9b37135c4> in <module>()
1 #from io import BytesIO
2 #import pickle
----> 3 [train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(DATA_DIR, MODEL_DIR)
~\Anaconda3\lib\site-packages\magenta\models\sketch_rnn\sketch_rnn_train.py in load_env(data_dir, model_dir)
71 with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
72 model_params.parse_json(f.read())
---> 73 return load_dataset(data_dir, model_params, inference_mode=True)
74
75
~\Anaconda3\lib\site-packages\magenta\models\sketch_rnn\sketch_rnn_train.py in load_dataset(data_dir, model_params, inference_mode)
131 tf.logging.info('Downloading %s', data_filepath)
132 response = requests.get(data_filepath)
--> 133 data = np.load(six.BytesIO(response.content), encoding='latin1')
134 else:
135 data_filepath = os.path.join(data_dir, dataset)
~\Anaconda3\lib\site-packages\numpy\lib\npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
448 except Exception:
449 raise IOError(
--> 450 "Failed to interpret file %s as a pickle" % repr(file))
451 finally:
452 if own_fid:
OSError: Failed to interpret file <_io.BytesIO object at 0x00000197ACA8D2B0> as a pickle