I'm working on the simple GridWorld problem (3x4, as described in Russell & Norvig, Ch. 21.2); I've solved it with Q-learning and a Q-table, and now I'd like to use a function approximator instead of the matrix.
I'm using MATLAB and have tried both neural networks and decision trees, but I'm not getting the expected results: the learned policy is wrong. I've read some papers on the topic, but most of them are theoretical and don't spend much time on the actual implementation.
I've been using offline (batch) learning because it's simpler. My approach is to play episodes, collect (state-action, target-Q) samples along the way, and fit a regression tree to them.
This seems too simple, and indeed I don't get the expected results. Here is some MATLAB code:
retrain = 1;
if(retrain)
    x = zeros(1, 16); %Training set; starts with one dummy all-zero row so the tree can be fitted
    y = 0;
    t = 0; %Iteration counter
end
tree = fitrtree(x, y);
x = zeros(1, 16);
y = 0;
for i = 1:100
    %Get the initial game state as a 3x4 matrix
    gamestate = initialstate();
    gameover = 0; %'end' is a reserved keyword in MATLAB, so use a flag variable instead
    while (gameover == 0)
        t = t + 1; %Increase the iteration
        %Get the index of the best action to take
        index = chooseaction(gamestate, tree);
        %Make the action and get the new game state and reward
        [newgamestate, reward] = makeaction(gamestate, index);
        %Get the state-action vector for the current gamestate and chosen action
        sa_pair = statetopair(gamestate, index);
        %Check for end of game
        if(isfinalstate(gamestate))
            gameover = 1;
            %Get the final reward
            reward = finalreward(gamestate);
        end
        %Add a sample to the training set
        x(size(x, 1)+1, :) = sa_pair;
        y(size(y, 1)+1, 1) = updateq(reward, gamestate, index, newgamestate, tree, t, gameover);
        %Update gamestate
        gamestate = newgamestate;
    end
end
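I haven't included chooseaction; as a rough sketch (not my exact code), a version with 50% random exploration could look like this, assuming a hard-coded set of 4 actions indexed 1..4:

function [ index ] = chooseaction( gamestate, tree )
%Sketch only: epsilon-greedy selection with epsilon = 0.5.
%The fixed count of 4 actions (up/down/left/right) is an assumption.
nactions = 4;
if(rand < 0.5)
    %Explore: pick a random action
    index = randi(nactions);
else
    %Exploit: pick the action with the highest predicted Q-value
    qvalues = zeros(nactions, 1);
    for a = 1:nactions
        qvalues(a) = predict(tree, statetopair(gamestate, a));
    end
    [~, index] = max(qvalues);
end
end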
So chooseaction picks a random action half of the time. The updateq function is:
function [ q ] = updateq( reward, gamestate, index, newgamestate, tree, iteration, finalstate )
    alfa = 1/iteration; %Learning rate, decaying over time
    gamma = 0.99; %Discount factor
    %Get the action with maximum qvalue in the new state s'
    amax = chooseaction(newgamestate, tree);
    %Get the corresponding state-action vectors
    newsa_pair = statetopair(newgamestate, amax);
    sa_pair = statetopair(gamestate, index);
    if(finalstate == 0)
        %Bootstrapped target: r + gamma * Q(s', amax)
        X = reward + gamma * predict(tree, newsa_pair);
    else
        %Terminal state: the target is just the final reward
        X = reward;
    end
    q = (1 - alfa) * predict(tree, sa_pair) + alfa * X;
end
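For reference, this computes the standard Q-learning update, q = (1 - alfa)*Q(s,a) + alfa*(r + gamma*max_a' Q(s',a')), with predict(tree, .) playing the role of the Q-table lookup. Note that since chooseaction also explores, amax is a random action rather than the true argmax half of the time.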
Any suggestion would be greatly appreciated!
Answer 0 (score: 2):
The problem is that in offline Q-learning you need to repeat the process of collecting data and retraining the approximator at least n times, where n depends on the problem you're trying to model. If you analyze the q-values computed during each iteration and think about it, it immediately becomes clear why this is necessary.
In the first iteration you learn only the final state; in the second iteration you also learn the second-to-last state; in the third iteration the third-to-last state, and so on. You learn from the final state back to the initial state, propagating the q-values backwards. In the GridWorld example, the minimum number of states you need to visit to finish the game is 6.
In the end, the correct algorithm becomes:
1. Initialize the approximator (e.g., fit the tree on a dummy sample).
2. Play a batch of games, collecting (state-action, target-Q) samples with the current tree.
3. Retrain the tree on the collected samples.
4. Repeat steps 2-3 at least n times, so the q-values propagate from the final state back to the initial state.
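A minimal sketch of that outer loop, reusing the names from the question (the episode body is elided, and n = 30 rounds is an arbitrary placeholder):

tree = fitrtree(zeros(1, 16), 0); %Initial fit on a dummy sample
n = 30; %Number of collect/retrain rounds; problem-dependent
t = 0; %Global iteration counter for the learning rate
for k = 1:n
    x = zeros(0, 16); %Fresh training set for this round
    y = zeros(0, 1);
    for i = 1:100
        %...play one episode exactly as in the question, appending
        %sa_pair rows to x and updateq(...) targets to y (t keeps increasing)...
    end
    %Retrain the approximator on the newly collected samples
    tree = fitrtree(x, y);
end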