Preamble

# Clear workspace
rm(list=ls()); graphics.off() 
### Load standard packages
library(tidyverse) # Collection of all the good stuff like dplyr, ggplot2 etc.
library(magrittr) # For extra piping operators (e.g. %<>%)
library(skimr) # For nice data summaries

# Load our main predictive tools
library(tidymodels)
library(keras)
# Test tensorflow
tensorflow::tf$constant(c("Hello Tensorflow", "Hello Keras"))
2021-11-09 10:21:20.585438: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-11-09 10:21:20.587140: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
tf.Tensor([b'Hello Tensorflow' b'Hello Keras'], shape=(2,), dtype=string)
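
If the test above errors because no TensorFlow installation is found, the Python backend can be set up from R first (a one-time step):

# One-time setup of the Python backend for keras/tensorflow
keras::install_keras()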

Part 1: Simple neural network

Load the data

listings <- read_csv('http://data.insideairbnb.com/denmark/hovedstaden/copenhagen/2021-07-20/data/listings.csv.gz')
Rows: 9949 Columns: 74
── Column specification ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr  (23): listing_url, name, description, neighborhood_overview, picture_url, host_url, host_name, host_location, host_about, host_response_ti...
dbl  (37): id, scrape_id, host_id, host_listings_count, host_total_listings_count, latitude, longitude, accommodates, bedrooms, beds, minimum_n...
lgl   (9): host_is_superhost, host_has_profile_pic, host_identity_verified, neighbourhood_group_cleansed, bathrooms, calendar_updated, has_avai...
date  (5): last_scraped, host_since, calendar_last_scraped, first_review, last_review

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
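
As the message suggests, the parsing report can be silenced on subsequent runs, e.g.:

listings <- read_csv('http://data.insideairbnb.com/denmark/hovedstaden/copenhagen/2021-07-20/data/listings.csv.gz', show_col_types = FALSE)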

Preprocessing

data <- listings %>%
  rename(y = price) %>%
  relocate(y) 
data %<>% 
  mutate(across(where(is.character), ~ifelse(.x == "", NA, .x))) # Recode empty strings as NA
select_var <- c('y', 
                'room_type',
                'neighbourhood_cleansed',
                'accommodates',
                'bedrooms',
                'review_scores_rating',
                'number_of_reviews',
                'host_is_superhost',
                'host_identity_verified') 

select_cat_room <- c('Entire home/apt', 'Private room')
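
skimr is already loaded, so a quick overview of the selected variables (types, missingness, distributions) is a useful sanity check before building the recipe:

data %>% select(all_of(select_var)) %>% skim()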

Split the data

data_split <- initial_split(data, prop = 0.80, strata = y)

data_train <- data_split %>% training()
data_test <- data_split %>% testing()

Define Recipe

data_recipe <- data_train %>%
  recipe(y ~ .) %>%
  # Case-specific variable selection
  step_select(all_of(select_var)) %>%   
  # Case-specific mutation
  step_mutate_at(where(is_logical), fn = as.numeric) %>%              # Logical columns sometimes cause problems with KNN imputation
  step_mutate(y = y %>% as.character() %>% parse_number()) %>%        # Price arrives as a formatted string; parse it to a number
  # Filtering
  step_filter(percent_rank(y) <= 0.95, skip = TRUE) %>%               # Filter y outliers
  step_naomit(y) %>%
  step_filter(room_type %in% select_cat_room) %>%                     # Filter room type
  step_filter(number_of_reviews >= 1) %>%                             # Filter number of reviews
  # Standard preprocessing
  # step_impute_knn(all_predictors()) %>%                             # Impute missing values
  step_naomit(everything()) %>%                                       # Drop remaining NAs when imputation is disabled
  step_normalize(all_numeric_predictors()) %>%                        # Center all numeric variables to mean = 0 & sd = 1
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%            # Dummies of all categorical variables
  prep()                                                              # prep() directly since no workflow is used later
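
Before extracting the matrices, it can be worth a quick look at what the prepped recipe actually produces (column names, dummy coding, scaling). A minimal check:

data_recipe %>% juice() %>% glimpse()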

Prepare input data

x_train <- juice(data_recipe) %>% select(-y) %>% as.matrix()
x_test <- bake(data_recipe, new_data = data_test) %>% select(-y) %>% as.matrix()
y_train <- juice(data_recipe) %>% pull(y)   
y_test <- bake(data_recipe, new_data = data_test) %>% pull(y) 
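
As a quick sanity check, the predictor matrix and target vector should agree on the number of rows, and Keras expects both to be numeric:

dim(x_train)    # rows x predictors
length(y_train) # should match nrow(x_train)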

Neural Network

Define the network

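The chunk that defines model_base did not make it into this output. A definition consistent with the summary printed below (two hidden layers of 64 units and a single linear output over the 21 predictors; the ReLU activations are an assumption, since the summary only shows shapes and parameter counts) would be:

model_base <- keras_model_sequential() %>% 
  # 2 hidden layers; activations assumed, not shown in the printed summary
  layer_dense(units = 64, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 64, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear")
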
model_base
Model
Model: "sequential_8"
___________________________________________________________________________________________________________________________________________________
Layer (type)                                                      Output Shape                                              Param #                
===================================================================================================================================================
dense_21 (Dense)                                                  (None, 64)                                                1408                   
___________________________________________________________________________________________________________________________________________________
dense_20 (Dense)                                                  (None, 64)                                                4160                   
___________________________________________________________________________________________________________________________________________________
dense_19 (Dense)                                                  (None, 1)                                                 65                     
===================================================================================================================================================
Total params: 5,633
Trainable params: 5,633
Non-trainable params: 0
___________________________________________________________________________________________________________________________________________________
model_base %>% 
  compile(
    optimizer = optimizer_adam(),
    loss = "mse",
    metrics = list("mean_squared_error")
  )

Run

set.seed(1337)
hist_base <- model_base %>% fit(x = x_train, 
                                y = y_train, 
                                epochs = 20, # How often to run over the whole training sample
                                batch_size = 128, # Number of observations per gradient update
                                validation_split = 0.25, # Share of the training data held out for validation
                                # callbacks = callback_tensorboard("logs/run_a")
                                )
Epoch  1/20 - loss: 86694.8750 - val_loss: 80476.4375
Epoch  2/20 - loss: 86196.8438 - val_loss: 80484.8594
Epoch  3/20 - loss: 85844.5391 - val_loss: 80130.9062
Epoch  4/20 - loss: 85495.7344 - val_loss: 79894.5312
Epoch  5/20 - loss: 85215.5938 - val_loss: 79563.3516
Epoch  6/20 - loss: 85013.0078 - val_loss: 79480.3672
Epoch  7/20 - loss: 84847.5234 - val_loss: 79348.1484
Epoch  8/20 - loss: 84698.2266 - val_loss: 79121.8047
Epoch  9/20 - loss: 84587.1328 - val_loss: 79213.5391
Epoch 10/20 - loss: 84506.0938 - val_loss: 78935.0156
Epoch 11/20 - loss: 84398.8438 - val_loss: 78729.4453
Epoch 12/20 - loss: 84303.7891 - val_loss: 78799.2578
Epoch 13/20 - loss: 84233.9688 - val_loss: 78412.1562
Epoch 14/20 - loss: 84193.9375 - val_loss: 78637.9219
Epoch 15/20 - loss: 84125.5859 - val_loss: 78433.1719
Epoch 16/20 - loss: 84082.6719 - val_loss: 78446.3125
Epoch 17/20 - loss: 84059.6484 - val_loss: 78875.4375
Epoch 18/20 - loss: 84037.7422 - val_loss: 78126.9531
Epoch 19/20 - loss: 84004.3672 - val_loss: 78183.5469
Epoch 20/20 - loss: 83978.6875 - val_loss: 78434.6328

Evaluate

hist_base %>% plot(smooth = TRUE)
`geom_smooth()` using formula 'y ~ x'

model_base %>% evaluate(x_test, y_test) %>% sqrt()

50/50 [==============================] - 0s 1ms/step - loss: 318837.2188 - mean_squared_error: 318837.2188
              loss mean_squared_error 
          564.6567           564.6567 
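
Since evaluate() returns the MSE, its square root is the test RMSE in the same units as the (parsed) price. As a cross-check, the same figure can be computed by hand with yardstick's rmse(), which is loaded with tidymodels:

tibble(truth = y_test,
       estimate = model_base %>% predict(x_test) %>% as.vector()) %>% 
  rmse(truth, estimate)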

Part 2: Play around

model_small <- keras_model_sequential() %>% 
  # architecture
  layer_dense(units = 21, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  # compile
  compile(optimizer = optimizer_adam(), loss = "mse", metrics = list("mean_squared_error"))
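
model_big is summarized below, but its defining chunk is likewise missing. A sketch matching the printed parameter counts (five hidden layers of 512 units plus a linear output; the ReLU activations are again assumed) would be:

model_big <- keras_model_sequential() %>% 
  # architecture: 5 wide hidden layers; activations assumed
  layer_dense(units = 512, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  # compile
  compile(optimizer = optimizer_adam(), loss = "mse", metrics = list("mean_squared_error"))
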
model_big
Model
Model: "sequential_11"
___________________________________________________________________________________________________________________________________________________
Layer (type)                                                      Output Shape                                              Param #                
===================================================================================================================================================
dense_29 (Dense)                                                  (None, 512)                                               11264                  
___________________________________________________________________________________________________________________________________________________
dense_28 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dense_27 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dense_26 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dense_25 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dense_24 (Dense)                                                  (None, 1)                                                 513                    
===================================================================================================================================================
Total params: 1,062,401
Trainable params: 1,062,401
Non-trainable params: 0
___________________________________________________________________________________________________________________________________________________
set.seed(1337)
hist_small <- model_small %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                                  epochs = 20, # How often to run over the whole training sample
                                  batch_size = 128 # Number of observations per gradient update
                                  )
Epoch  1/20 - loss: 1011964.6875 - val_loss: 793234.6250
Epoch  2/20 - loss: 999430.9531 - val_loss: 793080.9375
Epoch  3/20 - loss: 1010160.8628 - val_loss: 792927.3750
Epoch  4/20 - loss: 992749.2882 - val_loss: 792777.3750
Epoch  5/20 - loss: 991858.7691 - val_loss: 792623.7500
Epoch  6/20 - loss: 977232.7604 - val_loss: 792470.9375
Epoch  7/20 - loss: 973866.5330 - val_loss: 792317.8750
Epoch  8/20 - loss: 985811.5990 - val_loss: 792166.9375
Epoch  9/20 - loss: 996566.3750 - val_loss: 792015.4375
Epoch 10/20 - loss: 992963.3472 - val_loss: 791863.4375
Epoch 11/20 - loss: 977177.0104 - val_loss: 791711.6875
Epoch 12/20 - loss: 1000172.1771 - val_loss: 791558.6250
Epoch 13/20 - loss: 992363.6667 - val_loss: 791408.4375
Epoch 14/20 - loss: 971352.5972 - val_loss: 791255.3125
Epoch 15/20 - loss: 984113.6181 - val_loss: 791102.7500
Epoch 16/20 - loss: 1010698.4410 - val_loss: 790950.1875
Epoch 17/20 - loss: 964194.1024 - val_loss: 790801.8750
Epoch 18/20 - loss: 984717.7604 - val_loss: 790648.9375
Epoch 19/20 - loss: 975564.1667 - val_loss: 790498.4375
Epoch 20/20 - loss: 998548.1667 - val_loss: 790345.5625
set.seed(1337)
hist_big <- model_big %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                              epochs = 20, # How often to run over the whole training sample
                              batch_size = 125 # Number of observations per gradient update
                              )
Epoch  1/20 - loss: 710459.2019 - val_loss: 98967.3594
Epoch  2/20 - loss: 104222.5046 - val_loss: 83020.8906
Epoch  3/20 - loss: 89462.7016 - val_loss: 75445.9844
Epoch  4/20 - loss: 91411.0800 - val_loss: 92452.1562
Epoch  5/20 - loss: 87106.4367 - val_loss: 79835.0781
Epoch  6/20 - loss: 84035.2679 - val_loss: 83711.4688
Epoch  7/20 - loss: 85291.5367 - val_loss: 86435.3906
Epoch  8/20 - loss: 82351.1334 - val_loss: 80796.7188
Epoch  9/20 - loss: 78589.0868 - val_loss: 77173.8047
Epoch 10/20 - loss: 81287.0135 - val_loss: 81368.4219
Epoch 11/20 - loss: 80785.7552 - val_loss: 77068.2031
Epoch 12/20 - loss: 86156.8351 - val_loss: 77294.3906
Epoch 13/20 - loss: 80220.2044 - val_loss: 81422.3438
Epoch 14/20 - loss: 81547.7430 - val_loss: 84582.8125
Epoch 15/20 - loss: 79554.3188 - val_loss: 82460.0859
Epoch 16/20 - loss: 85216.0963 - val_loss: 86206.6953
Epoch 17/20 - loss: 81866.4709 - val_loss: 82459.6328
Epoch 18/20 - loss: 82497.6881 - val_loss: 81734.4297
Epoch 19/20 - loss: 77697.3729 - val_loss: 83758.5391
Epoch 20/20 - loss: 77951.4191 - val_loss: 88360.6328
val_collected <- tibble(epoch = 1:20, 
       RMSE = hist_base$metrics$val_mean_squared_error[1:20] %>% sqrt(), # sqrt of the validation MSE = RMSE
       model = 'base') %>%
  bind_rows(
    tibble(epoch = 1:20, 
       RMSE = hist_small$metrics$val_mean_squared_error[1:20] %>% sqrt(),
       model = 'small') 
  ) %>%
  bind_rows(
    tibble(epoch = 1:20, 
       RMSE = hist_big$metrics$val_mean_squared_error[1:20] %>% sqrt(),
       model = 'big') 
  )
val_collected %>%
  ggplot(aes(x = epoch, y = RMSE)) +
  geom_line(aes(col = model))

Part 3: Prevent overfitting

Dropout layers

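The chunk defining model_dropout is also missing. A sketch consistent with the summary below (three hidden layers of 512 units, with dropout after the first two; the activations and the dropout rate are assumptions, since neither appears in a Keras summary) would be:

model_dropout <- keras_model_sequential() %>% 
  layer_dense(units = 512, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dropout(rate = 0.3) %>% # rate is an assumption; dropout layers add no parameters
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dropout(rate = 0.3) %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  compile(optimizer = optimizer_adam(), loss = "mse", metrics = list("mean_squared_error"))
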
model_dropout 
Model
Model: "sequential_13"
___________________________________________________________________________________________________________________________________________________
Layer (type)                                                      Output Shape                                              Param #                
===================================================================================================================================================
dense_35 (Dense)                                                  (None, 512)                                               11264                  
___________________________________________________________________________________________________________________________________________________
dropout_1 (Dropout)                                               (None, 512)                                               0                      
___________________________________________________________________________________________________________________________________________________
dense_34 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dropout (Dropout)                                                 (None, 512)                                               0                      
___________________________________________________________________________________________________________________________________________________
dense_33 (Dense)                                                  (None, 512)                                               262656                 
___________________________________________________________________________________________________________________________________________________
dense_32 (Dense)                                                  (None, 1)                                                 513                    
===================================================================================================================================================
Total params: 537,089
Trainable params: 537,089
Non-trainable params: 0
___________________________________________________________________________________________________________________________________________________
set.seed(1337)
hist_dropout <- model_dropout %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                                      epochs = 10, # How often to run over the whole training sample
                                      batch_size = 128 # Number of observations per gradient update
                                      )
Epoch  1/10 - loss: 914283.5556 - val_loss: 137108.5000
Epoch  2/10 - loss: 162367.0803 - val_loss: 100051.3594
Epoch  3/10 - loss: 94435.7094 - val_loss: 82477.8672
Epoch  4/10 - loss: 88558.1048 - val_loss: 81927.8750
Epoch  5/10 - loss: 86962.9301 - val_loss: 80259.7109
Epoch  6/10 - loss: 85427.7401 - val_loss: 79161.9375
Epoch  7/10 - loss: 88259.9262 - val_loss: 78538.6875
Epoch  8/10 - loss: 85352.1714 - val_loss: 79789.1406
Epoch  9/10 - loss: 83008.2307 - val_loss: 78430.8750
Epoch 10/10 - loss: 84767.4818 - val_loss: 77337.3984
val_collected %<>%
  bind_rows(
    tibble(epoch = 1:10, 
       RMSE = hist_dropout$metrics$val_mean_squared_error[1:10] %>% sqrt(), # sqrt() to stay on the RMSE scale of the other models
       model = 'dropout') 
  )
val_collected %>%
  ggplot(aes(x = epoch, y = RMSE)) +
  geom_line(aes(col = model))

pred_collected <- tibble(
  truth = y_train,
  base = mean(truth),
  model_base = model_base %>% predict(x = x_train) %>% as.numeric(),   # as.numeric() flattens the 1-column prediction matrix
  model_small = model_small %>% predict(x = x_train) %>% as.numeric(),
  model_big = model_big %>% predict(x = x_train) %>% as.numeric()
  ) %>% 
  pivot_longer(cols = -truth,
               names_to = 'model',
               values_to = '.pred')
pred_collected %>%
  ggplot(aes(x = truth, y = .pred, color = model)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_point(alpha = 0.5, shape = 21) +
  labs(
    x = "Truth",
    y = "Predicted price",
    color = "Type of model"
  )

Early stopping

library(tfruns)
# launch TensorBoard (data won't show up until after the first epoch)
tensorboard("logs/run_a")

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.5.0 at http://127.0.0.1:7243/ (Press CTRL+C to quit)
Started TensorBoard at http://127.0.0.1:7243 
set.seed(1337)
hist_drop_stop <- model_dropout  %>% fit(x = x_train, 
                                         y = y_train, 
                                         epochs = 500, # Upper bound on passes over the data; early stopping will usually end training sooner
                                         batch_size = 128, # Number of observations per batch
                                         validation_split = 0.25, # Hold out 25% of the training data for validation
                                         callbacks = list(
                                           callback_model_checkpoint("logs_a/checkpoints.h5"),
                                           callback_tensorboard("logs/run_a"), 
                                           callback_early_stopping(patience = 3))
                                       )
Epoch 1/500
35/35 [==============================] - 0s 7ms/step - loss: 84217.8594 - mean_squared_error: 84217.8594 - val_loss: 77731.1719 - val_mean_squared_error: 77731.1719
Epoch 2/500
35/35 [==============================] - 0s 7ms/step - loss: 84791.2031 - mean_squared_error: 84791.2031 - val_loss: 81453.6641 - val_mean_squared_error: 81453.6641
Epoch 3/500
35/35 [==============================] - 0s 8ms/step - loss: 84364.0469 - mean_squared_error: 84364.0469 - val_loss: 84640.5938 - val_mean_squared_error: 84640.5938
Epoch 4/500
35/35 [==============================] - 0s 7ms/step - loss: 83708.6328 - mean_squared_error: 83708.6328 - val_loss: 82294.2500 - val_mean_squared_error: 82294.2500
tensorboard("logs/run_a")

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.5.0 at http://127.0.0.1:7243/ (Press CTRL+C to quit)
Started TensorBoard at http://127.0.0.1:7243 
---
title: 'Neural Networks Exercise: Predicting Airbnb prices with simple ANNs (R)'
author: "Daniel S. Hain (dsh@business.aau.dk)"
date: "Updated `r format(Sys.time(), '%B %d, %Y')`"
output:
  html_notebook:
    code_folding: show
    df_print: paged
    toc: true
    toc_depth: 2
    toc_float:
      collapsed: false
    theme: flatly
---

```{r setup, include=FALSE}
### Generic preamble
rm(list=ls())
Sys.setenv(LANG = "en") # For english language
options(scipen = 5) # To deactivate annoying scientific number notation

### Knitr options
library(knitr) # For display of the markdown
knitr::opts_chunk$set(warning=FALSE,
                     message=FALSE,
                     comment=FALSE, 
                     fig.align="center"
                     )
```


## Preamble

```{r}
# Clear workspace
rm(list=ls()); graphics.off() 
```

```{r}
### Load standard packages
library(tidyverse) # Collection of all the good stuff like dplyr, ggplot2 etc.
library(magrittr) # For extra piping operators (e.g. %<>%)
library(skimr) # For nice data summaries

# Load our main predictive tool
library(tidymodels)
library(keras)
```

```{r}
# Test that TensorFlow is up and running
tensorflow::tf$constant(c("Hello TensorFlow", "Hello Keras"))
```


# Part 1: Simple neural network

## Load the data

```{r}
listings <- read_csv('http://data.insideairbnb.com/denmark/hovedstaden/copenhagen/2021-07-20/data/listings.csv.gz')
```


## Preprocessing

```{r}
data <- listings %>%
  rename(y = price) %>%
  relocate(y) 
```

```{r}
data %<>% 
  mutate(across(where(is.character), ~ifelse(.x == "", NA, .x))) # Recode empty strings as NA
```


```{r}
select_var <- c('y', 
                'room_type',
                'neighbourhood_cleansed',
                'accommodates',
                'bedrooms',
                'review_scores_rating',
                'number_of_reviews',
                'host_is_superhost',
                'host_identity_verified') 

select_cat_room <- c('Entire home/apt', 'Private room')
```

## Split the data

```{r}
data_split <- initial_split(data, prop = 0.80, strata = y)

data_train <- data_split  %>%  training()
data_test <- data_split %>% testing()
```

## Define Recipe

```{r}
data_recipe <- data_train %>%
  recipe(y ~.) %>%
  # Case specific variable selection
  step_select(all_of(select_var)) %>%   
  # Case specific mutation
  step_mutate_at(where(is_logical), fn = as.numeric) %>%              # Logical columns sometimes cause problems with KNN imputation
  step_mutate(y = y %>% as.character() %>% parse_number()) %>%
  # Filtering
  step_filter(percent_rank(y) <= 0.95, skip = TRUE) %>% # filter y outliers
  step_naomit(y) %>%
  step_filter(room_type %in% select_cat_room) %>%                     # filter room type
  step_filter(number_of_reviews >= 1) %>%                             # filter on number of reviews
  # Standard preprocessing
  # step_impute_knn(all_predictors()) %>%                             # Impute missing values
  step_naomit(everything()) %>%                                       # Used when the imputation above doesn't work
  step_normalize(all_numeric_predictors()) %>%                        # Center and scale all numeric variables to mean = 0 & sd = 1
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%            # Dummies of all categorical
  prep()                                                              # Prep only if no workflows are needed later
```


## Prepare input data

```{r}
x_train <- juice(data_recipe) %>% select(-y) %>% as.matrix()
x_test <- bake(data_recipe, new_data = data_test) %>% select(-y) %>% as.matrix()
```

```{r}
y_train <- juice(data_recipe) %>% pull(y)   
y_test <- bake(data_recipe, new_data = data_test) %>% pull(y) 
```
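
Before fitting anything, a quick shape check confirms that the recipe produced matching predictor and outcome objects: the number of rows in `x_train` must equal the length of `y_train`.

```{r}
# Sanity check: predictor matrix and outcome vector must line up row-for-row
dim(x_train)
length(y_train)
```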


## Neural Network

### Define the network

```{r}
model_base <- keras_model_sequential() %>% 
  layer_dense(units = 64, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 64, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear") 
```

```{r}
model_base %>% 
  compile(
    optimizer = optimizer_adam(),
    loss = "mse",
    metrics = list("mean_squared_error")
  )
```
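
It can be helpful to inspect the architecture before training; `summary()` on a keras model prints the layer output shapes and parameter counts.

```{r}
# Layer output shapes and parameter counts
summary(model_base)
```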

### Run

```{r}
set.seed(1337)
hist_base <- model_base  %>% fit(x = x_train, 
                                         y = y_train, 
                                         epochs = 20, # Number of complete passes over the training data
                                         batch_size = 128, # Number of observations per batch
                                         validation_split = 0.25, # Hold out 25% of the training data for validation
                                         # callbacks = callback_tensorboard("logs/run_a")
                                         )
```

### Evaluate

```{r}
hist_base %>% plot(smooth = TRUE)
```

```{r}
model_base %>% evaluate(x_test, y_test) %>% sqrt()
```
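
Since the loss is the MSE, the `sqrt()` above reports results on the RMSE scale, i.e. in the original price units. The same test-set RMSE can be reproduced from explicit predictions with yardstick (loaded via tidymodels); a minimal sketch:

```{r}
# Test-set RMSE from explicit predictions, via yardstick's rmse()
tibble(truth = y_test,
       .pred = model_base %>% predict(x_test) %>% as.numeric()) %>%
  rmse(truth = truth, estimate = .pred)
```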

# Part 2: Play around

```{r}
model_small <- keras_model_sequential() %>% 
  # architecture
  layer_dense(units = 21, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  # compile
  compile(optimizer = optimizer_adam(), loss = "mse", metrics = list("mean_squared_error")
  )
```


```{r}
model_big <- keras_model_sequential() %>% 
  # architecture
  layer_dense(units = 512, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  # compile
  compile(optimizer = optimizer_adam(), loss = "mse", metrics = list("mean_squared_error")
  )
```

```{r}
set.seed(1337)
hist_small <- model_small %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                                        epochs = 20, # Number of complete passes over the training data
                                        batch_size = 128 # Number of observations per batch
                                        )
```

```{r}
set.seed(1337)
hist_big <- model_big %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                                        epochs = 20, # Number of complete passes over the training data
                                        batch_size = 128 # Number of observations per batch
                                        )
```

```{r}
# Collect the per-epoch validation RMSE (sqrt of the recorded MSE) for each model
val_collected <- tibble(epoch = 1:20, 
       RMSE = hist_base$metrics$val_mean_squared_error[1:20] %>% sqrt(),
       model = 'base') %>%
  bind_rows(
    tibble(epoch = 1:20, 
       RMSE = hist_small$metrics$val_mean_squared_error[1:20] %>% sqrt(),
       model = 'small') 
  ) %>%
  bind_rows(
    tibble(epoch = 1:20, 
       RMSE = hist_big$metrics$val_mean_squared_error[1:20] %>% sqrt(),
       model = 'big') 
  )
```


```{r}
val_collected %>%
  ggplot(aes(x = epoch, y = RMSE)) +
  geom_line(aes(col = model))
```

# Part 3: Prevent overfitting

## Dropout layers

```{r}
model_dropout <- keras_model_sequential() %>% 
  # architecture
  layer_dense(units = 512, activation = "relu", input_shape = ncol(x_train)) %>% 
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dropout(rate = 0.1) %>%
  layer_dense(units = 512, activation = "relu") %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  # compile
  compile(optimizer = "adam", loss = "mse", metrics = list("mean_squared_error")
  )
```
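
Dropout is only one way to regularize. A sketch of an alternative using L2 weight decay on the dense layers instead (the penalty of 0.001 is an arbitrary illustrative choice, not tuned for this data):

```{r}
# Sketch: same idea, but with L2 weight decay instead of dropout
model_l2 <- keras_model_sequential() %>% 
  layer_dense(units = 512, activation = "relu", input_shape = ncol(x_train),
              kernel_regularizer = regularizer_l2(0.001)) %>% 
  layer_dense(units = 512, activation = "relu",
              kernel_regularizer = regularizer_l2(0.001)) %>% 
  layer_dense(units = 1, activation = "linear") %>% 
  compile(optimizer = "adam", loss = "mse", metrics = list("mean_squared_error"))
```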

```{r}
set.seed(1337)
hist_dropout <- model_dropout %>% fit(x = x_train, y = y_train, validation_split = 0.25, 
                                        epochs = 10, # Number of complete passes over the training data
                                        batch_size = 128 # Number of observations per batch
                                        )
```

```{r}
val_collected %<>%
  bind_rows(
    tibble(epoch = 1:10, 
       RMSE = hist_dropout$metrics$val_mean_squared_error[1:10] %>% sqrt(), # sqrt() keeps the dropout run on the same RMSE scale as the other models
       model = 'dropout') 
  )
```


```{r}
val_collected %>%
  ggplot(aes(x = epoch, y = RMSE)) +
  geom_line(aes(col = model))
```


```{r}
pred_collected <- tibble(
  truth = y_train,
  base = mean(truth),
  model_base = model_base %>% predict(x = x_train) %>% as.numeric(),   # as.numeric() flattens the 1-column prediction matrix
  model_small = model_small %>% predict(x = x_train) %>% as.numeric(),
  model_big = model_big %>% predict(x = x_train) %>% as.numeric()
  ) %>% 
  pivot_longer(cols = -truth,
               names_to = 'model',
               values_to = '.pred')
```

```{r}
pred_collected %>%
  ggplot(aes(x = truth, y = .pred, color = model)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_point(alpha = 0.5, shape = 21) +
  labs(
    x = "Truth",
    y = "Predicted price",
    color = "Type of model"
  )
```
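
The scatterplot gives a visual impression; for a numeric comparison, the same long-format tibble can be summarized per model. A small sketch using yardstick's `rmse()`, which respects `group_by()`:

```{r}
# Training-set RMSE per model
pred_collected %>%
  group_by(model) %>%
  rmse(truth = truth, estimate = .pred)
```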

## Early stopping

```{r}
library(tfruns)
```

```{r}
# launch TensorBoard (data won't show up until after the first epoch)
tensorboard("logs/run_a")
```


```{r}
set.seed(1337)
hist_drop_stop <- model_dropout  %>% fit(x = x_train, 
                                         y = y_train, 
                                         epochs = 500, # Upper bound on passes over the data; early stopping will usually end training sooner
                                         batch_size = 128, # Number of observations per batch
                                         validation_split = 0.25, # Hold out 25% of the training data for validation
                                         callbacks = list(
                                           callback_model_checkpoint("logs_a/checkpoints.h5"),
                                           callback_tensorboard("logs/run_a"), 
                                           callback_early_stopping(patience = 3))
                                       )
```
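
One caveat: by default, `callback_early_stopping()` leaves the model with the weights from the final epoch, not the best one. A variant worth considering, assuming a keras version that supports the `restore_best_weights` argument:

```{r}
# Sketch: early stopping that rolls back to the weights with the best
# validation loss; pass this in the callbacks list of fit() as above
callback_early_stopping(monitor = "val_loss",
                        patience = 3,
                        restore_best_weights = TRUE)
```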

```{r}
tensorboard("logs/run_a")
```

# Things to look at

* https://blogs.rstudio.com/ai/posts/2019-04-16-autokeras/
* https://cran.r-project.org/web/packages/kerastuneR/vignettes/Introduction.html