Project author: gravesee

Project description:
Bernoulli Mixture Models
Primary language: C
Repository: git://github.com/gravesee/BMM.git
Created: 2018-07-10T02:28:26Z
Project home: https://github.com/gravesee/BMM

License: MIT License


Bernoulli Mixture Models

This package provides a fast implementation for modelling a mixture of
multivariate Bernoulli samples. Expectation Maximization is used to find
the multivariate Bernoulli prototypes and their mixture weights that
maximize the likelihood of the data.

The main function, BMM, works with both dense matrices and sparse
pattern matrices from the Matrix package that ships with most
installations of R.
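For example, a dense 0/1 matrix could be coerced to a sparse pattern
matrix before being handed to BMM. A minimal sketch (the BMM call is
commented out and only illustrates the intended usage; it assumes the
package is installed):

```r
library(Matrix)  # ships with R

## a small dense 0/1 matrix, as BMM expects
x <- matrix(rbinom(5000, 1, 0.2), nrow = 1000, ncol = 5)

## coerce to a sparse pattern matrix (ngCMatrix): only the positions of
## the 1s are stored, not their values
xs <- as(Matrix(x, sparse = TRUE), "nMatrix")

## res <- BMM(data = xs, K = 2L, max.iter = 20L, verbose = 0L)
```

For mostly-zero binary data the pattern representation can be much
smaller in memory than the dense matrix.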

Example Usage

We will first train a model on synthetic data generated from two known
prototypes and mixing weights. The hope is that, given only the
generated data, the model can recover those prototypes and mixing
weights.

    P1 <- c(0.9, 0.9, 0.9, 0.1, 0.1)
    P2 <- c(0.1, 0.1, 0.9, 0.9, 0.9)
    prototypes <- list(P1, P2)
    weights <- c(0.25, 0.75)
    x <- t(replicate(1000, {
      ## pick a random prototype
      i <- sample(1:2, size = 1, prob = weights)
      ## sample bits from the chosen prototype
      sapply(prototypes[[i]], function(p) rbinom(1, 1, p))
    }))
    head(x)
    ##      [,1] [,2] [,3] [,4] [,5]
    ## [1,]    1    1    1    0    1
    ## [2,]    0    0    1    1    1
    ## [3,]    0    0    1    1    1
    ## [4,]    0    0    1    1    1
    ## [5,]    0    1    1    0    1
    ## [6,]    0    0    1    1    1

Training a BMM

To train a BMM model invoke the BMM method passing the binary matrix of
data, the number of clusters to model, the maximum number of EM updates,
and whether to print training information to the console:

    set.seed(1234)
    res <- BMM(data = x, K = 2L, max.iter = 20L, verbose = 1L)
    ## 0 | -4155.3022
    ## 1 | -3454.8290
    ## 2 | -3454.7546
    ## 3 | -3454.2607
    ## 4 | -3450.9914
    ## 5 | -3429.9209
    ## 6 | -3313.7880
    ## 7 | -2956.5027
    ## 8 | -2593.1706
    ## 9 | -2449.3920
    ## 10 | -2371.0938
    ## 11 | -2325.7450
    ## 12 | -2303.7729
    ## 13 | -2292.8449
    ## 14 | -2286.8891
    ## 15 | -2283.4879
    ## 16 | -2281.5120
    ## 17 | -2280.3530
    ## 18 | -2279.6672
    ## 19 | -2279.2582

At each iteration, the model prints the log likelihood if the verbose
option is set. Training stops after convergence or once the maximum
number of iterations is reached, whichever happens first. The model is
considered converged when the log likelihood no longer improves between
iterations.
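The updates behind that log are standard EM for a Bernoulli mixture.
As a self-contained sketch of what each iteration does (plain R,
independent of this package; the name em_bmm and all details below are
mine, not the package's):

```r
set.seed(42)

## toy data drawn from two known prototypes
p1 <- c(0.9, 0.9, 0.9, 0.1, 0.1)
p2 <- c(0.1, 0.1, 0.9, 0.9, 0.9)
X <- rbind(
  matrix(rbinom(50 * 5, 1, rep(p1, each = 50)), 50, 5),
  matrix(rbinom(50 * 5, 1, rep(p2, each = 50)), 50, 5)
)

em_bmm <- function(X, K, iters = 50, eps = 1e-9) {
  N <- nrow(X); D <- ncol(X)
  mu  <- matrix(runif(K * D, 0.25, 0.75), K, D)  # random prototype init
  pis <- rep(1 / K, K)                           # uniform mixing weights
  r   <- matrix(0, N, K)
  for (it in seq_len(iters)) {
    ## E-step: per-component log likelihood of each row, plus log prior,
    ## normalized into responsibilities
    logp <- X %*% t(log(mu + eps)) + (1 - X) %*% t(log(1 - mu + eps))
    logp <- sweep(logp, 2, log(pis), "+")
    r <- exp(logp - apply(logp, 1, max))  # subtract row max for stability
    r <- r / rowSums(r)
    ## M-step: re-estimate weights and prototypes from responsibilities
    Nk  <- colSums(r)
    pis <- Nk / N
    mu  <- (t(r) %*% X) / Nk   # row k of t(r) %*% X divided by Nk[k]
  }
  list(prototypes = mu, pis = pis, cluster = max.col(r))
}

fit <- em_bmm(X, K = 2L)
```

Each E-step computes how responsible each prototype is for each row;
each M-step re-fits the prototypes and weights as responsibility-weighted
averages, which is what drives the log likelihood monotonically upward.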

Once the model finishes training it will return a list with three
elements:

  1. prototypes - a K x ncol(data) matrix where each row represents
     one of the learned prototypes.
  2. pis - a numeric vector of length K containing the mixing weights.
  3. cluster - an integer vector of length nrow(data) indicating which
     prototype was most likely to have generated each data point.
    ## Multivariate Bernoulli prototypes
    print(round(res$prototypes, 2))
    ##      [,1] [,2] [,3] [,4] [,5]
    ## [1,] 0.11 0.12 0.92 0.89 0.90
    ## [2,] 0.87 0.92 0.90 0.14 0.14
    ## mixing weights
    print(round(res$pis, 2))
    ## [1] 0.74 0.26
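The cluster element is simply the most probable component for each row.
A sketch of how one observation's assignment could be recovered from
prototypes and pis (the numeric values below are illustrative, not
actual package output):

```r
## illustrative fitted values
prototypes <- rbind(c(0.11, 0.12, 0.92, 0.89, 0.90),
                    c(0.87, 0.92, 0.90, 0.14, 0.14))
pis <- c(0.74, 0.26)

x1 <- c(1, 1, 1, 0, 0)  # one binary observation

## log posterior (up to a shared constant) for each component:
## log pi_k + sum_d [ x_d * log mu_kd + (1 - x_d) * log(1 - mu_kd) ]
logpost <- log(pis) + as.vector(
  log(prototypes) %*% x1 + log(1 - prototypes) %*% (1 - x1)
)
cluster <- which.max(logpost)  # -> 2: x1 matches the second prototype
```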

These results are very similar to the prototypes and mixing weights used
to generate the dataset!

MNIST

What about something a little more interesting? The following example
trains a BMM on a sample of the famous MNIST hand-written digit
dataset.

    url <- "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
    tmp <- tempfile()
    download.file(url, tmp)
    to.read <- gzfile(tmp, open = "rb")
    ## file header info: magic number, image count, rows, cols
    readBin(to.read, what = integer(), n = 4, endian = "big")
    ## [1]  2051 10000    28    28
    images <- sapply(seq.int(10000), function(x) {
      readBin(to.read, integer(), size = 1, n = 28 * 28, endian = "big")
    })
    close(to.read)
    d <- t(images)
    ## binarize: single-byte reads are signed, so pixels >= 128 come back
    ## negative; d < 0 therefore flags the dark pixels
    d <- (d < 0) + 0L
    res <- BMM(d, K = 12L, max.iter = 50L, verbose = 0L)

Visualizing Prototypes

With the BMM trained on the MNIST data we can now visualize the
prototypes to see what the model uncovered:

    par(mfrow = c(4, 3))
    par(mar = c(0, 0, 0, 0))
    for (i in seq.int(12)) {  # one panel per prototype (K = 12)
      image(matrix(res$prototypes[i, ], 28, 28)[, 28:1], axes = FALSE)
    }
    par(mfrow = c(1, 1))