Error when calling load_env in sketch_rnn

Time: 2019-05-01 14:03:52

Tags: python tensorflow machine-learning

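A few days ago I solved a similar problem. Now I am trying a different algorithm that loads the same pretrained dataset for the aaron_sheep model, and when I load the environment I get a different result this time.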

When I run the code below, it starts downloading the .npz file, but then stops and raises an error:

    from magenta.models.sketch_rnn.sketch_rnn_train import (
        load_env,
        load_checkpoint,
        reset_graph,
        download_pretrained_models,
        PRETRAINED_MODELS_URL)
    from magenta.models.sketch_rnn.model import Model, sample
    from magenta.models.sketch_rnn.utils import (get_bounds,
                                                 to_big_strokes,
                                                 to_normal_strokes)

    # DATA_DIR and MODELS_ROOT_DIR are defined in an earlier notebook cell (not shown).
    MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'

    [train_set, valid_set, test_set, hps_model, eval_hps_model,
     sample_hps_model] = load_env(DATA_DIR, MODEL_DIR)
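load_env(DATA_DIR, MODEL_DIR) takes the dataset location from DATA_DIR. In the failing run the logged URL is .../master/aaron/sheep//aaron_sheep.npz, whereas the expected one is .../master/aaron_sheep/aaron_sheep.npz. Below is a minimal sketch of how DATA_DIR could turn into the first URL, assuming (from the shape of the logged URL, not from Magenta's documentation) that load_dataset appends the file name to an http(s) data_dir with '/'. build_dataset_url, normalize_data_dir and the sample DATA_DIR values are hypothetical.

```python
# Hypothetical helpers that mimic the assumed '/'.join([data_dir, filename])
# step turning DATA_DIR into the download URL printed in the log.
BASE = "http://github.com/hardmaru/sketch-rnn-datasets/raw/master"


def build_dataset_url(data_dir: str, filename: str) -> str:
    """Join a remote data directory and a file name with '/' (assumed behaviour)."""
    return "/".join([data_dir, filename])


def normalize_data_dir(data_dir: str) -> str:
    """Strip trailing slashes so the join does not produce '//'."""
    return data_dir.rstrip("/")


if __name__ == "__main__":
    # A DATA_DIR like this reproduces the URL from the failing run.
    bad_dir = BASE + "/aaron/sheep/"
    print(build_dataset_url(bad_dir, "aaron_sheep.npz"))

    # The expected URL uses the 'aaron_sheep' directory without a trailing slash.
    good_dir = normalize_data_dir(BASE + "/aaron_sheep/")
    print(build_dataset_url(good_dir, "aaron_sheep.npz"))
```

Under that assumption, a DATA_DIR ending in .../aaron/sheep/ is enough to request a path that likely does not exist on GitHub, and the trailing slash is what produces the doubled //.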

I expected it to print this:

  INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
  INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
  INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
  INFO:tensorflow:model_params.max_seq_len 250.
  total images <= max_seq_len is 7400
  total images <= max_seq_len is 300
  total images <= max_seq_len is 300
  INFO:tensorflow:normalizing_scale_factor 18.5198.

But instead it printed this:

  INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron/sheep//aaron_sheep.npz
  ---------------------------------------------------------------------------
  UnpicklingError                           Traceback (most recent call last)
  ~\Anaconda3\lib\site-packages\numpy\lib\npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
  446             try:
  --> 447                 return pickle.load(fid, **pickle_kwargs)
  448             except Exception:

  UnpicklingError: invalid load key, '\x0a'.

  During handling of the above exception, another exception occurred:

  OSError                                   Traceback (most recent call last)
  <ipython-input-15-25a9b37135c4> in <module>()
  1 #from io import BytesIO
  2 #import pickle
  ----> 3 [train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(DATA_DIR, MODEL_DIR)

  ~\Anaconda3\lib\site-packages\magenta\models\sketch_rnn\sketch_rnn_train.py in load_env(data_dir, model_dir)
  71   with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
  72     model_params.parse_json(f.read())
  ---> 73   return load_dataset(data_dir, model_params, inference_mode=True)
  74 
  75 

  ~\Anaconda3\lib\site-packages\magenta\models\sketch_rnn\sketch_rnn_train.py in load_dataset(data_dir, model_params, inference_mode)
  131       tf.logging.info('Downloading %s', data_filepath)
  132       response = requests.get(data_filepath)
  --> 133       data = np.load(six.BytesIO(response.content), encoding='latin1')
  134     else:
  135       data_filepath = os.path.join(data_dir, dataset)

  ~\Anaconda3\lib\site-packages\numpy\lib\npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
  448             except Exception:
  449                 raise IOError(
  --> 450                     "Failed to interpret file %s as a pickle" % repr(file))
  451     finally:
  452         if own_fid:

  OSError: Failed to interpret file <_io.BytesIO object at 0x00000197ACA8D2B0> as a pickle
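The innermost error, invalid load key, '\x0a', shows that np.load fell back to pickle.load and the first byte it read was a newline, so the downloaded body was neither an .npz nor an .npy file (an .npz file is a ZIP archive that begins with the bytes PK\x03\x04). Below is a minimal sketch of a pre-check that makes this kind of failure readable, assuming you call it on response.content before np.load; check_npz_payload is a hypothetical helper, not part of Magenta or NumPy.

```python
import io
import zipfile

# Magic numbers at the start of NumPy's on-disk formats.
NPZ_MAGIC = b"PK\x03\x04"  # .npz files are ZIP archives
NPY_MAGIC = b"\x93NUMPY"   # single-array .npy files


def check_npz_payload(content: bytes, url: str = "<unknown url>") -> None:
    """Raise ValueError with a readable message if `content` is not an .npz archive.

    Hypothetical helper: call it on response.content before np.load.
    """
    if content.startswith(NPZ_MAGIC):
        return  # looks like a ZIP archive, so np.load can open it as .npz
    if content.startswith(NPY_MAGIC):
        raise ValueError(f"{url} returned a single .npy array, not an .npz archive")
    # Anything else (for example an HTML page starting with b"\n") cannot be
    # read as .npz; in the question's traceback np.load fell back to pickle.
    raise ValueError(
        f"{url} did not return an .npz archive; first bytes: {content[:40]!r}"
    )


if __name__ == "__main__":
    # A real ZIP archive built in memory passes the check.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("train.npy", b"placeholder")
    check_npz_payload(buf.getvalue())

    # A payload that starts with a newline, like the one in the question, fails.
    try:
        check_npz_payload(b"\n\n<!DOCTYPE html>", "http://example.invalid/data.npz")
    except ValueError as err:
        print(err)
```

In the notebook this check would sit between requests.get(...) and np.load(...); calling response.raise_for_status() from requests at the same point would additionally stop on an HTTP error status such as 404.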
