“发生异常:迁移到 stable_baselines3 时,AttributeError 'list' 对象没有属性 'get'”

时间:2020-12-26 06:26:53

标签: python-3.x stable-baselines

我已经在 stable_baselines 中进行了试验,效果不错,并且一直想在 stable_baselines3 上进行尝试。

我正在使用 A2C 模型来训练库存环境。使用的自定义环境目前与 stable_baselines 一起工作。我在训练过程中看到了不稳定性,并希望迁移到 stable_baselines3 以防万一。

0
100

自定义环境如下。调试时,错误会在环境中的第一个“步骤”之后立即弹出。可能是因为退货问题。我确实将状态的返回类型更改为 np.array 并且问题没有解决。

def train_A2C(env_train, model_name,timesteps = 50000, i=0 ):
    start = time.time()
    # policy_kwargs = dict(net_arch=[128, 128])
    policy_kwargs=dict(optimizer_class=RMSpropTFLike)
    model = A2C(MlpPolicy,env_train,verbose = 1,tensorboard_log='./tensorboard/tensorboard_A2C/', 
            learning_rate =0.0001, vf_coef = 0.05, ent_coef = 0.005, policy_kwargs=policy_kwargs)
        model.learn(total_timesteps = timesteps,tb_log_name = f"A2C_{i}")
    end = time.time()

    model.save(f'{config.TRAINED_MODEL_DIR}/{model_name}')
    print(f'Training Time A2C : ', (end - start) /60, ' minutes')
    return model

我得到的错误是在模型过程中。学习过程:

<块引用>

发生异常:AttributeError 'list' 对象没有属性 'get'

0 个答案:

没有答案