library(tidyverse)
library(magrittr)
library(keras)
library(tidymodels) # or load only 'rsample' and 'recipes' individually
library(skimr)
# Load the classic credit card fraud example dataset
data <- read_csv('creditcard.csv')
Rows: 284807 Columns: 31
── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
dbl (31): Time, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, ...

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
data %>% glimpse()
Rows: 284,807
Columns: 31
$ Time   <dbl> 0, 0, 1, 1, 2, 2, 4, 7, 7, 9, 10, 10, 10, 11, 12, 12, 12, 13, 14, 15, 16, 17, 18, 18, 22, 22, 23, 23, 23, 23, 24, 25, 26, 26…
$ V1     <dbl> -1.3598071, 1.1918571, -1.3583541, -0.9662717, -1.1582331, -0.4259659, 1.2296576, -0.6442694, -0.8942861, -0.3382618, 1.4490…
$ V2     <dbl> -0.07278117, 0.26615071, -1.34016307, -0.18522601, 0.87773675, 0.96052304, 0.14100351, 1.41796355, 0.28615720, 1.11959338, -…
$ V3     <dbl> 2.53634674, 0.16648011, 1.77320934, 1.79299334, 1.54871785, 1.14110934, 0.04537077, 1.07438038, -0.11319221, 1.04436655, 0.9…
$ V4     <dbl> 1.37815522, 0.44815408, 0.37977959, -0.86329128, 0.40303393, -0.16825208, 1.20261274, -0.49219902, -0.27152613, -0.22218728,…
$ V5     <dbl> -0.338320770, 0.060017649, -0.503198133, -0.010308880, -0.407193377, 0.420986881, 0.191880989, 0.948934095, 2.669598660, 0.4…
$ V6     <dbl> 0.46238778, -0.08236081, 1.80049938, 1.24720317, 0.09592146, -0.02972755, 0.27270812, 0.42811846, 3.72181806, -0.24676110, -…
$ V7     <dbl> 0.239598554, -0.078802983, 0.791460956, 0.237608940, 0.592940745, 0.476200949, -0.005159003, 1.120631358, 0.370145128, 0.651…
$ V8     <dbl> 0.098697901, 0.085101655, 0.247675787, 0.377435875, -0.270532677, 0.260314333, 0.081212940, -3.807864239, 0.851084443, 0.069…
$ V9     <dbl> 0.36378697, -0.25542513, -1.51465432, -1.38702406, 0.81773931, -0.56867138, 0.46495999, 0.61537473, -0.39204759, -0.73672732…
$ V10    <dbl> 0.09079417, -0.16697441, 0.20764287, -0.05495192, 0.75307443, -0.37140720, -0.09925432, 1.24937618, -0.41043043, -0.36684564…
$ V11    <dbl> -0.551599533, 1.612726661, 0.624501459, -0.226487264, -0.822842878, 1.341261980, -1.416907243, -0.619467796, -0.705116587, 1…
$ V12    <dbl> -0.61780086, 1.06523531, 0.06608369, 0.17822823, 0.53819555, 0.35989384, -0.15382583, 0.29147435, -0.11045226, 0.83638957, -…
$ V13    <dbl> -0.99138985, 0.48909502, 0.71729273, 0.50775687, 1.34585159, -0.35809065, -0.75106272, 1.75796421, -0.28625363, 1.00684351, …
$ V14    <dbl> -0.31116935, -0.14377230, -0.16594592, -0.28792375, -1.11966983, -0.13713370, 0.16737196, -1.32386522, 0.07435536, -0.443522…
$ V15    <dbl> 1.468176972, 0.635558093, 2.345864949, -0.631418118, 0.175121130, 0.517616807, 0.050143594, 0.686132504, -0.328783050, 0.150…
$ V16    <dbl> -0.470400525, 0.463917041, -2.890083194, -1.059647245, -0.451449183, 0.401725896, -0.443586798, -0.076126999, -0.210077268, …
$ V17    <dbl> 0.207971242, -0.114804663, 1.109969379, -0.684092786, -0.237033239, -0.058132823, 0.002820512, -1.222127345, -0.499767969, -…
$ V18    <dbl> 0.02579058, -0.18336127, -0.12135931, 1.96577500, -0.03819479, 0.06865315, -0.61198734, -0.35822157, 0.11876486, 0.47667726,…
$ V19    <dbl> 0.40399296, -0.14578304, -2.26185710, -1.23262197, 0.80348692, -0.03319379, -0.04557504, 0.32450473, 0.57032817, 0.45177296,…
$ V20    <dbl> 0.251412098, -0.069083135, 0.524979725, -0.208037781, 0.408542360, 0.084967672, -0.219632553, -0.156741852, 0.052735669, 0.2…
$ V21    <dbl> -0.018306778, -0.225775248, 0.247998153, -0.108300452, -0.009430697, -0.208253515, -0.167716266, 1.943465340, -0.073425100, …
$ V22    <dbl> 0.277837576, -0.638671953, 0.771679402, 0.005273597, 0.798278495, -0.559824796, -0.270709726, -1.015454710, -0.268091632, -0…
$ V23    <dbl> -0.110473910, 0.101288021, 0.909412262, -0.190320519, -0.137458080, -0.026397668, -0.154103787, 0.057503530, -0.204232670, -…
$ V24    <dbl> 0.0669280749, -0.3398464755, -0.6892809565, -1.1755753319, 0.1412669838, -0.3714265832, -0.7800554150, -0.6497090056, 1.0115…
$ V25    <dbl> 0.1285393583, 0.1671704044, -0.3276418337, 0.6473760346, -0.2060095876, -0.2327938167, 0.7501369358, -0.4152665662, 0.373204…
$ V26    <dbl> -0.18911484, 0.12589453, -0.13909657, -0.22192884, 0.50229222, 0.10591478, -0.25723685, -0.05163430, -0.38415731, 0.09419883…
$ V27    <dbl> 0.133558377, -0.008983099, -0.055352794, 0.062722849, 0.219422230, 0.253844225, 0.034507430, -1.206921081, 0.011747356, 0.24…
$ V28    <dbl> -0.021053053, 0.014724169, -0.059751841, 0.061457629, 0.215153147, 0.081080257, 0.005167769, -1.085339188, 0.142404330, 0.08…
$ Amount <dbl> 149.62, 2.69, 378.66, 123.50, 69.99, 3.67, 4.99, 40.80, 93.20, 3.68, 7.80, 9.99, 121.50, 27.50, 58.80, 15.99, 12.99, 0.89, 4…
$ Class  <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
# credit card fraud
data %>% count(Class)

Takeaway: Credit card fraud is extremely rare.

EDA

data %>% skim()
── Data Summary ────────────────────────
                           Values    
Name                       Piped data
Number of rows             284807    
Number of columns          31        
_______________________              
Column type frequency:               
  numeric                  31        
________________________             
Group variables            None      

── Variable type: numeric ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   skim_variable n_missing complete_rate      mean         sd      p0        p25         p50         p75      p100 hist 
 1 Time                  0             1  9.48e+ 4 47488.        0    54202.     84692       139320.     172792    ▃▇▅▆▇
 2 V1                    0             1  1.17e-15     1.96    -56.4     -0.920      0.0181       1.32        2.45 ▁▁▁▁▇
 3 V2                    0             1  3.12e-16     1.65    -72.7     -0.599      0.0655       0.804      22.1  ▁▁▁▇▁
 4 V3                    0             1 -1.36e-15     1.52    -48.3     -0.890      0.180        1.03        9.38 ▁▁▁▁▇
 5 V4                    0             1  2.11e-15     1.42     -5.68    -0.849     -0.0198       0.743      16.9  ▂▇▁▁▁
 6 V5                    0             1  9.80e-16     1.38   -114.      -0.692     -0.0543       0.612      34.8  ▁▁▁▇▁
 7 V6                    0             1  1.51e-15     1.33    -26.2     -0.768     -0.274        0.399      73.3  ▁▇▁▁▁
 8 V7                    0             1 -5.42e-16     1.24    -43.6     -0.554      0.0401       0.570     121.   ▁▇▁▁▁
 9 V8                    0             1  1.03e-16     1.19    -73.2     -0.209      0.0224       0.327      20.0  ▁▁▁▇▁
10 V9                    0             1 -2.42e-15     1.10    -13.4     -0.643     -0.0514       0.597      15.6  ▁▁▇▁▁
11 V10                   0             1  2.23e-15     1.09    -24.6     -0.535     -0.0929       0.454      23.7  ▁▁▇▁▁
12 V11                   0             1  1.71e-15     1.02     -4.80    -0.762     -0.0328       0.740      12.0  ▁▇▁▁▁
13 V12                   0             1 -1.24e-15     0.999   -18.7     -0.406      0.140        0.618       7.85 ▁▁▁▇▁
14 V13                   0             1  8.35e-16     0.995    -5.79    -0.649     -0.0136       0.663       7.13 ▁▃▇▁▁
15 V14                   0             1  1.23e-15     0.959   -19.2     -0.426      0.0506       0.493      10.5  ▁▁▁▇▁
16 V15                   0             1  4.84e-15     0.915    -4.50    -0.583      0.0481       0.649       8.88 ▁▇▂▁▁
17 V16                   0             1  1.43e-15     0.876   -14.1     -0.468      0.0664       0.523      17.3  ▁▁▇▁▁
18 V17                   0             1 -3.78e-16     0.849   -25.2     -0.484     -0.0657       0.400       9.25 ▁▁▁▇▁
19 V18                   0             1  9.76e-16     0.838    -9.50    -0.499     -0.00364      0.501       5.04 ▁▁▂▇▁
20 V19                   0             1  1.04e-15     0.814    -7.21    -0.456      0.00373      0.459       5.59 ▁▁▇▂▁
21 V20                   0             1  6.41e-16     0.771   -54.5     -0.212     -0.0625       0.133      39.4  ▁▁▇▁▁
22 V21                   0             1  1.69e-16     0.735   -34.8     -0.228     -0.0295       0.186      27.2  ▁▁▇▁▁
23 V22                   0             1 -3.38e-16     0.726   -10.9     -0.542      0.00678      0.529      10.5  ▁▁▇▁▁
24 V23                   0             1  2.67e-16     0.624   -44.8     -0.162     -0.0112       0.148      22.5  ▁▁▁▇▁
25 V24                   0             1  4.47e-15     0.606    -2.84    -0.355      0.0410       0.440       4.58 ▁▇▆▁▁
26 V25                   0             1  5.11e-16     0.521   -10.3     -0.317      0.0166       0.351       7.52 ▁▁▇▂▁
27 V26                   0             1  1.68e-15     0.482    -2.60    -0.327     -0.0521       0.241       3.52 ▁▆▇▁▁
28 V27                   0             1 -3.67e-16     0.404   -22.6     -0.0708     0.00134      0.0910     31.6  ▁▁▇▁▁
29 V28                   0             1 -1.23e-16     0.330   -15.4     -0.0530     0.0112       0.0783     33.8  ▁▇▁▁▁
30 Amount                0             1  8.83e+ 1   250.        0        5.6       22           77.2     25691.   ▇▁▁▁▁
31 Class                 0             1  1.73e- 3     0.0415    0        0          0            0           1    ▇▁▁▁▁

Apart from the extreme class imbalance, all variables look well behaved: there are no missing values.

library(ggridges)
data %>%
  pivot_longer(cols = !Class, names_to = 'feature', values_to = 'value') %>%
  ggplot(aes(y = feature  %>% as.factor(), 
             fill = Class %>% as.factor(), 
             x = value %>% percent_rank())) +
  geom_density_ridges() + 
  labs(title = 'Density Distribution by Outcome', x = 'Percentage Rank', y = 'Feature') +
  theme(legend.position = 'bottom')
Picking joint bandwidth of 0.0309

Some differences in the feature distributions between the two classes indeed suggest that the outcome may be predictable.

Preprocessing

# We don't use Amount and Time as input features
data %<>% rename(y = Class) %>%
  select(-Time, -Amount) %>%
  drop_na()
data_split <- initial_split(data, prop = 0.75, strata = y)

data_train <- data_split  %>%  training()
data_test <- data_split %>% testing()
data_recipe <- data_train %>%
  recipe(y ~.) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  prep()
x_train <- juice(data_recipe) %>% select(-y) %>% as.matrix()
y_train <- juice(data_recipe) %>% pull(y)   

x_test <- bake(data_recipe, new_data = data_test) %>% select(-y) %>% as.matrix()
y_test <- bake(data_recipe, new_data = data_test) %>% pull(y) 

Model

input_layer <- 
  layer_input(shape = ncol(x_train)) 
Loaded Tensorflow version 2.7.0
encoder <- 
  input_layer %>% 
  layer_dense(units = 64, activation = "relu") %>% 
  layer_batch_normalization() %>% 
  layer_dropout(rate = 0.1) %>% 
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 8) # dimensionality of the bottleneck (encoded) layer
2021-11-26 13:28:28.007256: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
decoder <- 
  encoder %>% 
  layer_dense(units = 16, activation = "relu") %>% 
  layer_dropout(rate = 0.1) %>% 
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = ncol(x_train)) # reconstruct the original number of input variables
model <- keras_model(inputs = input_layer, outputs = decoder)
model %>% compile(
  loss = 'mse',
  optimizer= optimizer_adam(),
  metrics = c('mse')
)
model %>% summary()
Model: "model"
_____________________________________________________________________________________________________________________________________________
 Layer (type)                                                  Output Shape                                            Param #               
=============================================================================================================================================
 input_1 (InputLayer)                                          [(None, 28)]                                            0                     
                                                                                                                                             
 dense_3 (Dense)                                               (None, 64)                                              1856                  
                                                                                                                                             
 batch_normalization (BatchNormalization)                      (None, 64)                                              256                   
                                                                                                                                             
 dropout_1 (Dropout)                                           (None, 64)                                              0                     
                                                                                                                                             
 dense_2 (Dense)                                               (None, 32)                                              2080                  
                                                                                                                                             
 dropout (Dropout)                                             (None, 32)                                              0                     
                                                                                                                                             
 dense_1 (Dense)                                               (None, 16)                                              528                   
                                                                                                                                             
 dense (Dense)                                                 (None, 8)                                               136                   
                                                                                                                                             
 dense_11 (Dense)                                              (None, 16)                                              144                   
                                                                                                                                             
 dropout_7 (Dropout)                                           (None, 16)                                              0                     
                                                                                                                                             
 dense_10 (Dense)                                              (None, 32)                                              544                   
                                                                                                                                             
 dropout_6 (Dropout)                                           (None, 32)                                              0                     
                                                                                                                                             
 dense_9 (Dense)                                               (None, 64)                                              2112                  
                                                                                                                                             
 dropout_5 (Dropout)                                           (None, 64)                                              0                     
                                                                                                                                             
 dense_8 (Dense)                                               (None, 28)                                              1820                  
                                                                                                                                             
=============================================================================================================================================
Total params: 9,476
Trainable params: 9,348
Non-trainable params: 128
_____________________________________________________________________________________________________________________________________________

Run the model

tensorboard("logs/run_a")

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.7.0 at http://127.0.0.1:4273/ (Press CTRL+C to quit)
Started TensorBoard at http://127.0.0.1:4273 
history <- model %>%
  keras::fit(x_train,
             x_train,
             epochs = 100,
             verbose = FALSE,
             validation_data = list(x_test, x_test),
             callbacks = list(callback_tensorboard("logs/run_a"),
                              callback_early_stopping(patience = 5))
             )
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0013s vs `on_train_batch_end` time: 0.0032s). Check your callbacks.
history %>% plot()
`geom_smooth()` using formula 'y ~ x'

Evaluate Reconstruction

# We only evaluate on 'normal' cases, since this is what we want to reconstruct best
loss <- model %>% evaluate(x = x_test[y_test == 0,], y = x_test[y_test == 0,])

2222/2222 [==============================] - 2s 751us/step - loss: 0.3212 - mse: 0.3212
loss
     loss       mse 
0.3211641 0.3211641 
# Per-observation reconstruction error (squared error summed across features)
pred_train <- predict(model, x_train)
mse_train <- apply((x_train - pred_train)^2, 1, sum)

pred_test <- predict(model, x_test)
mse_test <- apply((x_test - pred_test)^2, 1, sum)
pred_eval_train <- tibble(
  mse = mse_train,
  class = y_train %>% factor()
)
pred_eval_train %>%
  ggplot(aes(x = class, y = mse, fill = class)) +
  geom_boxplot() + 
  ylim(0, 250) + 
  labs(title = 'Boxplot: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
Warning: Removed 171 rows containing non-finite values (stat_boxplot).

pred_eval_train  %>%
  ggplot(aes(x = mse, fill = class, alpha = 0.5)) +
  geom_density() + 
  xlim(0, 100) + 
  labs(title = 'Density Distribution: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
Warning: Removed 576 rows containing non-finite values (stat_density).

This looks promising: the rare event (fraud) indeed has a noticeably higher reconstruction error.

Now we just need to find the right cutoff value k for the reconstruction error (MSE), which will determine the classification.

# Let's define a function that evaluates classification metrics over a grid of cutoff values k
find_optimal_k <- function(error, class, k_lower = 0, k_upper = 10, steps = 1){

  eval <- tibble()
  
for(k in seq(from = k_lower, to = k_upper, by = steps)) {
  
  pred_k <- tibble(
    truth = class %>% as.logical() %>% factor(levels = c('FALSE', 'TRUE')),
    # flag an observation as fraud when its reconstruction error exceeds the cutoff k
    pred = (error > k) %>% factor(levels = c('FALSE', 'TRUE'))
  )
  
  res_k <- tibble(
    k = k,
    accuracy = pred_k %>% yardstick::accuracy(truth = truth, estimate = pred) %>% pull(-1),
    precision = pred_k %>% yardstick::precision(truth = truth, estimate = pred) %>% pull(-1),
    recall = pred_k %>% yardstick::recall(truth = truth, estimate = pred) %>% pull(-1),
    specificity = pred_k %>% yardstick::specificity(truth = truth, estimate = pred) %>% pull(-1),
    f1 = pred_k %>% yardstick::f_meas(truth = truth, estimate = pred) %>% pull(-1)   
  )
  
  eval <- eval %>% 
    bind_rows(res_k)
  }
  
  return(eval)
    
}
eval <- find_optimal_k(class = y_train, error = mse_train, k_lower = 0, k_upper = 75, steps = 1)
Warning: While computing binary `precision()`, no predicted events were detected (i.e. `true_positive + false_positive = 0`). 
Precision is undefined in this case, and `NA` will be returned.
Note that 213232 true event(s) actually occured for the problematic event level, 'FALSE'.
Warning: While computing binary `precision()`, no predicted events were detected (i.e. `true_positive + false_positive = 0`). 
Precision is undefined in this case, and `NA` will be returned.
Note that 213232 true event(s) actually occured for the problematic event level, 'FALSE'.
eval %>%
  pivot_longer(cols = -k, names_to = 'metric', values_to = 'value') %>%
  ggplot(aes(x = k, y = value, col = metric)) +
  geom_line()
Warning: Removed 2 row(s) containing missing values (geom_path).

Finally, we test our selected cutoff (k = 17) on the test set.

k <- 17
class_collected <- tibble(truth = y_test %>% as.logical() %>% as.factor(),
                          pred = as.factor(mse_test > k))
cm <- class_collected %>% conf_mat(truth = truth, estimate = pred)
cm %>% autoplot(type = 'heatmap')

cm %>% summary()
---
title:  'M3 workshop: Autoencoder (R)'
author: "Daniel S. Hain (dsh@business.aau.dk)"
date: "Updated `r format(Sys.time(), '%B %d, %Y')`"
output:
  html_notebook:
    code_folding: show
    df_print: paged
    toc: true
    toc_depth: 2
    toc_float:
      collapsed: false
    theme: flatly
---

```{r setup, include=FALSE}
### Generic preamble
rm(list=ls())
Sys.setenv(LANG = "en") # For English-language messages
options(scipen = 5) # Suppress scientific number notation
```

```{r}
library(tidyverse)
library(magrittr)
library(keras)
library(tidymodels) # or load only 'rsample' and 'recipes' individually
library(skimr)
```


```{r}
# Load the classic credit card fraud example dataset
data <- read_csv('creditcard.csv')
```

```{r}
data %>% glimpse()
```

```{r}
# credit card fraud
data %>% count(Class)
```

Takeaway: Credit card fraud is extremely rare.
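
To make the imbalance explicit, here is a small sketch (not part of the original workflow) that turns the raw counts into shares:

```{r}
# Sketch: class shares rather than raw counts
data %>%
  count(Class) %>%
  mutate(share = n / sum(n))
```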

# EDA

```{r}
data %>% skim()
```


Apart from the extreme class imbalance, all variables look well behaved: there are no missing values.

```{r, fig.width=12.5, fig.height=9}
library(ggridges)
data %>%
  pivot_longer(cols = !Class, names_to = 'feature', values_to = 'value') %>%
  ggplot(aes(y = feature  %>% as.factor(), 
             fill = Class %>% as.factor(), 
             x = value %>% percent_rank())) +
  geom_density_ridges() + 
  labs(title = 'Density Distribution by Outcome', x = 'Percentage Rank', y = 'Feature') +
  theme(legend.position = 'bottom')
```

Some differences in the feature distributions between the two classes indeed suggest that the outcome may be predictable.
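
As a quick numeric cross-check of the ridge plot, one can compare per-class means for a handful of arbitrarily picked features (a sketch, not in the original notebook):

```{r}
# Sketch: per-class means of a few arbitrarily chosen features
data %>%
  group_by(Class) %>%
  summarise(across(c(V1, V4, V14), mean))
```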

# Preprocessing

```{r}
# We don't use Amount and Time as input features
data %<>% rename(y = Class) %>%
  select(-Time, -Amount) %>%
  drop_na()
```

```{r}
data_split <- initial_split(data, prop = 0.75, strata = y)

data_train <- data_split  %>%  training()
data_test <- data_split %>% testing()
```
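
Since we split with `strata = y`, a quick sketch to confirm that the fraud share is (roughly) preserved in both partitions:

```{r}
# Sketch: compare class shares in the training and test partitions
bind_rows(train = count(data_train, y), test = count(data_test, y), .id = "set") %>%
  group_by(set) %>%
  mutate(share = n / sum(n))
```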

```{r}
data_recipe <- data_train %>%
  recipe(y ~.) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  prep()
```

```{r}
x_train <- juice(data_recipe) %>% select(-y) %>% as.matrix()
y_train <- juice(data_recipe) %>% pull(y)   

x_test <- bake(data_recipe, new_data = data_test) %>% select(-y) %>% as.matrix()
y_test <- bake(data_recipe, new_data = data_test) %>% pull(y) 
```
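
A minimal sanity check on the prepared matrices (a sketch): the dimensions should match the split, and `step_normalize()` should have centred the training features around zero.

```{r}
# Sketch: dimensions and column means of the normalized training matrix
dim(x_train)
dim(x_test)
summary(colMeans(x_train))
```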

# Model

```{r}
input_layer <- 
  layer_input(shape = ncol(x_train)) 

encoder <- 
  input_layer %>% 
  layer_dense(units = 64, activation = "relu") %>% 
  layer_batch_normalization() %>% 
  layer_dropout(rate = 0.1) %>% 
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 8) # dimensionality of the bottleneck (encoded) layer
```
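
The final `layer_dense(units = 8)` is the bottleneck. If you later want the compressed 8-dimensional representation itself, a separate encoder model can be built on the same layers (a sketch; `encoder_model` is not used in the rest of the notebook):

```{r}
# Sketch: standalone model mapping inputs to the 8-dimensional bottleneck
encoder_model <- keras_model(inputs = input_layer, outputs = encoder)
```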

```{r}
decoder <- 
  encoder %>% 
  layer_dense(units = 16, activation = "relu") %>% 
  layer_dropout(rate = 0.1) %>% 
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = ncol(x_train)) # reconstruct the original number of input variables
```


```{r}
model <- keras_model(inputs = input_layer, outputs = decoder)
```

```{r}
model %>% compile(
  loss = 'mse',
  optimizer= optimizer_adam(),
  metrics = c('mse')
)
```

```{r}
model %>% summary()
```

# Run the model

```{r}
# for model monitoring
tensorboard("logs/run_a")
```

```{r}
history <- model %>%
  keras::fit(x_train,
             x_train,
             epochs = 100,
             verbose = FALSE,
             validation_data = list(x_test, x_test),
             callbacks = list(callback_tensorboard("logs/run_a"),
                              callback_early_stopping(patience = 5))
             )
```

```{r}
history %>% plot()
```

# Evaluate Reconstruction

```{r}
# We only evaluate on 'normal' cases, since this is what we want to reconstruct best
loss <- model %>% evaluate(x = x_test[y_test == 0,], y = x_test[y_test == 0,])
```

```{r}
loss
```
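
For comparison, the same evaluation can be run on the fraud cases only (a sketch mirroring the call above); if the autoencoder behaves as intended, this loss should be clearly higher:

```{r}
# Sketch: reconstruction loss on the fraud cases only
model %>% evaluate(x = x_test[y_test == 1, ], y = x_test[y_test == 1, ])
```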


```{r}
# Per-observation reconstruction error (squared error summed across features)
pred_train <- predict(model, x_train)
mse_train <- apply((x_train - pred_train)^2, 1, sum)

pred_test <- predict(model, x_test)
mse_test <- apply((x_test - pred_test)^2, 1, sum)
```


```{r}
pred_eval_train <- tibble(
  mse = mse_train,
  class = y_train %>% factor()
)
```

```{r}
pred_eval_train %>%
  ggplot(aes(x = class, y = mse, fill = class)) +
  geom_boxplot() + 
  ylim(0, 250) + 
  labs(title = 'Boxplot: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
```

```{r}
pred_eval_train  %>%
  ggplot(aes(x = mse, fill = class, alpha = 0.5)) +
  geom_density() + 
  xlim(0, 100) + 
  labs(title = 'Density Distribution: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
```

This looks promising: the rare event (fraud) indeed has a noticeably higher reconstruction error.

Now we just need to find the right cutoff value k for the reconstruction error (MSE), which will determine the classification.

```{r}
# Let's define a function that evaluates classification metrics over a grid of cutoff values k
find_optimal_k <- function(error, class, k_lower = 0, k_upper = 10, steps = 1){

  eval <- tibble()
  
for(k in seq(from = k_lower, to = k_upper, by = steps)) {
  
  pred_k <- tibble(
    truth = class %>% as.logical() %>% factor(levels = c('FALSE', 'TRUE')),
    # flag an observation as fraud when its reconstruction error exceeds the cutoff k
    pred = (error > k) %>% factor(levels = c('FALSE', 'TRUE'))
  )
  
  res_k <- tibble(
    k = k,
    accuracy = pred_k %>% yardstick::accuracy(truth = truth, estimate = pred) %>% pull(-1),
    precision = pred_k %>% yardstick::precision(truth = truth, estimate = pred) %>% pull(-1),
    recall = pred_k %>% yardstick::recall(truth = truth, estimate = pred) %>% pull(-1),
    specificity = pred_k %>% yardstick::specificity(truth = truth, estimate = pred) %>% pull(-1),
    f1 = pred_k %>% yardstick::f_meas(truth = truth, estimate = pred) %>% pull(-1)   
  )
  
  eval <- eval %>% 
    bind_rows(res_k)
  }
  
  return(eval)
    
}
```


```{r}
eval <- find_optimal_k(class = y_train, error = mse_train, k_lower = 0, k_upper = 75, steps = 1)
```


```{r}
eval %>%
  pivot_longer(cols = -k, names_to = 'metric', values_to = 'value') %>%
  ggplot(aes(x = k, y = value, col = metric)) +
  geom_line()
```
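
One simple way to read a cutoff off the `eval` table is to take the k with the highest F1 score (a sketch; the k = 17 used below was chosen from this evaluation):

```{r}
# Sketch: cutoff with the best F1 on the training data
eval %>% filter(f1 == max(f1, na.rm = TRUE))
```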

Finally, we test our selected cutoff (k = 17) on the test set.

```{r}
k <- 17
class_collected <- tibble(truth = y_test %>% as.logical() %>% as.factor(),
                          pred = as.factor(mse_test > k))
```

```{r}
cm <- class_collected %>% conf_mat(truth = truth, estimate = pred)
```

```{r}
cm %>% autoplot(type = 'heatmap')
```

```{r}
cm %>% summary()
```
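
Note that `summary(cm)` treats the first factor level ('FALSE', i.e. non-fraud) as the event by default. To report precision and recall for the fraud class directly, set `event_level = "second"` (a sketch, assuming 'TRUE' is the second factor level as constructed above):

```{r}
# Sketch: fraud-class precision and recall on the test set
class_collected %>% yardstick::precision(truth = truth, estimate = pred, event_level = "second")
class_collected %>% yardstick::recall(truth = truth, estimate = pred, event_level = "second")
```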




