{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "impaired-stylus",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import eli5\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from eli5.sklearn import PermutationImportance\n",
"from joblib import dump\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.metrics import (confusion_matrix, f1_score, make_scorer,\n",
" roc_auc_score)\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import Pipeline\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "accepting-wallet",
"metadata": {},
"outputs": [],
"source": [
"# Project root is the parent of the current working directory.\n",
"proj_path = Path.cwd().parent\n",
"# One raw CSV per country under <project root>/data/raw.\n",
"data_file_paths = [\n",
"    proj_path / 'data' / 'raw' / f'Churn_Modelling_{country}.csv'\n",
"    for country in ('Spain', 'France')\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "continued-juvenile",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" CustomerId \n",
" Surname \n",
" CreditScore \n",
" Geography \n",
" Gender \n",
" Age \n",
" Tenure \n",
" Balance \n",
" NumOfProducts \n",
" HasCrCard \n",
" IsActiveMember \n",
" EstimatedSalary \n",
" Exited \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 15647311 \n",
" Hill \n",
" 608 \n",
" Spain \n",
" Female \n",
" 41 \n",
" 1 \n",
" 83807.86 \n",
" 1 \n",
" 0 \n",
" 1 \n",
" 112542.58 \n",
" 0 \n",
" \n",
" \n",
" 1 \n",
" 15737888 \n",
" Mitchell \n",
" 850 \n",
" Spain \n",
" Female \n",
" 43 \n",
" 2 \n",
" 125510.82 \n",
" 1 \n",
" 1 \n",
" 1 \n",
" 79084.10 \n",
" 0 \n",
" \n",
" \n",
" 2 \n",
" 15574012 \n",
" Chu \n",
" 645 \n",
" Spain \n",
" Male \n",
" 44 \n",
" 8 \n",
" 113755.78 \n",
" 2 \n",
" 1 \n",
" 0 \n",
" 149756.71 \n",
" 1 \n",
" \n",
" \n",
" 3 \n",
" 15737173 \n",
" Andrews \n",
" 497 \n",
" Spain \n",
" Male \n",
" 24 \n",
" 3 \n",
" 0.00 \n",
" 2 \n",
" 1 \n",
" 0 \n",
" 76390.01 \n",
" 0 \n",
" \n",
" \n",
" 4 \n",
" 15600882 \n",
" Scott \n",
" 635 \n",
" Spain \n",
" Female \n",
" 35 \n",
" 7 \n",
" 0.00 \n",
" 2 \n",
" 1 \n",
" 1 \n",
" 65951.65 \n",
" 0 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" CustomerId Surname CreditScore Geography Gender Age Tenure \\\n",
"0 15647311 Hill 608 Spain Female 41 1 \n",
"1 15737888 Mitchell 850 Spain Female 43 2 \n",
"2 15574012 Chu 645 Spain Male 44 8 \n",
"3 15737173 Andrews 497 Spain Male 24 3 \n",
"4 15600882 Scott 635 Spain Female 35 7 \n",
"\n",
" Balance NumOfProducts HasCrCard IsActiveMember EstimatedSalary \\\n",
"0 83807.86 1 0 1 112542.58 \n",
"1 125510.82 1 1 1 79084.10 \n",
"2 113755.78 2 1 0 149756.71 \n",
"3 0.00 2 1 0 76390.01 \n",
"4 0.00 2 1 1 65951.65 \n",
"\n",
" Exited \n",
"0 0 \n",
"1 0 \n",
"2 1 \n",
"3 0 \n",
"4 0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ignore_index=True gives the combined frame one fresh 0..n-1 index instead of\n",
"# repeating each file's own row labels (which would leave duplicate labels).\n",
"df = pd.concat([pd.read_csv(fpath) for fpath in data_file_paths], ignore_index=True)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bb0d9f24",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7491, 13)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (rows, columns) of the combined frame\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "unnecessary-roots",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CustomerId 0\n",
"Surname 0\n",
"CreditScore 0\n",
"Geography 0\n",
"Gender 0\n",
"Age 0\n",
"Tenure 0\n",
"Balance 0\n",
"NumOfProducts 0\n",
"HasCrCard 0\n",
"IsActiveMember 0\n",
"EstimatedSalary 0\n",
"Exited 0\n",
"dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Are there missing values? Count the NaN entries in each column.\n",
"df.isna().sum()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b025ccce",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Geography\n",
"France 5014\n",
"Spain 2477\n",
"Name: count, dtype: int64"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of rows for each Geography value\n",
"df['Geography'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "scheduled-measure",
"metadata": {},
"outputs": [],
"source": [
"# Columns fed to the model, and the column it is trained to predict.\n",
"feat_cols = [\n",
"    'CreditScore',\n",
"    'Age',\n",
"    'Tenure',\n",
"    'Balance',\n",
"    'NumOfProducts',\n",
"    'HasCrCard',\n",
"    'IsActiveMember',\n",
"    'EstimatedSalary',\n",
"]\n",
"targ_col = 'Exited'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "junior-rating",
"metadata": {},
"outputs": [],
"source": [
"# Feature matrix and target vector taken from the combined frame.\n",
"X = df[feat_cols]\n",
"y = df[targ_col]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c4d476a6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.16326258176478442"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean of the target; if Exited is coded 0/1 this is the share of positive rows.\n",
"y.mean()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "historic-doubt",
"metadata": {},
"outputs": [],
"source": [
"# Hold out 25% of the rows for evaluation; the fixed seed makes the split repeatable.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X,\n",
"    y,\n",
"    test_size=0.25,\n",
"    random_state=42,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "infinite-african",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Pipeline(steps=[('preprocessor', SimpleImputer()),\n",
" ('clf', RandomForestClassifier(max_depth=10, random_state=42))]) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"Pipeline(steps=[('preprocessor', SimpleImputer()),\n",
" ('clf', RandomForestClassifier(max_depth=10, random_state=42))])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Seed and forest hyper-parameters in one place so they are easy to tune.\n",
"random_state = 42\n",
"train_params = {'n_estimators': 100, 'max_depth': 10}\n",
"\n",
"# Two-step pipeline: impute missing values, then classify with a random forest.\n",
"clf = RandomForestClassifier(random_state=random_state, **train_params)\n",
"model = Pipeline(steps=[\n",
"    ('preprocessor', SimpleImputer()),\n",
"    ('clf', clf),\n",
"])\n",
"\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "12944f94",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.5474137931034483, 0.8684512806155847)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# predict_proba returns one column per class; column 1 is thresholded at 0.5\n",
"# for the hard predictions and passed as the score to ROC AUC.\n",
"y_prob = model.predict_proba(X_test)\n",
"y_pred = y_prob[:, 1] >= 0.5\n",
"f1 = f1_score(y_test, y_pred)\n",
"roc_auc = roc_auc_score(y_test, y_prob[:, 1])\n",
"# Show both test-set metrics as the cell output.\n",
"f1, roc_auc"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4675491c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGdCAYAAACPX3D5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjb0lEQVR4nO3da3hU1d338d9MDhMCEgIhE4jRCMpJJIkBYkC0aDRWRai1Ug+AUfGESJnqI1EhikpUlHJbQlGUFtv6kEc8lEcR1Ci3osFgMKiIIAKiSE4cEggwgZm5X9h77OyEw+CEiazv57r2C9bsvWdtLzC//P9rzdh8Pp9PAADAWPZwTwAAAIQXYQAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMMRBgAAMFxkuCfwv9pk3BnuKQCtzs6Vs8I9BaBVimnhn16h/Jm079PW/++41YQBAABaDZtZhXOznhYAADRBZQAAACubLdwzOK4IAwAAWBnWJiAMAABgZVhlwKzoAwAAmqAyAACAFW0CAAAMR5sAAACYhMoAAABWtAkAADAcbQIAAGASKgMAAFjRJgAAwHC0CQAAgEmoDAAAYEWbAAAAwxnWJiAMAABgZVhlwKynBQAATVAZAADAyrDKAGEAAAAru1lrBsyKPgAAoAkqAwAAWNEmAADAcIZtLTQr+gAAgCaoDAAAYEWbAAAAw9EmAAAAJqEyAACAFW0CAAAMZ1ibgDAAAICVYZUBs54WAAA0QWUAAAAr2gQAABiONgEAADAJlQEAAKxoEwAAYDjaBAAAwCRUBgAAsDKsMkAYAADAyrA1A2ZFHwAA0ASVAQAArGgTAABgOMPaBIQBAACsDKsMmPW0AACgCSoDAABY0SYAAMBsNsPCAG0CAAAMR2UAAAAL0yoDhAEAAKzMygK0CQAAMB2VAQAALGgTAABgONPCAG0CAAAMR2UAAAAL0yoDhAEAACwIAwAAmM6sLMCaAQAATEdlAAAAC9oEAAAYzrQwQJsAAADDURkAAMDCtMoAYQAAAAvTwgBtAgAADEdlAAAAK7MKA1QGAACwstlsITuCVVRUpNTUVMXExCgrK0tlZWWHPX/mzJnq2bOn2rRpo5SUFE2cOFH79+8P6j0JAwAAtBLFxcVyuVwqKCjQqlWrlJaWptzcXFVXVzd7/osvvqhJkyapoKBAa9eu1fPPP6/i4mLdd999Qb0vYQAAAItwVQZmzJihsWPHKi8vT3369NGcOXMUGxurefPmNXv+Rx99pMGDB+vaa69VamqqLr74Yl1zzTVHrCZYEQYAALAIZRhwu92qr68PONxud5P3bGxsVHl5uXJycvxjdrtdOTk5Ki0tbXaegwYNUnl5uf+H/8aNG7V48WJdeumlQT0vYQAAACtb6I7CwkLFxcUFHIWFhU3esra2Vh6PR06nM2Dc6XSqsrKy2Wlee+21mjp1qs4991xFRUWpe/fu+tWvfkWbAACA1iQ/P191dXUBR35+fkjuvWzZMk2bNk2zZ8/WqlWr9Morr+iNN97Qww8/HNR92FoIAIBFKD90yOFwyOFwHPG8hIQERUREqKqqKmC8qqpKSUlJzV4zefJkjRo1SjfffLMk6ayzzlJDQ4NuueUW3X///bLbj+53fioDAABYhGMBYXR0tDIzM1VSUuIf83q9KikpUXZ2drPX7N27t8kP/IiICEmSz+c76vemMgAAQCvhcrk0ZswY9e/fXwMHDtTMmTPV0NCgvLw8SdLo0aOVnJzsX3MwbNgwzZgxQxkZGcrKytKGDRs0efJkDRs2zB8KjgZhAAAAi3B9N8HIkSNVU1OjKVOmqLKyUunp6VqyZIl/UeGWLVsCKgEPPPCAbDabHnjgAW3dulWdO3fWsGHD9Oijjwb1vjZfMHWEFtQm485wTwFodXaunBXuKQCtUkwL/yrb9dZXQnavH565MmT3aimsGQAAwHC0CQAAsDLsi4oIAwAAWIRrzU
C40CYAAMBwVAYAALAwrTJAGAAAwIIwAACA6czKAqwZAADAdFQGAACwMK1NQGXgBHfr1efpqzce0s4Vf9L7L9yt/meeeshzIyPtyr/lEq1ZVKCdK/6kj4sn6aJBvQPOsdttmnLHZVr7+oPaUTpDaxYVaNLYS1r6MYCfZcGL/9SvL7pAAzLO0nW//50+/+yzw57/1tI3NfzySzQg4yz9dsQwffD+fwe8/s7bb+nWsTfqvEFZSjuzp75auzbg9a1bv1famT2bPd5a+mbInw+hF44vKgonwsAJ7KqLz9bjf/yNHn3mTWVf+7g+W79Vi2aPU+f4ds2e/+Adw3Tzb8+V64mXlPHbR/TcwuUqfmqs0nqe7D/njzdcpLFXDdHEx15S+pWP6IGn/yXXmBzdcc35x+uxgKAseXOxnnyiULfeMU4LXnpVPXv20u233qTt27c3e37Fp6s06Z4/6jdXXqXiha9p6AUX6g/jx+nrr9f7z9m3b68yMs7WH1x3N3uPpKQuKlm2POC4fdx4xcbG6txzz2uR5wR+DsLACeyu6y/QX1/5SH9ftEJfbazU+EcXaN/+Ro0Z0fxXYV57+UA98fxbWrr8S23eul1zX1qupR9+qQmjLvCfc05aN73+359pyfI12rJth159p0IlK746bMUBCKe/z/+rrrzqao34zW/V/fTT9UDBQ4qJidFrr7zc7Pn//McLGnTuEN1w483q1r277rzrD+rdp48WvPgP/znDrhih2+64U1mH+FrZiIgIJXTuHHC8W/KOLr7k14pt27ZFnhOhRWUAJ4SoyAhl9E7Rux+v84/5fD69+/E6Dex3WrPXREdFan/jgYCxffsbNSiju//PK1Zv1NCBPXX6KYmSpLN6JCs7vZve+vDLFngK4Oc50NiotV+u0TnZg/xjdrtd55wzSJ+t/rTZaz6rqNA55wT+kB80+Fx9VlFxzPP4cs0XWvfVWv3myquO+R44vkwLA0EvIKytrdW8efNUWlqqyspKSVJSUpIGDRqkG264QZ07dw75JBG8hPh2ioyMUPWO3QHj1dvr1TPV2ew175Su1V3XX6DlqzZo43e1Gjqwp4ZfkK6IiJ/+Mj/517fVvl2MVr/6gDwenyIibCooel0L3vykRZ8HOBY7d+2Ux+NRp06dAsY7deqkTZs2NntNbW2tOnVKaHJ+7fbaY57Hqy8vVLdu3ZWecfYx3wNoSUGFgZUrVyo3N1exsbHKyclRjx49JElVVVV6+umn9dhjj2np0qXq37//Ye/jdrvldrsDxnxej2z2iCCnj1C6e/pCzZ58jVa/Mlk+n08bv6/VC4tWaMzwc/znXHXx2fr9rwfohvvm68tvtqlfz2RNv/sqbaup0z///8dhnD3QOu3fv19vLn5dY2+7I9xTQTB+Gb/Qh0xQYWD8+PH63e9+pzlz5jQpffh8Pt12220aP368SktLD3ufwsJCPfTQQwFjEc4BiuoyMJjp4DBqd+7RwYMeJXY8KWA8sVN7VW6vP+Q1V7vmyhEdqU5xbfVDTZ0euWu4Nm39aaHVtD+M0JN/fVsvLS2XJK3Z8INO6dJR9+RdRBhAqxPfIV4RERFNFgtu375dCQkJzV6TkJCg7ZYqwPbt25XQqfnzj+Ttt5Zo3779GnbFiGO6HuHxSynvh0pQawZWr16tiRMnNvsfyWazaeLEiao4ir5afn6+6urqAo5IZ2YwU8ERHDjo0adrv9PQrJ7+MZvNpqEDe6jss02HvdbdeFA/1NQpMtKuERem6/VlP23DahMTLa/PG3C+x+uT3c7yE7Q+UdHR6t3nTH284qdfULxerz7+uFT90jKavaZfero+XrEiYGxF6Ufql55+THN47ZWX9auhF6hjx47HdD1wPARVGUhKSlJZWZl69erV7OtlZWVyOpvvR/8nh8Mhh8MRMEaLIPSe/se7mjt1lMq/3KJPvtisO68dqtg2Dr3wrx//R/fcw6P0Q3Wdpvx5kSRpQN9T1TWxg1av+17JiR10/6
2Xym63acbf3vHfc/H7n+vem3L13bad+vKbbUrvdbLuun6oXnhtRbNzAMJt1Jg8Tb7vXp15Zl/1Pauf/vH3+dq3b59G/OZKSdL9+f9HiYlOTZj4R0nSddeP1k03jNL8v83TeeedryVvLtaaL77Q5Aen+u9Zt2uXtm3bppqaaknS5s0/BuyEhAQl/Me6qS3ffqvyT1aq6C/PHq/HRYiYVhkIKgzcfffduuWWW1ReXq4LL7zQ/4O/qqpKJSUlmjt3rp588skWmSiCt/CtVUqIb6cpt18mZ6eT9Nm6rRo+rsi/qDAlqaO8Xp//fIcjSgXjLtdpyQnas9etpR+u0U2TX1Ddnn3+c1yPv6SCOy7Xf903Up3j22lbTZ2eX/ihpj3LB6mgdbrk15dq544dmj3radXW1qhnr96a/cxz6vTvNkHltm2y236qbKVnnK3CJ57UrKdn6s8zZ+iUU1M1889FOuOMHv5zlr33rqY8kO//8713T5Qk3XbHnbp93Hj/+GuvviynM0nZg89t6cdEiBmWBWTz+Xy+I5/2k+LiYv3pT39SeXm5PB6PpB/31GZmZsrlcunqq68+pom0ybjzmK4DTmQ7V84K9xSAVimmhT9M/4x7loTsXl9Pb/2f0hr0f86RI0dq5MiROnDggGprf1xkk5CQoKioqJBPDgAAtLxjzlZRUVHq0qVLKOcCAECrYFqbgG8tBADAwrQFhOwHAwDAcFQGAACwMKwwQBgAAMDKbjcrDdAmAADAcFQGAACwoE0AAIDh2E0AAACMQmUAAAALwwoDhAEAAKxMaxMQBgAAsDAtDLBmAAAAw1EZAADAwrDCAGEAAAAr2gQAAMAoVAYAALAwrDBAGAAAwIo2AQAAMAqVAQAALAwrDBAGAACwok0AAACMQmUAAAALwwoDhAEAAKxMaxMQBgAAsDAsC7BmAAAA01EZAADAgjYBAACGMywL0CYAAMB0VAYAALCgTQAAgOEMywK0CQAAMB2VAQAALGgTAABgONPCAG0CAAAMR2UAAAALwwoDhAEAAKxMaxMQBgAAsDAsC7BmAAAA01EZAADAgjYBAACGMywL0CYAAMB0VAYAALCwG1YaIAwAAGBhWBagTQAAQGtSVFSk1NRUxcTEKCsrS2VlZYc9f9euXRo3bpy6dOkih8OhHj16aPHixUG9J5UBAAAswrWboLi4WC6XS3PmzFFWVpZmzpyp3NxcrVu3TomJiU3Ob2xs1EUXXaTExEQtXLhQycnJ+vbbb9WhQ4eg3pcwAACAhT1MbYIZM2Zo7NixysvLkyTNmTNHb7zxhubNm6dJkyY1OX/evHnasWOHPvroI0VFRUmSUlNTg35f2gQAAFjYbLaQHW63W/X19QGH2+1u8p6NjY0qLy9XTk6Of8xutysnJ0elpaXNznPRokXKzs7WuHHj5HQ61bdvX02bNk0ejyeo5yUMAADQggoLCxUXFxdwFBYWNjmvtrZWHo9HTqczYNzpdKqysrLZe2/cuFELFy6Ux+PR4sWLNXnyZD311FN65JFHgpojbQIAACxCuWQgPz9fLpcrYMzhcITk3l6vV4mJiXr22WcVERGhzMxMbd26VdOnT1dBQcFR34cwAACAhU2hSwMOh+OofvgnJCQoIiJCVVVVAeNVVVVKSkpq9pouXbooKipKERER/rHevXursrJSjY2Nio6OPqo50iYAAKAViI6OVmZmpkpKSvxjXq9XJSUlys7ObvaawYMHa8OGDfJ6vf6x9evXq0uXLkcdBCTCAAAATdhtoTuC4XK5NHfuXM2fP19r167V7bffroaGBv/ugtGjRys/P99//u23364dO3ZowoQJWr9+vd544w1NmzZN48aNC+p9aRMAAGARrs8ZGDlypGpqajRlyhRVVlYqPT1dS5Ys8S8q3LJli+z2n36PT0lJ0dKlSzVx4kT169dPycnJmjBhgu69996g3tfm8/l8IX2SY9Qm485wTwFodXaunBXuKQCtUkwL/yo7fO4nIb
vXv8b2D9m9WgqVAQAALEz7bgLCAAAAFqZ9ayELCAEAMByVAQAALAwrDBAGAACwCtdugnAhDAAAYGFYFmDNAAAApqMyAACAhWm7CQgDAABYmBUFaBMAAGA8KgMAAFiwmwAAAMMF+22Dv3S0CQAAMByVAQAALGgTAABgOMOyAG0CAABMR2UAAAAL2gQAABjOtN0EhAEAACxMqwywZgAAAMNRGQAAwMKsugBhAACAJkz71kLaBAAAGI7KAAAAFoYVBggDAABYsZsAAAAYhcoAAAAWhhUGCAMAAFixmwAAABiFygAAABaGFQYIAwAAWJm2m6DVhIELbxsT7ikArc7cjzeFewpAqzR+8Gkten/TeuimPS8AALBoNZUBAABaC9oEAAAYzm5WFqBNAACA6agMAABgYVplgDAAAICFaWsGaBMAAGA4KgMAAFjQJgAAwHCGdQloEwAAYDoqAwAAWJj2FcaEAQAALEwrmxMGAACwMKwwYFz4AQAAFlQGAACwYM0AAACGMywL0CYAAMB0VAYAALDgEwgBADCcaWsGaBMAAGA4KgMAAFgYVhggDAAAYGXamgHaBAAAGI7KAAAAFjaZVRogDAAAYGFam4AwAACAhWlhgDUDAAAYjsoAAAAWNsP2FhIGAACwoE0AAACMQmUAAAALw7oEhAEAAKz4oiIAAGAUKgMAAFiwgBAAAMPZbKE7glVUVKTU1FTFxMQoKytLZWVlR3XdggULZLPZNGLEiKDfkzAAAEArUVxcLJfLpYKCAq1atUppaWnKzc1VdXX1Ya/bvHmz7r77bg0ZMuSY3pcwAACAhV22kB3BmDFjhsaOHau8vDz16dNHc+bMUWxsrObNm3fIazwej6677jo99NBD6tat2zE+LwAACBDKNoHb7VZ9fX3A4Xa7m7xnY2OjysvLlZOT4x+z2+3KyclRaWnpIec6depUJSYm6qabbjrm5yUMAABgYbeF7igsLFRcXFzAUVhY2OQ9a2tr5fF45HQ6A8adTqcqKyubnefy5cv1/PPPa+7cuT/redlNAABAC8rPz5fL5QoYczgcP/u+u3fv1qhRozR37lwlJCT8rHsRBgAAsAjlhw45HI6j+uGfkJCgiIgIVVVVBYxXVVUpKSmpyfnffPONNm/erGHDhvnHvF6vJCkyMlLr1q1T9+7dj2qOtAkAALAIx9bC6OhoZWZmqqSkxD/m9XpVUlKi7OzsJuf36tVLn3/+uSoqKvzHFVdcoaFDh6qiokIpKSlH/d5UBgAAaCVcLpfGjBmj/v37a+DAgZo5c6YaGhqUl5cnSRo9erSSk5NVWFiomJgY9e3bN+D6Dh06SFKT8SMhDAAAYBGu7yYYOXKkampqNGXKFFVWVio9PV1LlizxLyrcsmWL7PbQF/VtPp/PF/K7HoPLn1kZ7ikArU5u35+3KAg4UY0ffFqL3n/eyi0hu9eNA04J2b1aCmsGAAAwHG0CAAAsTPtNmTAAAICFLUxrBsLFtPADAAAsqAwAAGBhVl2AMAAAQBPh2loYLoQBAAAszIoCrBkAAMB4VAYAALAwrEtAGAAAwIqthQAAwChUBgAAsDDtN2XCAAAAFrQJAACAUagMAABgYVZdgDAAAEATtAkAAIBRqAwAAGBh2m/KhAEAACxMaxMQBgAAsDArCphXCQEAABZUBgAAsDCsS0AYAADAym5Yo4A2AQAAhqMyAACABW0CAAAMZ6NNAAAATEJlAAAAC9oEAAAYjt0EAADAKFQGAACwoE0AAIDhCAMAABiOrYUAAMAoVAYAALCwm1UYIAwAAGBFmwAAABiFygAAABbsJgAAwHC0CQAAgFGoDAAAYMFuApxQLjszUVemJSm+TZQ2bd+rZz7covU1Dc2ee2GPTpo4tFvAWONBr658vtz/5w5tInVDVooyTm6vttERWlO5R88s/1Y/1Ltb9DmAUPusZJE+XbJQe+t2KiGlm8677g45u/U84nXrP16mt555TKdlZOuy8QX+8W/Kl+uLZY
tVvflruRt2a+SDRep8SveWfAS0INoEOGEM6d5RN2en6P+W/6AJL6/Rph17NfWyHoqLOXQGbHAf1PUvfOo/bnxxdcDrD+SeoaT2Dj2ydIMmvPylqne79cjlPeWI5K8Sfjm+LvtvLS+eqwFXXK+RBbPUKaWbFs24X3vrdx32uvraSn34/55T1x59m7x2wL1fXc44U4N+d2MLzRpoOfwf/AQ24iynlq6t0TvravXdrv0qev9buQ96dVGvhENe45O0a9/BgON/dY1zqJeznWZ/sFlf1zRoa91+zf7gW0VH2nX+6R2PwxMBoVGx9BWded4l6jPkYnVMPlVDR49XZLRDaz9YeshrvF6P3nr2CWUNv17tOyc1eb3XoBwNvOI6pfTJaMmp4zix2UJ3/BIQBk5QkXabTu/cVhVb6/1jPkkV39erl7PdIa9rExWhedf201+vS9MDuafrlPgY/2tRET/+dWn0+ALuecDjU5+kk0L+DEBL8Bw8oOpvvw74oW2z23VynwxVfrP2kNetXPSiYk/qoD7nXXI8pokws4Xw+CUgDJyg2sdEKsJu0659BwLGd+07oPg2Uc1es7Vuv/5r2SY9vPRrPfXuRtltNk0f3lud2v54/ve79qt6t1tjBp6sttERirTb9Nu0JHVuF62Osc3fE2ht9u2ul8/rVZv2HQLGY9t30N66nc1e88P6L/TlB0s19IYJx2GGaA3sNlvIjl+CkIeB7777TjfeePiemdvtVn19fcDhOdAY6qkgSF9VNejdr7dr0/Z9+mLbbj361gbV7T+oX/dOlCR5vD49+tYGJcfFqDjvbL18U6b6JbfXJ1t2yevzHeHuwC9T4769evu56bpgzAS1OSku3NMBWkTIdxPs2LFD8+fP17x58w55TmFhoR566KGAsTMuu1k9ht0S6ukYq37/QXm8PnWwVAE6tInSTku14FA8Xp821u5VlziHf+yb2r266+U1iv13ZaB+/0E9NaK3vq5tfocC0Nq0Oam9bHa79lkWC+6t36XYuPgm59fVbNPu2iq9/vRPOwd8/w6/RTdfquunPae4xK4tOmccf7+M3+dDJ+gwsGjRosO+vnHjxiPeIz8/Xy6XK2Bs5AufBzsVHMZBr08bahqUltxeKzbvkvTjX+605PZ6fU3VUd3DbpNO7dhG5d/VNXltb6NHktS1vUOnd26rf3yyNVRTB1pURGSUEk89Q9+trVC3swdJknxer75fW6F+Fwxrcn58lxRdM3VOwNiKV+frwP59GnLNbWrXsfNxmTeOM8PSQNBhYMSIEbLZbP5k3BzbEXokDodDDocjYCwiKjrYqeAIXvu8ShN/dZq+rmnQ+uoGDT/LqZgou95ZVytJcg09TdsbDmh+2feSpN+f3VXrqvfohzq32jkidGVakhJPcmjp2hr/PQd3i1f9voOq3tOo1I5tdMvgU7Ri8059+n19s3MAWqP03Cv1znNPKjH1DDlP66nVb7+qg+796n3uxZKkt+dOV9v4Thp01Y2KjIpWp5NTA653xLaVpIDx/Xt2a/eOajXs2i5J2lX547+r2Lh4tY1jtw1at6DDQJcuXTR79mwNHz682dcrKiqUmZn5syeGn++Db3YoLiZS1/dPVnxslDbW7tWUxev92wU7t4uW9z8yXTtHhMafl6r42CjtcXu0oaZB97y2Vt/t2u8/p2NslG7OPkUd2kRq594Denf9di1Y9cPxfjTgZzlj4Pnat7tOZa/9XQ11O9U5pZuGTXzE3ybYvaNatiA/gm5TRalK5s3w/3npnEJJ0oArrlPWiFGhmzyOC9M+dMjmO9yv+M244oorlJ6erqlTpzb7+urVq5WRkSGv1xvURC5/ZmVQ5wMmyO176M+EAEw2fvBpLXr/so1N26PHamC31r/wNOjKwD333KOGhkMvFjv99NP13nvv/axJAQCA4yfoMDBkyJDDvt62bVudf/75xzwhAADCzawmAV9UBABAU4alAT6BEAAAw1EZAADAwrTdBIQBAA
AsfiFfKRAyhAEAACwMywKsGQAAwHRUBgAAsDKsNEAYAADAwrQFhLQJAAAwHJUBAAAs2E0AAIDhDMsCtAkAAGhNioqKlJqaqpiYGGVlZamsrOyQ586dO1dDhgxRfHy84uPjlZOTc9jzD4UwAACAlS2ERxCKi4vlcrlUUFCgVatWKS0tTbm5uaqurm72/GXLlumaa67Re++9p9LSUqWkpOjiiy/W1q1bg3tcn8/nC26qLePyZ1aGewpAq5PbNyHcUwBapfGDT2vR+3/23Z6Q3atfSrujPjcrK0sDBgzQrFmzJEler1cpKSkaP368Jk2adMTrPR6P4uPjNWvWLI0ePfqo35fKAAAALcjtdqu+vj7gcLvdTc5rbGxUeXm5cnJy/GN2u105OTkqLS09qvfau3evDhw4oI4dOwY1R8IAAAAWNlvojsLCQsXFxQUchYWFTd6ztrZWHo9HTqczYNzpdKqysvKo5n3vvfeqa9euAYHiaLCbAAAAi1DuJsjPz5fL5QoYczgcIXyHHz322GNasGCBli1bppiYmKCuJQwAAGAVwjTgcDiO6od/QkKCIiIiVFVVFTBeVVWlpKSkw1775JNP6rHHHtM777yjfv36BT1H2gQAALQC0dHRyszMVElJiX/M6/WqpKRE2dnZh7zuiSee0MMPP6wlS5aof//+x/TeVAYAALAI13cTuFwujRkzRv3799fAgQM1c+ZMNTQ0KC8vT5I0evRoJScn+9ccPP7445oyZYpefPFFpaam+tcWtGvXTu3aHf0uBsIAAAAW4fo44pEjR6qmpkZTpkxRZWWl0tPTtWTJEv+iwi1btshu/6mo/5e//EWNjY266qqrAu5TUFCgBx988Kjfl88ZAFoxPmcAaF5Lf87Alz80hOxefbq2Ddm9WgqVAQAALEz7bgLCAAAAVoalAXYTAABgOCoDAABYhGs3QbgQBgAAsAjXboJwoU0AAIDhqAwAAGBhWGGAMAAAQBOGpQHCAAAAFqYtIGTNAAAAhqMyAACAhWm7CQgDAABYGJYFaBMAAGA6KgMAAFgZVhogDAAAYMFuAgAAYBQqAwAAWLCbAAAAwxmWBWgTAABgOioDAABYGVYaIAwAAGBh2m4CwgAAABamLSBkzQAAAIajMgAAgIVhhQHCAAAAVrQJAACAUagMAADQhFmlAcIAAAAWtAkAAIBRqAwAAGBhWGGAMAAAgBVtAgAAYBQqAwAAWPDdBAAAmM6sLEAYAADAyrAswJoBAABMR2UAAAAL03YTEAYAALAwbQEhbQIAAAxHZQAAACuzCgOEAQAArAzLArQJAAAwHZUBAAAs2E0AAIDh2E0AAACMQmUAAAAL09oEVAYAADAclQEAACyoDAAAAKNQGQAAwMK03QSEAQAALGgTAAAAo1AZAADAwrDCAGEAAIAmDEsDtAkAADAclQEAACzYTQAAgOHYTQAAAIxCZQAAAAvDCgOEAQAAmjAsDRAGAACwMG0BIWsGAAAwHJUBAAAsTNtNYPP5fL5wTwKth9vtVmFhofLz8+VwOMI9HaBV4N8FTnSEAQSor69XXFyc6urq1L59+3BPB2gV+HeBEx1rBgAAMBxhAAAAwxEGAAAwHGEAARwOhwoKClgkBfwH/l3gRMcCQgAADEdlAAAAwxEGAAAwHGEAAADDEQYAADAcYQB+RUVFSk1NVUxMjLKyslRWVhbuKQFh9f7772vYsGHq2rWrbDabXnvttXBPCWgRhAFIkoqLi+VyuVRQUKBVq1YpLS1Nubm5qq6uDvfUgLBpaGhQWlqaioqKwj0VoEWxtRCSpKysLA0YMECzZs2SJHm9XqWkpGj8+PGaNGlSmGcHhJ/NZtOrr76qESNGhHsqQMhRGYAaGxtVXl6unJwc/5jdbldOTo5KS0vDODMAwPFAGIBqa2vl8XjkdDoDxp1OpyorK8M0KwDA8UIYAADAcIQBKCEhQREREaqqqgoYr6qqUlJSUphmBQA4XggDUHR0tDIzM1VSUu
If83q9KikpUXZ2dhhnBgA4HiLDPQG0Di6XS2PGjFH//v01cOBAzZw5Uw0NDcrLywv31ICw2bNnjzZs2OD/86ZNm1RRUaGOHTvqlFNOCePMgNBiayH8Zs2apenTp6uyslLp6el6+umnlZWVFe5pAWGzbNkyDR06tMn4mDFj9Le//e34TwhoIYQBAAAMx5oBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcP8DgCSQtNkAFegAAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"reports_dir = proj_path/'reports'\n",
"fig_dir = reports_dir/'figures'\n",
"# parents=True also creates reports/, so one call covers both directories.\n",
"fig_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Rows are true labels and columns are predicted labels; normalize='true'\n",
"# scales each row to sum to 1.\n",
"cm = confusion_matrix(y_test, y_pred, normalize='true')\n",
"\n",
"# Explicit figure/axes so the saved file is exactly the figure drawn here,\n",
"# and labelled axes so the image can be read on its own.\n",
"fig, ax = plt.subplots()\n",
"sns.heatmap(cm, annot=True, cmap=plt.cm.Blues, ax=ax)\n",
"ax.set(\n",
"    xlabel='Predicted label',\n",
"    ylabel='True label',\n",
"    title='Confusion matrix (row-normalized)',\n",
")\n",
"fig.savefig(fig_dir/'cm.png')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9d8adf04",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['CreditScore', 'Age', 'Tenure', 'Balance', 'NumOfProducts',\n",
" 'HasCrCard', 'IsActiveMember', 'EstimatedSalary'], dtype=object)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# model[:-1] is the pipeline without its final step (the classifier), so this\n",
"# asks the preprocessing part for the feature names it outputs.\n",
"out_feat_names = model[:-1].get_feature_names_out(feat_cols)\n",
"out_feat_names"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d9262a4f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
"\n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
" \n",
" \n",
" \n",
" Weight \n",
" Feature \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.2444\n",
" \n",
" ± 0.0285\n",
" \n",
" \n",
" \n",
" Age\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.1816\n",
" \n",
" ± 0.0108\n",
" \n",
" \n",
" \n",
" NumOfProducts\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.1365\n",
" \n",
" ± 0.0315\n",
" \n",
" \n",
" \n",
" IsActiveMember\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.0326\n",
" \n",
" ± 0.0050\n",
" \n",
" \n",
" \n",
" Balance\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.0215\n",
" \n",
" ± 0.0136\n",
" \n",
" \n",
" \n",
" Tenure\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.0143\n",
" \n",
" ± 0.0118\n",
" \n",
" \n",
" \n",
" CreditScore\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.0057\n",
" \n",
" ± 0.0124\n",
" \n",
" \n",
" \n",
" EstimatedSalary\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" 0.0002\n",
" \n",
" ± 0.0083\n",
" \n",
" \n",
" \n",
" HasCrCard\n",
" \n",
" \n",
" \n",
" \n",
" \n",
"
\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Permutation importance is computed on the fitted classifier alone, so the\n",
"# test features are first passed through the pipeline's preprocessing step.\n",
"preprocessor = model.named_steps['preprocessor']\n",
"clf = model.named_steps['clf']\n",
"X_test_transformed = preprocessor.transform(X_test)\n",
"\n",
"# Importance of a feature = drop in F1 on the test set when it is shuffled.\n",
"perm = PermutationImportance(\n",
"    clf, scoring=make_scorer(f1_score), random_state=random_state\n",
").fit(X_test_transformed, y_test)\n",
"eli5.show_weights(perm, feature_names=out_feat_names)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "6e44b74c",
"metadata": {},
"outputs": [],
"source": [
"# perm was fitted on the preprocessor's output, so pair its importances with\n",
"# the preprocessor's output feature names rather than the raw X_test columns.\n",
"# Building from a dict also raises on a length mismatch instead of silently\n",
"# truncating the way zip() would.\n",
"df_feat_imp = (\n",
"    pd.DataFrame({\n",
"        'feature': out_feat_names,\n",
"        'importance': perm.feature_importances_,\n",
"    })\n",
"    .sort_values(by='importance', ascending=False)\n",
")\n",
"feat_importance_fpath = reports_dir/'feat_imp.csv'\n",
"df_feat_imp.to_csv(feat_importance_fpath, index=False)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "dac3462a",
"metadata": {},
"outputs": [],
"source": [
"# Persist the fitted pipeline (imputer + classifier) to <project root>/models.\n",
"models_dir = proj_path/'models'\n",
"models_dir.mkdir(exist_ok=True)\n",
"dump(model, models_dir/'clf-model.joblib');"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.10 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"vscode": {
"interpreter": {
"hash": "060614c890ed22051a9be2360999a13d2882d827ad8c9dd21319e1709b800224"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}