Neural Nets (pt.3), Interpretability and Convolutional Neural Networks

[1]:
# Ref: https://medium.com/biaslyai/pytorch-introduction-to-neural-network-feedforward-neural-network-model-e7231cff47cb

'''
Packages

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
conda install -c anaconda scikit-learn
conda install pandas
'''

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from scipy import stats
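The runs below are not seeded, so exact losses will differ between executions. If you want reproducible numbers, a minimal sketch (covering the RNGs this notebook actually touches) is:

import random

def set_seed(seed: int = 0) -> None:
    # Seed Python, NumPy and PyTorch RNGs so reruns produce the same
    # initial weights and the same sampling behaviour.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_seed(0)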
[5]:
# What are we trying to beat: the published forest_height product vs the measured heights h
data = pd.read_csv('./tree_height/txt/eu_x_y_height_predictors.txt',  sep=" ")
# data = pd.read_csv('./tree_height/txt/eu_x_y_height.txt',  sep=" ")
print(data.shape)
print(data.head())

y_true = data['h']
y_pred = data['forest_height']

slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_true)

fig,ax=plt.subplots(1,1,figsize=(5,5))
ax.scatter(y_pred, y_true)
ax.set_xlabel('Prediction')
ax.set_ylabel('True')
ax.set_title('slope: {:.4f}, r_value: {:.4f}'.format(slope, r_value))
plt.show()
(66522, 22)
          x          y      h     minSVVI     medSVVI     maxSVVI  treecover  \
0  6.894317  49.482459   2.73  176.799316  233.790283  128.065918         61
1  7.023274  49.510552  10.75  495.294678  555.553955  680.410400         36
2  7.394650  49.590488  21.20  172.081299  122.683594  231.771484        100
3  7.396895  49.590968  20.00  278.890625   12.374756   25.781250         83
4  7.397643  49.591128  24.23  305.592651   26.431641   20.360352         79

   forest_height  CHELSA_bio18  CHELSA_bio4  ...  northness  eastness  \
0             19          2251         6208  ...  -0.041041 -0.107523
1             11          2336         6144  ...   0.060015 -0.031948
2             20          2067         6194  ...   0.104165  0.049155
3             26          2067         6194  ...   0.073083  0.041050
4             26          2067         6194  ...   0.042943  0.064307

    elevation  outlet_dist_dw_basin        cti  BLDFIE  ORCDRC  CECSOL  \
0  344.212067                781109 -288081519    1526      12      13
1  367.109863                800307 -155512427    1479      16      11
2  402.408142                624259 -343142396    1531       9      15
3  384.229584                624227 -304025319    1524      10      14
4  382.408264                624209 -316507082    1524      10      14

        SBIO4      SBIO3
0  482.899994  22.400000
1  452.500000  18.900000
2  460.100006  19.200001
3  460.100006  19.200001
4  460.100006  19.200001

[5 rows x 22 columns]
../_images/CASESTUDY_NNs_pt3_SHAP_2_1.png
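If you also want a single-number baseline, here is a minimal sketch using the r2_score import above (RMSE is in the same units as h):

rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
print('baseline R2: {:.4f}, RMSE: {:.2f}'.format(r2_score(y_true, y_pred), rmse))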
[6]:
tree_height = data['h'].to_numpy()
data = data.drop(columns='h')              # the target; keep predictors only
data = data.drop(columns='forest_height')  # the product we are trying to beat
print(data.head())
          x          y     minSVVI     medSVVI     maxSVVI  treecover  \
0  6.894317  49.482459  176.799316  233.790283  128.065918         61
1  7.023274  49.510552  495.294678  555.553955  680.410400         36
2  7.394650  49.590488  172.081299  122.683594  231.771484        100
3  7.396895  49.590968  278.890625   12.374756   25.781250         83
4  7.397643  49.591128  305.592651   26.431641   20.360352         79

   CHELSA_bio18  CHELSA_bio4  dev-magnitude  convergence  northness  eastness  \
0          2251         6208       0.774696    -5.247141  -0.041041 -0.107523
1          2336         6144      -1.088770    44.009472   0.060015 -0.031948
2          2067         6194       0.892866    -2.490510   0.104165  0.049155
3          2067         6194       0.893981     4.893103   0.073083  0.041050
4          2067         6194       0.856278    -4.858290   0.042943  0.064307

    elevation  outlet_dist_dw_basin        cti  BLDFIE  ORCDRC  CECSOL  \
0  344.212067                781109 -288081519    1526      12      13
1  367.109863                800307 -155512427    1479      16      11
2  402.408142                624259 -343142396    1531       9      15
3  384.229584                624227 -304025319    1524      10      14
4  382.408264                624209 -316507082    1524      10      14

        SBIO4      SBIO3
0  482.899994  22.400000
1  452.500000  18.900000
2  460.100006  19.200001
3  460.100006  19.200001
4  460.100006  19.200001
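Before plotting, it is worth a quick sanity check that every remaining column is numeric and complete; a minimal sketch:

print(data.dtypes.value_counts())            # expect only numeric dtypes
print('missing values:', data.isna().sum().sum())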
[7]:
#Explore the raw data
# data = data.to_numpy()
n_plots_x = int(np.ceil(np.sqrt(data.shape[1])))
n_plots_y = int(np.floor(np.sqrt(data.shape[1])))
fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=(10, 10), dpi=100, facecolor='w', edgecolor='k')
ax=ax.ravel()
for idx in range(data.shape[1]):
    ax[idx].hist(data.iloc[:, idx].to_numpy().flatten())
    ax[idx].set_title(data.columns[idx])
fig.tight_layout()
../_images/CASESTUDY_NNs_pt3_SHAP_4_0.png
[8]:
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
qt = QuantileTransformer(
    n_quantiles=500, output_distribution="normal", random_state=0
)

# Map the tree heights to a roughly normal distribution, then rescale them to
# [-1, 1] so they match the tanh output range of the models trained below.
tree_height = qt.fit_transform(tree_height.reshape(-1, 1))
scaler_tree = MinMaxScaler()
tree_height = scaler_tree.fit_transform(tree_height).squeeze()
tree_height = ((tree_height - tree_height.min()) / (tree_height.max() - tree_height.min())) * 2 - 1
plt.hist(tree_height, 50);  # the trailing ';' suppresses the verbose (counts, bins) output
[8]:
(array([1.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,
        0.000e+00, 0.000e+00, 1.000e+00, 0.000e+00, 6.500e+01, 1.240e+02,
        2.050e+02, 3.430e+02, 3.760e+02, 8.450e+02, 1.115e+03, 1.770e+03,
        2.175e+03, 2.876e+03, 3.571e+03, 4.227e+03, 4.866e+03, 5.195e+03,
        5.512e+03, 5.437e+03, 5.233e+03, 4.862e+03, 4.215e+03, 3.589e+03,
        2.862e+03, 2.204e+03, 1.663e+03, 1.150e+03, 8.020e+02, 5.080e+02,
        3.100e+02, 1.950e+02, 1.890e+02, 3.400e+01, 1.000e+00, 0.000e+00,
        0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,
        0.000e+00, 1.000e+00]),
 array([-1.  , -0.96, -0.92, -0.88, -0.84, -0.8 , -0.76, -0.72, -0.68,
        -0.64, -0.6 , -0.56, -0.52, -0.48, -0.44, -0.4 , -0.36, -0.32,
        -0.28, -0.24, -0.2 , -0.16, -0.12, -0.08, -0.04,  0.  ,  0.04,
         0.08,  0.12,  0.16,  0.2 ,  0.24,  0.28,  0.32,  0.36,  0.4 ,
         0.44,  0.48,  0.52,  0.56,  0.6 ,  0.64,  0.68,  0.72,  0.76,
         0.8 ,  0.84,  0.88,  0.92,  0.96,  1.  ]),
 <BarContainer object of 50 artists>)
../_images/CASESTUDY_NNs_pt3_SHAP_5_1.png
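Keep in mind this transform chain is invertible: QuantileTransformer, then MinMaxScaler, then the affine rescale to [-1, 1]. A minimal sketch for mapping model outputs back to heights in the original units (the helper name is ours):

def heights_from_model_units(t):
    # Undo the affine rescale ([-1, 1] -> [0, 1]), then each fitted transformer.
    t01 = (np.asarray(t).reshape(-1, 1) + 1.0) / 2.0
    q = scaler_tree.inverse_transform(t01)     # undo MinMaxScaler
    return qt.inverse_transform(q).squeeze()   # undo QuantileTransformer

heights_from_model_units(tree_height[:5])  # approximately recovers the first five h values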
[9]:
#Normalize the data
scaler_data = MinMaxScaler()
data_transformed = scaler_data.fit_transform(data)

n_plots_x = int(np.ceil(np.sqrt(data.shape[1])))
n_plots_y = int(np.floor(np.sqrt(data.shape[1])))
fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=(10, 10), dpi=100, facecolor='w', edgecolor='k')
ax=ax.ravel()
for idx in range(data.shape[1]):
    ax[idx].hist(data_transformed[:, idx].flatten())
    ax[idx].set_title(data.columns[idx])
fig.tight_layout()
../_images/CASESTUDY_NNs_pt3_SHAP_6_0.png
[10]:
# Let's use all the data as one big minibatch

#Split the data
X_train, X_test, y_train, y_test = train_test_split(data_transformed,tree_height, test_size=0.30, random_state=0)
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.FloatTensor(y_test)
print('X_train.shape: {}, X_test.shape: {}, y_train.shape: {}, y_test.shape: {}'.format(X_train.shape, X_test.shape, y_train.shape, y_test.shape))
print('X_train.min: {}, X_test.min: {}, y_train.min: {}, y_test.min: {}'.format(X_train.min(), X_test.min(), y_train.min(), y_test.min()))
print('X_train.max: {}, X_test.max: {}, y_train.max: {}, y_test.max: {}'.format(X_train.max(), X_test.max(), y_train.max(), y_test.max()))
X_train.shape: torch.Size([46565, 20]), X_test.shape: torch.Size([19957, 20]), y_train.shape: torch.Size([46565]), y_test.shape: torch.Size([19957])
X_train.min: 0.0, X_test.min: 0.0, y_train.min: -0.6420077085494995, y_test.min: -1.0
X_train.max: 1.0, X_test.max: 1.0, y_train.max: 1.0, y_test.max: 0.57742840051651
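The comment above is literal: every optimizer step below uses the full training set as one batch. If the table were larger, mini-batching would be a small change via torch.utils.data; a minimal sketch (the batch size is an arbitrary choice):

from torch.utils.data import TensorDataset, DataLoader

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=4096, shuffle=True)
for xb, yb in train_loader:
    # one forward/backward/optimizer step per mini-batch would go here
    pass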
[11]:
# Create the model: a single linear unit with a tanh activation
# (a smooth stand-in for the Heaviside step of the classic perceptron)
class Perceptron(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super(Perceptron, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.tanh = torch.nn.Tanh()
    def forward(self, x):
        return self.tanh(self.fc(x))
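A quick shape check on a few rows (a throwaway sketch, not part of the sweep below) confirms the mapping from 20 predictors to a single tanh-bounded output:

tmp = Perceptron(X_train.shape[1], 1)
print(tmp(X_train[:4]).shape)   # torch.Size([4, 1]); tanh keeps outputs in (-1, 1)
del tmp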
[12]:
n_epochs = 2000                  # renamed from `epoch`, which the loop below would otherwise shadow
hid_dim_range = [1]              # also tried: 256, 512
lr_range = [0.5, 0.1, 0.01]      # also tried: 0.05, 0.001

for hid_dim in hid_dim_range:
    for lr in lr_range:
        print('\nhid_dim: {}, lr: {}'.format(hid_dim, lr))
        if 'model' in globals():
            print('Deleting previous model')
            del model, criterion, optimizer
        model = Perceptron(data.shape[1], hid_dim)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)

        all_loss_train=[]
        all_loss_val=[]
        for epoch in range(n_epochs):
            model.train()
            optimizer.zero_grad()
            # Forward pass
            y_pred = model(X_train)
            # Compute Loss
            loss = criterion(y_pred.squeeze(), y_train)

            # Backward pass
            loss.backward()
            optimizer.step()

            all_loss_train.append(loss.item())

            model.eval()
            with torch.no_grad():
                y_pred = model(X_test)
                # Compute Loss
                loss = criterion(y_pred.squeeze(), y_test)
                all_loss_val.append(loss.item())

                if epoch%100==0:
                    y_pred = y_pred.detach().numpy().squeeze()
                    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_test)
                    print('Epoch {}, train_loss: {:.4f}, val_loss: {:.4f}, r_value: {:.4f}'.format(epoch,all_loss_train[-1],all_loss_val[-1],r_value))

        fig,ax=plt.subplots(1,2,figsize=(10,5))
        ax[0].plot(all_loss_train)
        ax[0].plot(all_loss_val)

        ax[1].scatter(y_pred, y_test)
        ax[1].set_xlabel('Prediction')
        ax[1].set_ylabel('True')
        ax[1].set_title('slope: {:.4f}, r_value: {:.4f}'.format(slope, r_value))
        plt.show()

hid_dim: 1, lr: 0.5
Epoch 0, train_loss: 0.0532, val_loss: 0.1638, r_value: -0.0623
Epoch 100, train_loss: 0.1954, val_loss: 0.7130, r_value: 0.4214
Epoch 200, train_loss: 0.2813, val_loss: 0.6098, r_value: 0.3794
Epoch 300, train_loss: 0.3224, val_loss: 0.4925, r_value: 0.3421
Epoch 400, train_loss: 0.3879, val_loss: 0.4200, r_value: 0.3869
Epoch 500, train_loss: 0.4980, val_loss: 0.2536, r_value: 0.4959
Epoch 600, train_loss: 0.4884, val_loss: 0.2459, r_value: 0.5049
Epoch 700, train_loss: 0.4477, val_loss: 0.2981, r_value: 0.5021
Epoch 800, train_loss: 0.4495, val_loss: 0.2921, r_value: 0.5036
Epoch 900, train_loss: 0.4519, val_loss: 0.2866, r_value: 0.5047
Epoch 1000, train_loss: 0.4544, val_loss: 0.2813, r_value: 0.5056
Epoch 1100, train_loss: 0.4571, val_loss: 0.2761, r_value: 0.5063
Epoch 1200, train_loss: 0.4599, val_loss: 0.2711, r_value: 0.5068
Epoch 1300, train_loss: 0.4627, val_loss: 0.2662, r_value: 0.5074
Epoch 1400, train_loss: 0.4655, val_loss: 0.2615, r_value: 0.5078
Epoch 1500, train_loss: 0.4678, val_loss: 0.2575, r_value: 0.5082
Epoch 1600, train_loss: 0.4692, val_loss: 0.2550, r_value: 0.5085
Epoch 1700, train_loss: 0.4696, val_loss: 0.2541, r_value: 0.5087
Epoch 1800, train_loss: 0.4694, val_loss: 0.2539, r_value: 0.5089
Epoch 1900, train_loss: 0.4691, val_loss: 0.2541, r_value: 0.5090
../_images/CASESTUDY_NNs_pt3_SHAP_9_1.png

hid_dim: 1, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.1608, val_loss: 0.0436, r_value: 0.0382
Epoch 100, train_loss: 0.0303, val_loss: 0.0298, r_value: 0.4364
Epoch 200, train_loss: 0.0281, val_loss: 0.0277, r_value: 0.4957
Epoch 300, train_loss: 0.0274, val_loss: 0.0270, r_value: 0.5121
Epoch 400, train_loss: 0.0271, val_loss: 0.0267, r_value: 0.5195
Epoch 500, train_loss: 0.0269, val_loss: 0.0265, r_value: 0.5238
Epoch 600, train_loss: 0.0268, val_loss: 0.0264, r_value: 0.5267
Epoch 700, train_loss: 0.0267, val_loss: 0.0263, r_value: 0.5288
Epoch 800, train_loss: 0.0266, val_loss: 0.0263, r_value: 0.5303
Epoch 900, train_loss: 0.0266, val_loss: 0.0262, r_value: 0.5315
Epoch 1000, train_loss: 0.0266, val_loss: 0.0262, r_value: 0.5325
Epoch 1100, train_loss: 0.0265, val_loss: 0.0261, r_value: 0.5333
Epoch 1200, train_loss: 0.0265, val_loss: 0.0261, r_value: 0.5339
Epoch 1300, train_loss: 0.0265, val_loss: 0.0261, r_value: 0.5344
Epoch 1400, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5348
Epoch 1500, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5351
Epoch 1600, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5354
Epoch 1700, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5357
Epoch 1800, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5359
Epoch 1900, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5360
../_images/CASESTUDY_NNs_pt3_SHAP_9_3.png

hid_dim: 1, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0635, val_loss: 0.0618, r_value: -0.2940
Epoch 100, train_loss: 0.0521, val_loss: 0.0517, r_value: -0.2424
Epoch 200, train_loss: 0.0485, val_loss: 0.0482, r_value: -0.1854
Epoch 300, train_loss: 0.0455, val_loss: 0.0452, r_value: -0.1220
Epoch 400, train_loss: 0.0430, val_loss: 0.0426, r_value: -0.0548
Epoch 500, train_loss: 0.0408, val_loss: 0.0405, r_value: 0.0133
Epoch 600, train_loss: 0.0390, val_loss: 0.0387, r_value: 0.0791
Epoch 700, train_loss: 0.0374, val_loss: 0.0371, r_value: 0.1400
Epoch 800, train_loss: 0.0361, val_loss: 0.0358, r_value: 0.1946
Epoch 900, train_loss: 0.0350, val_loss: 0.0347, r_value: 0.2422
Epoch 1000, train_loss: 0.0340, val_loss: 0.0337, r_value: 0.2829
Epoch 1100, train_loss: 0.0332, val_loss: 0.0329, r_value: 0.3175
Epoch 1200, train_loss: 0.0324, val_loss: 0.0321, r_value: 0.3466
Epoch 1300, train_loss: 0.0318, val_loss: 0.0315, r_value: 0.3711
Epoch 1400, train_loss: 0.0313, val_loss: 0.0310, r_value: 0.3916
Epoch 1500, train_loss: 0.0308, val_loss: 0.0305, r_value: 0.4090
Epoch 1600, train_loss: 0.0303, val_loss: 0.0300, r_value: 0.4237
Epoch 1700, train_loss: 0.0300, val_loss: 0.0297, r_value: 0.4362
Epoch 1800, train_loss: 0.0296, val_loss: 0.0293, r_value: 0.4469
Epoch 1900, train_loss: 0.0294, val_loss: 0.0290, r_value: 0.4561
../_images/CASESTUDY_NNs_pt3_SHAP_9_5.png
[13]:
# Try a feedforward network with two hidden layers
class Feedforward(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Feedforward, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.fc2 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.fc3 = torch.nn.Linear(self.hidden_size, 1)
        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.tanh(self.fc3(hidden))
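Before sweeping hyperparameters, it helps to know how much bigger this model is than the perceptron; a minimal parameter-count sketch:

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(Perceptron(20, 1)))     # 21: 20 weights + 1 bias
print(n_params(Feedforward(20, 128)))  # 19329: 2688 (fc1) + 16512 (fc2) + 129 (fc3)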
[14]:
n_epochs = 2000
hid_dim_range = [128, 256, 512]
lr_range = [0.75, 0.5, 0.1]

for hid_dim in hid_dim_range:
    for lr in lr_range:
        print('\nhid_dim: {}, lr: {}'.format(hid_dim, lr))
        if 'model' in globals():
            print('Deleting previous model')
            del model, criterion, optimizer
        model = Feedforward(data.shape[1], hid_dim)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)

        all_loss_train=[]
        all_loss_val=[]
        for epoch in range(n_epochs):
            model.train()
            optimizer.zero_grad()
            # Forward pass
            y_pred = model(X_train)
            # Compute Loss
            loss = criterion(y_pred.squeeze(), y_train)

            # Backward pass
            loss.backward()
            optimizer.step()

            all_loss_train.append(loss.item())

            model.eval()
            with torch.no_grad():
                y_pred = model(X_test)
                # Compute Loss
                loss = criterion(y_pred.squeeze(), y_test)
                all_loss_val.append(loss.item())

                if epoch%100==0:
                    y_pred = y_pred.detach().numpy().squeeze()
                    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_test)
                    print('Epoch {}, train_loss: {:.4f}, val_loss: {:.4f}, r_value: {:.4f}'.format(epoch,all_loss_train[-1],all_loss_val[-1],r_value))

        fig,ax=plt.subplots(1,2,figsize=(10,5))
        ax[0].plot(np.log10(all_loss_train))
        ax[0].plot(np.log10(all_loss_val))

        ax[1].scatter(y_pred, y_test)
        ax[1].set_xlabel('Prediction')
        ax[1].set_ylabel('True')
        ax[1].set_title('slope: {:.4f}, r_value: {:.4f}'.format(slope, r_value))
        plt.show()

hid_dim: 128, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0403, val_loss: 0.0977, r_value: 0.2438
Epoch 100, train_loss: 0.0345, val_loss: 0.0356, r_value: 0.4551
Epoch 200, train_loss: 0.0327, val_loss: 0.0337, r_value: 0.4872
Epoch 300, train_loss: 0.0322, val_loss: 0.0327, r_value: 0.5180
Epoch 400, train_loss: 0.0316, val_loss: 0.0318, r_value: 0.5361
Epoch 500, train_loss: 0.0307, val_loss: 0.0308, r_value: 0.5470
Epoch 600, train_loss: 0.0300, val_loss: 0.0300, r_value: 0.5541
Epoch 700, train_loss: 0.0294, val_loss: 0.0293, r_value: 0.5589
Epoch 800, train_loss: 0.0288, val_loss: 0.0287, r_value: 0.5628
Epoch 900, train_loss: 0.0284, val_loss: 0.0283, r_value: 0.5659
Epoch 1000, train_loss: 0.0280, val_loss: 0.0279, r_value: 0.5685
Epoch 1100, train_loss: 0.0277, val_loss: 0.0277, r_value: 0.5706
Epoch 1200, train_loss: 0.0275, val_loss: 0.0274, r_value: 0.5724
Epoch 1300, train_loss: 0.0273, val_loss: 0.0272, r_value: 0.5741
Epoch 1400, train_loss: 0.0271, val_loss: 0.0271, r_value: 0.5755
Epoch 1500, train_loss: 0.0269, val_loss: 0.0269, r_value: 0.5769
Epoch 1600, train_loss: 0.0268, val_loss: 0.0268, r_value: 0.5781
Epoch 1700, train_loss: 0.0267, val_loss: 0.0266, r_value: 0.5794
Epoch 1800, train_loss: 0.0265, val_loss: 0.0265, r_value: 0.5805
Epoch 1900, train_loss: 0.0265, val_loss: 0.0265, r_value: 0.5815
../_images/CASESTUDY_NNs_pt3_SHAP_11_1.png

hid_dim: 128, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0385, val_loss: 0.0446, r_value: 0.2170
Epoch 100, train_loss: 0.0314, val_loss: 0.0298, r_value: 0.4997
Epoch 200, train_loss: 0.0292, val_loss: 0.0285, r_value: 0.5248
Epoch 300, train_loss: 0.0283, val_loss: 0.0277, r_value: 0.5393
Epoch 400, train_loss: 0.0277, val_loss: 0.0273, r_value: 0.5491
Epoch 500, train_loss: 0.0273, val_loss: 0.0268, r_value: 0.5554
Epoch 600, train_loss: 0.0269, val_loss: 0.0265, r_value: 0.5596
Epoch 700, train_loss: 0.0265, val_loss: 0.0262, r_value: 0.5628
Epoch 800, train_loss: 0.0263, val_loss: 0.0259, r_value: 0.5653
Epoch 900, train_loss: 0.0260, val_loss: 0.0258, r_value: 0.5674
Epoch 1000, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5693
Epoch 1100, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5711
Epoch 1200, train_loss: 0.0255, val_loss: 0.0253, r_value: 0.5727
Epoch 1300, train_loss: 0.0254, val_loss: 0.0252, r_value: 0.5742
Epoch 1400, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5755
Epoch 1500, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5769
Epoch 1600, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5781
Epoch 1700, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5793
Epoch 1800, train_loss: 0.0250, val_loss: 0.0248, r_value: 0.5804
Epoch 1900, train_loss: 0.0250, val_loss: 0.0248, r_value: 0.5815
../_images/CASESTUDY_NNs_pt3_SHAP_11_3.png

hid_dim: 128, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0375, val_loss: 0.0367, r_value: -0.0293
Epoch 100, train_loss: 0.0315, val_loss: 0.0311, r_value: 0.4840
Epoch 200, train_loss: 0.0286, val_loss: 0.0282, r_value: 0.5040
Epoch 300, train_loss: 0.0274, val_loss: 0.0271, r_value: 0.5148
Epoch 400, train_loss: 0.0270, val_loss: 0.0266, r_value: 0.5230
Epoch 500, train_loss: 0.0267, val_loss: 0.0263, r_value: 0.5301
Epoch 600, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5359
Epoch 700, train_loss: 0.0262, val_loss: 0.0259, r_value: 0.5404
Epoch 800, train_loss: 0.0261, val_loss: 0.0258, r_value: 0.5439
Epoch 900, train_loss: 0.0260, val_loss: 0.0256, r_value: 0.5468
Epoch 1000, train_loss: 0.0259, val_loss: 0.0255, r_value: 0.5493
Epoch 1100, train_loss: 0.0258, val_loss: 0.0254, r_value: 0.5514
Epoch 1200, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5532
Epoch 1300, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5547
Epoch 1400, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5561
Epoch 1500, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5574
Epoch 1600, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5586
Epoch 1700, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5596
Epoch 1800, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5606
Epoch 1900, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5615
../_images/CASESTUDY_NNs_pt3_SHAP_11_5.png

hid_dim: 256, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0403, val_loss: 0.0921, r_value: 0.0286
Epoch 100, train_loss: 0.0349, val_loss: 0.0329, r_value: 0.4931
Epoch 200, train_loss: 0.0336, val_loss: 0.0321, r_value: 0.5201
Epoch 300, train_loss: 0.0325, val_loss: 0.0315, r_value: 0.5385
Epoch 400, train_loss: 0.0313, val_loss: 0.0306, r_value: 0.5499
Epoch 500, train_loss: 0.0303, val_loss: 0.0298, r_value: 0.5571
Epoch 600, train_loss: 0.0295, val_loss: 0.0291, r_value: 0.5621
Epoch 700, train_loss: 0.0289, val_loss: 0.0286, r_value: 0.5659
Epoch 800, train_loss: 0.0284, val_loss: 0.0281, r_value: 0.5691
Epoch 900, train_loss: 0.0280, val_loss: 0.0278, r_value: 0.5719
Epoch 1000, train_loss: 0.0278, val_loss: 0.0275, r_value: 0.5743
Epoch 1100, train_loss: 0.0275, val_loss: 0.0273, r_value: 0.5765
Epoch 1200, train_loss: 0.0273, val_loss: 0.0271, r_value: 0.5785
Epoch 1300, train_loss: 0.0272, val_loss: 0.0269, r_value: 0.5803
Epoch 1400, train_loss: 0.0270, val_loss: 0.0268, r_value: 0.5819
Epoch 1500, train_loss: 0.0268, val_loss: 0.0266, r_value: 0.5835
Epoch 1600, train_loss: 0.0267, val_loss: 0.0265, r_value: 0.5849
Epoch 1700, train_loss: 0.0266, val_loss: 0.0264, r_value: 0.5862
Epoch 1800, train_loss: 0.0265, val_loss: 0.0263, r_value: 0.5874
Epoch 1900, train_loss: 0.0264, val_loss: 0.0262, r_value: 0.5885
../_images/CASESTUDY_NNs_pt3_SHAP_11_7.png

hid_dim: 256, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0382, val_loss: 0.0552, r_value: 0.1675
Epoch 100, train_loss: 0.0382, val_loss: 0.0346, r_value: 0.5131
Epoch 200, train_loss: 0.0292, val_loss: 0.0283, r_value: 0.5328
Epoch 300, train_loss: 0.0282, val_loss: 0.0275, r_value: 0.5474
Epoch 400, train_loss: 0.0275, val_loss: 0.0270, r_value: 0.5569
Epoch 500, train_loss: 0.0269, val_loss: 0.0265, r_value: 0.5631
Epoch 600, train_loss: 0.0265, val_loss: 0.0261, r_value: 0.5675
Epoch 700, train_loss: 0.0261, val_loss: 0.0258, r_value: 0.5709
Epoch 800, train_loss: 0.0258, val_loss: 0.0256, r_value: 0.5737
Epoch 900, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5761
Epoch 1000, train_loss: 0.0254, val_loss: 0.0252, r_value: 0.5782
Epoch 1100, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5801
Epoch 1200, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5819
Epoch 1300, train_loss: 0.0250, val_loss: 0.0249, r_value: 0.5835
Epoch 1400, train_loss: 0.0249, val_loss: 0.0248, r_value: 0.5850
Epoch 1500, train_loss: 0.0248, val_loss: 0.0247, r_value: 0.5864
Epoch 1600, train_loss: 0.0248, val_loss: 0.0246, r_value: 0.5877
Epoch 1700, train_loss: 0.0247, val_loss: 0.0246, r_value: 0.5889
Epoch 1800, train_loss: 0.0246, val_loss: 0.0245, r_value: 0.5900
Epoch 1900, train_loss: 0.0246, val_loss: 0.0244, r_value: 0.5911
../_images/CASESTUDY_NNs_pt3_SHAP_11_9.png

hid_dim: 256, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0383, val_loss: 0.0366, r_value: 0.0119
Epoch 100, train_loss: 0.0293, val_loss: 0.0289, r_value: 0.4922
Epoch 200, train_loss: 0.0274, val_loss: 0.0271, r_value: 0.5165
Epoch 300, train_loss: 0.0267, val_loss: 0.0264, r_value: 0.5302
Epoch 400, train_loss: 0.0263, val_loss: 0.0260, r_value: 0.5387
Epoch 500, train_loss: 0.0261, val_loss: 0.0257, r_value: 0.5446
Epoch 600, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5488
Epoch 700, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5519
Epoch 800, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5543
Epoch 900, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5562
Epoch 1000, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5578
Epoch 1100, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5593
Epoch 1200, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5606
Epoch 1300, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5618
Epoch 1400, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5630
Epoch 1500, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5641
Epoch 1600, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5651
Epoch 1700, train_loss: 0.0251, val_loss: 0.0248, r_value: 0.5662
Epoch 1800, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5671
Epoch 1900, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5680
../_images/CASESTUDY_NNs_pt3_SHAP_11_11.png

hid_dim: 512, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0458, val_loss: 0.6252, r_value: 0.0520
Epoch 100, train_loss: 0.0350, val_loss: 0.0329, r_value: 0.4944
Epoch 200, train_loss: 0.0348, val_loss: 0.0329, r_value: 0.5054
Epoch 300, train_loss: 0.0340, val_loss: 0.0328, r_value: 0.5199
Epoch 400, train_loss: 0.0331, val_loss: 0.0324, r_value: 0.5327
Epoch 500, train_loss: 0.0323, val_loss: 0.0318, r_value: 0.5429
Epoch 600, train_loss: 0.0315, val_loss: 0.0311, r_value: 0.5513
Epoch 700, train_loss: 0.0308, val_loss: 0.0304, r_value: 0.5579
Epoch 800, train_loss: 0.0301, val_loss: 0.0297, r_value: 0.5630
Epoch 900, train_loss: 0.0295, val_loss: 0.0292, r_value: 0.5672
Epoch 1000, train_loss: 0.0290, val_loss: 0.0287, r_value: 0.5707
Epoch 1100, train_loss: 0.0286, val_loss: 0.0283, r_value: 0.5737
Epoch 1200, train_loss: 0.0283, val_loss: 0.0280, r_value: 0.5764
Epoch 1300, train_loss: 0.0279, val_loss: 0.0276, r_value: 0.5787
Epoch 1400, train_loss: 0.0276, val_loss: 0.0274, r_value: 0.5807
Epoch 1500, train_loss: 0.0274, val_loss: 0.0271, r_value: 0.5826
Epoch 1600, train_loss: 0.0271, val_loss: 0.0269, r_value: 0.5842
Epoch 1700, train_loss: 0.0269, val_loss: 0.0267, r_value: 0.5857
Epoch 1800, train_loss: 0.0267, val_loss: 0.0265, r_value: 0.5871
Epoch 1900, train_loss: 0.0265, val_loss: 0.0263, r_value: 0.5883
../_images/CASESTUDY_NNs_pt3_SHAP_11_13.png

hid_dim: 512, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0367, val_loss: 0.0350, r_value: 0.3153
Epoch 100, train_loss: 0.0314, val_loss: 0.0296, r_value: 0.5187
Epoch 200, train_loss: 0.0292, val_loss: 0.0282, r_value: 0.5383
Epoch 300, train_loss: 0.0280, val_loss: 0.0274, r_value: 0.5516
Epoch 400, train_loss: 0.0272, val_loss: 0.0268, r_value: 0.5597
Epoch 500, train_loss: 0.0266, val_loss: 0.0263, r_value: 0.5652
Epoch 600, train_loss: 0.0262, val_loss: 0.0259, r_value: 0.5692
Epoch 700, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5724
Epoch 800, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5753
Epoch 900, train_loss: 0.0255, val_loss: 0.0253, r_value: 0.5777
Epoch 1000, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5798
Epoch 1100, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5817
Epoch 1200, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5833
Epoch 1300, train_loss: 0.0250, val_loss: 0.0248, r_value: 0.5848
Epoch 1400, train_loss: 0.0249, val_loss: 0.0248, r_value: 0.5862
Epoch 1500, train_loss: 0.0249, val_loss: 0.0247, r_value: 0.5875
Epoch 1600, train_loss: 0.0248, val_loss: 0.0246, r_value: 0.5887
Epoch 1700, train_loss: 0.0247, val_loss: 0.0245, r_value: 0.5898
Epoch 1800, train_loss: 0.0246, val_loss: 0.0245, r_value: 0.5909
Epoch 1900, train_loss: 0.0246, val_loss: 0.0244, r_value: 0.5919
../_images/CASESTUDY_NNs_pt3_SHAP_11_15.png

hid_dim: 512, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0402, val_loss: 0.0381, r_value: -0.0979
Epoch 100, train_loss: 0.0274, val_loss: 0.0270, r_value: 0.5218
Epoch 200, train_loss: 0.0266, val_loss: 0.0262, r_value: 0.5358
Epoch 300, train_loss: 0.0262, val_loss: 0.0258, r_value: 0.5440
Epoch 400, train_loss: 0.0260, val_loss: 0.0256, r_value: 0.5491
Epoch 500, train_loss: 0.0262, val_loss: 0.0259, r_value: 0.5529
Epoch 600, train_loss: 0.0260, val_loss: 0.0257, r_value: 0.5556
Epoch 700, train_loss: 0.0258, val_loss: 0.0254, r_value: 0.5579
Epoch 800, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5600
Epoch 900, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5617
Epoch 1000, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5633
Epoch 1100, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5648
Epoch 1200, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5663
Epoch 1300, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5676
Epoch 1400, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5689
Epoch 1500, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5701
Epoch 1600, train_loss: 0.0251, val_loss: 0.0248, r_value: 0.5713
Epoch 1700, train_loss: 0.0250, val_loss: 0.0248, r_value: 0.5724
Epoch 1800, train_loss: 0.0250, val_loss: 0.0247, r_value: 0.5735
Epoch 1900, train_loss: 0.0249, val_loss: 0.0247, r_value: 0.5745
../_images/CASESTUDY_NNs_pt3_SHAP_11_17.png
[15]:
# Try a deeper feedforward network (four hidden layers)
class Feedforward(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Feedforward, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.fc2 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.fc3 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.fc4 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.fc5 = torch.nn.Linear(self.hidden_size, 1)
        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        hidden = self.relu(self.fc3(hidden))
        hidden = self.relu(self.fc4(hidden))
        return self.tanh(self.fc5(hidden))
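Copy-pasting a class per depth gets tedious; the same stack of Linear/ReLU blocks ending in a tanh head can be built programmatically. A minimal sketch with nn.Sequential (make_mlp is our own helper, not used elsewhere in this notebook):

def make_mlp(input_size, hidden_size, n_hidden=4):
    layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
    for _ in range(n_hidden - 1):
        layers += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
    layers += [nn.Linear(hidden_size, 1), nn.Tanh()]
    return nn.Sequential(*layers)

mlp = make_mlp(20, 128)  # same layout as the four-hidden-layer Feedforward above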
[16]:
n_epochs = 2000
hid_dim_range = [128, 256, 512]
lr_range = [0.75, 0.5, 0.1, 0.01, 0.05]

for hid_dim in hid_dim_range:
    for lr in lr_range:
        print('\nhid_dim: {}, lr: {}'.format(hid_dim, lr))
        if 'model' in globals():
            print('Deleting previous model')
            del model, criterion, optimizer
        model = Feedforward(data.shape[1], hid_dim)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr = lr)

        all_loss_train=[]
        all_loss_val=[]
        for epoch in range(n_epochs):
            model.train()
            optimizer.zero_grad()
            # Forward pass
            y_pred = model(X_train)
            # Compute Loss
            loss = criterion(y_pred.squeeze(), y_train)

            # Backward pass
            loss.backward()
            optimizer.step()

            all_loss_train.append(loss.item())

            model.eval()
            with torch.no_grad():
                y_pred = model(X_test)
                # Compute Loss
                loss = criterion(y_pred.squeeze(), y_test)
                all_loss_val.append(loss.item())

                if epoch%100==0:
                    y_pred = y_pred.detach().numpy().squeeze()
                    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_test)
                    print('Epoch {}, train_loss: {:.4f}, val_loss: {:.4f}, r_value: {:.4f}'.format(epoch,all_loss_train[-1],all_loss_val[-1],r_value))

        fig,ax=plt.subplots(1,2,figsize=(10,5))
        ax[0].plot(np.log10(all_loss_train))
        ax[0].plot(np.log10(all_loss_val))

        ax[1].scatter(y_pred, y_test)
        ax[1].set_xlabel('Prediction')
        ax[1].set_ylabel('True')
        ax[1].set_title('slope: {:.4f}, r_value: {:.4f}'.format(slope, r_value))
        plt.show()

hid_dim: 128, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0373, val_loss: 0.0372, r_value: 0.1484
Epoch 100, train_loss: 0.0394, val_loss: 0.0394, r_value: 0.4557
Epoch 200, train_loss: 0.0359, val_loss: 0.0371, r_value: 0.4361
Epoch 300, train_loss: 0.0331, val_loss: 0.0344, r_value: 0.4529
Epoch 400, train_loss: 0.0324, val_loss: 0.0334, r_value: 0.4718
Epoch 500, train_loss: 0.0321, val_loss: 0.0329, r_value: 0.4873
Epoch 600, train_loss: 0.0320, val_loss: 0.0325, r_value: 0.5015
Epoch 700, train_loss: 0.0318, val_loss: 0.0321, r_value: 0.5146
Epoch 800, train_loss: 0.0315, val_loss: 0.0316, r_value: 0.5258
Epoch 900, train_loss: 0.0311, val_loss: 0.0311, r_value: 0.5345
Epoch 1000, train_loss: 0.0305, val_loss: 0.0305, r_value: 0.5408
Epoch 1100, train_loss: 0.0300, val_loss: 0.0299, r_value: 0.5457
Epoch 1200, train_loss: 0.0296, val_loss: 0.0295, r_value: 0.5498
Epoch 1300, train_loss: 0.0292, val_loss: 0.0291, r_value: 0.5530
Epoch 1400, train_loss: 0.0289, val_loss: 0.0287, r_value: 0.5557
Epoch 1500, train_loss: 0.0286, val_loss: 0.0284, r_value: 0.5579
Epoch 1600, train_loss: 0.0284, val_loss: 0.0282, r_value: 0.5597
Epoch 1700, train_loss: 0.0282, val_loss: 0.0280, r_value: 0.5613
Epoch 1800, train_loss: 0.0280, val_loss: 0.0278, r_value: 0.5627
Epoch 1900, train_loss: 0.0278, val_loss: 0.0276, r_value: 0.5638
../_images/CASESTUDY_NNs_pt3_SHAP_13_1.png

hid_dim: 128, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0392, val_loss: 0.0373, r_value: -0.1593
Epoch 100, train_loss: 0.0321, val_loss: 0.0316, r_value: 0.4797
Epoch 200, train_loss: 0.0324, val_loss: 0.0332, r_value: 0.4713
Epoch 300, train_loss: 0.0304, val_loss: 0.0310, r_value: 0.4947
Epoch 400, train_loss: 0.0294, val_loss: 0.0298, r_value: 0.5126
Epoch 500, train_loss: 0.0288, val_loss: 0.0289, r_value: 0.5264
Epoch 600, train_loss: 0.0282, val_loss: 0.0282, r_value: 0.5376
Epoch 700, train_loss: 0.0277, val_loss: 0.0276, r_value: 0.5458
Epoch 800, train_loss: 0.0273, val_loss: 0.0272, r_value: 0.5518
Epoch 900, train_loss: 0.0269, val_loss: 0.0268, r_value: 0.5563
Epoch 1000, train_loss: 0.0266, val_loss: 0.0264, r_value: 0.5597
Epoch 1100, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5624
Epoch 1200, train_loss: 0.0262, val_loss: 0.0259, r_value: 0.5645
Epoch 1300, train_loss: 0.0260, val_loss: 0.0258, r_value: 0.5662
Epoch 1400, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5676
Epoch 1500, train_loss: 0.0257, val_loss: 0.0255, r_value: 0.5688
Epoch 1600, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5698
Epoch 1700, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5707
Epoch 1800, train_loss: 0.0255, val_loss: 0.0253, r_value: 0.5715
Epoch 1900, train_loss: 0.0254, val_loss: 0.0253, r_value: 0.5722
../_images/CASESTUDY_NNs_pt3_SHAP_13_3.png

hid_dim: 128, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0389, val_loss: 0.0375, r_value: 0.0934
Epoch 100, train_loss: 0.0362, val_loss: 0.0358, r_value: 0.3076
Epoch 200, train_loss: 0.0356, val_loss: 0.0352, r_value: 0.3654
Epoch 300, train_loss: 0.0347, val_loss: 0.0343, r_value: 0.4075
Epoch 400, train_loss: 0.0333, val_loss: 0.0329, r_value: 0.4384
Epoch 500, train_loss: 0.0315, val_loss: 0.0311, r_value: 0.4633
Epoch 600, train_loss: 0.0295, val_loss: 0.0291, r_value: 0.4867
Epoch 700, train_loss: 0.0280, val_loss: 0.0277, r_value: 0.5066
Epoch 800, train_loss: 0.0271, val_loss: 0.0268, r_value: 0.5209
Epoch 900, train_loss: 0.0267, val_loss: 0.0263, r_value: 0.5305
Epoch 1000, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5375
Epoch 1100, train_loss: 0.0261, val_loss: 0.0258, r_value: 0.5424
Epoch 1200, train_loss: 0.0260, val_loss: 0.0257, r_value: 0.5458
Epoch 1300, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5483
Epoch 1400, train_loss: 0.0258, val_loss: 0.0255, r_value: 0.5503
Epoch 1500, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5520
Epoch 1600, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5534
Epoch 1700, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5546
Epoch 1800, train_loss: 0.0255, val_loss: 0.0253, r_value: 0.5557
Epoch 1900, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5568
../_images/CASESTUDY_NNs_pt3_SHAP_13_5.png

hid_dim: 128, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0376, val_loss: 0.0372, r_value: -0.1016
Epoch 100, train_loss: 0.0370, val_loss: 0.0366, r_value: -0.0717
Epoch 200, train_loss: 0.0370, val_loss: 0.0366, r_value: -0.0385
Epoch 300, train_loss: 0.0369, val_loss: 0.0365, r_value: -0.0043
Epoch 400, train_loss: 0.0369, val_loss: 0.0365, r_value: 0.0320
Epoch 500, train_loss: 0.0369, val_loss: 0.0365, r_value: 0.0708
Epoch 600, train_loss: 0.0368, val_loss: 0.0365, r_value: 0.1087
Epoch 700, train_loss: 0.0368, val_loss: 0.0364, r_value: 0.1422
Epoch 800, train_loss: 0.0368, val_loss: 0.0364, r_value: 0.1700
Epoch 900, train_loss: 0.0368, val_loss: 0.0364, r_value: 0.1937
Epoch 1000, train_loss: 0.0367, val_loss: 0.0364, r_value: 0.2148
Epoch 1100, train_loss: 0.0367, val_loss: 0.0363, r_value: 0.2338
Epoch 1200, train_loss: 0.0367, val_loss: 0.0363, r_value: 0.2511
Epoch 1300, train_loss: 0.0367, val_loss: 0.0363, r_value: 0.2669
Epoch 1400, train_loss: 0.0367, val_loss: 0.0363, r_value: 0.2814
Epoch 1500, train_loss: 0.0366, val_loss: 0.0362, r_value: 0.2946
Epoch 1600, train_loss: 0.0366, val_loss: 0.0362, r_value: 0.3067
Epoch 1700, train_loss: 0.0366, val_loss: 0.0362, r_value: 0.3179
Epoch 1800, train_loss: 0.0366, val_loss: 0.0362, r_value: 0.3282
Epoch 1900, train_loss: 0.0365, val_loss: 0.0361, r_value: 0.3377
../_images/CASESTUDY_NNs_pt3_SHAP_13_7.png

hid_dim: 128, lr: 0.05
Deleting previous model
Epoch 0, train_loss: 0.0372, val_loss: 0.0367, r_value: 0.0048
Epoch 100, train_loss: 0.0366, val_loss: 0.0362, r_value: 0.2854
Epoch 200, train_loss: 0.0362, val_loss: 0.0359, r_value: 0.3750
Epoch 300, train_loss: 0.0359, val_loss: 0.0356, r_value: 0.4130
Epoch 400, train_loss: 0.0356, val_loss: 0.0352, r_value: 0.4345
Epoch 500, train_loss: 0.0352, val_loss: 0.0348, r_value: 0.4492
Epoch 600, train_loss: 0.0347, val_loss: 0.0344, r_value: 0.4598
Epoch 700, train_loss: 0.0341, val_loss: 0.0338, r_value: 0.4685
Epoch 800, train_loss: 0.0334, val_loss: 0.0330, r_value: 0.4757
Epoch 900, train_loss: 0.0324, val_loss: 0.0321, r_value: 0.4821
Epoch 1000, train_loss: 0.0313, val_loss: 0.0310, r_value: 0.4884
Epoch 1100, train_loss: 0.0302, val_loss: 0.0298, r_value: 0.4936
Epoch 1200, train_loss: 0.0292, val_loss: 0.0288, r_value: 0.4989
Epoch 1300, train_loss: 0.0284, val_loss: 0.0280, r_value: 0.5043
Epoch 1400, train_loss: 0.0278, val_loss: 0.0274, r_value: 0.5098
Epoch 1500, train_loss: 0.0274, val_loss: 0.0271, r_value: 0.5151
Epoch 1600, train_loss: 0.0271, val_loss: 0.0268, r_value: 0.5200
Epoch 1700, train_loss: 0.0269, val_loss: 0.0266, r_value: 0.5246
Epoch 1800, train_loss: 0.0268, val_loss: 0.0264, r_value: 0.5285
Epoch 1900, train_loss: 0.0266, val_loss: 0.0262, r_value: 0.5320
../_images/CASESTUDY_NNs_pt3_SHAP_13_9.png

hid_dim: 256, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0410, val_loss: 0.0446, r_value: 0.2588
Epoch 100, train_loss: 0.0395, val_loss: 0.0378, r_value: 0.4959
Epoch 200, train_loss: 0.0361, val_loss: 0.0336, r_value: 0.5012
Epoch 300, train_loss: 0.0346, val_loss: 0.0325, r_value: 0.5152
Epoch 400, train_loss: 0.0337, val_loss: 0.0321, r_value: 0.5261
Epoch 500, train_loss: 0.0329, val_loss: 0.0317, r_value: 0.5365
Epoch 600, train_loss: 0.0321, val_loss: 0.0312, r_value: 0.5454
Epoch 700, train_loss: 0.0313, val_loss: 0.0306, r_value: 0.5518
Epoch 800, train_loss: 0.0306, val_loss: 0.0300, r_value: 0.5560
Epoch 900, train_loss: 0.0299, val_loss: 0.0294, r_value: 0.5589
Epoch 1000, train_loss: 0.0294, val_loss: 0.0290, r_value: 0.5611
Epoch 1100, train_loss: 0.0290, val_loss: 0.0286, r_value: 0.5628
Epoch 1200, train_loss: 0.0286, val_loss: 0.0283, r_value: 0.5643
Epoch 1300, train_loss: 0.0284, val_loss: 0.0280, r_value: 0.5657
Epoch 1400, train_loss: 0.0281, val_loss: 0.0278, r_value: 0.5669
Epoch 1500, train_loss: 0.0279, val_loss: 0.0276, r_value: 0.5681
Epoch 1600, train_loss: 0.0277, val_loss: 0.0274, r_value: 0.5692
Epoch 1700, train_loss: 0.0275, val_loss: 0.0273, r_value: 0.5702
Epoch 1800, train_loss: 0.0274, val_loss: 0.0271, r_value: 0.5711
Epoch 1900, train_loss: 0.0272, val_loss: 0.0270, r_value: 0.5720
../_images/CASESTUDY_NNs_pt3_SHAP_13_11.png

hid_dim: 256, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0394, val_loss: 0.0372, r_value: 0.2825
Epoch 100, train_loss: 0.0401, val_loss: 0.0367, r_value: 0.4948
Epoch 200, train_loss: 0.0319, val_loss: 0.0306, r_value: 0.5148
Epoch 300, train_loss: 0.0303, val_loss: 0.0292, r_value: 0.5296
Epoch 400, train_loss: 0.0293, val_loss: 0.0285, r_value: 0.5403
Epoch 500, train_loss: 0.0285, val_loss: 0.0279, r_value: 0.5483
Epoch 600, train_loss: 0.0279, val_loss: 0.0273, r_value: 0.5540
Epoch 700, train_loss: 0.0274, val_loss: 0.0269, r_value: 0.5582
Epoch 800, train_loss: 0.0270, val_loss: 0.0265, r_value: 0.5614
Epoch 900, train_loss: 0.0267, val_loss: 0.0263, r_value: 0.5639
Epoch 1000, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5659
Epoch 1100, train_loss: 0.0262, val_loss: 0.0259, r_value: 0.5675
Epoch 1200, train_loss: 0.0260, val_loss: 0.0257, r_value: 0.5689
Epoch 1300, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5701
Epoch 1400, train_loss: 0.0257, val_loss: 0.0255, r_value: 0.5712
Epoch 1500, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5722
Epoch 1600, train_loss: 0.0255, val_loss: 0.0253, r_value: 0.5730
Epoch 1700, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5738
Epoch 1800, train_loss: 0.0254, val_loss: 0.0252, r_value: 0.5745
Epoch 1900, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5753
../_images/CASESTUDY_NNs_pt3_SHAP_13_13.png

hid_dim: 256, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0391, val_loss: 0.0376, r_value: -0.0073
Epoch 100, train_loss: 0.0358, val_loss: 0.0354, r_value: 0.5005
Epoch 200, train_loss: 0.0344, val_loss: 0.0340, r_value: 0.4993
Epoch 300, train_loss: 0.0321, val_loss: 0.0317, r_value: 0.5011
Epoch 400, train_loss: 0.0293, val_loss: 0.0289, r_value: 0.5079
Epoch 500, train_loss: 0.0275, val_loss: 0.0271, r_value: 0.5183
Epoch 600, train_loss: 0.0268, val_loss: 0.0264, r_value: 0.5285
Epoch 700, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5368
Epoch 800, train_loss: 0.0261, val_loss: 0.0258, r_value: 0.5430
Epoch 900, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5477
Epoch 1000, train_loss: 0.0258, val_loss: 0.0254, r_value: 0.5513
Epoch 1100, train_loss: 0.0257, val_loss: 0.0253, r_value: 0.5541
Epoch 1200, train_loss: 0.0256, val_loss: 0.0252, r_value: 0.5564
Epoch 1300, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5582
Epoch 1400, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5598
Epoch 1500, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5611
Epoch 1600, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5623
Epoch 1700, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5634
Epoch 1800, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5643
Epoch 1900, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5652
../_images/CASESTUDY_NNs_pt3_SHAP_13_15.png

hid_dim: 256, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0403, val_loss: 0.0397, r_value: -0.1768
Epoch 100, train_loss: 0.0369, val_loss: 0.0365, r_value: -0.0038
Epoch 200, train_loss: 0.0368, val_loss: 0.0364, r_value: 0.2312
Epoch 300, train_loss: 0.0367, val_loss: 0.0363, r_value: 0.3537
Epoch 400, train_loss: 0.0365, val_loss: 0.0362, r_value: 0.4118
Epoch 500, train_loss: 0.0364, val_loss: 0.0361, r_value: 0.4425
Epoch 600, train_loss: 0.0363, val_loss: 0.0359, r_value: 0.4619
Epoch 700, train_loss: 0.0362, val_loss: 0.0358, r_value: 0.4750
Epoch 800, train_loss: 0.0361, val_loss: 0.0357, r_value: 0.4821
Epoch 900, train_loss: 0.0360, val_loss: 0.0356, r_value: 0.4869
Epoch 1000, train_loss: 0.0359, val_loss: 0.0355, r_value: 0.4904
Epoch 1100, train_loss: 0.0358, val_loss: 0.0354, r_value: 0.4929
Epoch 1200, train_loss: 0.0356, val_loss: 0.0353, r_value: 0.4948
Epoch 1300, train_loss: 0.0355, val_loss: 0.0352, r_value: 0.4962
Epoch 1400, train_loss: 0.0354, val_loss: 0.0350, r_value: 0.4973
Epoch 1500, train_loss: 0.0353, val_loss: 0.0349, r_value: 0.4982
Epoch 1600, train_loss: 0.0351, val_loss: 0.0348, r_value: 0.4989
Epoch 1700, train_loss: 0.0350, val_loss: 0.0346, r_value: 0.4997
Epoch 1800, train_loss: 0.0348, val_loss: 0.0345, r_value: 0.5004
Epoch 1900, train_loss: 0.0347, val_loss: 0.0343, r_value: 0.5009
../_images/CASESTUDY_NNs_pt3_SHAP_13_17.png

hid_dim: 256, lr: 0.05
Deleting previous model
Epoch 0, train_loss: 0.0375, val_loss: 0.0370, r_value: -0.0756
Epoch 100, train_loss: 0.0364, val_loss: 0.0360, r_value: 0.4232
Epoch 200, train_loss: 0.0358, val_loss: 0.0354, r_value: 0.4511
Epoch 300, train_loss: 0.0352, val_loss: 0.0348, r_value: 0.4615
Epoch 400, train_loss: 0.0345, val_loss: 0.0341, r_value: 0.4652
Epoch 500, train_loss: 0.0335, val_loss: 0.0332, r_value: 0.4679
Epoch 600, train_loss: 0.0324, val_loss: 0.0321, r_value: 0.4715
Epoch 700, train_loss: 0.0312, val_loss: 0.0308, r_value: 0.4764
Epoch 800, train_loss: 0.0301, val_loss: 0.0297, r_value: 0.4826
Epoch 900, train_loss: 0.0291, val_loss: 0.0287, r_value: 0.4898
Epoch 1000, train_loss: 0.0284, val_loss: 0.0280, r_value: 0.4973
Epoch 1100, train_loss: 0.0279, val_loss: 0.0275, r_value: 0.5044
Epoch 1200, train_loss: 0.0275, val_loss: 0.0271, r_value: 0.5108
Epoch 1300, train_loss: 0.0272, val_loss: 0.0269, r_value: 0.5165
Epoch 1400, train_loss: 0.0270, val_loss: 0.0267, r_value: 0.5215
Epoch 1500, train_loss: 0.0268, val_loss: 0.0265, r_value: 0.5259
Epoch 1600, train_loss: 0.0267, val_loss: 0.0263, r_value: 0.5300
Epoch 1700, train_loss: 0.0265, val_loss: 0.0262, r_value: 0.5336
Epoch 1800, train_loss: 0.0264, val_loss: 0.0260, r_value: 0.5368
Epoch 1900, train_loss: 0.0263, val_loss: 0.0259, r_value: 0.5397
../_images/CASESTUDY_NNs_pt3_SHAP_13_19.png

hid_dim: 512, lr: 0.75
Deleting previous model
Epoch 0, train_loss: 0.0370, val_loss: 0.0364, r_value: 0.2444
Epoch 100, train_loss: 0.0374, val_loss: 0.0387, r_value: 0.4517
Epoch 200, train_loss: 0.0335, val_loss: 0.0350, r_value: 0.4642
Epoch 300, train_loss: 0.0327, val_loss: 0.0337, r_value: 0.4954
Epoch 400, train_loss: 0.0321, val_loss: 0.0328, r_value: 0.5171
Epoch 500, train_loss: 0.0316, val_loss: 0.0319, r_value: 0.5317
Epoch 600, train_loss: 0.0309, val_loss: 0.0311, r_value: 0.5424
Epoch 700, train_loss: 0.0303, val_loss: 0.0304, r_value: 0.5502
Epoch 800, train_loss: 0.0298, val_loss: 0.0298, r_value: 0.5559
Epoch 900, train_loss: 0.0293, val_loss: 0.0292, r_value: 0.5602
Epoch 1000, train_loss: 0.0288, val_loss: 0.0288, r_value: 0.5636
Epoch 1100, train_loss: 0.0285, val_loss: 0.0284, r_value: 0.5664
Epoch 1200, train_loss: 0.0281, val_loss: 0.0280, r_value: 0.5686
Epoch 1300, train_loss: 0.0278, val_loss: 0.0278, r_value: 0.5705
Epoch 1400, train_loss: 0.0276, val_loss: 0.0275, r_value: 0.5721
Epoch 1500, train_loss: 0.0273, val_loss: 0.0273, r_value: 0.5735
Epoch 1600, train_loss: 0.0271, val_loss: 0.0271, r_value: 0.5747
Epoch 1700, train_loss: 0.0270, val_loss: 0.0269, r_value: 0.5757
Epoch 1800, train_loss: 0.0268, val_loss: 0.0268, r_value: 0.5767
Epoch 1900, train_loss: 0.0267, val_loss: 0.0267, r_value: 0.5776
../_images/CASESTUDY_NNs_pt3_SHAP_13_21.png

hid_dim: 512, lr: 0.5
Deleting previous model
Epoch 0, train_loss: 0.0374, val_loss: 0.0367, r_value: 0.2427
Epoch 100, train_loss: 0.0350, val_loss: 0.0334, r_value: 0.5088
Epoch 200, train_loss: 0.0317, val_loss: 0.0302, r_value: 0.5213
Epoch 300, train_loss: 0.0302, val_loss: 0.0290, r_value: 0.5350
Epoch 400, train_loss: 0.0292, val_loss: 0.0283, r_value: 0.5456
Epoch 500, train_loss: 0.0284, val_loss: 0.0277, r_value: 0.5536
Epoch 600, train_loss: 0.0278, val_loss: 0.0272, r_value: 0.5595
Epoch 700, train_loss: 0.0272, val_loss: 0.0268, r_value: 0.5638
Epoch 800, train_loss: 0.0268, val_loss: 0.0264, r_value: 0.5671
Epoch 900, train_loss: 0.0264, val_loss: 0.0261, r_value: 0.5696
Epoch 1000, train_loss: 0.0261, val_loss: 0.0258, r_value: 0.5716
Epoch 1100, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5733
Epoch 1200, train_loss: 0.0257, val_loss: 0.0255, r_value: 0.5747
Epoch 1300, train_loss: 0.0256, val_loss: 0.0254, r_value: 0.5759
Epoch 1400, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5771
Epoch 1500, train_loss: 0.0254, val_loss: 0.0252, r_value: 0.5782
Epoch 1600, train_loss: 0.0253, val_loss: 0.0251, r_value: 0.5792
Epoch 1700, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5802
Epoch 1800, train_loss: 0.0251, val_loss: 0.0250, r_value: 0.5811
Epoch 1900, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5820
../_images/CASESTUDY_NNs_pt3_SHAP_13_23.png

hid_dim: 512, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0369, val_loss: 0.0365, r_value: 0.0973
Epoch 100, train_loss: 0.0347, val_loss: 0.0343, r_value: 0.5013
Epoch 200, train_loss: 0.0319, val_loss: 0.0315, r_value: 0.5091
Epoch 300, train_loss: 0.0285, val_loss: 0.0282, r_value: 0.5197
Epoch 400, train_loss: 0.0268, val_loss: 0.0265, r_value: 0.5315
Epoch 500, train_loss: 0.0263, val_loss: 0.0259, r_value: 0.5407
Epoch 600, train_loss: 0.0260, val_loss: 0.0256, r_value: 0.5471
Epoch 700, train_loss: 0.0258, val_loss: 0.0254, r_value: 0.5514
Epoch 800, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5545
Epoch 900, train_loss: 0.0255, val_loss: 0.0252, r_value: 0.5569
Epoch 1000, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5588
Epoch 1100, train_loss: 0.0254, val_loss: 0.0251, r_value: 0.5605
Epoch 1200, train_loss: 0.0253, val_loss: 0.0250, r_value: 0.5619
Epoch 1300, train_loss: 0.0252, val_loss: 0.0250, r_value: 0.5631
Epoch 1400, train_loss: 0.0252, val_loss: 0.0249, r_value: 0.5642
Epoch 1500, train_loss: 0.0251, val_loss: 0.0249, r_value: 0.5652
Epoch 1600, train_loss: 0.0251, val_loss: 0.0248, r_value: 0.5661
Epoch 1700, train_loss: 0.0251, val_loss: 0.0248, r_value: 0.5669
Epoch 1800, train_loss: 0.0250, val_loss: 0.0248, r_value: 0.5677
Epoch 1900, train_loss: 0.0250, val_loss: 0.0247, r_value: 0.5685
../_images/CASESTUDY_NNs_pt3_SHAP_13_25.png

hid_dim: 512, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0385, val_loss: 0.0381, r_value: -0.1635
Epoch 100, train_loss: 0.0368, val_loss: 0.0364, r_value: 0.2001
Epoch 200, train_loss: 0.0366, val_loss: 0.0363, r_value: 0.3878
Epoch 300, train_loss: 0.0365, val_loss: 0.0361, r_value: 0.4541
Epoch 400, train_loss: 0.0363, val_loss: 0.0359, r_value: 0.4754
Epoch 500, train_loss: 0.0361, val_loss: 0.0358, r_value: 0.4826
Epoch 600, train_loss: 0.0360, val_loss: 0.0356, r_value: 0.4844
Epoch 700, train_loss: 0.0358, val_loss: 0.0354, r_value: 0.4844
Epoch 800, train_loss: 0.0356, val_loss: 0.0353, r_value: 0.4839
Epoch 900, train_loss: 0.0355, val_loss: 0.0351, r_value: 0.4832
Epoch 1000, train_loss: 0.0353, val_loss: 0.0349, r_value: 0.4825
Epoch 1100, train_loss: 0.0351, val_loss: 0.0347, r_value: 0.4820
Epoch 1200, train_loss: 0.0349, val_loss: 0.0345, r_value: 0.4818
Epoch 1300, train_loss: 0.0347, val_loss: 0.0343, r_value: 0.4821
Epoch 1400, train_loss: 0.0345, val_loss: 0.0341, r_value: 0.4825
Epoch 1500, train_loss: 0.0343, val_loss: 0.0339, r_value: 0.4829
Epoch 1600, train_loss: 0.0340, val_loss: 0.0336, r_value: 0.4835
Epoch 1700, train_loss: 0.0338, val_loss: 0.0334, r_value: 0.4842
Epoch 1800, train_loss: 0.0335, val_loss: 0.0331, r_value: 0.4851
Epoch 1900, train_loss: 0.0332, val_loss: 0.0328, r_value: 0.4862
../_images/CASESTUDY_NNs_pt3_SHAP_13_27.png

hid_dim: 512, lr: 0.05
Deleting previous model
Epoch 0, train_loss: 0.0393, val_loss: 0.0382, r_value: -0.0864
Epoch 100, train_loss: 0.0361, val_loss: 0.0357, r_value: 0.5113
Epoch 200, train_loss: 0.0352, val_loss: 0.0349, r_value: 0.5169
Epoch 300, train_loss: 0.0342, val_loss: 0.0338, r_value: 0.5148
Epoch 400, train_loss: 0.0328, val_loss: 0.0324, r_value: 0.5136
Epoch 500, train_loss: 0.0311, val_loss: 0.0307, r_value: 0.5143
Epoch 600, train_loss: 0.0293, val_loss: 0.0290, r_value: 0.5162
Epoch 700, train_loss: 0.0280, val_loss: 0.0276, r_value: 0.5198
Epoch 800, train_loss: 0.0272, val_loss: 0.0269, r_value: 0.5244
Epoch 900, train_loss: 0.0268, val_loss: 0.0264, r_value: 0.5292
Epoch 1000, train_loss: 0.0265, val_loss: 0.0262, r_value: 0.5339
Epoch 1100, train_loss: 0.0263, val_loss: 0.0260, r_value: 0.5382
Epoch 1200, train_loss: 0.0262, val_loss: 0.0258, r_value: 0.5419
Epoch 1300, train_loss: 0.0260, val_loss: 0.0257, r_value: 0.5450
Epoch 1400, train_loss: 0.0259, val_loss: 0.0256, r_value: 0.5475
Epoch 1500, train_loss: 0.0258, val_loss: 0.0255, r_value: 0.5497
Epoch 1600, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5516
Epoch 1700, train_loss: 0.0257, val_loss: 0.0254, r_value: 0.5533
Epoch 1800, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5547
Epoch 1900, train_loss: 0.0256, val_loss: 0.0253, r_value: 0.5560
../_images/CASESTUDY_NNs_pt3_SHAP_13_29.png

We have tested SGD extensively. Now let’s see what other optimizers have to offer (BACK TO SLIDES)
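Concretely, swapping the optimizer is a one-line change in the loop we have been running; everything else stays the same (the learning rates here are only illustrative starting points):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # adaptive per-parameter step sizes
# or SGD with momentum, a common middle ground:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)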

[17]:
# Try the Adam optimizer
n_epochs = 500
hid_dim_range = [128, 256, 512]
lr_range = [0.1, 0.01, 0.05, 0.001]

for hid_dim in hid_dim_range:
    for lr in lr_range:
        print('\nhid_dim: {}, lr: {}'.format(hid_dim, lr))
        if 'model' in globals():
            print('Deleting previous model')
            del model, criterion, optimizer
        model = Feedforward(data.shape[1], hid_dim)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        all_loss_train=[]
        all_loss_val=[]
        for epoch in range(n_epochs):
            model.train()
            optimizer.zero_grad()
            # Forward pass
            y_pred = model(X_train)
            # Compute Loss
            loss = criterion(y_pred.squeeze(), y_train)

            # Backward pass
            loss.backward()
            optimizer.step()

            all_loss_train.append(loss.item())

            model.eval()
            with torch.no_grad():
                y_pred = model(X_test)
                # Compute Loss
                loss = criterion(y_pred.squeeze(), y_test)
                all_loss_val.append(loss.item())

                if epoch%100==0:
                    y_pred = y_pred.detach().numpy().squeeze()
                    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_test)
                    print('Epoch {}, train_loss: {:.4f}, val_loss: {:.4f}, r_value: {:.4f}'.format(epoch,all_loss_train[-1],all_loss_val[-1],r_value))

        fig,ax=plt.subplots(1,2,figsize=(10,5))
        ax[0].plot(np.log10(all_loss_train))
        ax[0].plot(np.log10(all_loss_val))

        ax[1].scatter(y_pred, y_test)
        ax[1].set_xlabel('Prediction')
        ax[1].set_ylabel('True')
        ax[1].set_title('slope: {:.4f}, r_value: {:.4f}'.format(slope, r_value))
        plt.show()

hid_dim: 128, lr: 0.1
Deleting previous model
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 0, train_loss: 0.0447, val_loss: 1.0370, r_value: 0.0000
Epoch 100, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 200, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 300, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 400, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_3.png

hid_dim: 128, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0407, val_loss: 0.1065, r_value: 0.2456
Epoch 100, train_loss: 0.0250, val_loss: 0.0247, r_value: 0.5683
Epoch 200, train_loss: 0.0240, val_loss: 0.0239, r_value: 0.5913
Epoch 300, train_loss: 0.0232, val_loss: 0.0233, r_value: 0.6031
Epoch 400, train_loss: 0.0238, val_loss: 0.0241, r_value: 0.6070
../_images/CASESTUDY_NNs_pt3_SHAP_15_5.png

hid_dim: 128, lr: 0.05
Deleting previous model
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 0, train_loss: 0.0415, val_loss: 1.0370, r_value: 0.0000
Epoch 100, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 200, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 300, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 400, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_9.png

hid_dim: 128, lr: 0.001
Deleting previous model
Epoch 0, train_loss: 0.0371, val_loss: 0.0364, r_value: 0.3117
Epoch 100, train_loss: 0.0241, val_loss: 0.0239, r_value: 0.5877
Epoch 200, train_loss: 0.0230, val_loss: 0.0231, r_value: 0.6083
Epoch 300, train_loss: 0.0224, val_loss: 0.0228, r_value: 0.6166
Epoch 400, train_loss: 0.0218, val_loss: 0.0224, r_value: 0.6219
../_images/CASESTUDY_NNs_pt3_SHAP_15_11.png

hid_dim: 256, lr: 0.1
Deleting previous model
Epoch 0, train_loss: 0.0370, val_loss: 1.0370, r_value: -0.1892
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 100, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 200, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 300, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 400, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_15.png

hid_dim: 256, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0399, val_loss: 0.8095, r_value: 0.1103
Epoch 100, train_loss: 0.0244, val_loss: 0.0242, r_value: 0.5838
Epoch 200, train_loss: 0.0239, val_loss: 0.0235, r_value: 0.5981
Epoch 300, train_loss: 0.0232, val_loss: 0.0235, r_value: 0.6064
Epoch 400, train_loss: 0.0226, val_loss: 0.0231, r_value: 0.6115
../_images/CASESTUDY_NNs_pt3_SHAP_15_17.png

hid_dim: 256, lr: 0.05
Deleting previous model
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 0, train_loss: 0.0370, val_loss: 1.0370, r_value: 0.0000
Epoch 100, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 200, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 300, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 400, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_21.png

hid_dim: 256, lr: 0.001
Deleting previous model
Epoch 0, train_loss: 0.0372, val_loss: 0.0371, r_value: 0.4011
Epoch 100, train_loss: 0.0238, val_loss: 0.0238, r_value: 0.5994
Epoch 200, train_loss: 0.0225, val_loss: 0.0229, r_value: 0.6171
Epoch 300, train_loss: 0.0214, val_loss: 0.0224, r_value: 0.6265
Epoch 400, train_loss: 0.0206, val_loss: 0.0221, r_value: 0.6294
../_images/CASESTUDY_NNs_pt3_SHAP_15_23.png

hid_dim: 512, lr: 0.1
Deleting previous model
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 0, train_loss: 0.0370, val_loss: 1.0370, r_value: 0.0000
Epoch 100, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 200, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 300, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 400, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_27.png

hid_dim: 512, lr: 0.01
Deleting previous model
Epoch 0, train_loss: 0.0399, val_loss: 1.0361, r_value: 0.1849
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 100, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 200, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 300, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
Epoch 400, train_loss: 1.0372, val_loss: 1.0361, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_31.png

hid_dim: 512, lr: 0.05
Deleting previous model
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:170: RuntimeWarning: invalid value encountered in double_scalars
  slope = ssxym / ssxm
/Users/antonio/opt/anaconda3/envs/geo_comp/lib/python3.10/site-packages/scipy/stats/_stats_mstats_common.py:187: RuntimeWarning: divide by zero encountered in double_scalars
  slope_stderr = np.sqrt((1 - r**2) * ssym / ssxm / df)
Epoch 0, train_loss: 0.0377, val_loss: 1.0370, r_value: 0.0000
Epoch 100, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 200, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 300, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
Epoch 400, train_loss: 1.0366, val_loss: 1.0370, r_value: 0.0000
../_images/CASESTUDY_NNs_pt3_SHAP_15_35.png

hid_dim: 512, lr: 0.001
Deleting previous model
Epoch 0, train_loss: 0.0368, val_loss: 0.0361, r_value: 0.4858
Epoch 100, train_loss: 0.0243, val_loss: 0.0246, r_value: 0.5991
Epoch 200, train_loss: 0.0221, val_loss: 0.0226, r_value: 0.6204
Epoch 300, train_loss: 0.0212, val_loss: 0.0221, r_value: 0.6294
Epoch 400, train_loss: 0.0199, val_loss: 0.0220, r_value: 0.6329
../_images/CASESTUDY_NNs_pt3_SHAP_15_37.png

Adam is clearly more effective than the SGD runs above in this setup: at the smaller learning rates it reaches a better r_value in a fraction of the epochs, while at lr = 0.05 and 0.1 it diverges to a constant prediction (the loss gets stuck near 1.037 and r_value drops to zero, which is also what triggers the linregress warnings). Now let's see how we can add a degree of interpretability to the model.

[18]:
import shap

# Pick ten held-out samples to explain; the variable names echo SHAP's image
# examples, but these are tabular feature vectors
test_images = X_test[100:110]
test_labels = y_test[100:110]

# DeepExplainer uses the validation set as the background distribution
explainer = shap.DeepExplainer(model, X_test)
shap_values = explainer.shap_values(test_images)
Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
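The deprecation warning above comes from the hooks shap registers on the PyTorch model and can be ignored here. Also note that DeepExplainer only uses its second argument as a background dataset to estimate the expected model output, so a random subset is usually enough and much faster than passing the full X_test. A sketch (the subset size of 100 is an arbitrary choice, not from the original notebook):

# Use a small random background sample instead of the whole validation set
idx = torch.randperm(len(X_test))[:100]          # 100 random background rows
background = X_test[idx]
explainer_small = shap.DeepExplainer(model, background)
shap_values_small = explainer_small.shap_values(test_images)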
[19]:
model.eval()
with torch.no_grad():
    y_pred = model(test_images)
    after_train = criterion(y_pred.squeeze(), test_labels)
    print('Test loss after Training' , after_train.item())

    y_pred = y_pred.detach().numpy().squeeze()
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, test_labels)

    fig,ax=plt.subplots()
    ax.scatter(y_pred, test_labels)
    ax.set_xlabel('Prediction')
    ax.set_ylabel('True')
    ax.set_title('slope: {:.3f}, r_value: {:.3f}'.format(slope, r_value))
Test loss after Training 0.025363940745592117
../_images/CASESTUDY_NNs_pt3_SHAP_18_1.png
[38]:
print(test_labels.mean())
print(y_pred.mean())
tensor(0.0070)
-0.048042446
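On this small batch the mean prediction (-0.048) sits below the mean label (0.007), i.e. the model is slightly biased low here. A quick way to quantify that (a sketch reusing the arrays above):

# Mean error (bias) and RMSE over the ten explained samples
errors = y_pred - test_labels.numpy()
print('bias: {:.4f}, RMSE: {:.4f}'.format(errors.mean(), np.sqrt((errors**2).mean())))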
[20]:
# SHAP global interpretation

# The summary plot ranks the most important features and shows the magnitude
# of their impact on the model: this is the global interpretation.

print(shap_values.shape)
shap.summary_plot(shap_values, plot_type = 'bar', feature_names = data.columns, max_display=21)
(10, 20)
../_images/CASESTUDY_NNs_pt3_SHAP_20_1.png
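The bar summary plot is simply the mean absolute SHAP value of each feature across the explained samples; we can verify that directly (a sketch, not a cell from the original notebook):

# Mean |SHAP| per feature: the quantity the bar summary plot displays
mean_abs_shap = np.abs(shap_values).mean(axis=0)
for name, val in sorted(zip(data.columns, mean_abs_shap), key=lambda t: -t[1])[:5]:
    print('{:>20s}: {:.4f}'.format(name, val))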
[70]:
# SHAP local interpretation

# The force plot shows where the "output value" of a single sample sits relative to the "base value",
# which features push the prediction up (red) or down (blue), and the magnitude of each contribution.
shap.initjs()
print('output for the first sample: f(x)= ',y_pred[0])
print(explainer.expected_value[0])
# Explain the same sample we just printed (index 0)
shap.force_plot(explainer.expected_value[0], shap_values[0], test_images[0,:].numpy(), feature_names = data.columns)
output for the first sample: f(x)=  0.0084513
-0.013382731
[70]:
(Interactive force plot: it renders only in a live, trusted notebook with the shap Javascript library loaded via initjs().)
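SHAP values are additive by construction (local accuracy): the base value plus the sum of a sample's SHAP values reproduces the model output for that sample, up to numerical error. A quick check on the arrays already in memory (a sketch):

# Local accuracy check: expected_value + sum(shap values) ≈ model prediction
for i in range(3):
    reconstructed = explainer.expected_value[0] + shap_values[i].sum()
    print('sample {}: base + SHAP sum = {:.4f}, model output = {:.4f}'.format(i, reconstructed, y_pred[i]))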
[22]:
# Force plot for all ten explained samples stacked together
# (no link function: this is a regression in standardized units, so the "logit" link would be inappropriate)
shap.force_plot(explainer.expected_value[0], shap_values, test_images.numpy(), feature_names = data.columns)
[22]:
(Interactive force plot: it renders only in a live, trusted notebook with the shap Javascript library loaded via initjs().)
[36]:
'''
The waterfall plot also shows the amplitude and the direction of each feature's impact,
the order of importance of the features, and the value each feature takes for this sample.
'''
plt.figure(facecolor='w')
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[0], shap_values[2], feature_names = data.columns)
../_images/CASESTUDY_NNs_pt3_SHAP_23_0.png
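Newer shap releases expose the same plot through the public API by wrapping the pieces in a shap.Explanation object, which avoids the private waterfall_legacy helper (a version-dependent sketch):

# Equivalent call via the public API (recent shap versions)
exp = shap.Explanation(values=shap_values[2],
                       base_values=explainer.expected_value[0],
                       data=test_images[2,:].numpy(),
                       feature_names=list(data.columns))
shap.plots.waterfall(exp)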