Learning Mixtures of Product Distributions

- Jon Feldman, Columbia University
- Ryan O'Donnell, IAS
- Rocco Servedio, Columbia University

Learning Distributions

- There is an unknown distribution P over R^n, or maybe just over {0,1}^n.
- An algorithm gets access to random samples from P.
- In time polynomial in n/ε it should output a hypothesis distribution Q which (w.h.p.) is ε-close to P.
- Technical details later.

Learning Distributions

- Hopeless in general!

Learning Classes of Distributions

Learning Distributions

- Since this is hopeless in general, one assumes that P comes from a class of distributions C.
- We speak of whether C is polynomial-time learnable or not; this means that there is one algorithm that learns every P in C.
- Some easily learnable classes:
- C = Gaussians over R^n
- C = Product distributions over {0,1}^n

Learning product distributions over {0,1}^n

- E.g. n = 3. Samples:
- 0 1 0
- 0 1 1
- 0 1 1
- 1 1 1
- 0 1 0
- 0 1 1
- 0 1 0
- 0 1 0
- 1 1 1
- 0 0 0
- Hypothesis: .2 .9 .5
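
For a single product distribution, the learning step is just coordinate-wise mean estimation. A minimal sketch in Python (assuming numpy; the variable names are ours):

```python
import numpy as np

# Minimal sketch: learn one product distribution over {0,1}^n by estimating
# each coordinate's mean from the samples shown on this slide.
samples = np.array([
    [0, 1, 0], [0, 1, 1], [0, 1, 1], [1, 1, 1], [0, 1, 0],
    [0, 1, 1], [0, 1, 0], [0, 1, 0], [1, 1, 1], [0, 0, 0],
])

hypothesis = samples.mean(axis=0)   # empirical mean of each coordinate
print(hypothesis)                   # [0.2 0.9 0.5], matching the hypothesis above
```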

Mixtures of product distributions

- Fix k ≥ 2 and let p_1 + p_2 + ... + p_k = 1.
- The p-mixture of distributions P^1, ..., P^k is:
- Draw i according to the mixture weights p_i.
- Draw from P^i.
- In the case of product distributions over {0,1}^n, component i is given by its vector of coordinate means:
- p_1:  µ^1_1  µ^1_2  µ^1_3  ...  µ^1_n
- p_2:  µ^2_1  µ^2_2  µ^2_3  ...  µ^2_n
- ...
- p_k:  µ^k_1  µ^k_2  µ^k_3  ...  µ^k_n
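
The sampling process above can be made concrete with a small sketch (assuming numpy; sample_mixture is a name we introduce):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_mixture(weights, means, num_samples):
    """Draw samples from a p-mixture of product distributions over {0,1}^n.

    weights: length-k mixture weights p_1, ..., p_k (summing to 1)
    means:   k x n array; row i holds the coordinate means of component P^i
    """
    weights, means = np.asarray(weights), np.asarray(means)
    # Step 1: draw a component index i according to the mixture weights.
    comps = rng.choice(len(weights), size=num_samples, p=weights)
    # Step 2: draw each coordinate independently, using the chosen row's means.
    return (rng.random((num_samples, means.shape[1])) < means[comps]).astype(int)

# The 60/40 example from the next slide, over {0,1}^4.
print(sample_mixture([0.6, 0.4], [[.8, .8, .6, .2], [.2, .4, .3, .8]], 10))
```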

Learning mixture example

- E.g. n = 4. Samples:
- 1 1 0 0
- 0 0 0 1
- 0 1 0 1
- 0 1 1 0
- 0 0 0 1
- 1 1 1 0
- 0 1 0 1
- 0 0 1 1
- 1 1 1 0
- 1 0 1 0
- True distribution:
- 60%: .8 .8 .6 .2
- 40%: .2 .4 .3 .8

Prior work

- KMRRSS94 learned in time poly(n/ε, 2^k) in the special case that there is a number p < ½ such that every µ^i_j is either p or 1-p.
- FM99 learned mixtures of 2 product distributions over {0,1}^n in polynomial time (with a few minor technical deficiencies).
- CGG98 learned a generalization of 2 product distributions over {0,1}^n, with no deficiencies.
- The latter two leave mixtures of 3 as an open problem; there is a qualitative difference between 2 and 3. FM99 also leaves open learning mixtures of Gaussians and other R^n distributions.

Our results

- A poly(n/ε) time algorithm learning a mixture of k product distributions over {0,1}^n, for any constant k.
- Evidence that getting a poly(n/ε) algorithm for k = ω(1), even in the case where all the µ's are in {0, ½, 1}, will be very hard (if possible).
- Generalizations:
- Let C_1, ..., C_n be "nice" classes of distributions over R (definable in terms of O(1) moments). The algorithm learns a mixture of O(1) distributions in C_1 × ... × C_n.
- Only pairwise independence of the coords is used.

Technical definitions

- When is a hypothesis distribution Q ε-close to the target distribution P?
- L1 distance? Σ_x |P(x) - Q(x)|.
- KL divergence: KL(P || Q) = Σ_x P(x) log(P(x)/Q(x)).
- Getting a KL-close hypothesis is more stringent:
- fact: L1 = O(KL^½).
- We learn under KL divergence, which leads to some technical advantages (and some technical difficulties).

Learning distributions summary

- Learning a class of distributions C.
- Let P be any distribution in the class.
- Given ε and δ > 0.
- Get samples and do poly(n/ε, log(1/δ)) work.
- With probability at least 1-δ, output a hypothesis Q which satisfies KL(P || Q) < ε.

Some intuition for k = 2

- Idea: Find two coordinates j and j' to "key off" of.
- Suppose you notice that the bits in coords j and j' are very frequently different.
- Then probably most of the "01" examples come from one mixture component and most of the "10" examples come from the other.
- Use this separation to estimate all other means.

More details for the intuition

- Suppose you somehow know the following three things:
- The mixture weights are 60% / 40%.
- There are j and j' such that the 2×2 matrix of means
  [ p_j   p_j' ]
  [ q_j   q_j' ]
  has determinant > ε.
- The values p_j, p_j', q_j, q_j' themselves.

More details for the intuition

- Main algorithmic idea:
- For each coord m, estimate (to within ε²) the correlation between coords j, m and between j', m:
- corr(j, m) = (.6 p_j) p_m + (.4 q_j) q_m
- corr(j', m) = (.6 p_j') p_m + (.4 q_j') q_m
- Solve this system of equations for p_m, q_m. Done!
- Since the determinant is > ε, any error in the correlation estimates does not blow up too much.
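
A sketch of this 2×2 solve, under the assumptions of the previous slide (known 60/40 weights and known p_j, p_j', q_j, q_j'; the function name is ours):

```python
import numpy as np

def solve_means(pj, pjp, qj, qjp, corr_jm, corr_jpm, w1=0.6, w2=0.4):
    """Recover (p_m, q_m) from the two estimated correlations corr(j, m) and
    corr(j', m), given the four known means and the mixture weights.

    The linear system is
        corr(j, m)  = w1*pj *p_m + w2*qj *q_m
        corr(j', m) = w1*pjp*p_m + w2*qjp*q_m
    """
    A = np.array([[w1 * pj,  w2 * qj],
                  [w1 * pjp, w2 * qjp]])
    b = np.array([corr_jm, corr_jpm])
    return np.linalg.solve(A, b)  # well-conditioned when the determinant is > eps
```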

Two questions

- 1. This assumes that there is some 2×2 submatrix of means which is far from singular. In general, there is no reason to believe this is the case.
- But if not, then one set of means is very nearly a multiple of the other set, and the problem becomes very easy.
- 2. How did we know p_1, p_2? How did we know which j and j' were good? How did we know the 4 means p_j, p_j', q_j, q_j'?

Guessing

- Just guess. I.e., try all possibilities.
- Guess whether the 2 × n matrix of means is essentially rank 1 or not.
- Guess p_1, p_2 to within ε². (Time 1/ε⁴.)
- Guess the correct j, j'. (Time n².)
- Guess p_j, p_j', q_j, q_j' to within ε². (Time 1/ε⁸.)
- Solve the system of equations in every case.
- Total time: poly(n/ε).
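
A sketch of what "try all possibilities" means here: enumerate an ε²-spaced grid for each guessed parameter and every candidate pair (j, j'). The names and grid below are illustrative assumptions, not the paper's exact discretization:

```python
import itertools
import numpy as np

# Illustrative enumeration of the guesses for k = 2.
eps, n = 0.1, 4
grid = np.arange(0.0, 1.0 + 1e-9, eps ** 2)          # ~1/eps^2 candidate values

candidates = itertools.product(
    grid,                                             # mixture weight p_1 (p_2 = 1 - p_1)
    itertools.combinations(range(n), 2),              # which pair (j, j') to key off
    grid, grid, grid, grid,                           # the four means p_j, p_j', q_j, q_j'
)
# Each candidate is completed by solving the correlation equations for the
# remaining means, and all candidates are then checked with the ML test.
```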

Checking guesses

- After all this we get a whole bunch of candidate hypotheses.
- When we get lucky and make all the right guesses, the resulting candidate hypothesis will be a good one; say, ε-close in KL to the truth.
- Can we pick the (or, a) candidate hypothesis which is KL-close to the truth? I.e., can we "guess and check"?
- Yes: use a Maximum Likelihood test.

Checking with ML

- Suppose Q is a candidate hypothesis for P.
- Estimate its log-likelihood over the sample set S:
- log Π_{x ∈ S} Q(x) = Σ_{x ∈ S} log Q(x)
- ≈ |S| · E[log Q(x)]
- = |S| · Σ_x P(x) log Q(x)
- = |S| · ( Σ_x P(x) log P(x) - KL(P || Q) ).
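
A sketch of the Maximum Likelihood check: compute the empirical log-likelihood of each candidate mixture and keep the best one (the helper names are ours):

```python
import numpy as np

def log_likelihood(samples, weights, means):
    """Empirical log-likelihood of a candidate mixture of product
    distributions over {0,1}^n: sum over the samples of log Q(x)."""
    X = np.asarray(samples)
    weights, means = np.asarray(weights), np.asarray(means)
    # Probability of each sample under each component (product over coordinates).
    per_comp = np.prod(np.where(X[:, None, :] == 1, means, 1 - means), axis=2)
    return float(np.sum(np.log(per_comp @ weights)))

def best_candidate(samples, candidates):
    """ML check: keep the candidate (weights, means) whose estimated
    log-likelihood is largest."""
    return max(candidates, key=lambda cand: log_likelihood(samples, *cand))
```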

Checking with ML, cont'd

- By Chernoff bounds, if we take enough samples, all candidate hypotheses Q will have their estimated log-likelihoods close to their expectations.
- Any KL-close Q will look very good in the ML test.
- Anything which looks good in the ML test is KL-close.
- Thus, assuming there is an ε-close candidate hypothesis among the guesses, we find an O(ε)-close candidate hypothesis.
- I.e., we can "guess and check."

Overview of the algorithm

- We now give the precise algorithm for learning a mixture of k product distributions, along with intuition for why it works.
- Intuitively:
- Estimate all the pairwise correlations of bits.
- Guess a number of parameters of the mixture distribution.
- Use the guesses and correlation estimates to solve for the remaining parameters.
- Show that whenever the guesses are close, the resulting parameter estimates give a close-in-KL candidate hypothesis.
- Check candidates with the ML algorithm, and pick the best one.

The algorithm

- 1. Estimate all pairwise correlations corr(j, j') to within (ε/n)^k. (Time (n/ε)^k.)
- Note: corr(j, j') = Σ_{i=1..k} p_i µ^i_j µ^i_j' = ⟨ µ_j , µ_j' ⟩,
  where µ_j = ( (p_i)^½ µ^i_j )_{i=1..k}.
- 2. Guess all p_i to within (ε/n)^k. (Time (n/ε)^(k²).)
- Now it suffices to estimate all the vectors µ_j, j = 1, ..., n.

Guessing matrices from most of their Gram matrices

- Let A be the k × n matrix whose columns are the vectors µ_1, µ_2, ..., µ_n.
- After estimating all correlations, we know all dot products of distinct columns of A to high accuracy.
- Goal: determine all entries of A, making only O(1) guesses.

Two remarks

- This is the final problem, where all the main action and technical challenge lie. Note that all we ever do with the samples is estimate pairwise correlations.
- If we knew the dot products of the columns of A with themselves, we'd have the whole matrix A^T A. That would be great: we could just factor it and recover A exactly. Unfortunately, there doesn't seem to be any way to get at these quantities Σ_{i=1..k} p_i (µ^i_j)².

Keying off a nonsingular submatrix

- Idea: find a nonsingular k × k submatrix to "key off" of.
- As before, the usual case is that A has full rank.
- Then there is a k × k nonsingular submatrix A_J.
- Guess this submatrix (time n^k) and all its entries to within (ε/n)^k (time (n/ε)^(k³); the final running time).
- Now use this submatrix and the correlation estimates to find all other entries of A:
- for all m, solve A_J^T A_m = ( corr(m, j) )_{j ∈ J} for the column A_m.
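
A sketch of this full-rank step: once a nonsingular k×k submatrix A_J has been guessed, every remaining column of A is the solution of a k×k linear system built from the correlation estimates (names are ours):

```python
import numpy as np

def recover_columns(A_J, corr, J, n):
    """Given a guessed nonsingular k x k submatrix A_J (the columns of A
    indexed by J) and the estimated correlation matrix corr, solve
        A_J^T  A_m = ( corr[m, j] )_{j in J}
    for every remaining column A_m."""
    k = A_J.shape[0]
    A = np.zeros((k, n))
    A[:, list(J)] = A_J              # the keyed columns were guessed directly
    for m in range(n):
        if m in J:
            continue
        rhs = np.array([corr[m, j] for j in J])
        A[:, m] = np.linalg.solve(A_J.T, rhs)
    return A
```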

Non-full rank case

- But what if A is not full rank? (Or, in the actual analysis, if A is extremely close to being rank deficient.) A genuine problem.
- Then A's columns have some perpendicular space of dimension 0 < d ≤ k, spanned by some orthonormal vectors u_1, ..., u_d.
- Guess d and the vectors u_1, ..., u_d.
- Now adjoin these columns to A, getting a full rank matrix:
- A' = [ A | u_1 u_2 ... u_d ]

Non-full rank case, cont'd

- Now A' has full rank and we can do the full rank case!
- Why do we still know all pairwise dot products of the columns of A'?
- Dot products of the u's with A's columns are 0!
- Dot products of the u's with each other are 0 or 1 (orthonormality). (Don't need this.)
- 4. Guess a k × k submatrix of A' and all its entries. Use these to solve for all other entries.

The actual analysis

- The actual analysis of this algorithm is quite delicate.
- There are some linear algebra / numerical analysis ideas.
- The main issue is: the degree to which A is "essentially" of rank k - d is similar to the degree to which all guessed vectors u really do have dot product 0 with A's original columns.
- The key is to find a large multiplicative gap between A's singular values, and treat its location as the essential rank of A.
- This is where the necessary accuracy (ε/n)^k comes in.
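
A sketch of the singular-value-gap idea (essential_rank is our own name; the real analysis works with accuracy (ε/n)^k rather than a fixed floor):

```python
import numpy as np

def essential_rank(A, floor=1e-12):
    """Place the 'essential rank' of A at the largest multiplicative gap
    between consecutive singular values."""
    s = np.maximum(np.linalg.svd(A, compute_uv=False), floor)
    gaps = s[:-1] / s[1:]                 # multiplicative gaps sigma_i / sigma_{i+1}
    return int(np.argmax(gaps)) + 1       # number of singular values kept as "large"
```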

Can we learn a mixture of ω(1)?

- Claim: Let T be a decision tree on {0,1}^n with k leaves. Then the uniform distribution over the inputs which make T output 1 is a mixture of at most k product distributions.
- Indeed, all of the product distributions have means 0, ½, or 1.

[figure: example decision tree querying x1, x2, x3; the uniform distribution over its 1-inputs is the mixture 2/3: (0, 0, ½, ½, ½) and 1/3: (1, 1, 0, ½, ½)]

Learning DTs under uniform

- Cor: If one can learn a mixture of k product distributions over {0,1}^n (even 0/½/1 ones) in poly(n) time, then one can PAC-learn k-leaf decision trees under uniform in poly(n) time.
- PAC-learning ω(1)-size DTs under uniform is an extremely notorious problem:
- easier than learning ω(1)-term DNF under uniform, a 20-year-old problem
- essentially equivalent to learning ω(1)-juntas under uniform; worth $1000 from A. Blum to solve
Generalizations

- We gave an algorithm that guessed the means of an unknown mixture of k product distributions.
- What assumptions did we really need?
- pairwise independence of coords
- means fall in a bounded range [-poly(n), poly(n)]
- 1-d distributions (and pairwise products of same) are samplable, so true correlations can be found by estimation
- the means defined the 1-d distributions
- The last of these is rarely true. But...

Higher moments

- Suppose we ran the algorithm and got N guesses for the means of all the distributions.
- Now run the algorithm again, but whenever you get the sample point ⟨x_1, ..., x_n⟩, treat it as ⟨x_1², ..., x_n²⟩.
- You will get N guesses for the second moments!
- Cross-product the two lists to get N² guesses for the ⟨mean, second moment⟩ pairs.
- Guess and check, as always.

Generalizations

- Let C_1, ..., C_n be families of distributions on R which have the following "niceness" properties:
- means bounded in [-poly(n), poly(n)]
- sharp tail bounds / samplability
- defined by O(1) moments; closeness in moments ⇒ closeness in KL
- more technical concerns
- Should be able to learn O(1)-mixtures from C_1 × ... × C_n in the same time.
- Definitely can learn mixtures of axis-aligned Gaussians and mixtures of distributions on O(1)-sized sets.

Open questions

- Quantify some nice properties of families of distributions over R which this algorithm can learn.
- Simplify the algorithm:
- Simpler analysis?
- Faster? n^(k²) → n^k → n^(log k) ???
- Specific fast results for k = 2, 3.
- Solve other distribution-learning problems.