21  Xây dựng mô hình GAM trong R

https://m-clark.github.io/generalized-additive-models/

21.1 Import dataset

Câu hỏi nghiên cứu: Đánh giá ảnh hưởng của các biến Income, Edu, Health đến kết quả học tập Overall.

# https://raw.githubusercontent.com/m-clark/generalized-additive-models/master/data/pisasci2006.csv
pisa <- read.csv('data_raw_1/pisasci2006.csv')
pisa
                    Country Overall Issues Explain Evidence Interest Support Income Health   Edu   HDI
1                   Albania      NA     NA      NA       NA       NA      NA  0.599  0.886 0.716 0.724
2                 Argentina     391    395     386      385      567     506  0.678  0.868 0.786 0.773
3                 Australia     527    535     520      531      465     487  0.826  0.965 0.978 0.920
4                   Austria     511    505     516      505      507     515  0.835  0.944 0.824 0.866
5                Azerbaijan     382    353     412      344      612     542  0.566  0.780    NA    NA
6                   Belgium     510    515     503      516      503     492  0.831  0.935 0.868 0.877
7                    Brazil     390    398     390      378      592     519  0.637  0.818 0.646 0.695
8                  Bulgaria     434    427     444      417      523     527  0.663  0.829 0.778 0.753
9                    Canada     534    532     531      542      469     501  0.840  0.951 0.902 0.897
10                    Chile     438    444     432      440      591     564  0.673  0.923 0.764 0.780
11                  ChinaHK     542    528     549      542      536     529  0.853  0.966 0.763 0.857
12                 Colombia     388    402     379      383      644     546  0.616  0.829 0.624 0.683
13                  Croatia     493    494     492      490      535     514  0.724  0.878 0.762 0.785
14           Czech Republic     513    500     527      501      489     485  0.763  0.893 0.927 0.858
15                  Denmark     496    493     501      489      463     483  0.838  0.914 0.911 0.887
16                  Estonia     531    516     541      531      502     497  0.739  0.838 0.918 0.829
17                  Finland     563    555     566      567      448     479  0.827  0.931 0.879 0.878
18                   France     495    499     481      511      520     507  0.819  0.955 0.850 0.873
19                  Germany     516    510     519      515      513     518  0.831  0.939 0.929 0.898
20                   Greece     473    469     476      465      549     533  0.791  0.937 0.861 0.861
21                  Hungary     504    483     518      497      522     512  0.733  0.842 0.855 0.808
22                  Iceland     491    494     488      491      466     491  0.832  0.964 0.893 0.895
23                Indonesia     393    393     395      386      608     521  0.484  0.748 0.535 0.579
24                  Ireland     508    516     505      506      481     484  0.837  0.932 0.945 0.904
25                   Israel     454    457     443      460      509     512  0.786  0.952 0.901 0.877
26                    Italy     475    474     480      467      529     511  0.810  0.963 0.832 0.866
27                    Japan     531    522     527      544      512     468  0.825  0.986 0.869 0.891
28                   Jordan     422    409     438      405      609     555  0.552  0.832 0.679 0.678
29               Kazakhstan      NA     NA      NA       NA       NA      NA  0.635  0.717 0.826 0.721
30                    Korea     522    519     512      538      486     495     NA     NA    NA    NA
31          Kyrgyz Republic     322    321     334      288      580     502  0.407  0.735 0.712 0.598
32                   Latvia     490    489     486      491      504     494  0.711  0.820 0.855 0.793
33            Liechtenstein     522    522     516      535      504     524  0.941  0.931    NA    NA
34                Lithuania     488    476     494      487      544     541  0.717  0.813 0.871 0.798
35               Luxembourg     486    483     483      492      515     522  0.901  0.931 0.764 0.863
36              Macao-China     511    490     520      512      524     521     NA     NA    NA    NA
37                   Mexico     410    421     406      402      611     536  0.695  0.880 0.684 0.748
38               Montenegro     412    401     417      407      561     529  0.647  0.852 0.802 0.762
39              Netherlands     525    533     522      526      452     447  0.848  0.942 0.903 0.897
40              New Zealand     530    536     522      537      461     470  0.781  0.943 0.991 0.901
41                   Norway     487    489     495      473      472     485  0.884  0.948 0.993 0.940
42                   Panama      NA     NA      NA       NA       NA      NA  0.646  0.872 0.735 0.745
43                     Peru      NA     NA      NA       NA       NA      NA  0.591  0.832 0.688 0.697
44                   Poland     498    483     506      494      501     513  0.710  0.872 0.812 0.795
45                 Portugal     474    486     469      472      571     538  0.765  0.919 0.704 0.791
46                    Qatar     349    352     356      324      565     520  0.914  0.909 0.655 0.816
47                  Romania     418    409     426      407      591     540  0.658  0.831 0.792 0.757
48       Russian Federation     479    463     483      481      541     508  0.691  0.736 0.773 0.733
49                   Serbia     436    431     441      425      523     520  0.642  0.848 0.770 0.749
50           Shanghai-China      NA     NA      NA       NA       NA      NA     NA     NA    NA    NA
51                Singapore      NA     NA      NA       NA       NA      NA  0.876  0.951 0.720 0.843
52          Slovak Republic     488    475     501      478      522     497  0.734  0.859 0.864 0.817
53                 Slovenia     519    517     523      516      505     502  0.788  0.916 0.877 0.858
54                    Spain     488    489     490      485      534     529  0.805  0.950 0.836 0.862
55                   Sweden     503    499     510      496      454     471  0.836  0.956 0.904 0.898
56              Switzerland     512    515     508      519      504     510  0.857  0.970 0.856 0.893
57 Taiwan Province of China     532    509     545      532      533     546     NA     NA    NA    NA
58                 Thailand     421    413     420      423      642     569  0.603  0.842 0.569 0.661
59      Trinidad and Tobago      NA     NA      NA       NA       NA      NA  0.769  0.772 0.685 0.741
60                  Tunisia     386    384     383      382      590     534  0.597  0.846 0.608 0.675
61                   Turkey     424    427     423      417      540     563  0.679  0.828 0.562 0.681
62     United Arab Emirates      NA     NA      NA       NA       NA      NA  0.909  0.878 0.686 0.818
63           United Kingdom     515    514     517      514      464     470  0.833  0.934 0.798 0.853
64            United States     489    492     486      489      480     490  0.872  0.911 0.930 0.904
65                  Uruguay     428    429     423      429      567     510  0.658  0.884 0.740 0.755

