{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "impaired-stylus", "metadata": {}, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import seaborn as sns\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.metrics import (confusion_matrix, f1_score, make_scorer,\n", " roc_auc_score)\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.pipeline import Pipeline\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "accepting-wallet", "metadata": {}, "outputs": [], "source": [ "# Project root is the parent of the current working directory.\n", "proj_path = Path(os.getcwd()).parent.absolute()\n", "\n", "# One raw CSV per country, listed in the order they are read.\n", "raw_data_dir = proj_path / 'data' / 'raw'\n", "countries = ['Spain', 'France']\n", "data_file_paths = [raw_data_dir / f'Churn_Modelling_{name}.csv' for name in countries]" ] }, { "cell_type": "code", "execution_count": 3, "id": "continued-juvenile", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CustomerIdSurnameCreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExited
015647311Hill608SpainFemale41183807.86101112542.580
115737888Mitchell850SpainFemale432125510.8211179084.100
215574012Chu645SpainMale448113755.78210149756.711
315737173Andrews497SpainMale2430.0021076390.010
415600882Scott635SpainFemale3570.0021165951.650
\n", "
" ], "text/plain": [ " CustomerId Surname CreditScore Geography Gender Age Tenure \\\n", "0 15647311 Hill 608 Spain Female 41 1 \n", "1 15737888 Mitchell 850 Spain Female 43 2 \n", "2 15574012 Chu 645 Spain Male 44 8 \n", "3 15737173 Andrews 497 Spain Male 24 3 \n", "4 15600882 Scott 635 Spain Female 35 7 \n", "\n", " Balance NumOfProducts HasCrCard IsActiveMember EstimatedSalary \\\n", "0 83807.86 1 0 1 112542.58 \n", "1 125510.82 1 1 1 79084.10 \n", "2 113755.78 2 1 0 149756.71 \n", "3 0.00 2 1 0 76390.01 \n", "4 0.00 2 1 1 65951.65 \n", "\n", " Exited \n", "0 0 \n", "1 0 \n", "2 1 \n", "3 0 \n", "4 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ignore_index=True gives the combined frame a fresh 0..n-1 index; without it each\n", "# country's rows keep their own 0-based labels and the index contains duplicates.\n", "df = pd.concat([pd.read_csv(fpath) for fpath in data_file_paths], ignore_index=True)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "id": "bb0d9f24", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(7491, 13)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.shape" ] }, { "cell_type": "code", "execution_count": 5, "id": "unnecessary-roots", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CustomerId 0\n", "Surname 0\n", "CreditScore 0\n", "Geography 0\n", "Gender 0\n", "Age 0\n", "Tenure 0\n", "Balance 0\n", "NumOfProducts 0\n", "HasCrCard 0\n", "IsActiveMember 0\n", "EstimatedSalary 0\n", "Exited 0\n", "dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Are there missing values?\n", "df.isna().sum()" ] }, { "cell_type": "code", "execution_count": 6, "id": "b025ccce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Geography\n", "France 5014\n", "Spain 2477\n", "Name: count, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['Geography'].value_counts()" ] },
{ "cell_type": "code", "execution_count": 7, "id": "scheduled-measure", "metadata": {}, "outputs": [], "source": [ "# Model inputs: the numeric columns only. CustomerId, Surname, Geography and Gender\n", "# are not used as features.\n", "feat_cols = ['CreditScore', 'Age', 'Tenure', \n", " 'Balance', 'NumOfProducts', 'HasCrCard',\n", " 'IsActiveMember', 'EstimatedSalary']\n", "targ_col = 'Exited'" ] }, { "cell_type": "code", "execution_count": 8, "id": "junior-rating", "metadata": {}, "outputs": [], "source": [ "X, y = df[feat_cols], df[targ_col]" ] }, { "cell_type": "code", "execution_count": 9, "id": "c4d476a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16326258176478442" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.mean()" ] }, { "cell_type": "code", "execution_count": 10, "id": "historic-doubt", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): only ~16% of rows are positive (see y.mean() above) and this split is\n", "# not stratified, so the class ratio may differ slightly between train and test.\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)" ] }, { "cell_type": "code", "execution_count": 11, "id": "infinite-african", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocessor', SimpleImputer()),\n",
       "                ('clf', RandomForestClassifier(max_depth=10, random_state=42))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocessor', SimpleImputer()),\n", " ('clf', RandomForestClassifier(max_depth=10, random_state=42))])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Seed and forest hyper-parameters, declared together ahead of the model.\n", "random_state = 42\n", "train_params = {'n_estimators': 100, 'max_depth': 10}\n", "\n", "# Impute, then classify, in a single pipeline so both steps are fitted together.\n", "clf = RandomForestClassifier(**train_params, random_state=random_state)\n", "pipeline_steps = [(\"preprocessor\", SimpleImputer()), (\"clf\", clf)]\n", "model = Pipeline(steps=pipeline_steps)\n", "\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 12, "id": "12944f94", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.5474137931034483, 0.8684512806155847)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Column 1 of predict_proba is used as the positive-class score, for both the\n", "# 0.5-thresholded labels and the ROC AUC.\n", "y_prob = model.predict_proba(X_test)\n", "pos_prob = y_prob[:, 1]\n", "y_pred = pos_prob >= 0.5\n", "\n", "f1 = f1_score(y_test, y_pred)\n", "roc_auc = roc_auc_score(y_test, pos_prob)\n", "f1, roc_auc" ] }, { "cell_type": "code", "execution_count": 13, "id": "4675491c", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAgMAAAGdCAYAAACPX3D5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjb0lEQVR4nO3da3hU1d338d9MDhMCEgIhE4jRCMpJJIkBYkC0aDRWRai1Ug+AUfGESJnqI1EhikpUlHJbQlGUFtv6kEc8lEcR1Ci3osFgMKiIIAKiSE4cEggwgZm5X9h77OyEw+CEiazv57r2C9bsvWdtLzC//P9rzdh8Pp9PAADAWPZwTwAAAIQXYQAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMMRBgAAMFxkuCfwv9pk3BnuKQCtzs6Vs8I9BaBVimnhn16h/Jm079PW/++41YQBAABaDZtZhXOznhYAADRBZQAAACubLdwzOK4IAwAAWBnWJiAMAABgZVhlwKzoAwAAmqAyAACAFW0CAAAMR5sAAACYhMoAAABWtAkAADAcbQIAAGASKgMAAFjRJgAAwHC0CQAAgEmoDAAAYEWbAAAAwxnWJiAMAABgZVhlwKynBQAATVAZAADAyrDKAGEAAAAru1lrBsyKPgAAoAkqAwAAWNEmAADAcIZtLTQr+gAAgCaoDAAAYEWbAAAAw9EmAAAAJqEyAACAFW0CAAAMZ1ibgDAAAICVYZUBs54WAAA0QWUAAAAr2gQAABiONgEAADAJlQEAAKxoEwAAYDjaBAAAwCRUBgAAsDKsMkAYAADAyrA1A2ZFHwAA0ASVAQAArGgTAABgOMPaBIQBAACsDKsMmPW0AACgCSoDAABY0SYAAMBsNsPCAG0CAAAMR2UAAAAL0yoDhAEAAKzMygK0CQAAMB2VAQAALGgTAABgONPCAG0CAAAMR2UAAAAL0yoDhAEAACwIAwAAmM6sLMCaAQAATEdlAAAAC9oEAAAYzrQwQJsAAADDURkAAMDCtMoAYQAAAAvTwgBtAgAADEdlAAAAK7MKA1QGAACwstlsITuCVVRUpNTUVMXExCgrK0tlZWWHPX/mzJnq2bOn2rRpo5SUFE2cOFH79+8P6j0JAwAAtBLFxcVyuVwqKCjQqlWrlJaWptzcXFVXVzd7/osvvqhJkyapoKBAa9eu1fPPP6/i4mLdd999Qb0vYQAAAItwVQZmzJihsWPHKi8vT3369NGcOXMUGxurefPmNXv+Rx99pMGDB+vaa69VamqqLr74Yl1zzTVHrCZYEQYAALAIZRhwu92qr68PONxud5P3bGxsVHl5uXJycvxjdrtdOTk5Ki0tbXaegwYNUnl5uf+H/8aNG7V48WJdeumlQT0vYQAAACtb6I7CwkLFxcUFHIWFhU3esra2Vh6PR06nM2Dc6XSqsrKy2Wlee+21mjp1qs4991xFRUWpe/fu+tWvfkWbAACA1iQ/P191dXUBR35+fkjuvWzZMk2bNk2zZ8/WqlWr9Morr+iNN97Qww8/HNR92FoIAIBFKD90yOFwyOFwHPG8hIQERUREqKqqKmC8qqpKSUlJzV4zefJkjRo1SjfffLMk6ayzzlJDQ4NuueUW3X///bLbj+53fioDAABYhGMBYXR0tDIzM1VSUuIf83q9KikpUXZ2drPX7N27t8kP/IiICEmSz+c76vemMgAAQCvhcrk0ZswY9e/fXwMHDtTMmTPV0NCgvLw8SdLo0aOVnJzsX3MwbNgwzZgxQxkZGcrKytKGDRs0efJkDRs2zB8KjgZhAAAAi3B9N8HIkSNVU1OjKVOmqLKyUunp6VqyZIl/UeGWLVsCKgEPPPCAbDabHnjgAW3dulWdO3fWsGHD9Oijjwb1vjZfMHWEFtQm485wTwFodXaunBXuKQCtUkwL/yrb9dZXQnavH565MmT3aimsGQAAwHC0CQAAsDLsi4oIAwAAWIRrzUC40CYAAMBwVAY
AALAwrTJAGAAAwIIwAACA6czKAqwZAADAdFQGAACwMK1NQGXgBHfr1efpqzce0s4Vf9L7L9yt/meeeshzIyPtyr/lEq1ZVKCdK/6kj4sn6aJBvQPOsdttmnLHZVr7+oPaUTpDaxYVaNLYS1r6MYCfZcGL/9SvL7pAAzLO0nW//50+/+yzw57/1tI3NfzySzQg4yz9dsQwffD+fwe8/s7bb+nWsTfqvEFZSjuzp75auzbg9a1bv1famT2bPd5a+mbInw+hF44vKgonwsAJ7KqLz9bjf/yNHn3mTWVf+7g+W79Vi2aPU+f4ds2e/+Adw3Tzb8+V64mXlPHbR/TcwuUqfmqs0nqe7D/njzdcpLFXDdHEx15S+pWP6IGn/yXXmBzdcc35x+uxgKAseXOxnnyiULfeMU4LXnpVPXv20u233qTt27c3e37Fp6s06Z4/6jdXXqXiha9p6AUX6g/jx+nrr9f7z9m3b68yMs7WH1x3N3uPpKQuKlm2POC4fdx4xcbG6txzz2uR5wR+DsLACeyu6y/QX1/5SH9ftEJfbazU+EcXaN/+Ro0Z0fxXYV57+UA98fxbWrr8S23eul1zX1qupR9+qQmjLvCfc05aN73+359pyfI12rJth159p0IlK746bMUBCKe/z/+rrrzqao34zW/V/fTT9UDBQ4qJidFrr7zc7Pn//McLGnTuEN1w483q1r277rzrD+rdp48WvPgP/znDrhih2+64U1mH+FrZiIgIJXTuHHC8W/KOLr7k14pt27ZFnhOhRWUAJ4SoyAhl9E7Rux+v84/5fD69+/E6Dex3WrPXREdFan/jgYCxffsbNSiju//PK1Zv1NCBPXX6KYmSpLN6JCs7vZve+vDLFngK4Oc50NiotV+u0TnZg/xjdrtd55wzSJ+t/rTZaz6rqNA55wT+kB80+Fx9VlFxzPP4cs0XWvfVWv3myquO+R44vkwLA0EvIKytrdW8efNUWlqqyspKSVJSUpIGDRqkG264QZ07dw75JBG8hPh2ioyMUPWO3QHj1dvr1TPV2ew175Su1V3XX6DlqzZo43e1Gjqwp4ZfkK6IiJ/+Mj/517fVvl2MVr/6gDwenyIibCooel0L3vykRZ8HOBY7d+2Ux+NRp06dAsY7deqkTZs2NntNbW2tOnVKaHJ+7fbaY57Hqy8vVLdu3ZWecfYx3wNoSUGFgZUrVyo3N1exsbHKyclRjx49JElVVVV6+umn9dhjj2np0qXq37//Ye/jdrvldrsDxnxej2z2iCCnj1C6e/pCzZ58jVa/Mlk+n08bv6/VC4tWaMzwc/znXHXx2fr9rwfohvvm68tvtqlfz2RNv/sqbaup0z///8dhnD3QOu3fv19vLn5dY2+7I9xTQTB+Gb/Qh0xQYWD8+PH63e9+pzlz5jQpffh8Pt12220aP368SktLD3ufwsJCPfTQQwFjEc4BiuoyMJjp4DBqd+7RwYMeJXY8KWA8sVN7VW6vP+Q1V7vmyhEdqU5xbfVDTZ0euWu4Nm39aaHVtD+M0JN/fVsvLS2XJK3Z8INO6dJR9+RdRBhAqxPfIV4RERFNFgtu375dCQkJzV6TkJCg7ZYqwPbt25XQqfnzj+Ttt5Zo3779GnbFiGO6HuHxSynvh0pQawZWr16tiRMnNvsfyWazaeLEiao4ir5afn6+6urqAo5IZ2YwU8ERHDjo0adrv9PQrJ7+MZvNpqEDe6jss02HvdbdeFA/1NQpMtKuERem6/VlP23DahMTLa/PG3C+x+uT3c7yE7Q+UdHR6t3nTH284qdfULxerz7+uFT90jKavaZfero+XrEiYGxF6Ufql55+THN47ZWX9auhF6hjx47HdD1wPARVGUhKSlJZWZl69erV7OtlZWVyOpvvR/8nh8Mhh8MRMEaLIPSe/se7mjt1lMq/3KJPvtisO68dqtg2Dr3wrx//R/fcw6P0Q3Wdpvx5kSRpQN9T1TWxg1av+17JiR10/62Xym63acbf3vH
fc/H7n+vem3L13bad+vKbbUrvdbLuun6oXnhtRbNzAMJt1Jg8Tb7vXp15Zl/1Pauf/vH3+dq3b59G/OZKSdL9+f9HiYlOTZj4R0nSddeP1k03jNL8v83TeeedryVvLtaaL77Q5Aen+u9Zt2uXtm3bppqaaknS5s0/BuyEhAQl/Me6qS3ffqvyT1aq6C/PHq/HRYiYVhkIKgzcfffduuWWW1ReXq4LL7zQ/4O/qqpKJSUlmjt3rp588skWmSiCt/CtVUqIb6cpt18mZ6eT9Nm6rRo+rsi/qDAlqaO8Xp//fIcjSgXjLtdpyQnas9etpR+u0U2TX1Ddnn3+c1yPv6SCOy7Xf903Up3j22lbTZ2eX/ihpj3LB6mgdbrk15dq544dmj3radXW1qhnr96a/cxz6vTvNkHltm2y236qbKVnnK3CJ57UrKdn6s8zZ+iUU1M1889FOuOMHv5zlr33rqY8kO//8713T5Qk3XbHnbp93Hj/+GuvviynM0nZg89t6cdEiBmWBWTz+Xy+I5/2k+LiYv3pT39SeXm5PB6PpB/31GZmZsrlcunqq68+pom0ybjzmK4DTmQ7V84K9xSAVimmhT9M/4x7loTsXl9Pb/2f0hr0f86RI0dq5MiROnDggGprf1xkk5CQoKioqJBPDgAAtLxjzlZRUVHq0qVLKOcCAECrYFqbgG8tBADAwrQFhOwHAwDAcFQGAACwMKwwQBgAAMDKbjcrDdAmAADAcFQGAACwoE0AAIDh2E0AAACMQmUAAAALwwoDhAEAAKxMaxMQBgAAsDAtDLBmAAAAw1EZAADAwrDCAGEAAAAr2gQAAMAoVAYAALAwrDBAGAAAwIo2AQAAMAqVAQAALAwrDBAGAACwok0AAACMQmUAAAALwwoDhAEAAKxMaxMQBgAAsDAsC7BmAAAA01EZAADAgjYBAACGMywL0CYAAMB0VAYAALCgTQAAgOEMywK0CQAAMB2VAQAALGgTAABgONPCAG0CAAAMR2UAAAALwwoDhAEAAKxMaxMQBgAAsDAsC7BmAAAA01EZAADAgjYBAACGMywL0CYAAMB0VAYAALCwG1YaIAwAAGBhWBagTQAAQGtSVFSk1NRUxcTEKCsrS2VlZYc9f9euXRo3bpy6dOkih8OhHj16aPHixUG9J5UBAAAswrWboLi4WC6XS3PmzFFWVpZmzpyp3NxcrVu3TomJiU3Ob2xs1EUXXaTExEQtXLhQycnJ+vbbb9WhQ4eg3pcwAACAhT1MbYIZM2Zo7NixysvLkyTNmTNHb7zxhubNm6dJkyY1OX/evHnasWOHPvroI0VFRUmSUlNTg35f2gQAAFjYbLaQHW63W/X19QGH2+1u8p6NjY0qLy9XTk6Of8xutysnJ0elpaXNznPRokXKzs7WuHHj5HQ61bdvX02bNk0ejyeo5yUMAADQggoLCxUXFxdwFBYWNjmvtrZWHo9HTqczYNzpdKqysrLZe2/cuFELFy6Ux+PR4sWLNXnyZD311FN65JFHgpojbQIAACxCuWQgPz9fLpcrYMzhcITk3l6vV4mJiXr22WcVERGhzMxMbd26VdOnT1dBQcFR34cwAACAhU2hSwMOh+OofvgnJCQoIiJCVVVVAeNVVVVKSkpq9pouXbooKipKERER/rHevXursrJSjY2Nio6OPqo50iYAAKAViI6OVmZmpkpKSvxjXq9XJSUlys7ObvaawYMHa8OGDfJ6vf6x9evXq0uXLkcdBCTCAAAATdhtoTuC4XK5NHfuXM2fP19r167V7bffroaGBv/ugtGjRys/P99//u23364dO3ZowoQJWr9+vd544w1NmzZN48aNC+p9aRMAAGARrs8ZGDlypGpqajRlyhRVVlYqPT1dS5Ys8S8q3LJli+z2n36PT0lJ0dKlSzVx4kT169dPycnJmjBhgu69996g3tfm8/l8IX2SY9Qm485wTwFodXaunBXuKQCtUkwL/yo7fO4nIbvXv8b2D9m9Wgq
VAQAALEz7bgLCAAAAFqZ9ayELCAEAMByVAQAALAwrDBAGAACwCtdugnAhDAAAYGFYFmDNAAAApqMyAACAhWm7CQgDAABYmBUFaBMAAGA8KgMAAFiwmwAAAMMF+22Dv3S0CQAAMByVAQAALGgTAABgOMOyAG0CAABMR2UAAAAL2gQAABjOtN0EhAEAACxMqwywZgAAAMNRGQAAwMKsugBhAACAJkz71kLaBAAAGI7KAAAAFoYVBggDAABYsZsAAAAYhcoAAAAWhhUGCAMAAFixmwAAABiFygAAABaGFQYIAwAAWJm2m6DVhIELbxsT7ikArc7cjzeFewpAqzR+8Gkten/TeuimPS8AALBoNZUBAABaC9oEAAAYzm5WFqBNAACA6agMAABgYVplgDAAAICFaWsGaBMAAGA4KgMAAFjQJgAAwHCGdQloEwAAYDoqAwAAWJj2FcaEAQAALEwrmxMGAACwMKwwYFz4AQAAFlQGAACwYM0AAACGMywL0CYAAMB0VAYAALDgEwgBADCcaWsGaBMAAGA4KgMAAFgYVhggDAAAYGXamgHaBAAAGI7KAAAAFjaZVRogDAAAYGFam4AwAACAhWlhgDUDAAAYjsoAAAAWNsP2FhIGAACwoE0AAACMQmUAAAALw7oEhAEAAKz4oiIAAGAUKgMAAFiwgBAAAMPZbKE7glVUVKTU1FTFxMQoKytLZWVlR3XdggULZLPZNGLEiKDfkzAAAEArUVxcLJfLpYKCAq1atUppaWnKzc1VdXX1Ya/bvHmz7r77bg0ZMuSY3pcwAACAhV22kB3BmDFjhsaOHau8vDz16dNHc+bMUWxsrObNm3fIazwej6677jo99NBD6tat2zE+LwAACBDKNoHb7VZ9fX3A4Xa7m7xnY2OjysvLlZOT4x+z2+3KyclRaWnpIec6depUJSYm6qabbjrm5yUMAABgYbeF7igsLFRcXFzAUVhY2OQ9a2tr5fF45HQ6A8adTqcqKyubnefy5cv1/PPPa+7cuT/redlNAABAC8rPz5fL5QoYczgcP/u+u3fv1qhRozR37lwlJCT8rHsRBgAAsAjlhw45HI6j+uGfkJCgiIgIVVVVBYxXVVUpKSmpyfnffPONNm/erGHDhvnHvF6vJCkyMlLr1q1T9+7dj2qOtAkAALAIx9bC6OhoZWZmqqSkxD/m9XpVUlKi7OzsJuf36tVLn3/+uSoqKvzHFVdcoaFDh6qiokIpKSlH/d5UBgAAaCVcLpfGjBmj/v37a+DAgZo5c6YaGhqUl5cnSRo9erSSk5NVWFiomJgY9e3bN+D6Dh06SFKT8SMhDAAAYBGu7yYYOXKkampqNGXKFFVWVio9PV1LlizxLyrcsmWL7PbQF/VtPp/PF/K7HoPLn1kZ7ikArU5u35+3KAg4UY0ffFqL3n/eyi0hu9eNA04J2b1aCmsGAAAwHG0CAAAsTPtNmTAAAICFLUxrBsLFtPADAAAsqAwAAGBhVl2AMAAAQBPh2loYLoQBAAAszIoCrBkAAMB4VAYAALAwrEtAGAAAwIqthQAAwChUBgAAsDDtN2XCAAAAFrQJAACAUagMAABgYVZdgDAAAEATtAkAAIBRqAwAAGBh2m/KhAEAACxMaxMQBgAAsDArCphXCQEAABZUBgAAsDCsS0AYAADAym5Yo4A2AQAAhqMyAACABW0CAAAMZ6NNAAAATEJlAAAAC9oEAAAYjt0EAADAKFQGAACwoE0AAIDhCAMAABiOrYUAAMAoVAYAALCwm1UYIAwAAGBFmwAAABiFygAAABbsJgAAwHC0CQAAgFGoDAAAYMFuApxQLjszUVemJSm+TZQ2bd+rZz7covU1Dc2ee2GPTpo4tFvAWONBr658vtz/5w5tInVDVooyTm6vttERWlO5R88s/1Y/1Ltb9DmAUPusZJE+XbJQe+t2KiGlm8677g45u/U84nXrP16mt555TKdlZOuy8QX+8W/Kl+uLZYtVvflruRt2a+S
DRep8SveWfAS0INoEOGEM6d5RN2en6P+W/6AJL6/Rph17NfWyHoqLOXQGbHAf1PUvfOo/bnxxdcDrD+SeoaT2Dj2ydIMmvPylqne79cjlPeWI5K8Sfjm+LvtvLS+eqwFXXK+RBbPUKaWbFs24X3vrdx32uvraSn34/55T1x59m7x2wL1fXc44U4N+d2MLzRpoOfwf/AQ24iynlq6t0TvravXdrv0qev9buQ96dVGvhENe45O0a9/BgON/dY1zqJeznWZ/sFlf1zRoa91+zf7gW0VH2nX+6R2PwxMBoVGx9BWded4l6jPkYnVMPlVDR49XZLRDaz9YeshrvF6P3nr2CWUNv17tOyc1eb3XoBwNvOI6pfTJaMmp4zix2UJ3/BIQBk5QkXabTu/cVhVb6/1jPkkV39erl7PdIa9rExWhedf201+vS9MDuafrlPgY/2tRET/+dWn0+ALuecDjU5+kk0L+DEBL8Bw8oOpvvw74oW2z23VynwxVfrP2kNetXPSiYk/qoD7nXXI8pokws4Xw+CUgDJyg2sdEKsJu0659BwLGd+07oPg2Uc1es7Vuv/5r2SY9vPRrPfXuRtltNk0f3lud2v54/ve79qt6t1tjBp6sttERirTb9Nu0JHVuF62Osc3fE2ht9u2ul8/rVZv2HQLGY9t30N66nc1e88P6L/TlB0s19IYJx2GGaA3sNlvIjl+CkIeB7777TjfeePiemdvtVn19fcDhOdAY6qkgSF9VNejdr7dr0/Z9+mLbbj361gbV7T+oX/dOlCR5vD49+tYGJcfFqDjvbL18U6b6JbfXJ1t2yevzHeHuwC9T4769evu56bpgzAS1OSku3NMBWkTIdxPs2LFD8+fP17x58w55TmFhoR566KGAsTMuu1k9ht0S6ukYq37/QXm8PnWwVAE6tInSTku14FA8Xp821u5VlziHf+yb2r266+U1iv13ZaB+/0E9NaK3vq5tfocC0Nq0Oam9bHa79lkWC+6t36XYuPgm59fVbNPu2iq9/vRPOwd8/w6/RTdfquunPae4xK4tOmccf7+M3+dDJ+gwsGjRosO+vnHjxiPeIz8/Xy6XK2Bs5AufBzsVHMZBr08bahqUltxeKzbvkvTjX+605PZ6fU3VUd3DbpNO7dhG5d/VNXltb6NHktS1vUOnd26rf3yyNVRTB1pURGSUEk89Q9+trVC3swdJknxer75fW6F+Fwxrcn58lxRdM3VOwNiKV+frwP59GnLNbWrXsfNxmTeOM8PSQNBhYMSIEbLZbP5k3BzbEXokDodDDocjYCwiKjrYqeAIXvu8ShN/dZq+rmnQ+uoGDT/LqZgou95ZVytJcg09TdsbDmh+2feSpN+f3VXrqvfohzq32jkidGVakhJPcmjp2hr/PQd3i1f9voOq3tOo1I5tdMvgU7Ri8059+n19s3MAWqP03Cv1znNPKjH1DDlP66nVb7+qg+796n3uxZKkt+dOV9v4Thp01Y2KjIpWp5NTA653xLaVpIDx/Xt2a/eOajXs2i5J2lX547+r2Lh4tY1jtw1at6DDQJcuXTR79mwNHz682dcrKiqUmZn5syeGn++Db3YoLiZS1/dPVnxslDbW7tWUxev92wU7t4uW9z8yXTtHhMafl6r42CjtcXu0oaZB97y2Vt/t2u8/p2NslG7OPkUd2kRq594Denf9di1Y9cPxfjTgZzlj4Pnat7tOZa/9XQ11O9U5pZuGTXzE3ybYvaNatiA/gm5TRalK5s3w/3npnEJJ0oArrlPWiFGhmzyOC9M+dMjmO9yv+M244oorlJ6erqlTpzb7+urVq5WRkSGv1xvURC5/ZmVQ5wMmyO176M+EAEw2fvBpLXr/so1N26PHamC31r/wNOjKwD333KOGhkMvFjv99NP13nvv/axJAQCA4yfoMDBkyJDDvt62bVudf/75xzwhAADCzawmAV9UBABAU4alAT6BEAAAw1EZAADAwrTdBIQBAAAsfiFfKRAyhAE
AACwMywKsGQAAwHRUBgAAsDKsNEAYAADAwrQFhLQJAAAwHJUBAAAs2E0AAIDhDMsCtAkAAGhNioqKlJqaqpiYGGVlZamsrOyQ586dO1dDhgxRfHy84uPjlZOTc9jzD4UwAACAlS2ERxCKi4vlcrlUUFCgVatWKS0tTbm5uaqurm72/GXLlumaa67Re++9p9LSUqWkpOjiiy/W1q1bg3tcn8/nC26qLePyZ1aGewpAq5PbNyHcUwBapfGDT2vR+3/23Z6Q3atfSrujPjcrK0sDBgzQrFmzJEler1cpKSkaP368Jk2adMTrPR6P4uPjNWvWLI0ePfqo35fKAAAALcjtdqu+vj7gcLvdTc5rbGxUeXm5cnJy/GN2u105OTkqLS09qvfau3evDhw4oI4dOwY1R8IAAAAWNlvojsLCQsXFxQUchYWFTd6ztrZWHo9HTqczYNzpdKqysvKo5n3vvfeqa9euAYHiaLCbAAAAi1DuJsjPz5fL5QoYczgcIXyHHz322GNasGCBli1bppiYmKCuJQwAAGAVwjTgcDiO6od/QkKCIiIiVFVVFTBeVVWlpKSkw1775JNP6rHHHtM777yjfv36BT1H2gQAALQC0dHRyszMVElJiX/M6/WqpKRE2dnZh7zuiSee0MMPP6wlS5aof//+x/TeVAYAALAI13cTuFwujRkzRv3799fAgQM1c+ZMNTQ0KC8vT5I0evRoJScn+9ccPP7445oyZYpefPFFpaam+tcWtGvXTu3aHf0uBsIAAAAW4fo44pEjR6qmpkZTpkxRZWWl0tPTtWTJEv+iwi1btshu/6mo/5e//EWNjY266qqrAu5TUFCgBx988Kjfl88ZAFoxPmcAaF5Lf87Alz80hOxefbq2Ddm9WgqVAQAALEz7bgLCAAAAVoalAXYTAABgOCoDAABYhGs3QbgQBgAAsAjXboJwoU0AAIDhqAwAAGBhWGGAMAAAQBOGpQHCAAAAFqYtIGTNAAAAhqMyAACAhWm7CQgDAABYGJYFaBMAAGA6KgMAAFgZVhogDAAAYMFuAgAAYBQqAwAAWLCbAAAAwxmWBWgTAABgOioDAABYGVYaIAwAAGBh2m4CwgAAABamLSBkzQAAAIajMgAAgIVhhQHCAAAAVrQJAACAUagMAADQhFmlAcIAAAAWtAkAAIBRqAwAAGBhWGGAMAAAgBVtAgAAYBQqAwAAWPDdBAAAmM6sLEAYAADAyrAswJoBAABMR2UAAAAL03YTEAYAALAwbQEhbQIAAAxHZQAAACuzCgOEAQAArAzLArQJAAAwHZUBAAAs2E0AAIDh2E0AAACMQmUAAAAL09oEVAYAADAclQEAACyoDAAAAKNQGQAAwMK03QSEAQAALGgTAAAAo1AZAADAwrDCAGEAAIAmDEsDtAkAADAclQEAACzYTQAAgOHYTQAAAIxCZQAAAAvDCgOEAQAAmjAsDRAGAACwMG0BIWsGAAAwHJUBAAAsTNtNYPP5fL5wTwKth9vtVmFhofLz8+VwOMI9HaBV4N8FTnSEAQSor69XXFyc6urq1L59+3BPB2gV+HeBEx1rBgAAMBxhAAAAwxEGAAAwHGEAARwOhwoKClgkBfwH/l3gRMcCQgAADEdlAAAAwxEGAAAwHGEAAADDEQYAADAcYQB+RUVFSk1NVUxMjLKyslRWVhbuKQFh9f7772vYsGHq2rWrbDabXnvttXBPCWgRhAFIkoqLi+VyuVRQUKBVq1YpLS1Nubm5qq6uDvfUgLBpaGhQWlqaioqKwj0VoEWxtRCSpKysLA0YMECzZs2SJHm9XqWkpGj8+PGaNGlSmGcHhJ/NZtOrr76qESNGhHsqQMhRGYAaGxtVXl6unJwc/5jdbldOTo5KS0vDODMAwPFAGIBqa2vl8XjkdDoDxp1OpyorK8M0KwDA8UIYAADAcIQBKCEhQREREaqqqgoYr6qqUlJSUphmBQA4XggDUHR0tDIzM1VSUuIf83q9KikpUXZ
2dhhnBgA4HiLDPQG0Di6XS2PGjFH//v01cOBAzZw5Uw0NDcrLywv31ICw2bNnjzZs2OD/86ZNm1RRUaGOHTvqlFNOCePMgNBiayH8Zs2apenTp6uyslLp6el6+umnlZWVFe5pAWGzbNkyDR06tMn4mDFj9Le//e34TwhoIYQBAAAMx5oBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcP8DgCSQtNkAFegAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "reports_dir = proj_path/'reports'\n", "reports_dir.mkdir(exist_ok=True)\n", "fig_dir = reports_dir/'figures'\n", "fig_dir.mkdir(exist_ok=True)\n", "\n", "cm = confusion_matrix(y_test, y_pred, normalize='true') \n", "sns.heatmap(cm, annot=True, cmap=plt.cm.Blues)\n", "plt.savefig(fig_dir/'cm.png')" ] }, { "cell_type": "code", "execution_count": 14, "id": "9d8adf04", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['CreditScore', 'Age', 'Tenure', 'Balance', 'NumOfProducts',\n", " 'HasCrCard', 'IsActiveMember', 'EstimatedSalary'], dtype=object)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out_feat_names = model[:-1].get_feature_names_out(feat_cols)\n", "out_feat_names" ] }, { "cell_type": "code", "execution_count": 15, "id": "d9262a4f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WeightFeature
\n", " 0.2444\n", " \n", " ± 0.0285\n", " \n", " \n", " Age\n", "
\n", " 0.1816\n", " \n", " ± 0.0108\n", " \n", " \n", " NumOfProducts\n", "
\n", " 0.1365\n", " \n", " ± 0.0315\n", " \n", " \n", " IsActiveMember\n", "
\n", " 0.0326\n", " \n", " ± 0.0050\n", " \n", " \n", " Balance\n", "
\n", " 0.0215\n", " \n", " ± 0.0136\n", " \n", " \n", " Tenure\n", "
\n", " 0.0143\n", " \n", " ± 0.0118\n", " \n", " \n", " CreditScore\n", "
\n", " 0.0057\n", " \n", " ± 0.0124\n", " \n", " \n", " EstimatedSalary\n", "
\n", " 0.0002\n", " \n", " ± 0.0083\n", " \n", " \n", " HasCrCard\n", "
\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import eli5\n", "from sklearn.metrics import make_scorer\n", "from eli5.sklearn import PermutationImportance\n", "\n", "\n", "preprocessor = model.named_steps['preprocessor']\n", "clf = model.named_steps['clf']\n", "X_test_transformed = preprocessor.transform(X_test)\n", "\n", "perm = PermutationImportance(clf, scoring=make_scorer(f1_score), random_state=random_state).fit(X_test_transformed, y_test)\n", "eli5.show_weights(perm, feature_names=out_feat_names)" ] }, { "cell_type": "code", "execution_count": 16, "id": "6e44b74c", "metadata": {}, "outputs": [], "source": [ "feat_imp = zip(X_test.columns.tolist(), perm.feature_importances_)\n", "df_feat_imp = pd.DataFrame(feat_imp, \n", " columns=['feature', 'importance'])\n", "df_feat_imp = df_feat_imp.sort_values(by='importance', ascending=False)\n", "feat_importance_fpath = reports_dir/'feat_imp.csv'\n", "df_feat_imp.to_csv(feat_importance_fpath, index=False)" ] }, { "cell_type": "code", "execution_count": 17, "id": "dac3462a", "metadata": {}, "outputs": [], "source": [ "from joblib import dump\n", "\n", "models_dir = proj_path/'models'\n", "models_dir.mkdir(exist_ok=True)\n", "dump(model, models_dir/'clf-model.joblib');" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.10 ('.venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" }, "vscode": { "interpreter": { "hash": "060614c890ed22051a9be2360999a13d2882d827ad8c9dd21319e1709b800224" } } }, "nbformat": 4, "nbformat_minor": 5 }