我已经在 stable_baselines 中进行了试验,效果不错,并且一直想在 stable_baselines3 上进行尝试。
我正在使用 A2C 模型来训练库存环境。使用的自定义环境目前与 stable_baselines 一起工作。我在训练过程中看到了不稳定性,并希望迁移到 stable_baselines3 以防万一。
0
100
自定义环境如下。调试时,错误会在环境中的第一个“步骤”之后立即弹出。可能是因为退货问题。我确实将状态的返回类型更改为 np.array 并且问题没有解决。
def train_A2C(env_train, model_name,timesteps = 50000, i=0 ):
start = time.time()
# policy_kwargs = dict(net_arch=[128, 128])
policy_kwargs=dict(optimizer_class=RMSpropTFLike)
model = A2C(MlpPolicy,env_train,verbose = 1,tensorboard_log='./tensorboard/tensorboard_A2C/',
learning_rate =0.0001, vf_coef = 0.05, ent_coef = 0.005, policy_kwargs=policy_kwargs)
model.learn(total_timesteps = timesteps,tb_log_name = f"A2C_{i}")
end = time.time()
model.save(f'{config.TRAINED_MODEL_DIR}/{model_name}')
print(f'Training Time A2C : ', (end - start) /60, ' minutes')
return model
我得到的错误是在模型过程中。学习过程:
<块引用>发生异常:AttributeError 'list' 对象没有属性 'get'