{ "cells": [
  { "cell_type": "markdown", "id": "032578d6-7f00-4bed-ba92-615a598e2963", "metadata": { "tags": [] }, "source": [
    "# Florian Ellsäßer: Using an LSTM network and SHAP to determine the impact of drought and season on winter wheat\n",
    "\n",
    "[Presentation](http://spatial-ecology.net/docs/source/STUDENTSPROJECTS/Proj_2022_Matera/Using_a_LSTM_network_and_SHAP_to_determine_the_impact_of_drought_Florian_Ellsäßer.pdf) \n",
    "[Video Recording](https://youtu.be/EqxW-PfZsKY)"
   ] },
  { "cell_type": "code", "execution_count": 1, "id": "b76be5c5", "metadata": {}, "outputs": [], "source": [
    "import xarray as xr\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "# import the submodule explicitly: `import scipy` alone does not guarantee that\n",
    "# scipy.stats (used for linregress in the evaluation section) is loaded\n",
    "import scipy.stats\n",
    "import datetime as dt\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ] },
  { "cell_type": "markdown", "id": "a7cac6f2", "metadata": {}, "source": [ "## Load the data" ] },
  { "cell_type": "code", "execution_count": 2, "id": "63230794", "metadata": {}, "outputs": [], "source": [
    "# load the data\n",
    "# NOTE: absolute local paths - adjust them to your environment\n",
    "spei_data = xr.open_dataset(r'C:\\Users\\Florian\\Desktop\\Jupyter_Skripts\\05_ML_crop_vulnerability\\data/spei01_germany.nc')\n",
    "yield_data = xr.open_dataset(r'C:\\Users\\Florian\\Desktop\\Jupyter_Skripts\\05_ML_crop_vulnerability\\data/all_crops_productivity_gapfilled_detrended.nc4')\n",
    "yield_index_data = xr.open_dataset(r'C:\\Users\\Florian\\JLUbox\\Erntedaten_CROP_Projekt\\09_Final_output_files/all_crops_SHI.nc4')\n",
    "natural_areas_raster = xr.open_dataset('C:/Users/Florian/Desktop/E-OBS_data_25.0e/natural_areas_germany.nc')\n",
    "phen_data_winterwheat = pd.read_csv('C:/Users/Florian/Desktop/CROP_indices_data/phen_data_gapfilled.csv')"
   ] },
  { "cell_type": "code", "execution_count": 3, "id": "1ee3f7f1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# visualize the data for a quick check\n",
    "spei_data.spei.sel(time='2016-05-16T00:00:00.000000000').plot(cmap='seismic_r')"
   ] },
  { "cell_type": "code", "execution_count": 4, "id": "dbba4cf3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "yield_data.winter_wheat.sel(year=2018).plot()"
   ] },
  { "cell_type": "code", "execution_count": 5, "id": "5d05b8b4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "yield_index_data.winter_wheat.sel(year=2016).plot(cmap='seismic_r')"
   ] },
  { "cell_type": "code", "execution_count": 6, "id": "b005704b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# -9999. marks cells without a natural area: turn them into NaN\n",
    "natural_areas_raster = natural_areas_raster.__xarray_dataarray_variable__\n",
    "natural_areas_raster = natural_areas_raster.where(natural_areas_raster != -9999.)\n",
    "natural_areas_raster.plot()"
   ] },
  { "cell_type": "markdown", "id": "def0b186", "metadata": {}, "source": [
    "## Align the data properties\n",
    "\n",
    "Bring all rasters to the same dimension names, coordinate precision and grid extent so that they can be concatenated."
   ] },
  { "cell_type": "code", "execution_count": 7, "id": "41379aa7", "metadata": {}, "outputs": [], "source": [
    "# use the same dimension names: rename year to time in the yield data\n",
    "yield_data = yield_data.rename({'year':'time'})\n",
    "yield_index = yield_index_data.winter_wheat.rename({'year':'time'})\n",
    "\n",
    "# transpose the SPEI variable to (longitude, latitude, time)\n",
    "spei_data['spei'] = spei_data.spei.transpose('longitude','latitude','time')"
   ] },
  { "cell_type": "code", "execution_count": 8, "id": "88e63741", "metadata": {}, "outputs": [], "source": [
    "# create a function to round coordinates (this is taken from the crops package)\n",
    "def round_coordinates(in_array, decimals=2):\n",
    "    \"\"\"\n",
    "    Round the longitude and latitude coordinates to a fixed number of decimals.\n",
    "    This prevents alignment errors caused by tiny floating-point differences\n",
    "    between the coordinates stored in the different input files.\n",
    "\n",
    "    Parameters:\n",
    "    in_array (xarray.DataArray or xarray.Dataset) = input data with 'longitude'\n",
    "        and 'latitude' coordinates, e.g. Tmax\n",
    "    decimals (int) = number of decimals to keep (default: 2)\n",
    "\n",
    "    Returns:\n",
    "    (same type as in_array) = copy of in_array with rounded coordinates;\n",
    "        in_array itself is left unchanged\n",
    "    \"\"\"\n",
    "    return in_array.assign_coords(\n",
    "        longitude=np.round(in_array['longitude'].values, decimals),\n",
    "        latitude=np.round(in_array['latitude'].values, decimals),\n",
    "    )"
   ] },
  { "cell_type": "code", "execution_count": 9, "id": "eb178269", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Frozen(SortedKeysDict({'latitude': 79, 'longitude': 93, 'time': 32}))\n", "('latitude', 'longitude', 'time')\n", "Frozen(SortedKeysDict({'longitude': 90, 'latitude': 77, 'time': 852}))\n" ] } ], "source": [
    "print(yield_data.dims)\n",
    "print(yield_index.dims)\n",
    "print(spei_data.dims)"
   ] },
  { "cell_type": "code", "execution_count": 10, "id": "85802944", "metadata": {}, "outputs": [], "source": [
    "# round to two decimals so that the grids of the different files match exactly\n",
    "spei_data = round_coordinates(spei_data)\n",
    "yield_data = round_coordinates(yield_data)\n",
    "yield_index = round_coordinates(yield_index)\n",
    "natural_areas_raster = round_coordinates(natural_areas_raster)"
   ] },
  { "cell_type": "code", "execution_count": 11, "id": "df54e4f4", "metadata": {}, "outputs": [], "source": [
    "# now the dimension names are the same; adjust the extent of the lon and lat dims:\n",
    "# keep only the grid cells whose (rounded) coordinates also exist in the SPEI data,\n",
    "# which has the smallest grid\n",
    "def crop_to_grid(data, reference):\n",
    "    \"\"\"Return data restricted to the longitude/latitude values present in reference.\"\"\"\n",
    "    keep_lon = data['longitude'].isin(reference['longitude'].values).values\n",
    "    keep_lat = data['latitude'].isin(reference['latitude'].values).values\n",
    "    return data.isel(longitude=keep_lon, latitude=keep_lat)\n",
    "\n",
    "yield_data = crop_to_grid(yield_data, spei_data)\n",
    "yield_index = crop_to_grid(yield_index, spei_data)\n",
    "natural_areas_raster = crop_to_grid(natural_areas_raster, spei_data)\n",
    "\n",
    "# transpose SPEI back to the (latitude, longitude, time) order of the yield data\n",
    "spei_data['spei'] = spei_data.spei.transpose('latitude','longitude','time')"
   ] },
  { "cell_type": "code", "execution_count": 12, "id": "ff6b301e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Frozen(SortedKeysDict({'latitude': 77, 'longitude': 90, 'time': 32}))\n", "('latitude', 'longitude', 'time')\n", "Frozen(SortedKeysDict({'latitude': 77, 'longitude': 90, 'time': 852}))\n" ] } ], "source": [
    "print(yield_data.dims)\n",
    "print(yield_index.dims)\n",
    "print(spei_data.dims)"
   ] },
  { "cell_type": "code", "execution_count": 13, "id": "0a6e15d0", "metadata": {}, "outputs": [], "source": [
    "# concatenate the data into one array along 'time'; per grid cell the resulting vector is\n",
    "# [natural area] + [winter wheat yield per year] + [yield index per year] + [monthly SPEI]\n",
    "assert yield_data.winter_wheat.sizes['time'] == yield_index.sizes['time'], 'yield and yield index must cover the same number of years'\n",
    "in_data = xr.concat([natural_areas_raster,yield_data.winter_wheat,yield_index,spei_data.spei], dim='time')"
   ] },
  { "cell_type": "markdown", "id": "71d2c30b", "metadata": {}, "source": [
    "## Apply `xr.apply_ufunc` to go through the raster cells\n",
    "\n",
    "For every grid cell and harvest year, collect the harvest date, natural area, yield, yield index and the SPEI of the 24 months before the harvest month."
   ] },
  { "cell_type": "code", "execution_count": null, "id": "f6196e76", "metadata": {}, "outputs": [], "source": [
    "## use apply_ufunc to go through the xarrays along the time coordinate\n",
    "# the purpose of this is to create a list of lists that can later be converted into a pandas data frame:\n",
    "# for each harvest date, we want the SPEI of the 24 months before the harvest month.\n",
    "# NOTE: the previous version of this cell took the harvest month and the 23 months after it\n",
    "# (for ascending time), which contradicts the SPEI-24 ... SPEI-1 feature names used for SHAP;\n",
    "# regenerate yield_spei_data.csv (and the model) after this change.\n",
    "N_SPEI_MONTHS = 24\n",
    "\n",
    "# the first row is the header; run_model appends one row per grid cell and harvest year\n",
    "test_list = []\n",
    "test_list.append(['harvest_year','harvest_month','natural_area','yield','index']+\n",
    "                 ['SPEI'+str(i) for i in range(1, N_SPEI_MONTHS + 1)])\n",
    "\n",
    "\n",
    "class CalculateImpact:\n",
    "    \"\"\"\n",
    "    Turn the per-grid-cell vectors of in_data into model input rows.\n",
    "\n",
    "    Every vector passed to run_model has the layout of the concatenation above:\n",
    "    [natural area] + [yield (n_years)] + [yield index (n_years)] + [monthly SPEI].\n",
    "\n",
    "    Parameters:\n",
    "    in_time (array-like) = the concatenated 'time' coordinate of in_data\n",
    "    phen_data (pandas.DataFrame) = phenology table with the columns\n",
    "        natural_area_group_code, reference_year and start_date\n",
    "    rows (list) = list that run_model appends its output rows to\n",
    "    n_years (int) = number of yield / yield index years (default: 32)\n",
    "    n_spei_months (int) = number of SPEI months before harvest (default: 24)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_time, phen_data, rows, n_years=32, n_spei_months=24):\n",
    "        self.in_time = np.asarray(in_time)\n",
    "        self.phen_data = phen_data\n",
    "        self.rows = rows\n",
    "        self.n_years = n_years\n",
    "        self.n_spei_months = n_spei_months\n",
    "\n",
    "    def _harvest_date(self, natural_area, year):\n",
    "        \"\"\"Return the phenology start_date of a natural area and year, or None if there is none.\"\"\"\n",
    "        phen = self.phen_data\n",
    "        matches = phen.loc[(phen.natural_area_group_code == natural_area) &\n",
    "                           (phen.reference_year == int(year)), 'start_date'].values\n",
    "        return matches[0] if len(matches) > 0 else None\n",
    "\n",
    "    def _spei_window(self, spei, harvest_year, harvest_month):\n",
    "        \"\"\"Return the n_spei_months SPEI values before the harvest month (oldest first),\n",
    "        or Nones if the harvest month is unknown or not preceded by enough SPEI data.\"\"\"\n",
    "        hits = np.flatnonzero((spei.index.year == harvest_year) & (spei.index.month == harvest_month))\n",
    "        if len(hits) == 0 or hits[0] < self.n_spei_months:\n",
    "            return [None] * self.n_spei_months\n",
    "        harvest_pos = hits[0]\n",
    "        return list(spei.iloc[harvest_pos - self.n_spei_months:harvest_pos].values)\n",
    "\n",
    "    def run_model(self, in_data):\n",
    "        \"\"\"\n",
    "        Append one row per harvest year of this grid cell to self.rows:\n",
    "        [harvest_year, harvest_month, natural_area, yield, index, SPEI1, ..., SPEIn],\n",
    "        where SPEI1 is the oldest and SPEIn the last month before the harvest month.\n",
    "\n",
    "        Returns the harvest month of every year as a float array (NaN if unknown),\n",
    "        so that apply_ufunc always receives results of the same length and dtype.\n",
    "        \"\"\"\n",
    "        n = self.n_years\n",
    "        natural_area = in_data[0]\n",
    "        yield_values = in_data[1:1 + n]          # 2nd input of the concat: winter wheat yield\n",
    "        index_values = in_data[1 + n:1 + 2 * n]  # 3rd input of the concat: yield index\n",
    "        spei_values = in_data[1 + 2 * n:]\n",
    "        years = self.in_time[1 + n:1 + 2 * n]\n",
    "        spei_time = self.in_time[1 + 2 * n:]\n",
    "\n",
    "        # grid cells without any SPEI data produce no rows\n",
    "        if np.isnan(spei_values).all():\n",
    "            return np.full(n, np.nan)\n",
    "\n",
    "        # harvest date of every year from the phenology table (NaT where missing)\n",
    "        harvest_dates = pd.to_datetime(pd.Series(\n",
    "            [self._harvest_date(natural_area, year) for year in years], dtype=object))\n",
    "        harvest_years = harvest_dates.dt.year\n",
    "        harvest_months = harvest_dates.dt.month\n",
    "\n",
    "        # sort by time so that the window always covers the months before the harvest\n",
    "        spei = pd.Series(spei_values, index=pd.to_datetime(spei_time)).sort_index()\n",
    "\n",
    "        for harvest_year, harvest_month, yield_value, index_value in zip(\n",
    "                harvest_years, harvest_months, yield_values, index_values):\n",
    "            self.rows.append([harvest_year, harvest_month, natural_area, yield_value, index_value] +\n",
    "                             self._spei_window(spei, harvest_year, harvest_month))\n",
    "        return harvest_months.to_numpy(dtype=float)\n",
    "\n",
    "\n",
    "# pass the concatenated time axis, the phenology table and the output list explicitly\n",
    "yield_object = CalculateImpact(in_data.time, phen_data_winterwheat, test_list,\n",
    "                               n_years=yield_index.sizes['time'], n_spei_months=N_SPEI_MONTHS)\n",
    "\n",
    "# run_model collects the rows as a side effect, so it has to run in this process (no dask)\n",
    "test_out = xr.apply_ufunc(yield_object.run_model,\n",
    "                          in_data,\n",
    "                          input_core_dims=[['time']],\n",
    "                          output_core_dims=[['year']],\n",
    "                          vectorize = True)"
   ] },
  { 
"cell_type": "code", "execution_count": null, "id": "42d880e7", "metadata": {}, "outputs": [], "source": [
    "# build the data frame; the first row of test_list is the header\n",
    "df = pd.DataFrame(test_list[1:], columns=test_list[0])\n",
    "df.head()"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "a6564f67", "metadata": {}, "outputs": [], "source": [
    "# remove all rows with a NaN (e.g. no harvest date or an incomplete SPEI window)\n",
    "df = df.dropna().reset_index(drop=True)"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "77e93a6e", "metadata": {}, "outputs": [], "source": [
    "df.to_csv('C:/Users/Florian/Desktop/Jupyter_Skripts/05_ML_crop_vulnerability/data/yield_spei_data.csv')"
   ] },
  { "cell_type": "code", "execution_count": null, "id": "ae589cc4", "metadata": {}, "outputs": [], "source": [
    "df.describe()"
   ] },
  { "cell_type": "markdown", "id": "1b5f4d29", "metadata": {}, "source": [
    "## ML part\n",
    "\n",
    "Predict the yield index from the 24 monthly SPEI values before harvest with an LSTM."
   ] },
  { "cell_type": "code", "execution_count": 23, "id": "bf6ec909", "metadata": {}, "outputs": [], "source": [
    "import math\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.layers import LSTM\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import tensorflow as tf\n",
    "import time\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "import keras\n",
    "from keras.callbacks import EarlyStopping\n",
    "from keras.models import load_model\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "from tensorflow.keras import optimizers"
   ] },
  { "cell_type": "markdown", "id": "791b5442", "metadata": {}, "source": [
    "On LSTMs: https://machinelearningmastery.com/gentle-introduction-long-short-term-memory-networks-experts/\n",
    "\n",
    "The data are arranged like the univariate example in https://machinelearningmastery.com/how-to-develop-lstm-models-for-time-series-forecasting/ - every sample is a sequence and `y` holds one target value per sequence:\n",
    "\n",
    "```\n",
    "X -> array([[10, 20, 30],\n",
    "            [20, 30, 40],\n",
    "            [30, 40, 50],\n",
    "            [40, 50, 60],\n",
    "            [50, 60, 70],\n",
    "            [60, 70, 80]])\n",
    "y -> array([40, 50, 60, 70, 80, 90])\n",
    "```\n",
    "\n",
    "Here every sample is the sequence of 24 monthly SPEI values of one grid cell and harvest year, and the target is the yield index."
   ] },
  { "cell_type": "code", "execution_count": 25, "id": "fbbe03de", "metadata": {}, "outputs": [], "source": [
    "# read the data again\n",
    "df = pd.read_csv('C:/Users/Florian/Desktop/Jupyter_Skripts/05_ML_crop_vulnerability/data/yield_spei_data.csv')"
   ] },
  { "cell_type": "markdown", "id": "c4cb9271", "metadata": {}, "source": [ "### Use all data" ] },
  { "cell_type": "code", "execution_count": 26, "id": "119de43b", "metadata": {}, "outputs": [], "source": [
    "# create X (the 24 monthly values SPEI1 ... SPEI24 per sample) and y (the yield index)\n",
    "spei_columns = ['SPEI' + str(i) for i in range(1, 25)]\n",
    "X = df[spei_columns].to_numpy(dtype=float)\n",
    "y = df['index'].to_numpy()"
   ] },
  { "cell_type": "code", "execution_count": 27, "id": "fee1e9bc", "metadata": {}, "outputs": [], "source": [
    "# create train and test data by splitting; fix the seed so that the split is reproducible\n",
    "# (the model is trained once, saved and reloaded, so the test set must not change between runs)\n",
    "SEED = 42\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=SEED)"
   ] },
  { "cell_type": "code", "execution_count": 28, "id": "8f629322", "metadata": {}, "outputs": [], "source": [
    "# define model\n",
    "n_steps = 24 # --> month of SPEI\n",
    "n_features = 1\n",
    "model = Sequential()\n",
    "model.add(LSTM(50, activation='relu', input_shape=(n_steps, n_features)))\n",
    "model.add(Dense(1))\n",
    "model.compile(optimizer='adam', loss='mse')"
   ] },
  { "cell_type": "code", "execution_count": 29, "id": "ee48c129", "metadata": {}, "outputs": [], "source": [
    "# reshape from [samples, timesteps] into [samples, timesteps, features]\n",
    "X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 
n_features))"
   ] },
  { "cell_type": "code", "execution_count": 30, "id": "2593616a", "metadata": {}, "outputs": [], "source": [
    "# fit model -> commented out because it takes about 4 h; the fitted model is saved and reloaded below\n",
    "#model.fit(X_train, y_train, epochs=200, verbose=0)"
   ] },
  { "cell_type": "code", "execution_count": 31, "id": "94c4c8a1", "metadata": {}, "outputs": [], "source": [
    "# save the model\n",
    "#model.save('C:/Users/Florian/Desktop/Jupyter_Skripts/05_ML_crop_vulnerability/models/model_v001')"
   ] },
  { "cell_type": "code", "execution_count": 32, "id": "a865b8ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Florian\\anaconda3\\envs\\yield_ml_2\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n", "\n" ] } ], "source": [
    "# load the model again\n",
    "# NOTE(review): the saved model was fitted on the X_train of the run that saved it. Unless the\n",
    "# split above reproduces exactly that split, X_test overlaps the model's training data and the\n",
    "# scores below are optimistic - refit (and re-save) the model with a seeded split.\n",
    "model = keras.models.load_model('C:/Users/Florian/Desktop/Jupyter_Skripts/05_ML_crop_vulnerability/models/model_v001')"
   ] },
  { "cell_type": "code", "execution_count": 33, "id": "b01afa0c", "metadata": {}, "outputs": [], "source": [
    "# test the model\n",
    "X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], n_features))\n",
    "# predict returns one column per sample (Dense(1)); flatten it to a 1-D array\n",
    "y_pred = model.predict(X_test, verbose=0).ravel()"
   ] },
  { "cell_type": "code", "execution_count": 34, "id": "9b433d8a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "n predictions: 45036\n", "slope: 1.0054949082304028\n", "intercept: 0.01712772477690782\n", "r: 0.9177724506004394\n", "r²: 0.8423062710811361\n", "p-value: 0.0\n", "standard error: 0.002050131923845098\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# plot observed vs. predicted values together with the regression of the observed on the\n",
    "# predicted values - the same regression whose statistics are printed below\n",
    "y_observed = np.asarray(y_test)\n",
    "y_predicted = np.asarray(y_pred)\n",
    "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_predicted, y_observed)\n",
    "plt.scatter(y_predicted, y_observed, marker='.')\n",
    "line_x = np.array([y_predicted.min(), y_predicted.max()])\n",
    "plt.plot(line_x, slope*line_x + intercept, color='red')\n",
    "plt.xlabel('predicted values (harvest anomalies)')\n",
    "plt.ylabel('test values (harvest anomalies)')\n",
    "# check all other things\n",
    "n_pred = len(y_predicted)\n",
    "print('n predictions: ' + str(n_pred))\n",
    "print('slope: ' + str(slope))\n",
    "print('intercept: ' + str(intercept))\n",
    "print('r: ' + str(r_value))\n",
    "print('r²: ' + str(r_value**2))\n",
    "print('p-value: ' + str(p_value))\n",
    "print('standard error: ' +str(std_err))\n",
    "plt.savefig(r'C:\\Users\\Florian\\Desktop\\Jupyter_Skripts\\05_ML_crop_vulnerability/figures/scatterplot.png', format='png', dpi=300, bbox_inches='tight')"
   ] },
  { "cell_type": "markdown", "id": "c06d16e6", "metadata": {}, "source": [ "## Now implement SHAP for variable importance" ] },
  { "cell_type": "code", "execution_count": 35, "id": "2cc8cee0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Florian\\anaconda3\\envs\\yield_ml_2\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [
    "import shap"
   ] },
  { "cell_type": "code", "execution_count": 36, "id": "a49a9dc1", "metadata": {}, "outputs": [], "source": [
    "# create a smaller subset: only these 10 samples are explained, the plots below summarise them\n",
    "test_images= X_test[100:110]\n",
    "test_labels=y_test[100:110]"
   ] },
  { "cell_type": "code", "execution_count": 37, "id": "850af659", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Florian\\anaconda3\\envs\\yield_ml_2\\lib\\site-packages\\shap\\explainers\\tf_utils.py:28: The name tf.keras.backend.get_session is deprecated. Please use tf.compat.v1.keras.backend.get_session instead.\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "keras is no longer supported, please use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\Florian\\anaconda3\\envs\\yield_ml_2\\lib\\site-packages\\shap\\explainers\\_deep\\deep_tf.py:631: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" ] } ], "source": [
    "# fit the explainer; the first 4999 test samples serve as background data\n",
    "explainer = shap.DeepExplainer(model, X_test[0:4999])\n",
    "shap_values = explainer.shap_values(test_images)"
   ] },
  { "cell_type": "code", "execution_count": 38, "id": "2338fd0d", "metadata": {}, "outputs": [], "source": [
    "# reformat output: take the values of the single model output and drop the trailing\n",
    "# feature axis (n_features = 1) -> one row of 24 SHAP values per explained sample\n",
    "shap_array = np.asarray(shap_values[0])[:, :, 0]"
   ] },
  { "cell_type": "code", "execution_count": 39, "id": "adb81cd9", "metadata": {}, "outputs": [], "source": [
    "# create a feature name list: column SPEI1 is labelled SPEI-24, ..., column SPEI24 is SPEI-1\n",
    "feature_names = ['SPEI-' + str(i) for i in range(24, 0, -1)]"
   ] },
  { "cell_type": "code", "execution_count": 40, "id": "0960ff28", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# plot and save the figure\n",
    "shap.summary_plot(shap_array, plot_type = 'bar', feature_names = np.array(feature_names), show=False)\n",
    "plt.savefig(r'C:\\Users\\Florian\\Desktop\\Jupyter_Skripts\\05_ML_crop_vulnerability/figures/barplot.png', format='png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ] },
  { "cell_type": "code", "execution_count": 41, "id": "c4f4e3d7", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# plot and save the figure\n",
    "shap.summary_plot(shap_array,feature_names = np.array(feature_names),plot_type='violin', show=False)\n",
    "plt.savefig(r'C:\\Users\\Florian\\Desktop\\Jupyter_Skripts\\05_ML_crop_vulnerability/figures/violinplot.png', format='png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ] }
 ],
 "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } },
 "nbformat": 4, "nbformat_minor": 5 }