21.2 Summary dữ liệu cho các cột numeric

# devtools::install_github('m-clark/tidyext')
library(tidyverse)
library(tidyext)
library(gt)

pisa |> tidyext::num_by(vars(-Country)) |> gt::gt()
Variable N Mean SD Min Q1 Median Q3 Max % Missing
Overall 57 473.1 54.6 322.0 428.0 489.0 513.0 563.0 12
Issues 57 469.9 53.9 321.0 427.0 489.0 514.0 555.0 12
Explain 57 475.0 54.0 334.0 432.0 490.0 517.0 566.0 12
Evidence 57 469.8 61.7 288.0 423.0 489.0 515.0 567.0 12
Interest 57 528.2 49.8 448.0 501.0 522.0 565.0 644.0 12
Support 57 512.2 26.1 447.0 494.0 512.0 529.0 569.0 12
Income 61 0.7 0.1 0.4 0.7 0.8 0.8 0.9 6
Health 61 0.9 0.1 0.7 0.8 0.9 0.9 1.0 6
Edu 59 0.8 0.1 0.5 0.7 0.8 0.9 1.0 9
HDI 59 0.8 0.1 0.6 0.7 0.8 0.9 0.9 9

21.3 Pairs plot

library(GGally)

better_smooth <- function(data, mapping, ptcol, ptalpha=1, ptsize=1, linecol, ...) {
  p <- ggplot(data = data, mapping = mapping) +
    geom_point(color = ptcol, alpha=ptalpha, size = ptsize) +
    geom_smooth(color = linecol, ...)
  p
}

