library(tidyverse)
library(magrittr)
library(keras)
library(tidymodels) # or load only 'rsample', 'recipes', and 'yardstick' individually
library(skimr) # needed for skim() below
# Load the classic credit card fraud example dataset
data <- read_csv('creditcard.csv')
Rows: 284807 Columns: 31
── Column specification ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
dbl (31): Time, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, ...
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
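As the message suggests, the printed column specification can be silenced on future imports (optional):
# Optional: quiet the column specification message when re-reading
# data <- read_csv('creditcard.csv', show_col_types = FALSE)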
data %>% glimpse()
Rows: 284,807
Columns: 31
$ Time <dbl> 0, 0, 1, 1, 2, 2, 4, 7, 7, 9, 10, 10, 10, 11, 12, 12, 12, 13, 14, 15, 16, 17, 18, 18, 22, 22, 23, 23, 23, 23, 24, 25, 26, 26…
$ V1 <dbl> -1.3598071, 1.1918571, -1.3583541, -0.9662717, -1.1582331, -0.4259659, 1.2296576, -0.6442694, -0.8942861, -0.3382618, 1.4490…
$ V2 <dbl> -0.07278117, 0.26615071, -1.34016307, -0.18522601, 0.87773675, 0.96052304, 0.14100351, 1.41796355, 0.28615720, 1.11959338, -…
$ V3 <dbl> 2.53634674, 0.16648011, 1.77320934, 1.79299334, 1.54871785, 1.14110934, 0.04537077, 1.07438038, -0.11319221, 1.04436655, 0.9…
$ V4 <dbl> 1.37815522, 0.44815408, 0.37977959, -0.86329128, 0.40303393, -0.16825208, 1.20261274, -0.49219902, -0.27152613, -0.22218728,…
$ V5 <dbl> -0.338320770, 0.060017649, -0.503198133, -0.010308880, -0.407193377, 0.420986881, 0.191880989, 0.948934095, 2.669598660, 0.4…
$ V6 <dbl> 0.46238778, -0.08236081, 1.80049938, 1.24720317, 0.09592146, -0.02972755, 0.27270812, 0.42811846, 3.72181806, -0.24676110, -…
$ V7 <dbl> 0.239598554, -0.078802983, 0.791460956, 0.237608940, 0.592940745, 0.476200949, -0.005159003, 1.120631358, 0.370145128, 0.651…
$ V8 <dbl> 0.098697901, 0.085101655, 0.247675787, 0.377435875, -0.270532677, 0.260314333, 0.081212940, -3.807864239, 0.851084443, 0.069…
$ V9 <dbl> 0.36378697, -0.25542513, -1.51465432, -1.38702406, 0.81773931, -0.56867138, 0.46495999, 0.61537473, -0.39204759, -0.73672732…
$ V10 <dbl> 0.09079417, -0.16697441, 0.20764287, -0.05495192, 0.75307443, -0.37140720, -0.09925432, 1.24937618, -0.41043043, -0.36684564…
$ V11 <dbl> -0.551599533, 1.612726661, 0.624501459, -0.226487264, -0.822842878, 1.341261980, -1.416907243, -0.619467796, -0.705116587, 1…
$ V12 <dbl> -0.61780086, 1.06523531, 0.06608369, 0.17822823, 0.53819555, 0.35989384, -0.15382583, 0.29147435, -0.11045226, 0.83638957, -…
$ V13 <dbl> -0.99138985, 0.48909502, 0.71729273, 0.50775687, 1.34585159, -0.35809065, -0.75106272, 1.75796421, -0.28625363, 1.00684351, …
$ V14 <dbl> -0.31116935, -0.14377230, -0.16594592, -0.28792375, -1.11966983, -0.13713370, 0.16737196, -1.32386522, 0.07435536, -0.443522…
$ V15 <dbl> 1.468176972, 0.635558093, 2.345864949, -0.631418118, 0.175121130, 0.517616807, 0.050143594, 0.686132504, -0.328783050, 0.150…
$ V16 <dbl> -0.470400525, 0.463917041, -2.890083194, -1.059647245, -0.451449183, 0.401725896, -0.443586798, -0.076126999, -0.210077268, …
$ V17 <dbl> 0.207971242, -0.114804663, 1.109969379, -0.684092786, -0.237033239, -0.058132823, 0.002820512, -1.222127345, -0.499767969, -…
$ V18 <dbl> 0.02579058, -0.18336127, -0.12135931, 1.96577500, -0.03819479, 0.06865315, -0.61198734, -0.35822157, 0.11876486, 0.47667726,…
$ V19 <dbl> 0.40399296, -0.14578304, -2.26185710, -1.23262197, 0.80348692, -0.03319379, -0.04557504, 0.32450473, 0.57032817, 0.45177296,…
$ V20 <dbl> 0.251412098, -0.069083135, 0.524979725, -0.208037781, 0.408542360, 0.084967672, -0.219632553, -0.156741852, 0.052735669, 0.2…
$ V21 <dbl> -0.018306778, -0.225775248, 0.247998153, -0.108300452, -0.009430697, -0.208253515, -0.167716266, 1.943465340, -0.073425100, …
$ V22 <dbl> 0.277837576, -0.638671953, 0.771679402, 0.005273597, 0.798278495, -0.559824796, -0.270709726, -1.015454710, -0.268091632, -0…
$ V23 <dbl> -0.110473910, 0.101288021, 0.909412262, -0.190320519, -0.137458080, -0.026397668, -0.154103787, 0.057503530, -0.204232670, -…
$ V24 <dbl> 0.0669280749, -0.3398464755, -0.6892809565, -1.1755753319, 0.1412669838, -0.3714265832, -0.7800554150, -0.6497090056, 1.0115…
$ V25 <dbl> 0.1285393583, 0.1671704044, -0.3276418337, 0.6473760346, -0.2060095876, -0.2327938167, 0.7501369358, -0.4152665662, 0.373204…
$ V26 <dbl> -0.18911484, 0.12589453, -0.13909657, -0.22192884, 0.50229222, 0.10591478, -0.25723685, -0.05163430, -0.38415731, 0.09419883…
$ V27 <dbl> 0.133558377, -0.008983099, -0.055352794, 0.062722849, 0.219422230, 0.253844225, 0.034507430, -1.206921081, 0.011747356, 0.24…
$ V28 <dbl> -0.021053053, 0.014724169, -0.059751841, 0.061457629, 0.215153147, 0.081080257, 0.005167769, -1.085339188, 0.142404330, 0.08…
$ Amount <dbl> 149.62, 2.69, 378.66, 123.50, 69.99, 3.67, 4.99, 40.80, 93.20, 3.68, 7.80, 9.99, 121.50, 27.50, 58.80, 15.99, 12.99, 0.89, 4…
$ Class <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
# How imbalanced is the outcome?
data %>% count(Class)
Take away: Credit card fraud is extremely rare; only 492 of the 284,807 transactions (about 0.17%) are fraudulent.
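Expressing the counts as proportions makes the imbalance explicit (a quick check, not part of the modeling pipeline):
# Share of fraudulent vs. legitimate transactions
data %>%
  count(Class) %>%
  mutate(prop = n / sum(n))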
data %>% skim()
── Data Summary ────────────────────────
Values
Name Piped data
Number of rows 284807
Number of columns 31
_______________________
Column type frequency:
numeric 31
________________________
Group variables None
── Variable type: numeric ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
1 Time 0 1 9.48e+ 4 47488. 0 54202. 84692 139320. 172792 ▃▇▅▆▇
2 V1 0 1 1.17e-15 1.96 -56.4 -0.920 0.0181 1.32 2.45 ▁▁▁▁▇
3 V2 0 1 3.12e-16 1.65 -72.7 -0.599 0.0655 0.804 22.1 ▁▁▁▇▁
4 V3 0 1 -1.36e-15 1.52 -48.3 -0.890 0.180 1.03 9.38 ▁▁▁▁▇
5 V4 0 1 2.11e-15 1.42 -5.68 -0.849 -0.0198 0.743 16.9 ▂▇▁▁▁
6 V5 0 1 9.80e-16 1.38 -114. -0.692 -0.0543 0.612 34.8 ▁▁▁▇▁
7 V6 0 1 1.51e-15 1.33 -26.2 -0.768 -0.274 0.399 73.3 ▁▇▁▁▁
8 V7 0 1 -5.42e-16 1.24 -43.6 -0.554 0.0401 0.570 121. ▁▇▁▁▁
9 V8 0 1 1.03e-16 1.19 -73.2 -0.209 0.0224 0.327 20.0 ▁▁▁▇▁
10 V9 0 1 -2.42e-15 1.10 -13.4 -0.643 -0.0514 0.597 15.6 ▁▁▇▁▁
11 V10 0 1 2.23e-15 1.09 -24.6 -0.535 -0.0929 0.454 23.7 ▁▁▇▁▁
12 V11 0 1 1.71e-15 1.02 -4.80 -0.762 -0.0328 0.740 12.0 ▁▇▁▁▁
13 V12 0 1 -1.24e-15 0.999 -18.7 -0.406 0.140 0.618 7.85 ▁▁▁▇▁
14 V13 0 1 8.35e-16 0.995 -5.79 -0.649 -0.0136 0.663 7.13 ▁▃▇▁▁
15 V14 0 1 1.23e-15 0.959 -19.2 -0.426 0.0506 0.493 10.5 ▁▁▁▇▁
16 V15 0 1 4.84e-15 0.915 -4.50 -0.583 0.0481 0.649 8.88 ▁▇▂▁▁
17 V16 0 1 1.43e-15 0.876 -14.1 -0.468 0.0664 0.523 17.3 ▁▁▇▁▁
18 V17 0 1 -3.78e-16 0.849 -25.2 -0.484 -0.0657 0.400 9.25 ▁▁▁▇▁
19 V18 0 1 9.76e-16 0.838 -9.50 -0.499 -0.00364 0.501 5.04 ▁▁▂▇▁
20 V19 0 1 1.04e-15 0.814 -7.21 -0.456 0.00373 0.459 5.59 ▁▁▇▂▁
21 V20 0 1 6.41e-16 0.771 -54.5 -0.212 -0.0625 0.133 39.4 ▁▁▇▁▁
22 V21 0 1 1.69e-16 0.735 -34.8 -0.228 -0.0295 0.186 27.2 ▁▁▇▁▁
23 V22 0 1 -3.38e-16 0.726 -10.9 -0.542 0.00678 0.529 10.5 ▁▁▇▁▁
24 V23 0 1 2.67e-16 0.624 -44.8 -0.162 -0.0112 0.148 22.5 ▁▁▁▇▁
25 V24 0 1 4.47e-15 0.606 -2.84 -0.355 0.0410 0.440 4.58 ▁▇▆▁▁
26 V25 0 1 5.11e-16 0.521 -10.3 -0.317 0.0166 0.351 7.52 ▁▁▇▂▁
27 V26 0 1 1.68e-15 0.482 -2.60 -0.327 -0.0521 0.241 3.52 ▁▆▇▁▁
28 V27 0 1 -3.67e-16 0.404 -22.6 -0.0708 0.00134 0.0910 31.6 ▁▁▇▁▁
29 V28 0 1 -1.23e-16 0.330 -15.4 -0.0530 0.0112 0.0783 33.8 ▁▇▁▁▁
30 Amount 0 1 8.83e+ 1 250. 0 5.6 22 77.2 25691. ▇▁▁▁▁
31 Class 0 1 1.73e- 3 0.0415 0 0 0 0 1 ▇▁▁▁▁
Otherwise, all variables appear well behaved: there are no missing values, and the PCA features V1-V28 are already centered around zero.
library(ggridges)
data %>%
  pivot_longer(cols = !Class, names_to = 'feature', values_to = 'value') %>%
  ggplot(aes(y = as.factor(feature),
             fill = as.factor(Class),
             x = percent_rank(value))) +
  geom_density_ridges() +
  labs(title = 'Density Distribution by Outcome', x = 'Percentage Rank', y = 'Feature', fill = 'Class') +
  theme(legend.position = 'bottom')
Picking joint bandwidth of 0.0309
The visible differences in some feature distributions between the two classes suggest that the outcome should indeed be predictable from these features.
# Drop Time and Amount; we keep only the anonymized PCA features
data %<>%
  rename(y = Class) %>%
  select(-Time, -Amount) %>%
  drop_na()
set.seed(42) # arbitrary seed so the split is reproducible
data_split <- initial_split(data, prop = 0.75, strata = y)
data_train <- data_split %>% training()
data_test <- data_split %>% testing()
data_recipe <- data_train %>%
  recipe(y ~ .) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  prep()
x_train <- juice(data_recipe) %>% select(-y) %>% as.matrix()
y_train <- juice(data_recipe) %>% pull(y)
x_test <- bake(data_recipe, new_data = data_test) %>% select(-y) %>% as.matrix()
y_test <- bake(data_recipe, new_data = data_test) %>% pull(y)
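Before defining the network, a quick shape check (an optional sanity check, not in the original flow) confirms that the matrices line up with the 28 remaining features:
# Sanity check: both matrices should have 28 columns (V1-V28)
dim(x_train)
dim(x_test)
table(y_train) # strata = y preserves the ~0.17% fraud share in the training split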
input_layer <-
layer_input(shape = ncol(x_train))
Loaded Tensorflow version 2.7.0
encoder <-
  input_layer %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 8) # bottleneck: the 8-dimensional latent representation
2021-11-26 13:28:28.007256: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
decoder <-
  encoder %>%
  layer_dense(units = 16, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = ncol(x_train)) # output layer: reconstruct the original input dimensions
model <- keras_model(inputs = input_layer, outputs = decoder)
model %>% compile(
  loss = 'mse',
  optimizer = optimizer_adam(),
  metrics = c('mse')
)
model %>% summary()
Model: "model"
_____________________________________________________________________________________________________________________________________________
Layer (type) Output Shape Param #
=============================================================================================================================================
input_1 (InputLayer) [(None, 28)] 0
dense_3 (Dense) (None, 64) 1856
batch_normalization (BatchNormalization) (None, 64) 256
dropout_1 (Dropout) (None, 64) 0
dense_2 (Dense) (None, 32) 2080
dropout (Dropout) (None, 32) 0
dense_1 (Dense) (None, 16) 528
dense (Dense) (None, 8) 136
dense_11 (Dense) (None, 16) 144
dropout_7 (Dropout) (None, 16) 0
dense_10 (Dense) (None, 32) 544
dropout_6 (Dropout) (None, 32) 0
dense_9 (Dense) (None, 64) 2112
dropout_5 (Dropout) (None, 64) 0
dense_8 (Dense) (None, 28) 1820
=============================================================================================================================================
Total params: 9,476
Trainable params: 9,348
Non-trainable params: 128
_____________________________________________________________________________________________________________________________________________
tensorboard("logs/run_a")
NOTE: Using experimental fast data loading logic. To disable, pass
"--load_fast=false" and report issues on GitHub. More details:
https://github.com/tensorflow/tensorboard/issues/4784
TensorBoard 2.7.0 at http://127.0.0.1:4273/ (Press CTRL+C to quit)
Started TensorBoard at http://127.0.0.1:4273
history <- model %>%
  keras::fit(x_train,
             x_train, # an autoencoder is trained to reconstruct its own input
             epochs = 100,
             verbose = FALSE,
             validation_data = list(x_test, x_test),
             callbacks = list(callback_tensorboard("logs/run_a"),
                              callback_early_stopping(patience = 5))
  )
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0013s vs `on_train_batch_end` time: 0.0032s). Check your callbacks.
history %>% plot()
`geom_smooth()` using formula 'y ~ x'
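Because of the early-stopping callback, training may end well before epoch 100. The epoch with the lowest validation loss can be read off the history object (a small convenience, assuming the standard keras history metrics):
# Epoch at which validation loss bottomed out
best_epoch <- which.min(history$metrics$val_loss)
best_epoch
history$metrics$val_loss[best_epoch]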
# We only evaluate on 'normal' cases, since this is what we want to reconstruct best
loss <- model %>% evaluate(x = x_test[y_test == 0,], y = x_test[y_test == 0,])
2222/2222 [==============================] - 2s 751us/step - loss: 0.3212 - mse: 0.3212
loss
loss mse
0.3211641 0.3211641
# Per-observation reconstruction error (row-wise sum of squared errors)
pred_train <- predict(model, x_train)
mse_train <- apply((x_train - pred_train)^2, 1, sum)
pred_test <- predict(model, x_test)
mse_test <- apply((x_test - pred_test)^2, 1, sum)
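A side note: `apply(..., 1, sum)` yields the per-row sum of squared errors rather than a true mean. Since the two differ only by the constant factor ncol(x_train), either works as an anomaly score; a per-row mean variant would look like this (equivalent up to rescaling the cutoff k):
# Equivalent per-row mean squared error (rowMeans is also faster than apply)
mse_train_mean <- rowMeans((x_train - pred_train)^2)
mse_test_mean <- rowMeans((x_test - pred_test)^2)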
pred_eval_train <- tibble(
  mse = mse_train,
  class = y_train %>% factor()
)
pred_eval_train %>%
  ggplot(aes(x = class, y = mse, fill = class)) +
  geom_boxplot() +
  ylim(0, 250) +
  labs(title = 'Boxplot: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
Warning: Removed 171 rows containing non-finite values (stat_boxplot).
pred_eval_train %>%
  ggplot(aes(x = mse, fill = class)) +
  geom_density(alpha = 0.5) + # alpha is a constant, so it belongs outside aes()
  xlim(0, 100) +
  labs(title = 'Density Distribution: Reconstruction error by class', subtitle = 'Training Data') +
  theme(legend.position = 'bottom')
Warning: Removed 576 rows containing non-finite values (stat_density).
This looks promising: the rare event indeed produces a higher reconstruction error.
Now we just need to find the right cutoff value k for the reconstruction error (MSE), above which an observation is classified as fraud.
# Let's define a function for it
find_optimal_k <- function(error, class, k_lower = 0, k_upper = 10, steps = 1){
  eval <- tibble()
  for(k in seq(from = k_lower, to = k_upper, by = steps)) {
    # Classify every observation with a reconstruction error above k as an anomaly
    pred_k <- tibble(
      truth = class %>% as.logical() %>% factor(levels = c('FALSE', 'TRUE')),
      pred = (error > k) %>% factor(levels = c('FALSE', 'TRUE'))
    )
    # Note: with levels c('FALSE', 'TRUE'), yardstick by default treats 'FALSE'
    # (non-fraud) as the event class; pass event_level = 'second' to the metric
    # functions to score the fraud class instead.
    res_k <- tibble(
      k = k,
      accuracy = pred_k %>% yardstick::accuracy(truth = truth, estimate = pred) %>% pull(-1),
      precision = pred_k %>% yardstick::precision(truth = truth, estimate = pred) %>% pull(-1),
      recall = pred_k %>% yardstick::recall(truth = truth, estimate = pred) %>% pull(-1),
      specificity = pred_k %>% yardstick::specificity(truth = truth, estimate = pred) %>% pull(-1),
      f1 = pred_k %>% yardstick::f_meas(truth = truth, estimate = pred) %>% pull(-1)
    )
    eval <- eval %>%
      bind_rows(res_k)
  }
  return(eval)
}
eval <- find_optimal_k(class = y_train, error = mse_train, k_lower = 0, k_upper = 75, steps = 1)
Warning: While computing binary `precision()`, no predicted events were detected (i.e. `true_positive + false_positive = 0`).
Precision is undefined in this case, and `NA` will be returned.
Note that 213232 true event(s) actually occured for the problematic event level, 'FALSE'.
Warning: While computing binary `precision()`, no predicted events were detected (i.e. `true_positive + false_positive = 0`).
Precision is undefined in this case, and `NA` will be returned.
Note that 213232 true event(s) actually occured for the problematic event level, 'FALSE'.
eval %>%
  pivot_longer(cols = -k, names_to = 'metric', values_to = 'value') %>%
  ggplot(aes(x = k, y = value, col = metric)) +
  geom_line()
Warning: Removed 2 row(s) containing missing values (geom_path).
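Instead of reading k off the plot, the cutoff that maximizes a chosen metric can be pulled from the evaluation table directly; for instance, maximizing F1 (one reasonable criterion among several):
# Cutoff k with the highest F1 score on the training errors
eval %>% slice_max(f1, n = 1)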
Finally, we test our selected k on the test set.
k <- 17
class_collected <- tibble(truth = y_test %>% as.logical() %>% as.factor(),
                          pred = as.factor(mse_test > k))
cm <- class_collected %>% conf_mat(truth = truth, estimate = pred)
cm %>% autoplot(type = 'heatmap')
cm %>% summary()
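Given the extreme class imbalance, accuracy alone is misleading here; filtering the summary to the imbalance-aware metrics gives a clearer picture (a small convenience step):
# Focus on metrics that remain informative under heavy class imbalance
cm %>%
  summary() %>%
  filter(.metric %in% c('precision', 'recall', 'f_meas', 'bal_accuracy'))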