我正在尝试在PyTorch中建立一个通用的强化学习框架,以利用所有利用PyTorch DataSet和DataLoader的高级实用程序,例如Ignite或FastAI,但是我遇到了一个障碍强化学习数据的动态性质:
到目前为止,我的Google搜索和StackOverflow搜索均已成功。这里有人知道将DataLoader或DataSet与Reinforcement Learning结合使用的现有解决方案或解决方法吗?我讨厌放宽对所有现有库的依赖。
答案 0 :(得分:1)
Here是一个基于PyTorch的框架,而here是Facebook的东西。
关于您的问题(毫无疑问是崇高的追求):
您可以轻松地创建一个依赖于任何东西的torch.utils.data.Dataset
,包括模型(像这样)(请弱抽象,这只是为了证明这一点):
import typing
import torch
from torch.utils.data import Dataset
class Environment(Dataset):
def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
self.current_state = initial_state
self.actor: torch.nn.Module = actor
self.max_interactions: int = max_interactions
# Just ignore the index
def __getitem__(self, _):
self.current_state = self.actor.update(self.current_state)
return self.current_state.get_data()
def __len__(self):
return self.max_interactions
假设类似torch.nn.Module
的网络具有某种update
不断变化的环境状态。总而言之,它只是一个Python结构,因此您可以用它来建模很多东西。
您可以指定max_interactions
几乎为infinite
,也可以根据需要随时进行更改,并在训练过程中进行一些回调(因为__len__
在整个代码中可能会多次调用)。环境还可以提供batches
而不是示例。
torch.utils.data.DataLoader
具有batch_sampler
参数,在那里您可以生成不同长度的批次。由于网络不取决于第一维,因此您也可以从那里返回任何所需的批次大小。
顺便说一句。如果每个样本的长度都不同,则应使用填充,而批次大小的变化与此无关。