Unable to load a .pth file (pretrained neural network) with torch.load() on Google Colab

Date: 2019-11-27 19:30:33

Tags: deep-learning google-drive-api pytorch google-colaboratory ioerror

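My Google Drive is linked to my Google Colab notebook. Using PyTorch's torch.load($PATH), I can't load a 219 MB file (a pretrained neural network) stored in my Google Drive (https://drive.google.com/drive/folders/1-9m4aVg8Hze0IsZRyxvm5gLybuRLJHv-). However, when I do the same thing locally on my computer, it works fine. The error I get on Google Colab is (setup: Python 3.6, PyTorch 1.3.1):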

    state_dict = torch.load(model_path)['state_dict']
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 454, in _load
    return legacy_load(f)
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 380, in legacy_load
    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar,
  File "/usr/lib/python3.6/tarfile.py", line 1589, in open
    return func(name, filemode, fileobj, **kwargs)
  File "/usr/lib/python3.6/tarfile.py", line 1619, in taropen
    return cls(name, mode, fileobj, **kwargs)
  File "/usr/lib/python3.6/tarfile.py", line 1482, in __init__
    self.firstmember = self.next()
  File "/usr/lib/python3.6/tarfile.py", line 2297, in next
    tarinfo = self.tarinfo.fromtarfile(self)
  File "/usr/lib/python3.6/tarfile.py", line 1092, in fromtarfile
    buf = tarfile.fileobj.read(BLOCKSIZE)
OSError: [Errno 5] Input/output error

Any help would be much appreciated!
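The bottom frames of the traceback show that PyTorch never gets as far as parsing the model: the loader hands the file to `tarfile`, and the plain `tarfile.fileobj.read(BLOCKSIZE)` call is what raises Errno 5. A minimal stdlib-only check, assuming the goal is to tell a storage read failure apart from a PyTorch problem, reads the same file in the same 512-byte blocks without torch. `read_in_blocks` is a hypothetical helper, not part of either library.

```python
import tarfile


def read_in_blocks(path, blocksize=tarfile.BLOCKSIZE):
    """Read `path` front to back in fixed-size blocks; return the bytes read.

    Uses the same plain read(blocksize) call that fails at the bottom of the
    traceback (tarfile.BLOCKSIZE is 512), but without torch or tar parsing.
    An OSError here points at the storage the file lives on, not at PyTorch.
    """
    total = 0
    with open(path, "rb") as f:
        while True:
            block = f.read(blocksize)
            if not block:  # empty bytes means end of file
                break
            total += len(block)
    return total
```

For example, call `read_in_blocks(model_path)` with the same `model_path` passed to `torch.load`. If it raises `OSError: [Errno 5]` as well, the file cannot be read from that location at all; a larger `blocksize` (say `1 << 20`) makes a full pass over a 219 MB file faster.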

2 answers:

Answer 0 (score: 0)


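You can download the file directly with the Drive API and then pass it to torch. It shouldn't be hard to implement in Python; here is an example of how to download the file and pass it to torch: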

from __future__ import print_function  # must be the first statement in the file

import os.path
import pickle

import torch
from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload

url = "https://drive.google.com/file/d/1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs/view?usp=sharing"
# If modifying these scopes, delete the file token.pickle.
# Read-only access is enough to download a file's contents.
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']


def main():
    """Download a file with the Drive API, then load it with torch."""
    creds = None
    # The file token.pickle stores the user's access and refresh tokens, and is
    # created automatically when the authorization flow completes for the first
    # time.
    if os.path.exists('token.pickle'):
        with open('token.pickle', 'rb') as token:
            creds = pickle.load(token)
    # If there are no (valid) credentials available, let the user log in.
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())  # reuse the stored refresh token
        else:
            flow = InstalledAppFlow.from_client_secrets_file(
                'credentials.json', SCOPES)
            creds = flow.run_local_server(port=0)
        # Save the credentials for the next run
        with open('token.pickle', 'wb') as token:
            pickle.dump(creds, token)

    drive_service = build('drive', 'v2', credentials=creds)

    file_id = '1RwpuwNPt_r0M5mQGEw18w-bCfKVwnZrs'  # the ID from the url above
    request = drive_service.files().get_media(fileId=file_id)
    # Write the download to local disk (io.BytesIO() would keep it in memory).
    with open('model.pth', 'wb') as fh:
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while done is False:
            status, done = downloader.next_chunk()
            print("Download %d%%." % int(status.progress() * 100))

    # Load the local copy; 'state_dict' is the key used in the question.
    return torch.load('model.pth')['state_dict']


if __name__ == '__main__':
    main()


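To run it, you need to:

  • Enable the Drive API for your account
  • Install the Google Drive API client libraries (google-api-python-client and google-auth-oauthlib)

This takes no more than 3 minutes and is explained in the Quickstart Guide for the Google Drive API: just follow steps 1 and 2, then run the example code above.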
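The `while done is False` loop in the answer follows a fetch/report/stop pattern: each `next_chunk()` call pulls the next piece of the file and returns a status plus a done flag. A stdlib-only sketch of the same pattern, copying between two file-like objects, may make the loop easier to follow; `copy_in_chunks` and its `chunk_size` parameter are hypothetical names for illustration, not part of googleapiclient.

```python
import io


def copy_in_chunks(src, dst, chunk_size=1024 * 1024):
    """Copy seekable file-like `src` into `dst` one chunk at a time.

    Mirrors the shape of the MediaIoBaseDownload loop: fetch a chunk, report
    progress, stop when nothing is left. Returns the number of bytes copied.
    """
    total = src.seek(0, io.SEEK_END)  # total size, used only for the percentage
    src.seek(0)
    copied = 0
    done = False
    while not done:
        chunk = src.read(chunk_size)
        dst.write(chunk)
        copied += len(chunk)
        done = not chunk or copied >= total  # stop at end of data
        pct = 100 if total == 0 else int(copied * 100 / total)
        print("Download %d%%." % pct)
    return copied
```

If the destination is an in-memory `io.BytesIO`, it can be rewound with `seek(0)` and handed straight to `torch.load`, which accepts seekable file-like objects as well as paths.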

Answer 1 (score: 0)

It worked by uploading the file directly to Google Colab instead of loading it from Google Drive, using:

from google.colab import files
uploaded = files.upload()
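`files.upload()` returns a dict mapping each uploaded filename to its contents as bytes, and it also saves the files into the notebook's current working directory, so the filename can be passed to `torch.load`. A small sketch for picking the checkpoint out of that dict, assuming exactly one `.pth` file was uploaded; `pick_checkpoint` is a hypothetical helper.

```python
from typing import Dict


def pick_checkpoint(uploaded: Dict[str, bytes], suffix: str = ".pth") -> str:
    """Return the name of the single uploaded file ending in `suffix`.

    `uploaded` is expected to be the dict returned by google.colab.files.upload(),
    which maps each uploaded filename to its contents as bytes.
    """
    matches = [name for name in uploaded if name.endswith(suffix)]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one {suffix} file, got {matches}")
    return matches[0]
```

In the notebook this would be used as `state_dict = torch.load(pick_checkpoint(uploaded))['state_dict']`, with the same `'state_dict'` key as in the question.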
