%%%% GPL source code
%%%% We do not claim that this source code works: use it at your own risk,
%%%% and tell us if you find bugs or bad features :-)
%%%%
%%%% Do you want to use this multivariate regression tree?
%%%% No problem! Just remove the sample-data lines below (down to the next
%%%% "%%%%" marker), replace them by your own x and y, and put the names of
%%%% the variables in "variableNames" (variableNamesX / variableNamesY),
%%%% where x is the set of independent variables
%%%% and y is the set of dependent variables.
%%%%
%%%% The source code contains the number 10 (in mvrt),
%%%% which is the maximum depth and can be modified.
%%%% A node is split if (i) all sons contain at least 30 individuals and
%%%% (ii) 20% of the variance is discarded by this split.
%%%% Splits are chosen in a greedy manner.
%%%% All possible splits are tested.
%%%%
%%%% Please feel free to ask for help by email.
%%%% Please also tell me if you find this code useful.
%%%% Please also tell me if it is too slow; I can easily
%%%% make it much faster.
%%%%
%%%% teytaud@lri.fr

% ---- sample data: replace with your own x, y and variable names ----
x=rand(1000,5)*2-1;
% Collect the names in a cell array first and convert with char() afterwards:
% char() pads shorter names with trailing blanks, so names of different
% lengths (e.g. 'inputVariable9' and 'inputVariable10') no longer break the
% vertical concatenation.
variableNamesX=cell(size(x,2),1);
for i=1:size(x,2),
  variableNamesX{i}=sprintf('inputVariable%d',i);
end
variableNamesX=char(variableNamesX);

y=ceil(x(:,1:3));
y=y+0.01*randn(size(y));
variableNamesY=cell(size(y,2),1);
for i=1:size(y,2),
  variableNamesY{i}=sprintf('outputVariable%d',i);
end
variableNamesY=char(variableNamesY);
%%%%

more off

% assert(x,y): abort with message y when condition x is zero.
%   x  condition value; the check fires when (x==0) is true
%   y  message string reported on failure
% NOTE(review): this definition shadows any built-in assert.
function assert(x,y)
  if (x==0)
    % Raise a real error instead of relying on an undefined identifier;
    % '%s' keeps any '%' inside y from being read as a format directive.
    error('%s',y);
  end
end

% s=split(d): candidate thresholds for the values in d.
%   d  numeric array of observed values (any shape)
%   s  column vector of midpoints between consecutive distinct sorted
%      values of d; [] (with a message) when fewer than two distinct
%      values exist, including when d is empty
function s=split(d)
  d=sort(d(:));
  if (isempty(d))
    distinct=d;   % nothing to index; falls through to the "no split" branch
  else
    % keep the first value plus every value strictly larger than its predecessor
    distinct=[d(1);d(find(diff(d)>0)+1)];
  end
  if (length(distinct)>1)
    s=0.5*(distinct(1:end-1)+distinct(2:end));
  else
    s=[];
    disp('no possible splits!');
  end
end

% e=mvar(z): sum over columns of the population (1/N) variance of z.
%   z  matrix with one individual per row and one variable per column
%   e  scalar; NaN when z has no rows
function e=mvar(z)
  % Centre each column before squaring: the algebraically equivalent
  % mean(z.^2)-mean(z).^2 loses all precision when the mean is large
  % compared with the spread, and can even come out negative.
  centred=bsxfun(@minus,z,mean(z,1));
  e=sum(mean(centred.^2,1));
end

%%%% zmrt: tree construction; the body of this function continues beyond this point.
function dt=zmrt(x,y,n,bornesup,borneinf,intitulesx,intitulesy)
  assert(size(x,2)==length(bornesup),'bornesup!');
  assert(size(x,2)==length(borneinf),'borneinf!');
  if (n>0)
    bestIndex=0;
    bestRatio=Inf;
    bestSeuil=nan;
    currentVar=mvar(y);
    for index=1:size(x,2),
      possibleSplits=split(x(:,index));
      for 
seuil=vec(possibleSplits)', y1=y(find(x(:,index)=seuil),:); sumVars=(size(y1,1)*mvar(y1)+size(y2,1)*mvar(y2))/size(y,1); ratio=sumVars/currentVar; assert(ratio>=-1e-3,'var>=0'); if (ratio>=1) disp(sprintf('%gx%g+%gx%g<=%gx%g',size(y1,1),mvar(y1),size(y2,1),mvar(y2),size(y,1),mvar(y))); end assert(ratio<=1+1e-1,'var(intra) < var(totale)'); if (ratio0.8) %%%% meme pas 20% de reduction de variance dt=[]; else f1=find(x(:,bestIndex)=bestSeuil); if ((length(f1)<30)||(length(f2)<30)) %%%% meme pas 30 personnes par classe dt=[]; else bornesup1=bornesup; borneinf2=borneinf; bornesup1(bestIndex)=bestSeuil; borneinf2(bestIndex)=bestSeuil; dt1=zmrt(x(f1,:),y(f1,:),n-1,bornesup1,borneinf,intitulesx,intitulesy); dt2=zmrt(x(f2,:),y(f2,:),n-1,bornesup,borneinf2,intitulesx,intitulesy); dt=[bestIndex,bestSeuil,size(dt1,1),size(dt2,1),mean(y),size(y),bornesup,borneinf;dt1;dt2]; end end else dt=[]; end if (size(dt,1)==0) disp(sprintf('Final class: %d individuals ===========================',size(y,1))); for i=1:size(x,2), if ((borneinf(i)>-Inf)||(bornesup(i)