p <- GGally::ggpairs(
  pisa[, -c(1, 3:5)],
  lower = list(
    continuous = GGally::wrap(
      better_smooth,
      ptalpha = .25,
      ptcol = '#D55E00',
      ptsize = 1,
      linecol = '#03b3ff',
      method = 'loess',
      se = FALSE,
      lwd = .5
    )
  ),
  diag = list(continuous = GGally::wrap(
    'densityDiag', color = 'gray50', lwd = .5
  )),
  # upper=list(continuous=GGally::wrap(better_corr)),
  axisLabels = "none"
)

p

21.4 Fitting model GAM

library(gamRR)
library(nlme)
library(mgcv)
mod_gam2 = gam(Overall ~ s(Income) + s(Edu) + s(Health), data = pisa)
summary(mod_gam2)

Family: gaussian 
Link function: identity 

Formula:
Overall ~ s(Income) + s(Edu) + s(Health)

Parametric coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  471.154      2.772     170   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate significance of smooth terms:
            edf Ref.df     F  p-value    
s(Income) 7.593  8.415 8.826 1.29e-06 ***
s(Edu)    6.204  7.178 3.308  0.00771 ** 
s(Health) 1.000  1.000 2.736  0.10679    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

R-sq.(adj) =  0.863   Deviance explained = 90.3%
GCV = 573.83  Scale est. = 399.5     n = 52
plot(mod_gam2)

21.5 Thể hiện ảnh hưởng của từng biến x đầu vào lên biến y đầu ra

library(ggeffects)
library(gratia)
plot(ggeffects::ggpredict(mod_gam2), facets = TRUE)

gratia::draw(mod_gam2)

21.6 Cách tính relative risk trong model GAM

Code này tham khảo theo package gamRR. Khi phân tích source code ta thấy relative risk được tính như sau:

1/ Từ model GAM đã fitting ở trên mod_gam2

2/ Ta chọn 1 vector tham chiếu, ở đây ví dụ là ref = c(Income = pisa$Income[1], Edu = pisa$Edu[1], Health = pisa$Health[1]) nghĩa là ngay dòng đầu tiên của dataset. Vai trò của vector tham chiếu này (1 điểm data point) là giúp làm mẫu số cho công thức tính relative risk.

3/ Ta chọn 1 biến x để kiểm tra relative risk, như trong trường hợp này là biến Income.

4/ Khi đó ta tạo dataframe với cột Income sẽ thay đổi theo dataset ban đầu nhưng hai cột EduHealth sẽ không đổi (chính là điểm datapoint tham chiếu ở bước 2)

5/ Thực hiện predict để tìm giá trị y predict cho điểm tham chiếu (ở bước 2).

6/ Thực hiện predict để tìm các giá trị y predict cho dataframe (ở bước 4)

7/ Relative risk là tỷ số giữa các giá trị y predict cho dataframe (ở bước 4) chia cho giá trị y predict (ở bước 2), trong trường hợp này là relative risk của biến Income.

Toàn bộ quy trình này được thể hiện qua đồ thị như bên dưới. Lưu ý là các thông tin sau chỉ là nhận định của cá nhân mình, để hiểu rõ thuật toán bạn sẽ liên hệ với tác giả package theo thông tin sau.

https://cran.r-project.org/web/packages/gamRR/index.html

gamRR::gamRR <- function (fit, ref, est, data, n.points = 10, plot = TRUE, ylim = NULL) 
{
  ref = data.frame(t(ref))
  form = as.character(fit$formula)
  x.list = strsplit(form[3], "\\+")[[1]]
  x.list = gsub(" ", "", x.list)
  x.list = sapply(strsplit(x.list, "\\,"), "[", 1)
  x.list = gsub("s\\(", "", x.list)
  x.list = gsub("as.factor\\(", "", x.list)
  x.list = gsub("factor\\(", "", x.list)
  x.list = gsub("offset\\(", "", x.list)
  x.list = gsub("log\\(", "", x.list)
  x.list = gsub("\\)", "", x.list)
  if (length(names(ref)) != length(x.list)) {
    stop("The number of variables in the 'ref' argument is not equal to those in the model!")
  }
  if (any(!(names(ref) %in% x.list))) {
    stop("Some variables in the 'ref' argument are not in the model!")
  }
  for (i in 1:length(x.list)) {
    data = data[!is.na(data[, x.list[i]]), ]
  }
  rrref = predict(fit, type = "response", newdata = ref)
  ndata = matrix(rep(0, nrow(data) * length(names(ref))), ncol = length(names(ref)))
  ndata = data.frame(ndata)
  names(ndata) = names(ref)
  # 1 biến chạy, giữ nguyên các biến còn lại
  ndata[, match(est, names(ndata))] = data[, match(est, names(data))]
  ndata[, -match(est, names(ref))] = ref[, -match(est, names(ref))]
  rr = predict(fit, type = "response", newdata = ndata)
  # relative risk là kết quả y đầu ra (tiên lượng với 1 biến x chạy, và các biến kia giữ nguyên) chia cho điểm tham chiếu
  rr = as.numeric(rr)/as.numeric(rrref)
  ref_no_est = names(ref)[-match(est, names(ref))]
  i = 1
  ndata = matrix(rep(0, nrow(data) * length(names(ref))), ncol = length(names(ref)))
  ndata = data.frame(ndata)
  names(ndata) = names(ref)
  ndata[, match(est, names(ndata))] = data[, match(est, names(data))]
  for (j in 1:length(ref_no_est)) {
    ndata[, ref_no_est[j]] = data[i, ref_no_est[j]]
  }
  rrn = predict(fit, type = "response", newdata = ndata)/as.numeric(rrref)
  for (i in 2:nrow(data)) {
    ndata = matrix(rep(0, nrow(data) * length(names(ref))), 
                   ncol = length(names(ref)))
    ndata = data.frame(ndata)
    names(ndata) = names(ref)
    ndata[, match(est, names(ndata))] = data[, match(est, 
                                                     names(data))]
    for (j in 1:length(ref_no_est)) {
      ndata[, ref_no_est[j]] = data[i, ref_no_est[j]]
    }
    rrn = cbind(rrn, predict(fit, type = "response", newdata = ndata)/as.numeric(rrref))
  }
  se = apply(rrn, 1, FUN = "sd")/sqrt(nrow(data) - 1)
  u = rr + 1.96 * se
  l = rr - 1.96 * se
  xy = data.frame(x = data[, est], rr = rr, u = u, l = l)
  xy = xy[order(xy$x), ]
  rangE = range(data[, est])
  est.seq = seq(from = rangE[1], to = rangE[2], length.out = n.points)
  seq.ind = which(abs(est.seq - as.numeric(ref[est])) == min(abs(est.seq - 
                                                                   as.numeric(ref[est]))))
  est.seq[seq.ind] = as.numeric(ref[est])
  nxy = matrix(rep(0, n.points * 4), ncol = 4)
  nxy = data.frame(nxy)
  names(nxy) = c("x", "rr", "u", "l")
  for (i in 1:n.points) {
    ind = which(abs(xy$x - est.seq[i]) == min(abs(xy$x - 
                                                    est.seq[i])))
    nxy[i, ] = xy[ind, ]
  }
  nxy[seq.ind, 2:4] = 1
  if (plot) {
    if (is.null(ylim)) {
      ylim = c(min(xy$l), max(xy$u))
    }
    plot(spline(nxy$x, nxy$rr, xmax = as.numeric(ref[, est])), 
         type = "l", xlim = c(min(nxy$x), max(nxy$x)), ylim = ylim, 
         xlab = est, ylab = "RR")
    lines(spline(nxy$x, nxy$l, xmax = as.numeric(ref[, est])), 
          lty = 2)
    lines(spline(nxy$x, nxy$u, xmax = as.numeric(ref[, est])), 
          lty = 2)
    lines(spline(nxy$x, nxy$rr, xmin = as.numeric(ref[, est])), 
          lty = 1)
    lines(spline(nxy$x, nxy$l, xmin = as.numeric(ref[, est])), 
          lty = 2)
    lines(spline(nxy$x, nxy$u, xmin = as.numeric(ref[, est])), 
          lty = 2)
  }
  return(nxy)
}
gamRR::gamRR(fit = mod_gam2,
             ref = c(Income = pisa$Income[1],
                     Edu = pisa$Edu[1],
                     Health = pisa$Health[1]),
             est = "Income",
             data = pisa,
             n.points = 10,
             plot = TRUE,
             ylim = NULL)

       x        rr         u         l
