如何使用PyTorch DataLoader进行强化学习?

时间:2019-07-29 16:58:12

标签: pytorch reinforcement-learning dataloader

我正在尝试在PyTorch中建立一个通用的强化学习框架,以利用所有利用PyTorch DataSet和DataLoader的高级实用程序,例如Ignite或FastAI,但是我遇到了一个障碍强化学习数据的动态性质:

  • 数据项是从代码生成的,而不是从文件中读取的,它们取决于先前的操作和模型结果,因此每个nextItem调用都需要访问模型状态。
  • 培训情节的长度不是固定的,因此我需要动态的批处理大小以及动态的总数据集大小。我更喜欢使用终止条件函数而不是数字。我可以像在NLP句子处理中那样“可能”用填充来做到这一点,但这是一个真正的技巧。

到目前为止,我的Google搜索和StackOverflow搜索均已成功。这里有人知道将DataLoader或DataSet与Reinforcement Learning结合使用的现有解决方案或解决方法吗?我讨厌放宽对所有现有库的依赖。

1 个答案:

答案 0 :(得分:1)

Here是一个基于PyTorch的框架,而here是Facebook的东西。

关于您的问题(毫无疑问是崇高的追求):

您可以轻松地创建一个依赖于任何东西的torch.utils.data.Dataset,包括模型(像这样)(请弱抽象,这只是为了证明这一点):

import typing

import torch
from torch.utils.data import Dataset


class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
        self.current_state = initial_state
        self.actor: torch.nn.Module = actor
        self.max_interactions: int = max_interactions

    # Just ignore the index
    def __getitem__(self, _):
        self.current_state = self.actor.update(self.current_state)
        return self.current_state.get_data()

    def __len__(self):
        return self.max_interactions

假设类似torch.nn.Module的网络具有某种update不断变化的环境状态。总而言之,它只是一个Python结构,因此您可以用它来建模很多东西。

您可以指定max_interactions几乎为infinite,也可以根据需要随时进行更改,并在训练过程中进行一些回调(因为__len__在整个代码中可能会多次调用)。环境还可以提供batches而不是示例。

torch.utils.data.DataLoader具有batch_sampler参数,在那里您可以生成不同长度的批次。由于网络不取决于第一维,因此您也可以从那里返回任何所需的批次大小。

顺便说一句。如果每个样本的长度都不同,则应使用填充,而批次大小的变化与此无关。