{ "cells": [ { "cell_type": "markdown", "id": "41147902", "metadata": {}, "source": [ "# Integrating with scikit-learn: Pipeline and GridSearch\n", "\n", "ITEA implementations inherit from scikit-learn's base classes. This means that ITEA can be combined with scikit-learn tools such as `Pipeline` and `GridSearchCV`. In this notebook, we'll show some examples of how to take advantage of that to tune a predictor." ] }, { "cell_type": "code", "execution_count": 1, "id": "d3847fde", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from scipy import stats\n", "from sklearn.model_selection import train_test_split\n", "from IPython.display import display\n", "\n", "from sklearn.pipeline import Pipeline\n", "\n", "from sklearn.feature_selection import SelectKBest\n", "from sklearn.feature_selection import mutual_info_regression\n", "\n", "from sklearn import datasets\n", "from sklearn.model_selection import GridSearchCV\n", "\n", "# HalvingGridSearchCV is experimental; importing enable_halving_search_cv enables it\n", "from sklearn.experimental import enable_halving_search_cv\n", "from sklearn.model_selection import HalvingGridSearchCV\n", "\n", "from itea.regression import ITEA_regressor\n", "from itea.inspection import *\n", "\n", "import warnings\n", "warnings.filterwarnings(action='ignore', module=r'itea')" ] }, { "cell_type": "markdown", "id": "7eeb718c", "metadata": {}, "source": [ "## Loading the data\n", "\n", "First, let's load the data and split it into training and test partitions. The training partition will be used for training and validation; only after settling on a final model will we train it on the whole training partition and evaluate it on the test partition."
] }, { "cell_type": "code", "execution_count": 2, "id": "abce135b", "metadata": {}, "outputs": [], "source": [ "# Note: load_boston was removed in scikit-learn 1.2, so this cell requires an older version\n", "boston_data = datasets.load_boston()\n", "X, y = boston_data['data'], boston_data['target']\n", "labels = boston_data['feature_names']\n", "\n", "# Hold out 33% of the samples as the final test partition\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.33, random_state=42)" ] }, { "cell_type": "markdown", "id": "2f280150", "metadata": {}, "source": [ "## Inspecting the data\n", "\n", "Let's look at some descriptive statistics for the variables.\n", "\n", "Suppose that, to reduce the complexity of the final model, we want to work with only a subset of these variables." ] }, { "cell_type": "code", "execution_count": 3, "id": "b72c9693", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 | 506.000000 |
| mean | 3.613524 | 11.363636 | 11.136779 | 0.069170 | 0.554695 | 6.284634 | 68.574901 | 3.795043 | 9.549407 | 408.237154 | 18.455534 | 356.674032 | 12.653063 |
| std | 8.601545 | 23.322453 | 6.860353 | 0.253994 | 0.115878 | 0.702617 | 28.148861 | 2.105710 | 8.707259 | 168.537116 | 2.164946 | 91.294864 | 7.141062 |
| min | 0.006320 | 0.000000 | 0.460000 | 0.000000 | 0.385000 | 3.561000 | 2.900000 | 1.129600 | 1.000000 | 187.000000 | 12.600000 | 0.320000 | 1.730000 |
| 25% | 0.082045 | 0.000000 | 5.190000 | 0.000000 | 0.449000 | 5.885500 | 45.025000 | 2.100175 | 4.000000 | 279.000000 | 17.400000 | 375.377500 | 6.950000 |
| 50% | 0.256510 | 0.000000 | 9.690000 | 0.000000 | 0.538000 | 6.208500 | 77.500000 | 3.207450 | 5.000000 | 330.000000 | 19.050000 | 391.440000 | 11.360000 |
| 75% | 3.677083 | 12.500000 | 18.100000 | 0.000000 | 0.624000 | 6.623500 | 94.075000 | 5.188425 | 24.000000 | 666.000000 | 20.200000 | 396.225000 | 16.955000 |
| max | 88.976200 | 100.000000 | 27.740000 | 1.000000 | 0.871000 | 8.780000 | 100.000000 | 12.126500 | 24.000000 | 711.000000 | 22.000000 | 396.900000 | 37.970000 |