是否可以从GCS存储桶URL加载预训练的Pytorch模型而无需先在本地持久保存?

时间:2019-09-12 02:36:30

标签: python google-cloud-storage pytorch google-cloud-dataflow

我是在Google Dataflow的背景下问这个问题,但通常也是这样。

使用PyTorch,我可以引用包含多个文件的本地目录,这些文件包含一个预先训练的模型。我碰巧正在使用Roberta模型,但是其他人的界面是相同的。

ls some-directory/
      added_tokens.json
      config.json             
      merges.txt              
      pytorch_model.bin       
      special_tokens_map.json vocab.json
from pytorch_transformers import RobertaModel

# this works
model = RobertaModel.from_pretrained('/path/to/some-directory/')

但是,我的预训练模型存储在GCS存储桶中。我们称之为gs://my-bucket/roberta/

在将模型加载到Google Dataflow的情况下,我试图保持无状态并避免持久化到磁盘,因此我的首选方法是直接从GCS获得该模型。据我了解,PyTorch通用接口方法from_pretrained()可以采用本地目录或URL的字符串表示形式。但是,我似乎无法从GCS URL加载模型。

# this fails
model = RobertaModel.from_pretrained('gs://my-bucket/roberta/')
# ValueError: unable to parse gs://mahmed_bucket/roberta-base as a URL or as a local path

如果我尝试使用目录blob的公共https URL,它也会失败,尽管这很可能是由于lack of authentication所致,因为python环境中可以创建客户端的参考凭证无法转换为公开要求https://storage.googleapis

# this fails, probably due to auth
bucket = gcs_client.get_bucket('my-bucket')
directory_blob = bucket.blob(prefix='roberta')
model = RobertaModel.from_pretrained(directory_blob.public_url)
# ValueError: No JSON object could be decoded

# and for good measure, it also fails if I append a trailing /
model = RobertaModel.from_pretrained(directory_blob.public_url + '/')
# ValueError: No JSON object could be decoded

我知道GCS doesn't actually have subdirectories,它实际上只是存储区名称下的平面命名空间。但是,似乎我对认证的必要性和PyTorch不会说gs://感到沮丧。

我可以通过首先在本地保留文件来解决此问题。

from pytorch_transformers import RobertaModel
from google.cloud import storage
import tempfile

local_dir = tempfile.mkdtemp()
gcs = storage.Client()
bucket = gcs.get_bucket(bucket_name)
blobs = bucket.list_blobs(prefix=blob_prefix)
for blob in blobs:
    blob.download_to_filename(local_dir + '/' + os.path.basename(blob.name))
model = RobertaModel.from_pretrained(local_dir)

但是这似乎是一种hack,我一直认为我一定会丢失一些东西。当然,有一种方法可以保持无状态,而不必依赖磁盘持久性!

  • 那么有没有办法加载存储在GCS中的预训练模型?
  • 在这种情况下进行公共URL请求时,是否有一种方法可以进行身份​​验证?
  • 即使有身份验证的方法,子目录的不存在仍然会成为问题吗?

感谢您的帮助!我也很高兴能指出任何重复的问题,因为我确定找不到任何问题。


编辑和说明

  • 我的Python会话已通过GCS认证,这就是为什么我能够在本地下载blob文件,然后使用load_frompretrained()指向该本地目录的原因

  • load_frompretrained()需要目录引用,因为它需要问题顶部列出的所有文件,而不仅仅是pytorch-model.bin

  • 为澄清问题#2,我想知道是否有某种方法可以为PyTorch方法提供一个嵌入了加密凭据或类似内容的请求URL。有点远,但是我想确保自己没有错过任何东西。

  • 为阐明问题3(除了以下一个答案的注释),即使有一种方法可以将凭据嵌入我不知道的URL中,我仍然需要引用目录,而不是单个Blob,而且我不知道GCS子目录是否会被识别,因为(如Google docs所述)GCS中的子目录是一种错觉,它们并不代表真实目录结构体。所以我认为这个问题无关紧要,或者至少被问题2所阻止,但这是我追逐的话题,所以我还是很好奇。

