---
title: "Chapter 8. Markov Chain Monte Carlo"
output: github_document
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, comment = "#>", dpi = 500)
library(glue)
library(mustashe)
library(broom)
library(patchwork)
library(rethinking)
library(tidyverse)
library(conflicted)
conflict_prefer("filter", "dplyr")
conflict_prefer("select", "dplyr")
conflict_prefer("rename", "dplyr")
# To be able to use `map2stan()`:
conflict_prefer("collapse", "dplyr")
conflict_prefer("extract", "tidyr")
conflict_prefer("rstudent", "rethinking")
conflict_prefer("lag", "dplyr")
conflict_prefer("map", "purrr")
conflict_prefer("Position", "ggplot2")
theme_set(theme_minimal())
source_scripts()
set.seed(0)
```
- estimation of posterior probability distributions using a stochastic process called *Markov chain Monte Carlo* (MCMC)
* sample directly from the posterior instead of approximating curves (like the quadratic approximation)
* allows for models that do not assume multivariate normality
- can use generalized linear and multilevel models
- use *Stan* to fit these models
## 8.1 Good King Markov and His island kingdom
- a tale of the Good King Markov
* he was king of a ring of 10 islands
* the second island had twice the population of the first, the third had three times the population of the first, etc.
* the king wanted to visit the islands in proportion to the population size, but didn't want to have to keep track of a schedule
* he would also only travel between adjacent islands
* he used the *Metropolis algorithm* to decide which island to visit next
1. the king decides to stay or travel by a coin flip
2. if the coin is heads, the king considers moving clockwise; if tails, he considers moving counter-clockwise; call this next island the "proposal" island
3. to decide whether or not to move, the king collects a number of seashells proportional to the population of the proposed island, and a number of stones proportional to the population of the current island
4. if there are more seashells than stones, the king moves to the proposed island; otherwise, he discards a number of stones equal to the number of seashells and randomly selects one object from the remaining seashells and stones; if he selects a shell, he travels to the proposed island, else he stays
- below is a simulation of this process
```{r}
set.seed(0)
current_island <- 10
num_weeks <- 1e5
visited_islands <- rep(0, num_weeks)
flip_coin_to_decide_proposal_island <- function(x) {
  # Flip a coin: tails (0) proposes the counter-clockwise island, heads (1) clockwise.
  a <- sample(c(0, 1), 1)
  if (a == 0) {
    y <- x - 1
  } else {
    y <- x + 1
  }
  # Wrap around the ring of 10 islands.
  if (y == 0) {
    y <- 10
  } else if (y == 11) {
    y <- 1
  }
  return(y)
}
for (wk in seq(1, num_weeks)) {
  # Record where the king is this week.
  visited_islands[[wk]] <- current_island
  proposal_island <- flip_coin_to_decide_proposal_island(current_island)
  if (proposal_island > current_island) {
    # The proposal island has a larger population: always move.
    current_island <- proposal_island
  } else {
    # Otherwise move with probability (proposal population) / (current population),
    # using the seashells-and-stones draw described above.
    shells <- rep("shell", proposal_island)
    stones <- rep("stone", current_island - proposal_island)
    selection <- sample(unlist(c(shells, stones)), 1)
    if (selection == "shell") {
      current_island <- proposal_island
    }
  }
}
```
```{r}
tibble(wk = seq(1, num_weeks),
islands = visited_islands) %>%
ggplot(aes(x = factor(islands))) +
geom_bar(fill = "skyblue4") +
labs(x = "islands",
y = "number of weeks visited",
title = "Distribution of King Markov's visits")
```
```{r}
tibble(wk = seq(1, num_weeks),
islands = visited_islands) %>%
slice(1:400) %>%
ggplot(aes(x = wk, y = islands)) +
geom_line(color = "skyblue4", alpha = 0.5) +
geom_point(color = "skyblue4", size = 0.7) +
scale_y_continuous(breaks = c(1:10)) +
labs(x = "week number", y = "island",
title = "The path of King Markov",
subtitle = "Showing only the first 400 weeks")
```
- this algorithm still works if the king is equally likely to propose a move to any island from the current island
* still use the ratio of the proposal island's population to the current island's population as the probability of moving (a sketch of this variant follows below)
- at any point, the king only needs to know the population of the current island and the proposal island
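Below is a minimal sketch of this variant (reusing `num_weeks` from the simulation above; `visited_islands_any` is a name introduced here only for illustration):
```{r}
# Metropolis with proposals drawn uniformly from all 10 islands.
# The acceptance probability is still proposal population / current population.
set.seed(0)
current_island <- 10
visited_islands_any <- rep(0, num_weeks)

for (wk in seq(1, num_weeks)) {
  visited_islands_any[[wk]] <- current_island
  proposal_island <- sample(1:10, 1)
  prob_move <- min(1, proposal_island / current_island)
  if (runif(1) < prob_move) {
    current_island <- proposal_island
  }
}

table(visited_islands_any) / num_weeks
```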
## 8.2 Markov chain Monte Carlo
- the *Metropolis algorithm* used above is an example of *Markov chain Monte Carlo*
* the goal is to draw samples from an unknown and complex target distribution
- "islands": the parameter values (they can be continuous, too)
- "population sizes": the posterior probabilities at each parameter value
- "week": samples taken from the joint posterior of the parameters
* we can use the samples from the Metropolis algorithm just like samples from any other distribution so far (e.g. to check visit frequencies against the true proportions, as sketched after this list)
- we will also cover *Gibbs sampling* and *Hamiltonian Monte Carlo*
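For instance, the empirical visit frequencies from the simulation above can be checked against the theoretical proportions implied by the island populations (a quick sanity check, not from the book):
```{r}
# Empirical frequency of visits to each island vs. the theoretical proportion
# implied by the populations (island i has a population proportional to i).
tibble(island = 1:10,
       observed = as.numeric(table(factor(visited_islands, levels = 1:10))) / num_weeks,
       expected = (1:10) / sum(1:10))
```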
### 8.2.1 Gibbs sampling
- can achieve a more efficient sampling procedure
* *adaptive proposals*: the distribution of proposed parameter values adjusts itself intelligently depending upon the parameter values at the moment
* the adaptive proposals depend on using particular combinations of prior distributions and likelihoods known as *conjugate pairs*
- these pairs have analytical solutions for the posterior of an individual parameter (a minimal example follows this list)
- can use these solutions to make smart jumps around the joint posterior
- Gibbs sampling is used in BUGS (Bayesian inference Using Gibbs Sampling) and JAGS (Just Another Gibbs Sampler)
- limitations:
* don't always want to use the conjugate priors
* Gibbs sampling becomes very inefficient for models with hundreds or thousands of parameters
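As a concrete illustration of a conjugate pair (an invented example, not from the book): a Normal likelihood with known standard deviation and a Normal prior on the mean give a Normal posterior for the mean with a closed-form update, which is exactly what a Gibbs sampler exploits to draw one parameter at a time.
```{r}
# Conjugate Normal-Normal update: Normal likelihood with known sigma and a
# Normal(mu0, tau0) prior on the mean -> Normal posterior with a closed form.
# Hypothetical data and prior values, for illustration only.
y_obs <- c(0.2, -0.5, 1.1, 0.4)
sigma_known <- 1
mu0 <- 0    # prior mean
tau0 <- 10  # prior standard deviation

n <- length(y_obs)
post_var <- 1 / (1 / tau0^2 + n / sigma_known^2)
post_mean <- post_var * (mu0 / tau0^2 + sum(y_obs) / sigma_known^2)

# A Gibbs step would sample mu directly from this exact posterior:
rnorm(1, post_mean, sqrt(post_var))
```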
### 8.2.2 Hamiltonian Monte Carlo (HMC)
- more computationally expensive than Gibbs sampling and the Metropolis algorithm, but more efficient
* doesn't require as many samples to describe the posterior distribution
- an analogy using King Monty:
* king of a continuous stretch of land, not discrete islands
* wants to visit his citizens in proportion to their local density
* decides to travel back and forth across the length of the country, slowing down the vehicle when houses grow more dense, and speeding up when the houses are more sparse
* requires knowledge of how quickly the population is changing at the current location
- in this analogy:
* "royal vehicle": the current vector of parameter values
* the log-posterior forms a bowl with the MAP at the nadir
* the vehicle sweeps across the bowl, adjusting its speed in proportion to the height on the bowl
- HMC does run a sort of physics simulation (a minimal leapfrog sketch follows this list)
* a vector of parameters gives the position of a frictionless particle
* the log-posterior provides the surface
- limitations:
* HMC requires continuous parameters
* HMC needs to be tuned to a particular model and its data
- the particle needs mass (so it can have momentum)
- other hyperparameters of the MCMC are handled by Stan
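Below is a minimal sketch of that physics simulation for a one-dimensional standard-Normal log-posterior (all names and values here are invented for illustration; real HMC also includes an acceptance step and tuned step sizes):
```{r}
# A single HMC-style trajectory: leapfrog integration of a frictionless
# particle on the surface U(q) = q^2 / 2, the negative log-density of a
# standard Normal.
U_grad <- function(q) q   # gradient of the negative log-posterior
step_size <- 0.1
n_steps <- 20

q <- 1.5        # current parameter value (the particle's position)
p <- rnorm(1)   # random momentum to start the trajectory

for (s in seq_len(n_steps)) {
  p <- p - step_size / 2 * U_grad(q)   # half-step update of momentum
  q <- q + step_size * p               # full-step update of position
  p <- p - step_size / 2 * U_grad(q)   # half-step update of momentum
}
q   # the proposed parameter value at the end of the trajectory
```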
## 8.3 Easy HMC: `map2stan`
- `map2stan()` provides an interface to Stan similar to how we have used `quap()`
* need to do any preprocessing and variable transformations beforehand
* input data frame can only have columns used in the formulae
- example: look at terrain ruggedness data
```{r}
data("rugged")
dd <- as_tibble(rugged) %>%
mutate(log_gdp = log(rgdppc_2000)) %>%
filter(!is.na(rgdppc_2000))
```
- fit a model to predict log-GDP with terrain ruggedness, continent, and the interaction of the two
- first fit with `quap()` like before
```{r}
m8_1 <- quap(
alist(
log_gdp ~ dnorm(mu, sigma),
mu <- a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa,
a ~ dnorm(0, 100),
bR ~ dnorm(0, 10),
bA ~ dnorm(0, 10),
bAR ~ dnorm(0, 10),
sigma ~ dunif(0, 10)
),
data = dd
)
precis(m8_1)
```
### 8.3.1 Preparation
- must:
* do any transformations beforehand (e.g. logarithm, squared, etc.)
* reduce data frame to only used variables
```{r}
dd_trim <- dd %>%
select(log_gdp, rugged, cont_africa)
```
### 8.3.2 Estimation
- can now fit the model with HMC using `map2stan()`
```{r}
stash("m8_1stan", depends_on = "dd_trim", {
m8_1stan <- map2stan(
alist(
log_gdp ~ dnorm(mu, sigma),
mu <- a + bR*rugged + bA*cont_africa + bAR*rugged*cont_africa,
a ~ dnorm(0, 100),
bR ~ dnorm(0, 10),
bA ~ dnorm(0, 10),
bAR ~ dnorm(0, 10),
sigma ~ dcauchy(0, 2)
),
data = dd_trim
)
})
precis(m8_1stan)
```
- a half-Cauchy prior was used for $\sigma$
* a uniform distribution would work here, too
* it is a useful, "thick-tailed" probability distribution (compared against a half-Normal in the sketch after this list)
* related to the Student $t$ distribution
* can think of it as a weakly-regularizing prior for the standard deviation
- the estimates from this model are similar to those from the quadratic approximation
- a few differences to note about the output of `precis()` (`summary()`):
* the *highest posterior density intervals* (HPDI) are shown, not just percentile intervals (PI), like before
* two new columns (they will be discussed further later):
- `n_eff`: a crude estimate of the number of independent samples that were collected
- `Rhat4`: an estimate of the convergence of the Markov chains (1 is good)
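Below is a quick visual comparison (not from the book) of the half-Cauchy(0, 2) prior used above against a half-Normal with the same scale, showing the thicker tail:
```{r}
# Compare the half-Cauchy(0, 2) prior used for sigma to a half-Normal(0, 2):
# the Cauchy places much more mass on large standard deviations.
tibble(x = seq(0, 20, 0.05)) %>%
  mutate(`half-Cauchy(0, 2)` = 2 * dcauchy(x, 0, 2),
         `half-Normal(0, 2)` = 2 * dnorm(x, 0, 2)) %>%
  pivot_longer(-x, names_to = "prior", values_to = "density") %>%
  ggplot(aes(x = x, y = density, color = prior)) +
  geom_line(size = 1.2) +
  labs(x = "sigma", y = "prior density",
       title = "Half-Cauchy vs. half-Normal priors on the standard deviation")
```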
### 8.3.3 Sampling again, in parallel
- specific advice on the number of samples to run will be given later
- a compiled model can be resampled from again
* multiple chains can also be used and run in parallel
```{r}
stash("m8_1stan_4chains", {
m8_1stan_4chains <- map2stan(m8_1stan, chains = 4, cores = 4)
})
precis(m8_1stan_4chains)
```
### 8.3.4 Visualization
- can plot the samples
* interesting to see how Gaussian the distribution actually was
```{r}
post <- extract.samples(m8_1stan)
str(post)
```
```{r}
pairs(m8_1stan)
```
### 8.3.5 Using the samples
- we can use the samples just like before
* simulate predictions
* compute differences of parameters
* calculate DIC and WAIC,
* etc.
```{r}
show(m8_1stan)
WAIC(m8_1stan)
```
### 8.3.6 Checking the chain
- the Markov chain is guaranteed to converge eventually
* need to check it actually did in the time we gave it
- use the *trace plot* to plot the samples in sequential order
* this is the first plot that should be looked at after fitting with MCMC
* grey region is the warm-up ("adaptation") where the chain is learning to more efficiently sample from the posterior
- look for 2 characteristics of a "good" chain:
* *stationarity*: the path stays within the posterior distribution
- the center of each path is relatively stable from start to end
* *good mixing*: each successive sample within each parameter is not highly correlated with the sample before it
- this is represented by a rapid zig-zag of the paths
```{r}
plot(m8_1stan)
```
- can access the raw Stan code in case we want to make specific changes not possible through `map2stan()`
```{r}
str_split(m8_1stan@model, "\n")
```
## 8.4 Care and feeding of your Markov chain
- it is not necessary to fully understand the mechanics of MCMC, but some understanding of the process is needed to check whether it actually worked
### 8.4.1 How many samples do you need?
- defaults: `iter = 2000` and `warmup = iter/2`
* gives 1000 warm-ups and 1000 samples
* this is a good place to start to make sure the model is defined correctly
- the number of samples needed for inference depends on many factors:
* the *effective* number of samples is the important part
- it is an estimate of the number of independent samples from the posterior distribution
- chains can become *autocorrelated* (a quick check with `acf()` is sketched after this list)
* what do we want to know?
- if we just want posterior means, not many samples are needed
- if we care about the shape of the posterior tails/extreme values, then we need many more samples
- Gaussian posterior distributions should need about 2000 samples, but skewed distributions likely need more
- for warm-up, we want as few iterations as possible so more time is spent on sampling, but more warm-up helps the MCMC sample more efficiently
* for Stan models, it is good to devote half of the total samples to warm-up
* for the very simple models we have fit so far, don't need very much warm-up
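One quick way to see autocorrelation (using the `m8_1stan` fit from above) is to look at the lagged correlations of a single parameter's samples:
```{r}
# Autocorrelation of the samples for one parameter from the Stan fit above.
# Samples that are nearly independent drop to ~0 correlation after a few lags,
# which is why n_eff can approach the raw number of samples.
post <- extract.samples(m8_1stan)
acf(post$a, main = "Autocorrelation of samples for parameter a")
```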
### 8.4.2 How many chains do you need?
- general workflow
1. when debugging a model, use one chain
2. run multiple chains to make sure the chains are working
3. for the final model (whose samples will be used for inference), only one chain is really needed, but running multiple in parallel can help speed it up
- for typical regression models: "four short chains to check, one long chain for inference"
- when sampling is not working right, it is usually very obvious
* we will see some bad chains in the sections below
- Rhat tells us if the chains converged
* use it as a diagnostic signal of danger (when it is above 1.00), but not as a sign of safety (when it is 1.00)
### 8.4.3 Taming a wild chain
- a common problem is a model with a broad, flat posterior density
* happens most often with flat priors
* this causes the Markov chain to wander erratically
* a simple example with flat priors:
```{r}
y <- c(-1, 1)
stash("m8_2", depends_on = "y", {
m8_2 <- map2stan(
alist(
y ~ dnorm(mu, sigma),
mu <- alpha
),
data = list(y = y),
start = list(alpha = 0, sigma = 1),
chains = 2,
iter = 4000,
warmup = 1000
)
})
```
```{r}
precis(m8_2)
```
```{r}
plot(m8_2)
```
- even weakly informative priors can fix this problem
* the above example just has very little data (only two y values) and uninformative priors
* with little data, the priors become more important
- here's the same model with weak, but more informative priors
```{r}
stash("m8_3", depends_on = "y", {
m8_3 <- map2stan(
alist(
y ~ dnorm(mu, sigma),
mu <- alpha,
alpha ~ dnorm(1, 10),
sigma ~ dcauchy(0, 1)
),
data = list(y = y),
start = list(alpha = 0, sigma = 1),
chains = 2,
iter = 4e3,
warmup = 1e3
)
})
precis(m8_3)
plot(m8_3)
```
- below is a plot of the prior and posterior distributions used above
```{r}
post <- extract.samples(m8_3, n = 1e4, clean = FALSE)
post_tibble <- tibble(parameter = c("alpha", "sigma"),
value = c(list(post$alpha), list(post$sigma))) %>%
unnest(value) %>%
filter((parameter == "sigma" & value > 0 & value < 10) |
parameter == "alpha")
prior_tibble <- tibble(parameter = c("alpha", "sigma"),
x = rep(list(seq(-20, 20, 0.1)), each = 2),
value = c(
list(dnorm(seq(-20, 20, 0.1), 1, 10)),
list(dcauchy(seq(-20, 20, 0.1), 0, 1))
)) %>%
unnest(c(x, value)) %>%
filter((parameter == "sigma" & x > 0 & x < 10) | parameter == "alpha")
post_tibble %>%
ggplot() +
facet_wrap(~ parameter, scales = "free", nrow = 1) +
geom_density(aes(x = value), color = "dodgerblue3", size = 1.3) +
geom_line(data = prior_tibble, aes(x = x, y = value),
color = "darkorange1", size = 1.3, lty = 2) +
labs(x = "values",
y = "probability density",
title = "Priors and posteriors in a simple model",
subtitle = "For each parameter, the prior is in orange and posterior in blue.")
```
### 8.4.4 Non-identifiable parameters
- see the result of highly correlated predictors on a Markov chain
* produce a characteristic shape in the chain
* as an example, fit the following model without any priors on the $\alpha$'s.
$$
y_i \sim \text{Normal}(\mu, \sigma) \\
\mu = \alpha_1 + \alpha_2 \\
\sigma \sim \text{HalfCauchy}(0, 1)
$$
```{r}
y <- rnorm(100, 0, 1)
stash("m8_4", depends_on = "y", {
m8_4 <- map2stan(
alist(
y ~ dnorm(mu, sigma),
mu <- a1 + a2,
sigma ~ dcauchy(0, 1)
),
data = list(y = y),
start = list(a1 = 0, a2 = 0, sigma = 1),
chains = 2,
iter = 4e3,
warmup = 1e3
)
})
precis(m8_4, digits = 3)
plot(m8_4)
```
- the chains wander very far and never converge
* the `n_eff` and `Rhat` values are horrible
* these samples should not be used
- using weak priors helps a lot
```{r}
stash("m8_5", depends_on = "y", {
m8_5 <- map2stan(
alist(
y ~ dnorm(mu, sigma),
mu <- a1 + a2,
c(a1, a2) ~ dnorm(0, 10),
sigma ~ dcauchy(0, 1)
),
data = list(y = y),
start = list(a1 = 0, a2 = 0, sigma = 1),
chains = 2,
iter = 4e3,
warmup = 1e3
)
})
precis(m8_5, digits = 3)
plot(m8_5)
```
## 8.6 Practice
(I skipped to the problems that seemed interesting.)
**8H1. Run the model below and then inspect the posterior distribution and explain what it is accomplishing. Compare the samples for the parameters a and b. Can you explain the different trace plots, using what you know about the Cauchy distribution?**
```{r}
# Code from the question:
stash("mp", {
mp <- map2stan(
alist(
a ~ dnorm(0, 1),
b ~ dcauchy(0, 1)
),
data = list(y = 1),
start = list(a = 0, b = 0),
iter = 1e4,
warmup = 1e2,
WAIC = FALSE
)
})
###
```
```{r}
precis(mp, digits = 3)
plot(mp)
```
```{r}
post <- extract.samples(mp, clean = FALSE)
enframe(post) %>%
unnest(value) %>%
ggplot(aes(x = value, color = name)) +
geom_density(size = 1.3) +
scale_x_continuous(limits = c(-10, 10)) +
scale_color_brewer(palette = "Set1")
```
The posterior plots basically look like their priors.
Parameter `b` has fatter tails, which is characteristic of the Cauchy distribution compared to a Gaussian.
With only one data point, the priors will be very important in defining the posterior distributions.
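One way to see the fatter tails numerically (using the `post` samples extracted above) is to compare extreme quantiles of the two parameters:
```{r}
# Extreme quantiles: b (Cauchy prior) reaches much further out than a (Normal prior).
sapply(post[c("a", "b")], quantile, probs = c(0.005, 0.995))
```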