{ "cells": [ { "cell_type": "markdown", "metadata": { "nbpresent": { "id": "a20290f5-2a26-4622-9e8b-ce94cc1dcb1a" }, "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Estimation of tree height using GEDI dataset - Clean Data - Perceptron 2 - 2022" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Based on the data quality flags, select the more reliable tree heights." ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "nbpresent": { "id": "8f67df50-3050-47d2-a5d9-4329a61325fa" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import scipy\n", "from sklearn.metrics import r2_score\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**File storing tree height (cm) obtained by 6 algorithms, with their associated quality flags.**\n", "The quality flags can be used to refine and select the best tree height estimation and use it as tree height observation.\n", "\n", "* a?_95: tree height (cm) at the 95th percentile, for each algorithm \n", "* min_rh_95: minimum value of tree height (cm) among the 6 algorithms \n", "* max_rh_95: maximum value of tree height (cm) among the 6 algorithms \n", "* BEAM: 1-4 coverage beam = lower power (worse) ; 5-8 power beam = higher power (better) \n", "* digital_elev: digital model elevation \n", "* elev_low: elevation of center of lowest mode \n", "* qc_a?: quality_flag for six algorithms quality_flag = 1 (better); = 0 (worse) \n", "* se_a?: sensitivity for six algorithms sensitivity < 0.95 (worse); sensitivity > 0.95 (better) \n", "* deg_fg: (degrade_flag) not-degraded 0 (better) ; degraded > 0 (worse) \n", "* solar_ele: solar elevation. 
> 0 day (worse); < 0 night (better) " ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDXYa1_95a2_95a3_95a4_95a5_95a6_95min_rh_95max_rh_95BEAMdigital_elevelev_lowqc_a1qc_a2qc_a3qc_a4qc_a5qc_a6se_a1se_a2se_a3se_a4se_a5se_a6deg_fgsolar_ele
016.05000149.727499313931393139312031393139312031395410.0383.721531111110.9620.9840.9680.9620.9890.979017.7
126.05000249.922155102223039708725596152487255965290.02374.141100000000.9480.9900.9600.9480.9940.980043.7
236.05000248.60237738013363323621336134033213404440.0435.977811111110.9470.9750.9560.9470.9810.96800.2
346.05000948.151979315331423142312731383142312731532450.0422.005371111110.9300.9700.9430.9300.9780.9620-14.2
456.05001049.588410666422165133561127233356118370.02413.748300000000.9410.9830.9460.9410.9920.969022.1
566.05001448.608456787117911877611833183376118333420.0415.515811111110.9520.9790.9610.9520.9860.97500.2
\n", "
" ], "text/plain": [ " ID X Y a1_95 a2_95 a3_95 a4_95 a5_95 a6_95 \\\n", "0 1 6.050001 49.727499 3139 3139 3139 3120 3139 3139 \n", "1 2 6.050002 49.922155 1022 2303 970 872 5596 1524 \n", "2 3 6.050002 48.602377 380 1336 332 362 1336 1340 \n", "3 4 6.050009 48.151979 3153 3142 3142 3127 3138 3142 \n", "4 5 6.050010 49.588410 666 4221 651 33 5611 2723 \n", "5 6 6.050014 48.608456 787 1179 1187 761 1833 1833 \n", "\n", " min_rh_95 max_rh_95 BEAM digital_elev elev_low qc_a1 qc_a2 qc_a3 \\\n", "0 3120 3139 5 410.0 383.72153 1 1 1 \n", "1 872 5596 5 290.0 2374.14110 0 0 0 \n", "2 332 1340 4 440.0 435.97781 1 1 1 \n", "3 3127 3153 2 450.0 422.00537 1 1 1 \n", "4 33 5611 8 370.0 2413.74830 0 0 0 \n", "5 761 1833 3 420.0 415.51581 1 1 1 \n", "\n", " qc_a4 qc_a5 qc_a6 se_a1 se_a2 se_a3 se_a4 se_a5 se_a6 deg_fg \\\n", "0 1 1 1 0.962 0.984 0.968 0.962 0.989 0.979 0 \n", "1 0 0 0 0.948 0.990 0.960 0.948 0.994 0.980 0 \n", "2 1 1 1 0.947 0.975 0.956 0.947 0.981 0.968 0 \n", "3 1 1 1 0.930 0.970 0.943 0.930 0.978 0.962 0 \n", "4 0 0 0 0.941 0.983 0.946 0.941 0.992 0.969 0 \n", "5 1 1 1 0.952 0.979 0.961 0.952 0.986 0.975 0 \n", "\n", " solar_ele \n", "0 17.7 \n", "1 43.7 \n", "2 0.2 \n", "3 -14.2 \n", "4 22.1 \n", "5 0.2 " ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "height_6algorithms = pd.read_csv(\"tree_height/txt/eu_y_x_select_6algorithms_fullTable.txt\", sep=\" \", index_col=False)\n", "pd.set_option('display.max_columns',None)\n", "height_6algorithms.head(6)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "height_6algorithms_sel = height_6algorithms.loc[(height_6algorithms['BEAM'] > 4) \n", " & (height_6algorithms['qc_a1'] == 1)\n", " & (height_6algorithms['qc_a2'] == 1)\n", " & (height_6algorithms['qc_a3'] == 1) \n", " & (height_6algorithms['qc_a4'] == 1) \n", " & (height_6algorithms['qc_a5'] == 1) \n", " & (height_6algorithms['qc_a6'] == 1)\n", " & 
(height_6algorithms['se_a1'] > 0.95) \n", " & (height_6algorithms['se_a2'] > 0.95)\n", " & (height_6algorithms['se_a3'] > 0.95)\n", " & (height_6algorithms['se_a4'] > 0.95)\n", " & (height_6algorithms['se_a5'] > 0.95) \n", " & (height_6algorithms['se_a6'] > 0.95)\n", " & (height_6algorithms['deg_fg'] == 0) \n", " & (height_6algorithms['solar_ele'] < 0)]" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDXYa1_95a2_95a3_95a4_95a5_95a6_95min_rh_95max_rh_95BEAMdigital_elevelev_lowqc_a1qc_a2qc_a3qc_a4qc_a5qc_a6se_a1se_a2se_a3se_a4se_a5se_a6deg_fgsolar_ele
786.05001949.921613330332883296323638573292323638577320.0297.685331111110.9710.9880.9760.9710.9920.9840-33.9
11126.05003947.995344276227362740274738932736273638935390.0368.551211111110.9750.9900.9790.9750.9940.9870-37.3
14156.05004649.865317139825052509131628482505131628486340.0330.405641111110.9730.9900.9790.9730.9940.9860-18.2
15166.05004849.050020984943947958261794794326176300.0291.225981111110.9780.9910.9820.9780.9950.9880-35.4
16176.05004948.391359336233323336335144673336333244675530.0504.781221111110.9730.9880.9770.9730.9920.9840-5.1
.......................................................................................
126720712672089.94982949.216272216028162816210432992816210432998420.0386.445561111110.9800.9930.9840.9800.9950.9890-16.9
126721112672129.94985649.881190319031793179317138223179317138226380.0363.693481111110.9680.9860.9740.9680.9900.9820-35.1
126721612672179.94988049.873435206128282046202428282828202428287380.0361.068121111110.9670.9880.9740.9670.9930.9830-35.1
126722712672289.94995849.127182366230712603553531271935535316500.0493.527921111110.9730.9890.9780.9730.9930.9850-36.0
126723712672389.94999949.936763251324902490249424902490249025135360.0346.542271111110.9680.9880.9740.9680.9930.9830-32.2
\n", "

226892 rows × 28 columns

\n", "
" ], "text/plain": [ " ID X Y a1_95 a2_95 a3_95 a4_95 a5_95 \\\n", "7 8 6.050019 49.921613 3303 3288 3296 3236 3857 \n", "11 12 6.050039 47.995344 2762 2736 2740 2747 3893 \n", "14 15 6.050046 49.865317 1398 2505 2509 1316 2848 \n", "15 16 6.050048 49.050020 984 943 947 958 2617 \n", "16 17 6.050049 48.391359 3362 3332 3336 3351 4467 \n", "... ... ... ... ... ... ... ... ... \n", "1267207 1267208 9.949829 49.216272 2160 2816 2816 2104 3299 \n", "1267211 1267212 9.949856 49.881190 3190 3179 3179 3171 3822 \n", "1267216 1267217 9.949880 49.873435 2061 2828 2046 2024 2828 \n", "1267227 1267228 9.949958 49.127182 366 2307 1260 355 3531 \n", "1267237 1267238 9.949999 49.936763 2513 2490 2490 2494 2490 \n", "\n", " a6_95 min_rh_95 max_rh_95 BEAM digital_elev elev_low qc_a1 \\\n", "7 3292 3236 3857 7 320.0 297.68533 1 \n", "11 2736 2736 3893 5 390.0 368.55121 1 \n", "14 2505 1316 2848 6 340.0 330.40564 1 \n", "15 947 943 2617 6 300.0 291.22598 1 \n", "16 3336 3332 4467 5 530.0 504.78122 1 \n", "... ... ... ... ... ... ... ... \n", "1267207 2816 2104 3299 8 420.0 386.44556 1 \n", "1267211 3179 3171 3822 6 380.0 363.69348 1 \n", "1267216 2828 2024 2828 7 380.0 361.06812 1 \n", "1267227 2719 355 3531 6 500.0 493.52792 1 \n", "1267237 2490 2490 2513 5 360.0 346.54227 1 \n", "\n", " qc_a2 qc_a3 qc_a4 qc_a5 qc_a6 se_a1 se_a2 se_a3 se_a4 se_a5 \\\n", "7 1 1 1 1 1 0.971 0.988 0.976 0.971 0.992 \n", "11 1 1 1 1 1 0.975 0.990 0.979 0.975 0.994 \n", "14 1 1 1 1 1 0.973 0.990 0.979 0.973 0.994 \n", "15 1 1 1 1 1 0.978 0.991 0.982 0.978 0.995 \n", "16 1 1 1 1 1 0.973 0.988 0.977 0.973 0.992 \n", "... ... ... ... ... ... ... ... ... ... ... 
\n", "1267207 1 1 1 1 1 0.980 0.993 0.984 0.980 0.995 \n", "1267211 1 1 1 1 1 0.968 0.986 0.974 0.968 0.990 \n", "1267216 1 1 1 1 1 0.967 0.988 0.974 0.967 0.993 \n", "1267227 1 1 1 1 1 0.973 0.989 0.978 0.973 0.993 \n", "1267237 1 1 1 1 1 0.968 0.988 0.974 0.968 0.993 \n", "\n", " se_a6 deg_fg solar_ele \n", "7 0.984 0 -33.9 \n", "11 0.987 0 -37.3 \n", "14 0.986 0 -18.2 \n", "15 0.988 0 -35.4 \n", "16 0.984 0 -5.1 \n", "... ... ... ... \n", "1267207 0.989 0 -16.9 \n", "1267211 0.982 0 -35.1 \n", "1267216 0.983 0 -35.1 \n", "1267227 0.985 0 -36.0 \n", "1267237 0.983 0 -32.2 \n", "\n", "[226892 rows x 28 columns]" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "height_6algorithms_sel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calculate the mean height excluding the maximum and minimum values: the four remaining heights are averaged and converted from cm to m, hence the division by 400." ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "height_sel = pd.DataFrame({'ID' : height_6algorithms_sel['ID'] , \n", " 'hm_sel': (height_6algorithms_sel['a1_95'] + height_6algorithms_sel['a2_95'] + height_6algorithms_sel['a3_95'] + height_6algorithms_sel['a4_95'] \n", " + height_6algorithms_sel['a5_95'] + height_6algorithms_sel['a6_95'] - height_6algorithms_sel['min_rh_95'] - height_6algorithms_sel['max_rh_95']) / 400 } )" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDhm_sel
7832.9475
111227.4625
141522.2925
15169.5900
161733.4625
.........
1267207126720826.5200
1267211126721231.8175
1267216126721724.4075
1267227126722816.6300
1267237126723824.9100
\n", "

226892 rows × 2 columns

\n", "
" ], "text/plain": [ " ID hm_sel\n", "7 8 32.9475\n", "11 12 27.4625\n", "14 15 22.2925\n", "15 16 9.5900\n", "16 17 33.4625\n", "... ... ...\n", "1267207 1267208 26.5200\n", "1267211 1267212 31.8175\n", "1267216 1267217 24.4075\n", "1267227 1267228 16.6300\n", "1267237 1267238 24.9100\n", "\n", "[226892 rows x 2 columns]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "height_sel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import raw data, extracted predictors and show the data distribution" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDXYhBLDFIE_WeigAverCECSOL_WeigAverCHELSA_bio18CHELSA_bio4convergencectidevmagnitudeeastnesselevforestheightglad_ard_SVVI_maxglad_ard_SVVI_medglad_ard_SVVI_minnorthnessORCDRC_WeigAveroutlet_dist_dw_basinSBIO3_Isothermality_5_15cmSBIO4_Temperature_Seasonality_5_15cmtreecover
016.05000149.7274993139.0015401321135893-10.486560-2380431201.1584170.069094353.98312423276.87109446.444092347.6654050.042500978040319.798992440.67221185
126.05000249.9221551454.751491121993591233.274361-208915344-1.7553410.269112267.51168819-49.52636719.552734-130.5417480.1827801677277720.889412457.75619585
236.05000248.602377853.50152117212459830.045293-1374797921.908780-0.016055389.7511602193.25732450.743652384.5224610.0362531489882020.695877481.87970062
346.05000948.1519793141.0015261625696130-33.654274-2672230720.9657870.067767380.20770327542.401367202.264160386.1567380.0051391583182419.375000479.41027885
456.05001049.5884102065.251547142108592327.493824-107809368-0.1626240.014065308.04278625136.048340146.835205198.1274410.0288471779696218.777500457.88006685
566.05001448.6084561246.5015151921246010-1.602039173842821.447979-0.018912364.52710018221.339844247.387207480.3879390.0427471489794519.398880474.33132962
676.05001648.5714012938.751520192169614727.856503-66516432-1.0739560.002280254.67959619125.25048887.865234160.6967770.0372541190842620.170450476.41452096
786.05001949.9216133294.751490121995591222.102139-297770784-1.4026330.309765294.92776526-86.729492-145.584229-190.0629880.2224351577278420.855963457.19540486
896.05002048.8226451623.501554181973613818.496584-25336536-0.8000160.010370240.49375922-51.470703-245.886719172.0747070.004428883913221.812290496.23111064
9106.05002449.8475221400.0015211521875886-5.660453-2786526081.477951-0.068720376.67114312277.297363273.141846-138.8959960.0988171376887321.137711466.97668570
\n", "
" ], "text/plain": [ " ID X Y h BLDFIE_WeigAver CECSOL_WeigAver \\\n", "0 1 6.050001 49.727499 3139.00 1540 13 \n", "1 2 6.050002 49.922155 1454.75 1491 12 \n", "2 3 6.050002 48.602377 853.50 1521 17 \n", "3 4 6.050009 48.151979 3141.00 1526 16 \n", "4 5 6.050010 49.588410 2065.25 1547 14 \n", "5 6 6.050014 48.608456 1246.50 1515 19 \n", "6 7 6.050016 48.571401 2938.75 1520 19 \n", "7 8 6.050019 49.921613 3294.75 1490 12 \n", "8 9 6.050020 48.822645 1623.50 1554 18 \n", "9 10 6.050024 49.847522 1400.00 1521 15 \n", "\n", " CHELSA_bio18 CHELSA_bio4 convergence cti devmagnitude eastness \\\n", "0 2113 5893 -10.486560 -238043120 1.158417 0.069094 \n", "1 1993 5912 33.274361 -208915344 -1.755341 0.269112 \n", "2 2124 5983 0.045293 -137479792 1.908780 -0.016055 \n", "3 2569 6130 -33.654274 -267223072 0.965787 0.067767 \n", "4 2108 5923 27.493824 -107809368 -0.162624 0.014065 \n", "5 2124 6010 -1.602039 17384282 1.447979 -0.018912 \n", "6 2169 6147 27.856503 -66516432 -1.073956 0.002280 \n", "7 1995 5912 22.102139 -297770784 -1.402633 0.309765 \n", "8 1973 6138 18.496584 -25336536 -0.800016 0.010370 \n", "9 2187 5886 -5.660453 -278652608 1.477951 -0.068720 \n", "\n", " elev forestheight glad_ard_SVVI_max glad_ard_SVVI_med \\\n", "0 353.983124 23 276.871094 46.444092 \n", "1 267.511688 19 -49.526367 19.552734 \n", "2 389.751160 21 93.257324 50.743652 \n", "3 380.207703 27 542.401367 202.264160 \n", "4 308.042786 25 136.048340 146.835205 \n", "5 364.527100 18 221.339844 247.387207 \n", "6 254.679596 19 125.250488 87.865234 \n", "7 294.927765 26 -86.729492 -145.584229 \n", "8 240.493759 22 -51.470703 -245.886719 \n", "9 376.671143 12 277.297363 273.141846 \n", "\n", " glad_ard_SVVI_min northness ORCDRC_WeigAver outlet_dist_dw_basin \\\n", "0 347.665405 0.042500 9 780403 \n", "1 -130.541748 0.182780 16 772777 \n", "2 384.522461 0.036253 14 898820 \n", "3 386.156738 0.005139 15 831824 \n", "4 198.127441 0.028847 17 796962 \n", "5 480.387939 0.042747 14 897945 \n", "6 
160.696777 0.037254 11 908426 \n", "7 -190.062988 0.222435 15 772784 \n", "8 172.074707 0.004428 8 839132 \n", "9 -138.895996 0.098817 13 768873 \n", "\n", " SBIO3_Isothermality_5_15cm SBIO4_Temperature_Seasonality_5_15cm treecover \n", "0 19.798992 440.672211 85 \n", "1 20.889412 457.756195 85 \n", "2 20.695877 481.879700 62 \n", "3 19.375000 479.410278 85 \n", "4 18.777500 457.880066 85 \n", "5 19.398880 474.331329 62 \n", "6 20.170450 476.414520 96 \n", "7 20.855963 457.195404 86 \n", "8 21.812290 496.231110 64 \n", "9 21.137711 466.976685 70 " ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictors = pd.read_csv(\"tree_height/txt/eu_x_y_height_predictors_select.txt\", sep=\" \", index_col=False)\n", "pd.set_option('display.max_columns',None)\n", "# change column name\n", "predictors = predictors.rename({'dev-magnitude':'devmagnitude'} , axis='columns')\n", "predictors.head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Merge the new height with the predictors table, using the ID as Primary Key" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "predictors_hm_sel = pd.merge( predictors , height_sel , left_on='ID' , right_on='ID' , how='right')" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDXYhBLDFIE_WeigAverCECSOL_WeigAverCHELSA_bio18CHELSA_bio4convergencectidevmagnitudeeastnesselevforestheightglad_ard_SVVI_maxglad_ard_SVVI_medglad_ard_SVVI_minnorthnessORCDRC_WeigAveroutlet_dist_dw_basinSBIO3_Isothermality_5_15cmSBIO4_Temperature_Seasonality_5_15cmtreecoverhm_sel
086.05001949.9216133294.751490121995591222.102139-297770784-1.4026330.309765294.92776526-86.729492-145.584229-190.0629880.2224351577278420.855963457.1954048632.9475
1126.05003947.9953442746.25152312261261813.549103-712799920.507727-0.021408322.92022726660.00610492.722168190.979736-0.0347871678480720.798000460.5012219727.4625
2156.05004649.8653172229.251517132191590131.054762-186807440-1.375050-0.126880291.41253771028.385498915.806396841.5861820.0246771676644419.941267454.1850895422.2925
3166.05004849.050020959.00152614208161009.933455-183562672-0.3828340.086874246.28801024-12.283691-58.179199174.2055660.0941751080573019.849365470.946533789.5900
4176.05004948.3913593346.2514891924865966-6.957157-2735226882.9897590.214769474.40908824125.5830086.154297128.1291500.0171641595019021.179420491.3983768533.4625
5196.05005349.877876529.0015311221845915-24.278454-3773352960.265329-0.248356335.53476025593.601074228.712402315.298340-0.1273651776471319.760756448.580811965.2900
\n", "
" ], "text/plain": [ " ID X Y h BLDFIE_WeigAver CECSOL_WeigAver \\\n", "0 8 6.050019 49.921613 3294.75 1490 12 \n", "1 12 6.050039 47.995344 2746.25 1523 12 \n", "2 15 6.050046 49.865317 2229.25 1517 13 \n", "3 16 6.050048 49.050020 959.00 1526 14 \n", "4 17 6.050049 48.391359 3346.25 1489 19 \n", "5 19 6.050053 49.877876 529.00 1531 12 \n", "\n", " CHELSA_bio18 CHELSA_bio4 convergence cti devmagnitude eastness \\\n", "0 1995 5912 22.102139 -297770784 -1.402633 0.309765 \n", "1 2612 6181 3.549103 -71279992 0.507727 -0.021408 \n", "2 2191 5901 31.054762 -186807440 -1.375050 -0.126880 \n", "3 2081 6100 9.933455 -183562672 -0.382834 0.086874 \n", "4 2486 5966 -6.957157 -273522688 2.989759 0.214769 \n", "5 2184 5915 -24.278454 -377335296 0.265329 -0.248356 \n", "\n", " elev forestheight glad_ard_SVVI_max glad_ard_SVVI_med \\\n", "0 294.927765 26 -86.729492 -145.584229 \n", "1 322.920227 26 660.006104 92.722168 \n", "2 291.412537 7 1028.385498 915.806396 \n", "3 246.288010 24 -12.283691 -58.179199 \n", "4 474.409088 24 125.583008 6.154297 \n", "5 335.534760 25 593.601074 228.712402 \n", "\n", " glad_ard_SVVI_min northness ORCDRC_WeigAver outlet_dist_dw_basin \\\n", "0 -190.062988 0.222435 15 772784 \n", "1 190.979736 -0.034787 16 784807 \n", "2 841.586182 0.024677 16 766444 \n", "3 174.205566 0.094175 10 805730 \n", "4 128.129150 0.017164 15 950190 \n", "5 315.298340 -0.127365 17 764713 \n", "\n", " SBIO3_Isothermality_5_15cm SBIO4_Temperature_Seasonality_5_15cm \\\n", "0 20.855963 457.195404 \n", "1 20.798000 460.501221 \n", "2 19.941267 454.185089 \n", "3 19.849365 470.946533 \n", "4 21.179420 491.398376 \n", "5 19.760756 448.580811 \n", "\n", " treecover hm_sel \n", "0 86 32.9475 \n", "1 97 27.4625 \n", "2 54 22.2925 \n", "3 78 9.5900 \n", "4 85 33.4625 \n", "5 96 5.2900 " ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictors_hm_sel.head(6)" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, 
"outputs": [], "source": [ "# Build the mask from predictors_hm_sel itself; a mask taken from `predictors` is matched by index label and tests the h of unrelated rows.\n", "predictors_hm_sel = predictors_hm_sel.loc[predictors_hm_sel['h'] < 5000]" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDXYhBLDFIE_WeigAverCECSOL_WeigAverCHELSA_bio18CHELSA_bio4convergencectidevmagnitudeeastnesselevforestheightglad_ard_SVVI_maxglad_ard_SVVI_medglad_ard_SVVI_minnorthnessORCDRC_WeigAveroutlet_dist_dw_basinSBIO3_Isothermality_5_15cmSBIO4_Temperature_Seasonality_5_15cmtreecoverhm_sel
1126.05003947.9953442746.25152312261261813.549103-712799920.507727-0.021408322.92022726660.00610492.722168190.979736-0.0347871678480720.798000460.5012219727.4625
2156.05004649.8653172229.251517132191590131.054762-186807440-1.375050-0.126880291.41253771028.385498915.806396841.5861820.0246771676644419.941267454.1850895422.2925
5196.05005349.877876529.0015311221845915-24.278454-3773352960.265329-0.248356335.53476025593.601074228.712402315.298340-0.1273651776471319.760756448.580811965.2900
8276.05008349.2814393921.25148813234559153.646593-2234992480.3833140.062349309.14260925862.362305263.612793249.693115-0.0688101678112017.538614463.28024310039.2125
9356.05011949.9286101765.0015101319765917-2.138205-393005696-1.467515-0.316702301.64956720248.561035164.831299237.283447-0.1444071977364721.324263465.0464788717.6500
...........................................................................
22687612671769.94963749.8874713266.50155814207464844.993805-1966916801.2605400.044238328.01916526-144.124023-145.522949-14.3176270.020914785722917.376682463.4213879832.6650
22687812671809.94965849.856387497.001515171956656134.034641-126274136-0.9590000.047194260.8890690681.798340657.745605642.0114750.050846585722521.177349573.086243124.9700
22687912671849.94968849.3628312344.00151715226464999.1731681609678720.645957-0.011381447.67804023282.11938579.83081110.1403810.014716782415718.283070471.1674198923.4400
22688112671899.94970149.1144582014.501529122608663227.137199104082784-0.481382-0.012974447.81439221187.93408290.763672168.5271000.0456021890789418.010750473.2279667220.1450
22688212671929.94973149.8931962891.0015661520886482-36.581142-2191424961.3487700.060537331.81591821-68.830078-160.90966811.658203-0.0460481185770516.320225450.4092717528.9100
\n", "

116199 rows × 24 columns

\n", "
" ], "text/plain": [ " ID X Y h BLDFIE_WeigAver \\\n", "1 12 6.050039 47.995344 2746.25 1523 \n", "2 15 6.050046 49.865317 2229.25 1517 \n", "5 19 6.050053 49.877876 529.00 1531 \n", "8 27 6.050083 49.281439 3921.25 1488 \n", "9 35 6.050119 49.928610 1765.00 1510 \n", "... ... ... ... ... ... \n", "226876 1267176 9.949637 49.887471 3266.50 1558 \n", "226878 1267180 9.949658 49.856387 497.00 1515 \n", "226879 1267184 9.949688 49.362831 2344.00 1517 \n", "226881 1267189 9.949701 49.114458 2014.50 1529 \n", "226882 1267192 9.949731 49.893196 2891.00 1566 \n", "\n", " CECSOL_WeigAver CHELSA_bio18 CHELSA_bio4 convergence cti \\\n", "1 12 2612 6181 3.549103 -71279992 \n", "2 13 2191 5901 31.054762 -186807440 \n", "5 12 2184 5915 -24.278454 -377335296 \n", "8 13 2345 5915 3.646593 -223499248 \n", "9 13 1976 5917 -2.138205 -393005696 \n", "... ... ... ... ... ... \n", "226876 14 2074 6484 4.993805 -196691680 \n", "226878 17 1956 6561 34.034641 -126274136 \n", "226879 15 2264 6499 9.173168 160967872 \n", "226881 12 2608 6632 27.137199 104082784 \n", "226882 15 2088 6482 -36.581142 -219142496 \n", "\n", " devmagnitude eastness elev forestheight glad_ard_SVVI_max \\\n", "1 0.507727 -0.021408 322.920227 26 660.006104 \n", "2 -1.375050 -0.126880 291.412537 7 1028.385498 \n", "5 0.265329 -0.248356 335.534760 25 593.601074 \n", "8 0.383314 0.062349 309.142609 25 862.362305 \n", "9 -1.467515 -0.316702 301.649567 20 248.561035 \n", "... ... ... ... ... ... 
\n", "226876 1.260540 0.044238 328.019165 26 -144.124023 \n", "226878 -0.959000 0.047194 260.889069 0 681.798340 \n", "226879 0.645957 -0.011381 447.678040 23 282.119385 \n", "226881 -0.481382 -0.012974 447.814392 21 187.934082 \n", "226882 1.348770 0.060537 331.815918 21 -68.830078 \n", "\n", " glad_ard_SVVI_med glad_ard_SVVI_min northness ORCDRC_WeigAver \\\n", "1 92.722168 190.979736 -0.034787 16 \n", "2 915.806396 841.586182 0.024677 16 \n", "5 228.712402 315.298340 -0.127365 17 \n", "8 263.612793 249.693115 -0.068810 16 \n", "9 164.831299 237.283447 -0.144407 19 \n", "... ... ... ... ... \n", "226876 -145.522949 -14.317627 0.020914 7 \n", "226878 657.745605 642.011475 0.050846 5 \n", "226879 79.830811 10.140381 0.014716 7 \n", "226881 90.763672 168.527100 0.045602 18 \n", "226882 -160.909668 11.658203 -0.046048 11 \n", "\n", " outlet_dist_dw_basin SBIO3_Isothermality_5_15cm \\\n", "1 784807 20.798000 \n", "2 766444 19.941267 \n", "5 764713 19.760756 \n", "8 781120 17.538614 \n", "9 773647 21.324263 \n", "... ... ... \n", "226876 857229 17.376682 \n", "226878 857225 21.177349 \n", "226879 824157 18.283070 \n", "226881 907894 18.010750 \n", "226882 857705 16.320225 \n", "\n", " SBIO4_Temperature_Seasonality_5_15cm treecover hm_sel \n", "1 460.501221 97 27.4625 \n", "2 454.185089 54 22.2925 \n", "5 448.580811 96 5.2900 \n", "8 463.280243 100 39.2125 \n", "9 465.046478 87 17.6500 \n", "... ... ... ... \n", "226876 463.421387 98 32.6650 \n", "226878 573.086243 12 4.9700 \n", "226879 471.167419 89 23.4400 \n", "226881 473.227966 72 20.1450 \n", "226882 450.409271 75 28.9100 \n", "\n", "[116199 rows x 24 columns]" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictors_hm_sel" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
XYhm_sel
16.05003947.99534427.4625
26.05004649.86531722.2925
56.05005349.8778765.2900
86.05008349.28143939.2125
96.05011949.92861017.6500
............
2268769.94963749.88747132.6650
2268789.94965849.8563874.9700
2268799.94968849.36283123.4400
2268819.94970149.11445820.1450
2268829.94973149.89319628.9100
\n", "

116199 rows × 3 columns

\n", "
# %%
# Keep only the two coordinate columns and the selected tree height (the target).
x_y_hm_sel = predictors_hm_sel.loc[:, ["X", "Y", "hm_sel"]]
x_y_hm_sel

# %%
# Normalize the data: rescale every column to the [0, 1] range.
# NOTE(review): the scaler is fitted on all rows, i.e. before the train/test
# split that happens further down, so the test rows' min/max values influence
# the scaling of the training data. Confirm this is acceptable for the exercise.
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
scaler.fit(x_y_hm_sel)
data = scaler.transform(x_y_hm_sel)
]),\n", " )" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
# %%
# Inspect the ranges of the three scaled columns (one histogram per column).
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for column_index, column_name in enumerate(("X", "Y", "hm_sel")):
    ax[column_index].hist(data[:, column_index], bins=50)
    ax[column_index].set_title(column_name)
    ax[column_index].set_xlabel("scaled value")
    ax[column_index].set_ylabel("count")

# %%
# Split the data: columns 0-1 are the inputs, column 2 is the target.
X_train, X_test, y_train, y_test = train_test_split(
    data[:, :2], data[:, 2], test_size=0.30, random_state=0
)
# Convert the splits to float32 tensors. torch.tensor with an explicit dtype
# replaces the legacy torch.FloatTensor constructor.
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)
print('X_train.shape: {}, X_test.shape: {}, y_train.shape: {}, y_test.shape: {}'.format(X_train.shape, X_test.shape, y_train.shape, y_test.shape))


# %%
class Perceptron(torch.nn.Module):
    """Single linear layer, optionally followed by a ReLU.

    Args:
        input_size: number of input features.
        output_size: number of output units.
        use_activation_fn: when True, ReLU is applied to the linear output.
    """

    def __init__(self, input_size, output_size, use_activation_fn=False):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)  # Initializes weights with uniform distribution centered in zero
        self.activation_fn = nn.ReLU()  # instead of Heaviside step fn
        self.use_activation_fn = use_activation_fn  # If we want to use an activation function

    def forward(self, x):
        """Return the linear output of ``x``, passed through ReLU if enabled."""
        output = self.fc(x)
        if self.use_activation_fn:
            # To add the non-linearity. Try training your Perceptron with and
            # without the non-linearity.
            output = self.activation_fn(output)
        return output


# %%
# Create perceptron
# Seed torch so the random weight initialisation, and therefore the whole fit,
# is reproducible across kernel restarts.
torch.manual_seed(0)
model = Perceptron(input_size=2, output_size=1, use_activation_fn=True)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# %%
model.train()
# Keep the epoch count in its own name; the original reused `epoch` for both
# the count and the loop variable.
n_epochs = 5000
all_loss = []
for epoch in range(n_epochs):
    optimizer.zero_grad()
    # Forward pass on the whole training set (no mini-batches)
    y_pred = model(X_train)
    # Compute loss; squeeze(-1) drops only the trailing size-1 output axis so
    # the prediction has the same shape as y_train.
    loss = criterion(y_pred.squeeze(-1), y_train)

    # Backward pass
    loss.backward()
    optimizer.step()

    all_loss.append(loss.item())
# %%
# Training-loss curve: one value per epoch, as collected in `all_loss`.
fig, ax = plt.subplots()
ax.plot(all_loss)
ax.set_xlabel('Epoch')
ax.set_ylabel('Training loss (MSE)')
ax.set_title('Training loss');  # trailing ';' keeps the cell from echoing the Text repr
# %%
# `import scipy` on its own does not guarantee that the `stats` submodule is
# loaded; import it explicitly so the linregress call cannot fail with an
# AttributeError.
import scipy.stats

model.eval()
# Gradients are only irrelevant for the forward pass and the loss, so only
# those two statements live inside no_grad.
with torch.no_grad():
    y_pred = model(X_test)
    after_train = criterion(y_pred.squeeze(-1), y_test)
print('Test loss after Training', after_train.item())

# Hand plain 1-D NumPy arrays to SciPy and Matplotlib instead of torch tensors.
y_pred = y_pred.numpy().squeeze(-1)
y_true = y_test.numpy()
# Regress the true values on the predictions; slope and r_value summarise how
# well the predictions line up with the observations.
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(y_pred, y_true)

fig, ax = plt.subplots()
ax.scatter(y_pred, y_true)
ax.set_xlabel('Prediction')
ax.set_ylabel('True')
ax.set_title('slope: {:.3f}, r_value: {:.3f}'.format(slope, r_value));