无法将Pytorch模型保存到Google Colab中的Google云端硬盘?

时间:2019-04-09 15:30:30

标签: python-3.x google-drive-api pytorch google-colaboratory

我正在尝试将模型保存到Google colab上的驱动器中。我已使用以下代码安装我的Google云端硬盘-

from google.colab import drive
drive.mount('/content/gdrive')

在进行所有预处理,模型定义和训练之后,我想将模型保存到驱动器中,因为训练需要很长时间。因此,我将其保存为定期驱动并从该点重新加载以继续。 保存我的模型的代码是:

def save_model(model, model_name, iter):
  path = f'content/gdrive/My Drive/Machine Learning Models/kaggle_jigsaw_{model_name}_iter_{iter}.pth'
  print(f'Saving {model_name} model...')
  torch.save(model.state_dict(), path)
  print(f'{model_name} saved successfully.')

EMBEDDING_DIMS = 128
HIDDEN_SIZE = 256

gru = GRU(vocab.n_words, EMBEDDING_DIMS, HIDDEN_SIZE, 2).to(device)
save_model(gru, 'gru', 0)

我遇到以下错误:

Saving gru model...
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-27-d2510611a9d4> in <module>()
      9 
     10 gru = GRU(vocab.n_words, EMBEDDING_DIMS, HIDDEN_SIZE, 2).to(device)
---> 11 save_model(gru, 'gru', 0)

<ipython-input-27-d2510611a9d4> in save_model(model, model_name, iter)
      2   path = f'content/gdrive/My Drive/Machine Learning Models/kaggle_jigsaw_{model_name}_iter_{iter}.pth'
      3   print(f'Saving {model_name} model...')
----> 4   torch.save(model.state_dict(), path)
      5   print(f'{model_name} saved successfully.')
      6 

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
    217         >>> torch.save(x, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    220 
    221 

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _with_file_like(f, mode, body)
    140             (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
    141         new_fd = True
--> 142         f = open(f, mode)
    143     try:
    144         return body(f)

FileNotFoundError: [Errno 2] No such file or directory: 'content/gdrive/My Drive/Machine Learning Models/kaggle_jigsaw_gru_iter_0.pth'

我已经在驱动器中手动创建了该文件夹,只需要创建该文件。但是,错误仍然存​​在。但是,我确定不需要手动创建文件夹。问题是别的。 我要去哪里错了?

3 个答案:

答案 0 :(得分:2)

您无法将文件直接保存到已安装的云端硬盘。它不像常规文件系统那样工作。尝试使用基于PyDrive的{​​{1}}或CoUtils工具,该工具专为Google Colab设计:Working with Google Drive

答案 1 :(得分:2)

您的路径中可能需要前导/

尝试更改此行:

  path = f'content/gdrive/My Drive/Machine Learning Models/kaggle_jigsaw_{model_name}_iter_{iter}.pth'

收件人:

  path = f'/content/gdrive/My Drive/Machine Learning Models/kaggle_jigsaw_{model_name}_iter_{iter}.pth'

答案 2 :(得分:0)

我不知道为什么,但是现在它可以正常工作了。我仍然想知道为什么这个问题首先发生。