function [P,Q,pi_i,LL,observ,gamma] = HMM_train_Mult_2(spk_vec,time_vec,cell_nums,num_states,plot_prog,num_iter)
%HMM_TRAIN_MULT_2 Train a discrete HMM on multiple-trial spike data.
%   [P,Q,PI_I,LL,OBSERV,GAMMA] = HMM_TRAIN_MULT_2(SPK_VEC,TIME_VEC,CELL_NUMS,
%   NUM_STATES,PLOT_PROG) uses the HMM Matlab toolbox (normalise,
%   mk_stochastic, dhmm_em, multinomial_prob, fwdback) to fit a discrete
%   HMM to the spike trains in SPK_VEC, one EM iteration at a time.
%
%   ... = HMM_TRAIN_MULT_2(..., NUM_ITER) runs NUM_ITER EM iterations
%   instead of the default 12.
%
%   Inputs:
%     spk_vec    - multiple-trial array of spike vectors from the MEA,
%                  indexed (trial, timestep, channel); any nonzero entry
%                  is treated as a spike.
%     time_vec   - time vector associated with the timestep dimension of
%                  spk_vec (only used for plotting).
%     cell_nums  - [First_Cell Last_Cell]: the channel range to train with.
%     num_states - guess of the relevant number of hidden states.
%     plot_prog  - scalar; if nonzero, plot the state probabilities of
%                  trial 1 in figure 1 after every iteration; if 0, only
%                  print progress messages.
%     num_iter   - (optional) number of EM iterations, a positive integer.
%                  Default 12.
%
%   Outputs:
%     P      - num_states x num_states transition matrix.
%     Q      - num_states x (number of selected channels + 1) observation
%              matrix.
%     pi_i   - num_states x 1 initial state distribution.
%     LL     - log-likelihood as returned by the final single-iteration
%              dhmm_em call.
%     observ - trials x timesteps matrix of observation symbols:
%              1 = no spike, k+1 = spike on the k-th selected channel.
%     gamma  - posterior state probabilities for trial 1 under the final
%              parameters (one row per state, plotted against time_vec).

if nargin < 6 || isempty(num_iter),
    num_iter = 12; %Default number of HMM training iterations
end;
if ~isscalar(num_iter) || num_iter < 1 || num_iter ~= round(num_iter),
    %With zero iterations LL/gamma would never be assigned
    error('HMM_train_Mult_2:badNumIter', 'num_iter must be a positive integer.');
end;

%Any nonzero plot_prog enables plotting (previously only exactly 1 did, and
%any value other than 0/1 left the outputs unassigned)
do_plot = (plot_prog ~= 0);

%Truncate spk_vec to contain only desired cells
spk_vec = spk_vec(:,:,cell_nums(1):cell_nums(2));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Determine observation made at each timestep
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
observ = build_observations(spk_vec);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Extract HMM parameters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%Total # of observation symbols: one per channel plus the "no spike" symbol
num_observs = size(spk_vec,3)+1;

%Initialize random guesses for all parameters (call order kept so that a
%given RNG state produces the same starting point as before)
pi_i = normalise(rand(num_states,1));
%Strong diagonal biases the initial transition matrix toward staying in the
%current state. (Earlier alternative: 0.99*diag(...)+0.01*rand(...))
P = mk_stochastic(0.95*diag(ones(num_states,1))+0.01*rand(num_states,num_states));
Q = mk_stochastic(rand(num_states,num_observs));

%Run EM algorithm one iteration at a time so progress can be reported
if do_plot,
    figure(1);
else
    disp('Training HMM...');
end;
for i = 1:num_iter,
    disp(sprintf('Running iteration #%d...',i));
    %Compute 1 iteration
    [LL, pi_i, P, Q] = dhmm_em(observ,pi_i,P,Q,'max_iter',1);
    if do_plot,
        %Posterior for the current parameters is needed every iteration
        gamma = trial1_posterior(observ,pi_i,P,Q);
        plot_state_probs(time_vec,gamma,num_states,i);
    end;
end;
if ~do_plot,
    %Only the final posterior is returned, so compute it once
    gamma = trial1_posterior(observ,pi_i,P,Q);
end;
return;


function observ = build_observations(spk_vec)
%BUILD_OBSERVATIONS Collapse (trial, time, channel) spikes to one symbol per bin.
%   Only one observation is allowed per timestep: when several channels
%   spike in the same bin, the highest-numbered channel wins, because
%   channels are written in ascending order and overwrite earlier ones.
%   Symbols are shifted up by 1 (1 = no spike), because a zero observation
%   would break the toolbox's 1-based symbol indexing.
observ = ones(size(spk_vec,1),size(spk_vec,2));
for j = 1:size(spk_vec,3),
    observ(spk_vec(:,:,j) ~= 0) = j + 1;
end;


function gamma = trial1_posterior(observ,pi_i,P,Q)
%TRIAL1_POSTERIOR Forward-backward state posteriors for the first trial.
obslik = multinomial_prob(observ(1,:),Q);
[alpha, beta, gamma] = fwdback(pi_i, P, obslik);


function plot_state_probs(time_vec,gamma,num_states,iter)
%PLOT_STATE_PROBS Redraw one subplot per hidden state for iteration ITER.
%   Colors alternate blue/red between iterations for easy viewing; mod()
%   keeps this valid for any number of iterations.
col = 'br';
line_spec = sprintf('%s-',col(mod(iter-1,2)+1));
clf;
for j = 1:num_states,
    subplot(num_states,1,j);
    plot(time_vec,gamma(j,:),line_spec);
    axis([0 time_vec(end) -0.1 1.1]);
    title(sprintf('State %d',j));
    drawnow;
    hold on;
end;