4 个答案:

答案 0 :(得分:1)

我对Pytorch或Roberta模型了解不多,但我会尽力回答您有关GCS的询问:

1.-“那么,有没有办法加载存储在GCS中的预训练模型?”

如果您的模型可以直接从二进制文件加载Blob:

from google.cloud import storage

client = storage.Client()
bucket = client.get_bucket("bucket name")
blob = bucket.blob("path_to_blob/blob_name.ext")
data = blob.download_as_string() # you will have your binary data transformed into string here.

2.-“在这种情况下执行公共URL请求时,是否有一种方法可以进行身份​​验证?”

这是棘手的部分,因为根据运行脚本的上下文,它将使用默认服务帐户进行身份验证。因此,当您使用官方GCP库时,您可以:

A.-授予该默认服务帐户访问存储桶/对象的权限。

B.-创建一个新的服务帐户并在脚本中对其进行身份验证(您还将需要为该服务帐户生成身份验证令牌):

from google.cloud import storage
from google.oauth2 import service_account

VISION_SCOPES = ['https://www.googleapis.com/auth/devstorage']
SERVICE_ACCOUNT_FILE = 'key.json'

cred = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=VISION_SCOPES)

client = storage.Client(credentials=cred)
bucket = client.get_bucket("bucket_name")
blob = bucket.blob("path/object.ext")
data = blob.download_as_string()

但这是可行的,因为官方库会在后台处理对API调用的身份验证,因此在from_pretrained()函数不起作用的情况下。

因此,另一种选择是将对象公开,这样您可以在使用公共URL时访问它。

3.-“即使有身份验证的方法,子目录的不存在仍然会成为问题吗?”

不确定在这里的意思是,您的存储桶中可以有文件夹。

答案 1 :(得分:1)

主要编辑:

您可以在Dataflow worker上安装wheel文件,还可以使用worker临时存储在本地持久存储二进制文件!

的确是(当前,截至2019年11月)您不能通过提供--requirements参数来做到这一点。相反,您必须像这样使用setup.py。假定IN CAPS中的所有常量都在其他位置定义。

REQUIRED_PACKAGES = [
    'torch==1.3.0',
    'pytorch-transformers==1.2.0',
]

setup(
    name='project_dir',
    version=VERSION,
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES)

运行脚本

python setup.py sdist

python project_dir/my_dataflow_job.py \
--runner DataflowRunner \
--project ${GCP_PROJECT} \
--extra_package dist/project_dir-0.1.0.tar.gz \
# SNIP custom args for your job and required Dataflow Temp and Staging buckets #

在工作中,这里是在自定义Dataflow运算符的上下文中从GCS下载和使用模型的步骤。为了方便起见,我们将一些实用程序方法包装在一个SEPARATE MODULE中(对于绕过Dataflow依赖项上传很重要),然后将它们导入到自定义运算符的LOCAL SCOPE中,而不是全局的。

class AddColumn(beam.DoFn):
    PRETRAINED_MODEL = 'gs://my-bucket/blah/roberta-model-files'

    def get_model_tokenizer_wrapper(self):
        import shutil
        import tempfile
        import dataflow_util as util
        try:
            return self.model_tokenizer_wrapper
        except AttributeError:
            tmp_dir = tempfile.mkdtemp() + '/'
            util.download_tree(self.PRETRAINED_MODEL, tmp_dir)
            model, tokenizer = util.create_model_and_tokenizer(tmp_dir)
            model_tokenizer_wrapper = util.PretrainedPyTorchModelWrapper(
                model, tokenizer)
            shutil.rmtree(tmp_dir)
            self.model_tokenizer_wrapper = model_tokenizer_wrapper
            logging.info(
                'Successfully created PretrainedPyTorchModelWrapper')
            return self.model_tokenizer_wrapper

    def process(self, elem):
        model_tokenizer_wrapper = self.get_model_tokenizer_wrapper()

        # And now use that wrapper to process your elem however you need.
        # Note that when you read from BQ your elements are dictionaries
        # of the column names and values for each BQ row.

