(Rene) Segment 28: GMM's in N-dimensions

From Computational Statistics Course Wiki

To Compute

The file twoexondata.txt contains data similar to that shown in slide 6, as 3000 (x,y) pairs.

1. In your favorite computer language, write a code for K-means clustering, and cluster the given data using (a) 3 components and (b) 8 components. Don't use anybody's K-means clustering package for this part: Code it yourself. Hint: Don't try to do it as limiting case of GMMs, just code it from the definition of K-means clustering, using an E-M iteration. Plot your results by coloring the data points according to which cluster they are in. How sensitive is your answer to the starting guesses?

function varargout = kmeansclustering(data,K,varargin)

% size data set
N = size(data,1);

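if nargin==2
    % no starting guess given: draw K initial means from a Gaussian
    % with the sample mean and covariance of the data
    mu = mvnrnd(mean(data),cov(data),K);
elseif nargin==3
    % user-supplied initial means: one row per cluster, one column per dimension
    mu = varargin{1};
    if ~isequal(size(mu),[K,size(data,2)])
        error('Initial means must be a K-by-D matrix, where D is the number of columns of data');
    end
else
    error('Wrong number of input arguments');
end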

for jj=1:100
    % compute distances w.r.t. means
    for k=K:-1:1
        temp = data - repmat(mu(k,:),N,1);
        R(:,k) = sqrt(sum(temp.^2,2)); % Euclidean distance, any number of dimensions
    end

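    % find minimum distance and corresponding cluster
    [~,ind] = min(R,[],2);

    % update means (row indexing keeps every coordinate of the selected points)
    for k=1:K
        if any(ind==k) % leave the mean of an empty cluster unchanged
            mu(k,:) = mean(data(ind==k,:),1);
        end
    end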
    
end

% return output
varargout{1} = mu;
if nargout==2
    varargout{2} = ind; 
end

end
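
The last part of the question, sensitivity to the starting guesses, can be probed by rerunning K-means from several random initializations and comparing the final within-cluster sum of squared distances (WCSS). The following is a minimal sketch of that idea in Python (standard library only) rather than MATLAB. It uses synthetic three-blob data instead of twoexondata.txt, and the names sqdist, assign, kmeans, wcss and the blob parameters are illustrative assumptions, not part of the solution above.

```python
import random

def sqdist(p, q):
    """Squared Euclidean distance between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))

def assign(points, centers):
    """E-step: index of the nearest center for every point."""
    return [min(range(len(centers)), key=lambda k: sqdist(pt, centers[k]))
            for pt in points]

def kmeans(points, start, n_iter=100):
    """K-means from the given starting centers; returns (centers, labels)."""
    centers = [tuple(c) for c in start]
    for _ in range(n_iter):
        labels = assign(points, centers)
        updated = []
        for k, old in enumerate(centers):
            members = [pt for pt, lab in zip(points, labels) if lab == k]
            if members:
                # M-step: move the center to the mean of its members
                updated.append(tuple(sum(x) / len(members) for x in zip(*members)))
            else:
                updated.append(old)  # empty cluster: keep the old center
        if updated == centers:  # centers stopped moving, so labels are final
            break
        centers = updated
    return centers, assign(points, centers)

def wcss(points, centers, labels):
    """Within-cluster sum of squared distances (the K-means objective)."""
    return sum(sqdist(pt, centers[lab]) for pt, lab in zip(points, labels))

# Synthetic stand-in for the real data: three well-separated blobs (assumption)
rng = random.Random(0)
blob_means = [(0.0, 0.0), (5.0, 5.0), (0.0, 5.0)]
points = [(rng.gauss(mx, 0.5), rng.gauss(my, 0.5))
          for mx, my in blob_means for _ in range(100)]

# Ten restarts, each starting from three randomly chosen data points
scores = []
for seed in range(10):
    start = random.Random(seed).sample(points, 3)
    centers, labels = kmeans(points, start)
    scores.append(wcss(points, centers, labels))
print("best WCSS: %.2f, worst WCSS: %.2f" % (min(scores), max(scores)))
```

A large gap between the best and worst value would indicate that some starting guesses end in a poorer local optimum; with 8 clusters on the real data one would expect this to happen more often than with 3.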



2. In your favorite computer language, and either writing your own GMM program or using any code you can find elsewhere (e.g., Numerical Recipes for C++, or scikit-learn, which is installed on the class server, for Python), construct mixture models like those shown in slide 8 (for 3 components) and slide 9 (for 8 components). You should plot 2-sigma error ellipses for the individual components, as shown in those slides.

function varargout = gaussmixturemodel(data,K,max_it,varargin)

% size data set
N = size(data,1); dim = size(data,2);

% initial guess for the mean, variance and population fraction
pop_frac = ones(1,K)./K;
sigma = cell(max_it,1); mu = cell(max_it,1);

if nargin==3
    mu{1} = data(randsample(N,K),:);
    sigma{1} = zeros(dim,dim,K);
    sigma{1}(:,:,1) = cov(data) ./ K;
    for k=K:-1:2
        sigma{1}(:,:,k) = sigma{1}(:,:,1);
    end
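elseif nargin==5
    % user-supplied initial means (K-by-dim) and covariances (dim-by-dim-by-K)
    for k=K:-1:1
        mu{1}(k,:) = varargin{1}(k,:);
        sigma{1}(:,:,k) = varargin{2}(:,:,k);
    end
else
    error('Wrong number of input arguments');
end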
    

for jj=1:max_it-1 

    for k=K:-1:1
        % joint probability of data point n and component k
        Pnk(:,k) = mvnpdf(data,mu{jj}(k,:),sigma{jj}(:,:,k)) * pop_frac(k);
    end
    
    % marginal probability of the nth data point under the mixture
    Pxn = sum(Pnk,2); 
    
    % posterior probability that the nth data point belongs to the kth component
    Pnk = Pnk./ repmat(Pxn,1,K);
    
    for k=K:-1:1
        
        % update population fraction: mean of pnk over the data points
        pop_frac(k) = sum(Pnk(:,k),1) / N;
        
        % estimate mean
        mu{jj+1}(k,:) = sum(repmat(Pnk(:,k),1,dim).*data,1) / (N*pop_frac(k));    
        
        % estimate covariance matrix
        y = data - repmat(mu{jj+1}(k,:),N,1); % centered data
        sigma{jj+1}(:,:,k) = (repmat(Pnk(:,k),1,dim).* y).' * y / (N*pop_frac(k));    
    end 

end

  
% return output
varargout{1} = mu;
varargout{2} = sigma;
if nargout==3
    varargout{3} = pop_frac;
end

end
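
The function above returns the means and covariances but does not itself draw error ellipses. As a minimal sketch, an n-sigma ellipse for one two-dimensional component can be traced from the eigen-decomposition of its covariance matrix. The example is in Python (standard library only) rather than MATLAB; the function name error_ellipse and the example mean and covariance are hypothetical.

```python
import math

def error_ellipse(mean, cov, n_sigma=2.0, n_points=100):
    """Points on the n_sigma ellipse of a 2-D Gaussian.

    mean: (mx, my); cov: symmetric 2x2 matrix ((a, b), (b, c)).
    """
    (a, b), (_, c) = cov
    # Closed-form eigenvalues of a symmetric 2x2 matrix
    half_trace = 0.5 * (a + c)
    root = math.sqrt((0.5 * (a - c)) ** 2 + b ** 2)
    lam_major, lam_minor = half_trace + root, half_trace - root
    # Orientation of the major axis
    theta = 0.5 * math.atan2(2.0 * b, a - c)
    r_major = n_sigma * math.sqrt(lam_major)
    r_minor = n_sigma * math.sqrt(max(lam_minor, 0.0))
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    points = []
    for i in range(n_points):
        t = 2.0 * math.pi * i / n_points
        u, v = r_major * math.cos(t), r_minor * math.sin(t)
        # Rotate from the principal axes into data coordinates, then shift
        points.append((mean[0] + cos_t * u - sin_t * v,
                       mean[1] + sin_t * u + cos_t * v))
    return points

# Hypothetical component parameters, for illustration only
mean = (1.0, -2.0)
cov = ((2.0, 0.6), (0.6, 1.0))
ellipse = error_ellipse(mean, cov, n_sigma=2.0)
```

The returned points can be handed to any line-plotting routine, one call per component. In two dimensions the 2-sigma ellipse (Mahalanobis distance 2) encloses about 86.5% of a Gaussian component's probability mass, not the roughly 95% familiar from one dimension.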

We fitted the mixture model by expectation maximization with three and with eight Gaussian components, and plotted the confidence region of each component at 3 standard deviations. The results for 3 components are shown below.

Mmx31.jpg Mmx32.jpg Mmx33.jpg Mmx34.jpg Mmx35.jpg Mmx36.jpg

The results for 8 components are shown below.

Mmx81.jpg Mmx82.jpg Mmx83.jpg Mmx84.jpg Mmx85.jpg Mmx86.jpg

To Think About

1. The segment (or the previous one) mentioned that the log-likelihood can sometimes get stuck on plateaus, barely increasing, for long periods of time, and then can suddenly increase by a lot. What do you think is happening from iteration to iteration during these times on a plateau?
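
One way to investigate this question empirically is to record, at every E-M iteration, both the log-likelihood and how far the component means moved. In the notation of the MATLAB function above, the log-likelihood at iteration jj is sum(log(Pxn)). The following is a simplified one-dimensional sketch in Python (standard library only); the function names, the synthetic data and the deliberately near-identical starting means are all assumptions chosen for illustration.

```python
import math
import random

def normal_pdf(x, mu, var):
    """Density of a 1-D Gaussian with mean mu and variance var."""
    return math.exp(-0.5 * (x - mu) ** 2 / var) / math.sqrt(2.0 * math.pi * var)

def em_trace(xs, mus, variances, weights, n_iter=40):
    """Run E-M on a 1-D mixture; return (log-likelihood, largest mean shift) per iteration."""
    mus, variances, weights = list(mus), list(variances), list(weights)
    n, n_comp = len(xs), len(mus)
    trace = []
    for _ in range(n_iter):
        # E-step: joint probabilities, per-point likelihoods, responsibilities
        joint = [[weights[k] * normal_pdf(x, mus[k], variances[k])
                  for k in range(n_comp)] for x in xs]
        px = [sum(row) for row in joint]
        log_like = sum(math.log(p) for p in px)  # for the current parameters
        resp = [[j / p for j in row] for row, p in zip(joint, px)]
        # M-step: re-estimate weights, means and variances
        old_mus = list(mus)
        for k in range(n_comp):
            nk = sum(r[k] for r in resp)
            weights[k] = nk / n
            mus[k] = sum(r[k] * x for r, x in zip(resp, xs)) / nk
            variances[k] = sum(r[k] * (x - mus[k]) ** 2 for r, x in zip(resp, xs)) / nk
        shift = max(abs(m - o) for m, o in zip(mus, old_mus))
        trace.append((log_like, shift))
    return trace

# Synthetic data: two overlapping groups (assumption)
rng = random.Random(1)
xs = ([rng.gauss(0.0, 1.0) for _ in range(150)]
      + [rng.gauss(4.0, 1.0) for _ in range(150)])

# Nearly coincident starting means make the early iterations slow
trace = em_trace(xs, mus=[1.9, 2.1], variances=[4.0, 4.0], weights=[0.5, 0.5])
for it, (log_like, shift) in enumerate(trace):
    print("iter %2d  log-likelihood %.4f  max mean shift %.5f" % (it, log_like, shift))
```

Comparing the two columns shows whether the parameters are still moving during stretches in which the log-likelihood barely changes.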