1  0.407 0.6838944 0.6944088 0.6733801
2  0.484 0.8713265 0.8818409 0.8608122
3  0.552 0.9970502 1.0075646 0.9865359
4  0.599 1.0000000 1.0000000 1.0000000
5  0.635 0.9529522 0.9634665 0.9424379
6  0.691 1.0115245 1.0220389 1.0010102
7  0.739 1.1412184 1.1517328 1.1307041
8  0.805 1.1676452 1.1781595 1.1571308
9  0.857 1.1974125 1.2079269 1.1868982
10 0.914 0.9669727 0.9774871 0.9564584
gamRR::gamRR(fit = mod_gam2,
             ref = c(Income = pisa$Income[1],
                     Edu = pisa$Edu[1],
                     Health = pisa$Health[1]),
             est = "Edu",
             data = pisa,
             n.points = 10,
             plot = TRUE,
             ylim = NULL)

       x        rr         u         l
1  0.535 0.9878204 1.0124750 0.9631658
2  0.569 0.9450332 0.9696878 0.9203786
3  0.646 0.8924675 0.9171221 0.8678129
4  0.688 0.9458709 0.9705255 0.9212163
5  0.716 1.0000000 1.0000000 1.0000000
6  0.792 1.0169171 1.0415717 0.9922625
7  0.836 1.0009779 1.0256325 0.9763233
8  0.893 1.0115830 1.0362376 0.9869284
9  0.945 1.0165534 1.0412080 0.9918988
10 0.993 1.0657664 1.0904210 1.0411118
gamRR::gamRR(fit = mod_gam2,
             ref = c(Income = pisa$Income[1],
                     Edu = pisa$Edu[1],
                     Health = pisa$Health[1]),
             est = "Health",
             data = pisa,
             n.points = 10,
             plot = TRUE,
             ylim = NULL)

       x        rr         u         l
1  0.717 1.0740251 1.1081999 1.0398504
2  0.748 1.0604466 1.0946213 1.0262718
3  0.772 1.0499341 1.0841089 1.0157594
4  0.813 1.0319754 1.0661501 0.9978006
5  0.838 1.0210249 1.0551996 0.9868501
6  0.868 1.0078843 1.0420591 0.9737096
7  0.886 1.0000000 1.0000000 1.0000000
8  0.923 0.9837933 1.0179681 0.9496186
9  0.956 0.9693387 1.0035135 0.9351640
10 0.986 0.9561981 0.9903729 0.9220234

22 Tài liệu tham khảo

  1. https://rpubs.com/HeatWave2019/572700

  2. https://rpubs.com/Huyen_Nguyen_Rosie/1055522

  3. https://stat.ethz.ch/pipermail/r-help/2003-May/033804.html

  4. https://stats.stackexchange.com/questions/33327/confidence-interval-for-gam-model

  5. https://www.researchgate.net/post/How_can_I_calculate_the_confidence_Interval_of_the_relative_risk_when_using_the_GAM_model_and_the_R_language_My_modelgam_DeathsHeat_Cold_sTM