实用程序在代码库内的SEPARATE MODULE中起作用。在我们的项目根目录中,它位于dataflow_util / init.py中,但您不必那样做。

from contextlib import closing
import logging

import apache_beam as beam
import numpy as np
from pytorch_transformers import RobertaModel, RobertaTokenizer
import torch

class PretrainedPyTorchModelWrapper():
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

def download_tree(gcs_dir, local_dir):
    gcs = beam.io.gcp.gcsio.GcsIO()
    assert gcs_dir.endswith('/')
    assert local_dir.endswith('/')
    for entry in gcs.list_prefix(gcs_dir):
        download_file(gcs, gcs_dir, local_dir, entry)


def download_file(gcs, gcs_dir, local_dir, entry):
    rel_path = entry[len(gcs_dir):]
    dest_path = local_dir + rel_path
    logging.info('Downloading %s', dest_path)
    with closing(gcs.open(entry)) as f_read:
        with open(dest_path, 'wb') as f_write:
            # Download the file in chunks to avoid requiring large amounts of
            # RAM when downloading large files.
            while True:
                file_data_chunk = f_read.read(
                    beam.io.gcp.gcsio.DEFAULT_READ_BUFFER_SIZE)
                if len(file_data_chunk):
                    f_write.write(file_data_chunk)
                else:
                    break


def create_model_and_tokenizer(local_model_path_str):
    """
    Instantiate transformer model and tokenizer

      :param local_model_path_str: string representation of the local path 
             to the directory containing the pretrained model
      :return: model, tokenizer
    """
    model_class, tokenizer_class = (RobertaModel, RobertaTokenizer)

    # Load the pretrained tokenizer and model
    tokenizer = tokenizer_class.from_pretrained(local_model_path_str)
    model = model_class.from_pretrained(local_model_path_str)

    return model, tokenizer

那里有乡亲!可在此处找到更多详细信息:https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/


我发现的是,整个查询链都是无关紧要的,因为Dataflow仅允许您在worker上安装源分发包,这意味着您实际上无法安装PyTorch。

提供requirements.txt文件时,Dataflow将带有--no-binary标志进行安装,该标志阻止安装Wheel(.whl)软件包,并且仅允许源分发(.tar.gz)。我决定尝试在Google Dataflow上投放自己的PyTorch源代码发行版,其中一半是C ++,一部分是Cuda,另一部分是谁知道傻瓜的事。

感谢大家的投入。

答案 2 :(得分:0)

正如您正确地指出的那样,开箱即用的pytorch-transformers似乎不支持此功能,主要是因为它无法将文件链接识别为URL。

经过一些搜索,我在this source file的第144-155行附近找到了相应的错误消息。

当然,您可以尝试将'gs'标签添加到第144行,然后将您与GCS的连接解释为一个http请求(第269-272行)。如果GCS接受此要求,那它应该是更改后才需要工作的唯一内容。
如果这不起作用,唯一的立即解决方法是实施类似于Amazon S3存储桶功能的功能,但是我对S3和GCS存储桶的了解不足,无法在此处主张任何有意义的判断。

答案 3 :(得分:0)

当前我不是在玩Roberta,而是在使用Bert进行NER的令牌分类,但是我认为它具有相同的机制。

下面是我的代码:

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'your_gcs_auth.json'

# initiate storage
client = storage.Client()
en_bucket = client.get_bucket('your-gcs-bucketname')

# get blob
en_model_blob = en_bucket.get_blob('your-modelname-in-gcsbucket.bin')
en_model = en_model_blob.download_as_string()

# because model downloaded into string, need to convert it back
buffer = io.BytesIO(en_model)

# prepare loading model
state_dict = torch.load(buffer, map_location=torch.device('cpu'))
model = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=main_config)
model.load_state_dict(state_dict)

不能确定download_as_string()方法是否将数据保存到本地磁盘,但是根据执行download_to_filename()的经验,该函数会将模型下载到本地。

如果您还修改了变压器网络的配置(并将其放在GCS中并且还需要加载),则还需要修改类PretrainedConfig,因为它可以处理{{1 }}功能。

欢呼,希望对您有帮助