Matlab的Hebbian学习实现

时间:2020-06-18 14:50:31

标签: matlab machine-learning neural-network neuroscience

我正在尝试完成一项任务,要求我对具有两个输入的单个神经元(线性激发速率模型)实施基本的Hebbian学习;我得到了训练集,一个2x100的输入模式,该模式在每个时期都会随机播放。主要要求是针对输入相关矩阵Q的主要特征向量绘制最终权重向量。这是我的问题:Q定义为

    Q = <uu>

其中,角度符号表示输入模式的平均值。我被困在正确计算该矩阵的过程中,因为我无法正确理解如何实现“输入模式的平均” ...

我将代码拆分如下​​:实际执行学习的功能

function [w_out,w_t,w_norm] = hebbian(xtrain,eta)

        train_size = size(xtrain,2); % Checking training set dimension
        w = -1 + 2.*rand(2,1); % Weights random initialization
        epochs = 1000;

        w_t = zeros(2,epochs);
        w_norm = zeros(1,epochs);


        for i = 1:epochs

                w_old = w;
                randp = randperm(train_size); % Generating random index
                for k=1:train_size

                        xtrain_k = xtrain(:,randp(k));% Shuffling training set
                        % Computing output via linear firing rate model
                        v = w'*xtrain_k;
                        w = w_old + eta*v*xtrain_k; % Updating weights

                end

                w_norm(i) = norm(w);
                w_out = w/norm(w);
                w_t(:,i) = w_out;
        end 
end

虽然这是主要的.m文件:我认为学习算法确实有效,但是我不能说我在比较正确的对象。

T = readtable('lab2_1_data.csv'); % Importing data as table
u = table2array(T);% Converting table into input array
eta = 10e-3; % Learning rate

[w, w_t, w_norm] = hebbian(u,eta);

Q = u*u'; % Input correlation matrix
[vec, D] = eig(Q); % Computing eigenvalues and eigenvectors of Q

% Plotting data points and comparison between final weight vector and main
% eigenvector of Q
figure('Name','P1): Dataset, final weight vector and main eigenvector of Q','NumberTitle','off')
scatter(u(1,:),u(2,:))
hold on
plotv(vec(:,end));
set(findall(gca,'Type', 'Line'),'LineWidth',1.75);
plotv(w)
hold off
legend('Dataset','Dominant eigenvector of Q','Final weight vector','Location','best')

% Weight evolution, first component
figure('Name','P2.1): Weight vector time evolution (1st component)','NumberTitle','off')
plot(w_t(1,:))

% Weight evolution, second component
figure('Name','P2.2): Weight vector time evolution (2nd component)','NumberTitle','off')
plot(w_t(2,:))

% Weight norm evolution
figure('Name','P2.3): Weight vector norm time evolution','NumberTitle','off')
plot(w_norm)

0 个答案:

没有答案
相关问题