clear all;
% Precursor to the multivariate classification described in (e.g.)
% Haynes & Rees, Nature Neuroscience 2005. Works on one or multiple
% subjects, taking as input the realigned time series, a mask image
% (usually the thresholded F-map of the effects of interest, to identify
% the stimulus representation) and an ROI (usually V1 or some other
% retinotopic area). For every subject / ROI / localizer it saves the
% variable `vectors_all` (files allvectors_in.mat / allvectors_out.mat in
% the ROI's prediction directory): one matrix of patterns per condition,
% across the voxels that lie in the ROI AND in the mask image (e.g.
% activated voxels in V1), expressed as a time series (rows = time points,
% columns = voxels). It also saves leave-one-run-out train/test splits
% (vectors_in_<run>.mat / vectors_out_<run>.mat). These are the input to
% the next (classifier) stage, e.g. lin_discrim_.m
%
% @author: Code by Christian Kaul (2007)   c.kaul@ucl.ac.uk
% (with parts from Su Watkins, John Dylan Haynes & Geraint Rees)

%__________________________________________________________________________
%%%%%%%%%%%%%%%%%%%%%%% edit this section
%--------------------------------------------------------------------------
% WHAT TO DO WITH THE DATA
scramble = 0;                   % scramble all vectors to get unbiased results (not used further in this script)
n_voxels = 50;                  % target number of voxels (e.g. 100 or 50) for the threshold search
average_how_many_vectors = 1;   % number of consecutive volumes averaged into one vector;
                                % MUST divide howmanyvolumes (checked below)
mormalise = 1;                  % 1 = rescale each run of each condition to [0,1] (name kept for compatibility)
%--------------------------------------------------------------------------
% DATA PARAMETERS
ndummy=3;            % only used to calculate the file name of the first image in the realigned time series
nvol=101;            % number of volumes per scanning run
howmanyvolumes = 11; % how many volumes are considered after each start point
adjust_HDR = 3;      % simple HDR model: shift the considered volumes by this many scans
%--------------------------------------------------------------------------
conditions_per_run = 4;    % total conditions per run (scanner on-off)
n_conditions = 4;

% Averaging groups consecutive rows. If the group size does not divide the
% block length, averaged vectors mix volumes of neighbouring blocks, and the
% per-run row ranges used later (normalisation, train/test split) become
% fractional indices that MATLAB rejects. Fail here, before any image I/O.
if mod(howmanyvolumes, average_how_many_vectors) ~= 0
    error('average_how_many_vectors (%d) must divide howmanyvolumes (%d).', ...
        average_how_many_vectors, howmanyvolumes);
end

% Number of (averaged) vectors per condition contributed by one run; used as
% the size of each held-out test chunk.
NRvec_test = howmanyvolumes/average_how_many_vectors*(conditions_per_run/n_conditions);
%--------------------------------------------------------------------------
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%         THIS IS THE BIT WHERE YOU SHOULD PUT YOUR OWN FOLDER
%         STRUCTURE
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
subjects = {'Williamson_Kirk', 'JALKANEN_Lauri', 'Bays_Paul', 'Schulvinik_Marieke'};

ROIs = {'V1', 'V2', 'V3', 'V123'};

% Localizer subdirectories, appended to each subject's base directory.
localizers = {...
    ['\MOTION_localizer_inner'];...
    ['\MOTION_localizer_outer']};

% Main loop. For every subject, ROI and localizer:
%   1. lower an F threshold until the intersection of the thresholded
%      F-map and the ROI holds more than n_voxels+10 voxels,
%   2. read those voxels' time courses from every run,
%   3. sort them into conditions using the onsets in sots.mat,
%   4. optionally average / normalise, then save the result and
%      leave-one-run-out train/test splits.
%
% NOTE(review): the subject loop starts at 3, so the first two entries of
% `subjects` are skipped (looks like a resume-after-crash setting). Set the
% start index to 1 to process every subject.
for subjtodo= 3:length(subjects)
    for g = 1:length(ROIs)
        region_todo = ROIs{g};
        subjectname = subjects{subjtodo}
        basedir= ['D:\V1load_motion\' subjectname];
        [filsubjectnumber filehdr dir_data localizer_100plus] = filsubjectnumbersMOTION(subjectname,basedir,region_todo);

        % ROI map. Make by summing left and right (e.g.) V1 (or higher
        % area) with ImCalc beforehand.
        roimaps(1) = {[basedir '\structural\ROI\' region_todo '.img']};

        % Directory that receives the extracted time courses at the end.
        dir_analysis(1) = {['D:\V1load_motion\' subjectname '\prediction\' region_todo]};

        % Start points of the experimental conditions of each run. Loads
        % `sots`, indexed below as sots(run, condition).
        load([basedir '\experiment\sots.mat']);

        % One entry of filehdr{1} per scanning run.
        nruns = size(filehdr{1},1);

        for loc = 1:length(localizers)

            % ImCalc inputs: i1 = localizer F-map, i2 = ROI map.
            images{subjtodo} = {...
                [basedir localizers{loc} '\' 'spmF_0001.img'];...  % F-test
                [basedir '\structural\ROI\' region_todo '.img'];... % ROI map
                };
            cd ([basedir localizers{loc}]);

            % spm_imcalc_ui takes a char matrix of file names; char() pads
            % the shorter path with trailing blanks. Row order = i1, i2.
            P = char(images{subjtodo});

            %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
            %    IMCALC TO THRESHOLD FOR GIVEN NUMBER OF VOXELS (n_voxels)
            % The mask is abs(F) > thresh, restricted to the ROI. Lower
            % thresh in steps of 0.5 until enough voxels survive.
            spm_defaults
            warning off MATLAB:divideByZero

            % The ROI does not depend on the threshold: read it once.
            roimap_hdr=spm_vol(roimaps{1});
            tmpvol=spm_read_vols(roimap_hdr);
            roimap_vol = tmpvol > 0;   % NaN > 0 is false, so NaNs count as outside the ROI

            % ROI voxel coordinates, with a column of ones so the 4x4
            % voxel-to-world affine can be applied by matrix product.
            [x,y,z]=ind2sub(size(roimap_vol),find(roimap_vol));
            rXYZ=[x y z ones(size(x,1),1)];

            thresh = 45;
            while true
                thresh = thresh-0.5;  % applied in imcalc as a strict "greater than"
                f = sprintf('(abs(i1)>%g).*i2',thresh);
                Q = [region_todo '_thresh_' subjectname num2str(thresh) '.img'];

                warning off all
                spm_imcalc_ui(P,Q,f);  % P = input files, Q = output file, f = function
                warning on all

                maskmap_hdr=spm_vol(Q);
                maskvol_vol=spm_read_vols(maskmap_hdr);
                sz=size(maskvol_vol);
                catchdup=zeros(sz);   % marks voxels already added to pts

                % Transform ROI voxel coords into thresholded-map voxel coords
                % (inverse of the map's affine times the ROI's affine), rounded
                % to the nearest voxel.
                tXYZ = rXYZ * (inv(maskmap_hdr.mat)*roimap_hdr.mat)';
                tXYZ = round(tXYZ(:,1:3));

                % Keep voxels that lie inside the volume (valid 1-based
                % indices), survive the threshold, and are not duplicates.
                % 0x3 so that pts(:,k) is valid even if nothing survives.
                pts = zeros(0,3);
                for pt=1:size(tXYZ,1)
                    v = tXYZ(pt,:);
                    if v(1) >= 1 && v(2) >= 1 && v(3) >= 1 && ...
                            v(1) <= sz(1) && v(2) <= sz(2) && v(3) <= sz(3) && ...
                            maskvol_vol(v(1),v(2),v(3)) > 0 && ...
                            catchdup(v(1),v(2),v(3)) == 0
                        pts(end+1,:) = v; %#ok<AGROW>
                        catchdup(v(1),v(2),v(3)) = 1;
                    end
                end
                % Linear indices into the thresholded-map volume.
                lin_index=sub2ind(sz,pts(:,1),pts(:,2),pts(:,3));

                dims = numel(lin_index);

                if dims>n_voxels+10
                    % Margin of 10 voxels: normalisation later removes all-zero
                    % voxel columns, and extra voxels help later voxel sorting.
                    n_voxels_above_thresh{subjtodo,g,loc} = dims;
                    thresh_used{subjtodo,g,loc} = thresh;
                    fn3 = ['Fin_' region_todo '_' num2str(n_voxels) '_thresh_' subjectname '.img'];
                    spm_imcalc_ui(P,fn3,f);
                    % Remove this ROI's intermediate threshold images.
                    delete([region_todo '_thresh*'])
                    fprintf('Subject %d number of voxels: %d\n',subjtodo, dims);
                    break  % threshold found
                elseif thresh < 0
                    % abs(i1) > thresh already holds for every non-NaN voxel,
                    % so lowering the threshold further cannot add voxels.
                    delete([region_todo '_thresh*'])
                    error('Subject %s, ROI %s, localizer %s: only %d voxels at threshold %g, need more than %d.', ...
                        subjectname, region_todo, localizers{loc}, dims, thresh, n_voxels+10);
                end
            end

            %%%%%%%%%%%%%%%%%% now read out the complete time series of these voxels of all runs %%%%%%%%%%%%%%%%%%
            % lin_index holds the voxels in the intersection of localizer
            % and ROI: the pattern vector for multivariate classification.
            % Extract their raw time courses from every run's images.
            clear vectors_inorder;
            for r=1:nruns
                fprintf('Subject %d loading run %d\n',subjtodo, r);
                vectors_inorder{r} = zeros(nvol, numel(lin_index));   % rows = volumes, cols = voxels
                for vol=1+ndummy:(nvol+ndummy)
                    fname=[dir_data{1}{r} '/' filehdr{1}{r} sprintf('%6.6d',vol) '.img'];
                    vol_vol=spm_read_vols(spm_vol(fname));
                    vectors_inorder{r}(vol-ndummy,:)=vol_vol(lin_index)';
                end
                clear vol_vol;
            end

            %%%%%%%%%%%%%%%%%% extract & sort conditions according to sots file %%%%%%%%%%%%%%%%%%
            % vectors_all{cond}: rows = time points (stacked block after block,
            % run after run), columns = voxels.
            vectors_all = cell(1, n_conditions);
            for r=1:nruns
                for cond = 1:conditions_per_run   % all blocks in one run (scanner on-off)
                    startvolumes = sots(r,cond);
                    for t = 1:size(startvolumes,2)
                        % NOTE(review): assumes onsets are 0-based scan indices
                        % (SPM convention), hence the +1 - verify against sots.mat.
                        first_vol = startvolumes(t)+adjust_HDR+1;
                        onset_block = vectors_inorder{r}(first_vol : first_vol+howmanyvolumes-1, :);
                        vectors_all{cond} = [vectors_all{cond}; onset_block];
                    end
                end
            end

            %%%%%%%%%%%%%%%%%% average over several TRs to reduce autocorrelation %%%%%%%%%%%%%%%%%%
            % Groups of average_how_many_vectors consecutive rows become one
            % row; trailing rows that do not fill a group are dropped.
            if average_how_many_vectors > 1
                fprintf('averaging %d vectors...', average_how_many_vectors);
                vectors_all_new = cell(size(vectors_all));
                for cond = 1:conditions_per_run
                    n_groups = floor(size(vectors_all{cond},1)/average_how_many_vectors);
                    for a = 0:n_groups-1
                        rows = 1+a*average_how_many_vectors:(a+1)*average_how_many_vectors;
                        vectors_all_new{cond}(a+1,:) = mean(vectors_all{cond}(rows,:), 1);
                    end
                end
                vectors_all = vectors_all_new;
            end

            %%%%%%%%%%%%%%%%%% normalise %%%%%%%%%%%%%%%%%%
            if mormalise ==1
                % Rescale each run of each condition (all selected voxels
                % together) to the range [0, 1].
                fprintf('normalizing...');
                rows_per_run = howmanyvolumes/average_how_many_vectors;
                for cond = 1:conditions_per_run
                    for a = 1:nruns
                        rows = rows_per_run*(a-1)+1 : rows_per_run*a;
                        d = vectors_all{cond}(rows,:);
                        d = d - min(d(:));
                        vectors_all{cond}(rows,:) = d ./ max(d(:));
                    end
                end
                clear d rows
                % Drop spurious voxel columns that sum to zero after
                % normalisation. A column is dropped from EVERY condition if
                % it is zero in ANY condition, so column k still refers to the
                % same voxel in all conditions.
                zero_cols = false(1, size(vectors_all{1},2));
                for cond = 1:conditions_per_run
                    zero_cols = zero_cols | (sum(vectors_all{cond},1) == 0);
                end
                for cond = 1:conditions_per_run
                    vectors_all{cond}(:,zero_cols) = [];
                end
                if size(vectors_all{1},2) < n_voxels+5
                    warning('there are not enough voxels after normalisation!');
                end
            end

            %%%%%%%%%%%%%%%%%% save %%%%%%%%%%%%%%%%%%
            if loc ==1
                save([dir_analysis{1} '\allvectors_in'],'vectors_all');
            else
                save([dir_analysis{1} '\allvectors_out'],'vectors_all');
            end

            %%%%%%%%%%%%%%%%%% leave-one-run-out splits for cross-validation %%%%%%%%%%%%%%%%%%
            % For run i, the NRvec_test rows of that run form the test set
            % and all remaining rows form the training set, per condition.
            n_total = NRvec_test*nruns;
            for i = 1:nruns
                thestartpoint   = ((i-1)*NRvec_test)+1;
                theendpoint     = thestartpoint+NRvec_test-1;
                test_rows  = thestartpoint:theendpoint;
                train_rows = [1:thestartpoint-1, theendpoint+1:n_total];

                vectors_test  = cell(1, conditions_per_run);
                vectors_train = cell(1, conditions_per_run);
                for cond = 1:conditions_per_run
                    vectors_test{cond}  = vectors_all{cond}(test_rows,:);
                    vectors_train{cond} = vectors_all{cond}(train_rows,:);
                end

                if loc ==1
                    filestring = sprintf('vectors_in_%d',i);
                else
                    filestring = sprintf('vectors_out_%d',i);
                end
                save ([dir_analysis{1} '/' filestring], 'vectors_train','vectors_test');
                clear vectors_train vectors_test;
            end

            clear vectors_all vectors_all_new lin_index maskmaps
        end % loc of diff localizers
    end %ROIS
end % subjtodo