Loading¶

In [29]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from mpl_toolkits.basemap import Basemap
from sklearn.impute import KNNImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder, StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
from imblearn.combine import SMOTETomek
import xgboost as xgb
from sklearn.metrics import roc_curve,roc_auc_score
#load_df = pd.read_csv("Killed_and_Seriously_Injured.csv")
load_df = pd.read_csv("allfilter_injury_data2.csv")
load_df.head()
Out[29]:
X Y OBJECTID INDEX_ ACCNUM DATE TIME STREET1 STREET2 OFFSET ... SPEEDING AG_DRIV REDLIGHT ALCOHOL DISABILITY HOOD_158 NEIGHBOURHOOD_158 HOOD_140 NEIGHBOURHOOD_140 DIVISION
0 642702.4974 4.855938e+06 20 3363207 882024.0 2006/01/07 10:00:00+00 2325 STEELES AVE E NINTH LINE ST NaN ... NaN NaN NaN NaN NaN 144 Morningside Heights 131 Rouge (131) D42
1 616144.1868 4.841944e+06 32 3363869 882497.0 2006/01/08 10:00:00+00 1828 ISLINGTON AVE GOLFDOWN DR NaN ... NaN Yes NaN NaN NaN 5 Elms-Old Rexdale 5 Elms-Old Rexdale (5) D23
2 638249.2383 4.847699e+06 35 3363416 882174.0 2006/01/09 10:00:00+00 1435 KENNEDY RD GLAMORGAN AVE NaN ... NaN NaN NaN NaN NaN 126 Dorset Park 126 Dorset Park (126) D41
3 636288.2909 4.842392e+06 43 3363879 882501.0 2006/01/11 10:00:00+00 1120 BARTLEY DR JINNAH CRT NaN ... Yes Yes NaN NaN NaN 43 Victoria Village 43 Victoria Village (43) D55
4 638765.5901 4.848810e+06 63 3371161 886230.0 2006/01/21 10:00:00+00 1829 MIDLAND AVE GOODLAND GT NaN ... Yes Yes NaN NaN NaN 128 Agincourt South-Malvern West 128 Agincourt South-Malvern West (128) D42

5 rows × 54 columns

ETL: persons to incidents¶

In [30]:
# fatal_rows = (load_df['ACCLASS'] == 'Fatal') & (load_df['INJURY'] == 'Fatal')
# df_fatal = load_df.loc[fatal_rows]
# # df_fatal = df_fatal.drop_duplicates(subset=['ACCNUM'])
# no_fatal_row = (load_df['ACCLASS'] == 'Non-Fatal Injury')
# df_non_fatal = load_df.loc[no_fatal_row]
# df_non_fatal = df_non_fatal.drop_duplicates(subset=['ACCNUM'])
# df_final = pd.concat([df_fatal, df_non_fatal], ignore_index=True)
# df_final.to_csv('allfilter_injury_data2.csv', index=False)

EDA: exploring data initially¶

In [31]:
load_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5299 entries, 0 to 5298
Data columns (total 54 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   X                  5299 non-null   float64
 1   Y                  5299 non-null   float64
 2   OBJECTID           5299 non-null   int64  
 3   INDEX_             5299 non-null   int64  
 4   ACCNUM             4962 non-null   float64
 5   DATE               5299 non-null   object 
 6   TIME               5299 non-null   int64  
 7   STREET1            5299 non-null   object 
 8   STREET2            4804 non-null   object 
 9   OFFSET             732 non-null    object 
 10  ROAD_CLASS         5155 non-null   object 
 11  DISTRICT           5208 non-null   object 
 12  LATITUDE           5299 non-null   float64
 13  LONGITUDE          5299 non-null   float64
 14  ACCLOC             3452 non-null   object 
 15  TRAFFCTL           5269 non-null   object 
 16  VISIBILITY         5287 non-null   object 
 17  LIGHT              5297 non-null   object 
 18  RDSFCOND           5286 non-null   object 
 19  ACCLASS            5299 non-null   object 
 20  IMPACTYPE          5290 non-null   object 
 21  INVTYPE            5296 non-null   object 
 22  INVAGE             5299 non-null   object 
 23  INJURY             2487 non-null   object 
 24  FATAL_NO           864 non-null    float64
 25  INITDIR            4000 non-null   object 
 26  VEHTYPE            4816 non-null   object 
 27  MANOEUVER          3448 non-null   object 
 28  DRIVACT            3194 non-null   object 
 29  DRIVCOND           3192 non-null   object 
 30  PEDTYPE            672 non-null    object 
 31  PEDACT             673 non-null    object 
 32  PEDCOND            667 non-null    object 
 33  CYCLISTYPE         109 non-null    object 
 34  CYCACT             114 non-null    object 
 35  CYCCOND            113 non-null    object 
 36  PEDESTRIAN         2402 non-null   object 
 37  CYCLIST            623 non-null    object 
 38  AUTOMOBILE         4691 non-null   object 
 39  MOTORCYCLE         530 non-null    object 
 40  TRUCK              296 non-null    object 
 41  TRSN_CITY_VEH      282 non-null    object 
 42  EMERG_VEH          7 non-null      object 
 43  PASSENGER          1223 non-null   object 
 44  SPEEDING           660 non-null    object 
 45  AG_DRIV            2547 non-null   object 
 46  REDLIGHT           350 non-null    object 
 47  ALCOHOL            208 non-null    object 
 48  DISABILITY         145 non-null    object 
 49  HOOD_158           5299 non-null   object 
 50  NEIGHBOURHOOD_158  5299 non-null   object 
 51  HOOD_140           5299 non-null   object 
 52  NEIGHBOURHOOD_140  5299 non-null   object 
 53  DIVISION           5299 non-null   object 
dtypes: float64(6), int64(3), object(45)
memory usage: 2.2+ MB
In [32]:
print("\nMissing values:")
print(load_df.isnull().sum())
Missing values:
X                       0
Y                       0
OBJECTID                0
INDEX_                  0
ACCNUM                337
DATE                    0
TIME                    0
STREET1                 0
STREET2               495
OFFSET               4567
ROAD_CLASS            144
DISTRICT               91
LATITUDE                0
LONGITUDE               0
ACCLOC               1847
TRAFFCTL               30
VISIBILITY             12
LIGHT                   2
RDSFCOND               13
ACCLASS                 0
IMPACTYPE               9
INVTYPE                 3
INVAGE                  0
INJURY               2812
FATAL_NO             4435
INITDIR              1299
VEHTYPE               483
MANOEUVER            1851
DRIVACT              2105
DRIVCOND             2107
PEDTYPE              4627
PEDACT               4626
PEDCOND              4632
CYCLISTYPE           5190
CYCACT               5185
CYCCOND              5186
PEDESTRIAN           2897
CYCLIST              4676
AUTOMOBILE            608
MOTORCYCLE           4769
TRUCK                5003
TRSN_CITY_VEH        5017
EMERG_VEH            5292
PASSENGER            4076
SPEEDING             4639
AG_DRIV              2752
REDLIGHT             4949
ALCOHOL              5091
DISABILITY           5154
HOOD_158                0
NEIGHBOURHOOD_158       0
HOOD_140                0
NEIGHBOURHOOD_140       0
DIVISION                0
dtype: int64

Transfrom columns:¶

'TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'ACCLASS', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE', 'ALCOHOL'

In [33]:
'''
    1: 'Small Vehicles',
    2: 'Trucks and Vans',
    3: 'Public Transit',
    4: 'Emergency and Unknown',
    5: 'Special Equipment',
    6: 'Off-Road',
    7: 'Bicycles and Mopeds',
    8: 'Motorcycles',
    9: 'Rickshaws',
    10: 'Others'
'''
load_df["VEHTYPE"] = load_df["VEHTYPE"].fillna('Other')

classification = {
    'Automobile, Station Wagon': 1,
    'Bicycle': 7,
    'Motorcycle': 8,
    'Pick Up Truck': 1,
    'Passenger Van': 1,
    'Taxi': 1,
    'Moped': 7,
    'Delivery Van': 2,
    'Truck - Open': 2,
    'Truck - Closed (Blazer, etc)': 2,
    'Truck - Dump': 2,
    'Truck-Tractor': 2,
    'Truck (other)': 2,
    'Truck - Tank': 2,
    'Tow Truck': 2,
    'Truck - Car Carrier': 2,
    'Municipal Transit Bus (TTC)': 3,
    'Street Car': 3,
    'Bus (Other) (Go Bus, Gray Coa': 3,
    'Intercity Bus': 3,
    'School Bus': 3,
    'Other': 10,
    'Unknown': 4,
    'Police Vehicle': 4,
    'Fire Vehicle': 4,
    'Other Emergency Vehicle': 4,
    'Construction Equipment': 5,
    'Rickshaw': 9,
    'Ambulance': 4,
    'Off Road - 2 Wheels': 6,
    'Off Road - 4 Wheels': 6,
    'Off Road - Other': 6
}


load_df['VEHTYPE'] = load_df['VEHTYPE'].map(classification)
load_df['VEHTYPE'].value_counts()
Out[33]:
VEHTYPE
1     2741
10    1923
8      351
7      127
2       96
3       54
4        5
6        1
9        1
Name: count, dtype: int64
In [34]:
'''
1Normal
2Impaired (includes inattentive, medical or physical disability, had been drinking, alcohol impairment, drug impairment)
3Other (includes other and fatigue)
'''
load_df["DRIVCOND"] = load_df["DRIVCOND"].fillna('Other')
drivcond_classification = {
    'Normal': 1,
    'Inattentive': 2,
    'Unknown': 2,
    'Medical or Physical Disability': 2,
    'Had Been Drinking': 2,
    'Ability Impaired, Alcohol Over .08': 2,
    'Ability Impaired, Alcohol': 2,
    'Other': 3,
    'Fatigue': 3,
    'Ability Impaired, Drugs': 2
}
load_df['DRIVCOND'] = load_df['DRIVCOND'].map(drivcond_classification)
load_df['DRIVCOND'].value_counts()
Out[34]:
DRIVCOND
3    2193
1    1757
2    1349
Name: count, dtype: int64
In [35]:
'''
1Infants and Young Children (0 to 9)
2Adolescents (10 to 19)
3Young Adults (20 to 34)
4Middle-Aged Adults (35 to 49)
5Older Adults (50 and above)
6Unknown
'''
load_df['INVAGE'].value_counts()

age_classification = {
    'unknown': 6,       # Category 6: Unknown
    '0 to 4': 1,        # Category 1: Infants and Young Children
    '5 to 9': 1,        # Category 1: Infants and Young Children
    '10 to 14': 2,      # Category 2: Adolescents
    '15 to 19': 2,      # Category 2: Adolescents
    '20 to 24': 3,      # Category 3: Young Adults
    '25 to 29': 3,      # Category 3: Young Adults
    '30 to 34': 3,      # Category 3: Young Adults
    '35 to 39': 4,      # Category 4: Middle-Aged Adults
    '40 to 44': 4,      # Category 4: Middle-Aged Adults
    '45 to 49': 4,      # Category 4: Middle-Aged Adults
    '50 to 54': 5,      # Category 5: Older Adults
    '55 to 59': 5,      # Category 5: Older Adults
    '60 to 64': 5,      # Category 5: Older Adults
    '65 to 69': 5,      # Category 5: Older Adults
    '70 to 74': 5,      # Category 5: Older Adults
    '75 to 79': 5,      # Category 5: Older Adults
    '80 to 84': 5,      # Category 5: Older Adults
    '85 to 89': 5,      # Category 5: Older Adults
    '90 to 94': 5,      # Category 5: Older Adults
    'Over 95': 5        # Category 5: Older Adults
}

# Apply classification to the DataFrame
load_df['INVAGE'] = load_df['INVAGE'].map(age_classification)
load_df['INVAGE'].value_counts()
Out[35]:
INVAGE
5    1799
3    1214
4    1048
6     994
2     200
1      44
Name: count, dtype: int64
In [36]:
'''
1: No Control (e.g., 'No Control')
2: Traffic Control Devices (e.g., 'Traffic Signal', 'Stop Sign', 'Pedestrian Crossover', etc.)
3: Other (e.g., 'Traffic Gate', 'School Guard', 'Police Control')
'''
load_df["TRAFFCTL"] = load_df["TRAFFCTL"].fillna('No Control')
load_df['TRAFFCTL'].value_counts()
traffic_control_classification = {
    'No Control': 1,
    'Traffic Signal': 2,
    'Stop Sign': 2,
    'Pedestrian Crossover': 2,
    'Traffic Controller': 2,
    'Yield Sign': 2,
    'Streetcar (Stop for)': 2,
    'Traffic Gate': 3,
    'School Guard': 3,
    'Police Control': 3
}
load_df['TRAFFCTL'] = load_df['TRAFFCTL'].map(traffic_control_classification)
load_df['TRAFFCTL'].value_counts()
Out[36]:
TRAFFCTL
2    2668
1    2627
3       4
Name: count, dtype: int64
In [37]:
'''
1: Clear (e.g., 'Clear')
2: Adverse Weather (e.g., 'Rain', 'Snow', 'Fog, Mist, Smoke, Dust', etc.)
3: Severe Weather (e.g., 'Strong wind')
'''
load_df["VISIBILITY"] = load_df["VISIBILITY"].fillna('Clear')
load_df['VISIBILITY'].value_counts()
# Define the classification
visibility_classification = {
    'Clear': 1,
    'Rain': 2,
    'Snow': 2,
    'Other': 2,
    'Fog, Mist, Smoke, Dust': 2,
    'Freezing Rain': 2,
    'Drifting Snow': 2,
    'Strong wind': 3
}

# Apply classification to the DataFrame
load_df['VISIBILITY'] = load_df['VISIBILITY'].map(visibility_classification)
load_df['VISIBILITY'].value_counts()
Out[37]:
VISIBILITY
1    4530
2     766
3       3
Name: count, dtype: int64
In [38]:
'''
1: Daylight (e.g., 'Daylight', 'Daylight, artificial')
2: Artificial Light (e.g., 'Dark, artificial', 'Dusk, artificial', 'Dawn, artificial')
3: Low Light (e.g., 'Dark', 'Dusk', 'Dawn', 'Other')
'''
load_df["LIGHT"] = load_df["LIGHT"].fillna('Other')
load_df['LIGHT'].value_counts()
light_classification = {
    'Daylight': 1,
    'Daylight, artificial': 1,
    'Dark': 3,
    'Dark, artificial': 2,
    'Dusk': 3,
    'Dusk, artificial': 2,
    'Dawn': 3,
    'Dawn, artificial': 2,
    'Other': 3
}

# Apply classification to the DataFrame
load_df['LIGHT'] = load_df['LIGHT'].map(light_classification)
load_df['LIGHT'].value_counts()
Out[38]:
LIGHT
1    3071
3    1318
2     910
Name: count, dtype: int64
In [39]:
'''
Dry (1)
Wet (2): Includes Wet and Spilled Liquid conditions.
Slushy/Other (3): Includes Slush and any other unspecified conditions.
Loose Surface (4): Includes Loose Snow, Packed Snow, and Loose Sand/Gravel.
Ice (5): Purely icy conditions.
'''
load_df["RDSFCOND"] = load_df["RDSFCOND"].fillna('Other')
load_df['RDSFCOND'].value_counts()
road_condition_classification = {
    'Dry': 1,                  # Category 1: Dry
    'Wet': 2,                  # Category 2: Wet
    'Slush': 3,                # Category 3: Slushy
    'Loose Snow': 4,           # Category 4: Loose Snow
    'Packed Snow': 4,          # Category 4: Packed Snow
    'Ice': 5,                  # Category 5: Ice
    'Loose Sand or Gravel': 4, # Category 4: Loose Sand/Gravel
    'Spilled liquid': 2,       # Category 2: Wet (Spilled Liquid)
    'Other': 3                 # Category 3: Slushy/Other
}

load_df['RDSFCOND'] = load_df['RDSFCOND'].map(road_condition_classification)
load_df['RDSFCOND'].value_counts()
Out[39]:
RDSFCOND
1    4201
2     921
3      95
4      63
5      19
Name: count, dtype: int64
In [40]:
'''
Drivers (1): Includes all types of drivers (e.g., Car Driver, Motorcycle Driver, Truck Driver).
Cyclists/Skaters (2): Includes Cyclists, Cyclist Passengers, and In-Line Skaters.
Passengers (3): Includes Car, Motorcycle, and Moped Passengers.
Pedestrians (4): Includes Pedestrians and those using Wheelchairs.
Vehicle & Property Owners (5): Includes Vehicle Owners and Other Property Owners.
Other/Special Cases (6): Includes Witnesses, Trailer Owners, and Other unspecified cases.
'''
load_df["INVTYPE"] = load_df["INVTYPE"].fillna('Other')

invtype_classification = {
    'Driver': 1,
    'Motorcycle Driver': 1,
    'Truck Driver': 1,
    'Moped Driver': 1,
    'Driver - Not Hit': 1,
    'Cyclist': 2,
    'In-Line Skater': 2,
    'Passenger': 3,
    'Motorcycle Passenger': 3,
    'Pedestrian': 4,
    'Wheelchair': 4,
    'Vehicle Owner': 5,
    'Other Property Owner': 5,
    'Other': 6
}
# Apply classification to the DataFrame
load_df['INVTYPE'] = load_df['INVTYPE'].map(invtype_classification)
load_df = load_df.dropna(subset=['INVTYPE'])

load_df['INVTYPE'].value_counts()
Out[40]:
INVTYPE
1    3253
5     759
4     676
3     463
2     118
6      30
Name: count, dtype: int64
In [41]:
'''
1: Collisions Involving Vulnerable Road Users (e.g., 'Pedestrian Collisions', 'Cyclist Collisions')
2: Vehicle-to-Vehicle Collisions (e.g., 'Turning Movement', 'Rear End', 'Angle', 'Sideswipe', 'Approaching')
3: Other (e.g., 'SMV Other', 'Other', 'SMV Unattended Vehicle')
'''
load_df["IMPACTYPE"] = load_df["IMPACTYPE"].fillna('Other')
load_df['IMPACTYPE'].value_counts()
impact_type_classification = {
    'Pedestrian Collisions': 1,
    'Cyclist Collisions': 1,
    'Turning Movement': 2,
    'Rear End': 2,
    'SMV Other': 2,
    'Angle': 2,
    'Approaching': 2,
    'Sideswipe': 2,
    'Other': 3,
    'SMV Unattended Vehicle': 3
}

# Apply classification to the DataFrame
load_df['IMPACTYPE'] = load_df['IMPACTYPE'].map(impact_type_classification)
load_df['IMPACTYPE'].value_counts()
Out[41]:
IMPACTYPE
1    2970
2    2210
3     119
Name: count, dtype: int64
In [42]:
load_df["ACCLASS"] = load_df["ACCLASS"].fillna('Non-Fatal Injury')
load_df['ACCLASS'].value_counts()
load_df["ACCLASS"] = (
    load_df["ACCLASS"].map(
        {"Non-Fatal Injury": 0,
         "Fatal": 1,
         "Property Damage O": 0
        }
    )
)
load_df["ACCLASS"].value_counts()
Out[42]:
ACCLASS
0    4325
1     974
Name: count, dtype: int64

map¶

In [43]:
def mapToronto(data_full):
    # Coordinates for Toronto, Canada
    llcrnrlat = 43.581024  # Lower left corner latitude
    urcrnrlat = 43.855457   # Upper right corner latitude
    llcrnrlon = -79.639219  # Lower left corner longitude
    urcrnrlon = -79.115218  # Upper right corner longitude

    # Initialize the Basemap
    m = Basemap(projection='merc', llcrnrlat=llcrnrlat, urcrnrlat=urcrnrlat,
                llcrnrlon=llcrnrlon, urcrnrlon=urcrnrlon, resolution='i')

    # Draw map details
    m.drawcountries()
    m.drawparallels(np.arange(-90, 91., 2.), labels=[1,0,0,0])
    m.drawmeridians(np.arange(-180, 181., 2.), labels=[0,0,0,1])

    # Extract data from dataframe
    lat = data_full['LATITUDE'].values
    lon = data_full['LONGITUDE'].values
    a_1 = data_full['ACCLASS'].values

    # Plot data
    m.scatter(lon, lat, latlon=True, c=a_1, s=50, linewidth=1, edgecolors='red', cmap='hot', alpha=1)

    # Add color bar
    cbar = m.colorbar()
    cbar.set_label('Fatality Count')

    # Add title
    plt.title("Toronto, Canada Fatalities", fontsize=30)
    plt.show()

# Set the style and size of the plot
sns.set(style="white", font_scale=1.5)
plt.figure(figsize=(20,20))

# Call the function to plot the map
mapToronto(load_df)
No description has been provided for this image
In [44]:
##Feature selection:delete alcohol
In [45]:
new_df = load_df[['TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND','DRIVCOND', 'ACCLASS', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE']]
print(new_df)
      TRAFFCTL  VISIBILITY  LIGHT  RDSFCOND  DRIVCOND  ACCLASS  IMPACTYPE  \
0            1           1      3         2         2        1          2   
1            2           1      3         1         3        1          1   
2            2           1      1         1         3        1          1   
3            1           1      1         2         3        1          2   
4            1           1      3         1         3        1          1   
...        ...         ...    ...       ...       ...      ...        ...   
5294         1           2      2         2         1        0          2   
5295         2           1      1         1         1        0          1   
5296         2           1      2         1         2        0          2   
5297         2           2      2         2         1        0          1   
5298         1           2      3         2         1        0          1   

      INVTYPE  INVAGE  VEHTYPE  
0           1       5        1  
1           4       2       10  
2           4       5       10  
3           3       2       10  
4           4       5       10  
...       ...     ...      ...  
5294        1       4        1  
5295        1       5        1  
5296        1       5        1  
5297        1       5        1  
5298        1       5        1  

[5299 rows x 10 columns]
In [46]:
sns.heatmap(
    new_df.corr(numeric_only=True),
    vmin=-1,
    vmax=1,
    cmap="coolwarm"
)
Out[46]:
<Axes: >
No description has been provided for this image

imbalanced data¶

In [47]:
new_df["ACCLASS"].value_counts(normalize=True).plot.bar()
Out[47]:
<Axes: xlabel='ACCLASS'>
No description has been provided for this image
In [48]:
X = new_df.drop(columns=['ACCLASS'])
y = new_df['ACCLASS']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

Apply SMOTETomek to the training data¶

In [49]:
smote_tomek = SMOTETomek(random_state=42)
X_train_resampled, y_train_resampled = smote_tomek.fit_resample(X_train, y_train)

# Show SMOTETomek result

# Plot class distribution before SMOTETomek
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train)
plt.title('Class Distribution Before SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()

# Plot class distribution after SMOTETomek
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train_resampled)
plt.title('Class Distribution After SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()
No description has been provided for this image
No description has been provided for this image

Scale the features¶

In [50]:
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_resampled)
X_test_scaled = scaler.transform(X_test)

Function to train and evaluate models¶

In [51]:
def train_and_evaluate(model, X_train, X_test, y_train, y_test, model_name):
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    
    # Calculate probabilities for ROC AUC
    if hasattr(model, "predict_proba"):
        y_pred_proba = model.predict_proba(X_test)[:, 1]
    else:
        y_pred_proba = model.decision_function(X_test)
    
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    auc_score = roc_auc_score(y_test, y_pred_proba)
    
    print(f"Results for {model_name}:")
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print("Classification Report:\n", classification_report(y_test, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
    print("AUC Score:", auc_score)
    print("\n" + "="*60 + "\n")
    
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (area = {auc_score:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic - {model_name}')
    plt.legend(loc="lower right")
    plt.show()

Define models¶

In [52]:
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Decision Tree": DecisionTreeClassifier(),
    "Random Forest": RandomForestClassifier(n_estimators=100),
    "Support Vector Machine": SVC(probability=True),
    "Neural Network": MLPClassifier(hidden_layer_sizes=(100,), max_iter=300),
    "XGBoost": xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss')
}

Train and evaluate each model¶

In [53]:
for name, model in models.items():
    if name in ["Logistic Regression", "Support Vector Machine", "Neural Network"]:
        train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name)
    else:
        train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name)
Results for Logistic Regression:
Accuracy: 0.7688679245283019
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.76      0.84       865
           1       0.43      0.81      0.56       195

    accuracy                           0.77      1060
   macro avg       0.69      0.78      0.70      1060
weighted avg       0.85      0.77      0.79      1060

Confusion Matrix:
 [[657 208]
 [ 37 158]]
AUC Score: 0.8238002074996295

============================================================

No description has been provided for this image
Results for Decision Tree:
Accuracy: 0.8188679245283019
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.82      0.88       865
           1       0.50      0.82      0.62       195

    accuracy                           0.82      1060
   macro avg       0.73      0.82      0.75      1060
weighted avg       0.87      0.82      0.83      1060

Confusion Matrix:
 [[709 156]
 [ 36 159]]
AUC Score: 0.8935111901585889

============================================================

No description has been provided for this image
Results for Random Forest:
Accuracy: 0.8235849056603773
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.83      0.88       865
           1       0.51      0.81      0.63       195

    accuracy                           0.82      1060
   macro avg       0.73      0.82      0.76      1060
weighted avg       0.87      0.82      0.84      1060

Confusion Matrix:
 [[715 150]
 [ 37 158]]
AUC Score: 0.9177708611234624

============================================================

No description has been provided for this image
Results for Support Vector Machine:
Accuracy: 0.8207547169811321
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.81      0.88       865
           1       0.51      0.87      0.64       195

    accuracy                           0.82      1060
   macro avg       0.74      0.84      0.76      1060
weighted avg       0.88      0.82      0.84      1060

Confusion Matrix:
 [[700 165]
 [ 25 170]]
AUC Score: 0.9025018526752631

============================================================

No description has been provided for this image
Results for Neural Network:
Accuracy: 0.8330188679245283
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.53      0.83      0.65       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.85      1060

Confusion Matrix:
 [[722 143]
 [ 34 161]]
AUC Score: 0.9227538165110419

============================================================

No description has been provided for this image
Results for XGBoost:
Accuracy: 0.8292452830188679
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.84      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[716 149]
 [ 32 163]]
AUC Score: 0.9232666370238625

============================================================

No description has been provided for this image

Cross-validation evaluation for each model¶

In [54]:
print("\nCross-validation scores:\n")
for name, model in models.items():
    skf = StratifiedKFold(n_splits=5)
    cv_scores = cross_val_score(model, X, y, cv=skf)
    print(f"{name}: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
Cross-validation scores:

Logistic Regression: 0.795 (+/- 0.012)
Decision Tree: 0.870 (+/- 0.013)
Random Forest: 0.873 (+/- 0.013)
Support Vector Machine: 0.879 (+/- 0.015)
Neural Network: 0.885 (+/- 0.015)
XGBoost: 0.875 (+/- 0.012)

Final: All in one code shell: Randomized CV¶

In [55]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import xgboost as xgb
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, roc_auc_score, ConfusionMatrixDisplay
from scipy.stats import uniform, randint
from imblearn.combine import SMOTETomek

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Apply SMOTETomek to the training data
smote_tomek = SMOTETomek(random_state=42)
X_train_resampled, y_train_resampled = smote_tomek.fit_resample(X_train, y_train)

# Show SMOTETomek result
plt.figure(figsize=(10, 5))
sns.countplot(x=y_train)
plt.title('Class Distribution Before SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()

plt.figure(figsize=(10, 5))
sns.countplot(x=y_train_resampled)
plt.title('Class Distribution After SMOTETomek')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_resampled)
X_test_scaled = scaler.transform(X_test)

# Define models, including XGBoost
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Decision Tree": DecisionTreeClassifier(),
    "Random Forest": RandomForestClassifier(n_estimators=100),
    "Support Vector Machine": SVC(probability=True),
    "XGBoost": xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss')
}

# Define hyperparameter distributions
lr_params = {
    'C': uniform(loc=0.1, scale=10),
    'penalty': ['l1', 'l2']
}
dt_params = {
    'max_depth': randint(2, 20),
    'min_samples_split': randint(2, 10),
    'min_samples_leaf': randint(1, 10)
}
rf_params = {
    'n_estimators': randint(50, 200),
    'max_depth': randint(2, 20),
    'min_samples_split': randint(2, 10),
    'min_samples_leaf': randint(1, 10)
}
svm_params = {
    'C': uniform(loc=0.1, scale=10),
    'gamma': uniform(loc=0.001, scale=0.1)
}
xgb_params = {
    'n_estimators': randint(50, 200),
    'max_depth': randint(2, 10),
    'learning_rate': uniform(loc=0.01, scale=0.1)
}

# Function to train and evaluate models with Randomized Search CV
def train_and_evaluate(model, X_train, X_test, y_train, y_test, model_name, param_dist):
    # Randomized Search CV
    if model_name in models:
        rand_search = RandomizedSearchCV(model, param_dist, n_iter=50, cv=5, scoring='roc_auc', random_state=42)
        rand_search.fit(X_train, y_train)
        model = rand_search.best_estimator_
        print(f"Best {model_name} parameters: {rand_search.best_params_}")
    else:
        model.fit(X_train, y_train)
    
    y_pred = model.predict(X_test)
    
    # Calculate probabilities for ROC AUC
    if hasattr(model, "predict_proba"):
        y_pred_proba = model.predict_proba(X_test)[:, 1]
    else:
        y_pred_proba = model.decision_function(X_test)
    
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    auc_score = roc_auc_score(y_test, y_pred_proba)
    
    print(f"Results for {model_name}:")
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print("Classification Report:\n", classification_report(y_test, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
    print("AUC Score:", auc_score)
    print("\n" + "="*60 + "\n")
    
    # Plot ROC Curve
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (area = {auc_score:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic - {model_name}')
    plt.legend(loc="lower right")
    plt.show()
    
    # Plot Confusion Matrix
    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=model.classes_)
    disp.plot(cmap=plt.cm.Blues, values_format='d')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.show()

# Train and evaluate each model
for name, model in models.items():
    if name == "Logistic Regression":
        train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name, lr_params)
    elif name == "Decision Tree":
        train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, dt_params)
    elif name == "Random Forest":
        train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, rf_params)
    elif name == "Support Vector Machine":
        train_and_evaluate(model, X_train_scaled, X_test_scaled, y_train_resampled, y_test, name, svm_params)
    elif name == "XGBoost":
        train_and_evaluate(model, X_train_resampled, X_test, y_train_resampled, y_test, name, xgb_params)

# Cross-validation evaluation for each model
print("\nCross-validation scores:\n")
for name, model in models.items():
    skf = StratifiedKFold(n_splits=5)
    cv_scores = cross_val_score(model, X, y, cv=skf)
    print(f"{name}: {cv_scores.mean():.3f} (+/- {cv_scores.std():.3f})")
No description has been provided for this image
No description has been provided for this image
Best Logistic Regression parameters: {'C': 0.10778765841014329, 'penalty': 'l2'}
Results for Logistic Regression:
Accuracy: 0.7566037735849057
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.74      0.83       865
           1       0.42      0.81      0.55       195

    accuracy                           0.76      1060
   macro avg       0.68      0.78      0.69      1060
weighted avg       0.85      0.76      0.78      1060

Confusion Matrix:
 [[644 221]
 [ 37 158]]
AUC Score: 0.8245531347265451

============================================================

D:\anaconda\Lib\site-packages\sklearn\model_selection\_validation.py:378: FitFailedWarning: 
90 fits failed out of a total of 250.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
90 fits failed with the following error:
Traceback (most recent call last):
  File "D:\anaconda\Lib\site-packages\sklearn\model_selection\_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "D:\anaconda\Lib\site-packages\sklearn\linear_model\_logistic.py", line 1162, in fit
    solver = _check_solver(self.solver, self.penalty, self.dual)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda\Lib\site-packages\sklearn\linear_model\_logistic.py", line 54, in _check_solver
    raise ValueError(
ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got l1 penalty.

  warnings.warn(some_fits_failed_message, FitFailedWarning)
D:\anaconda\Lib\site-packages\sklearn\model_selection\_search.py:952: UserWarning: One or more of the test scores are non-finite: [       nan 0.792131          nan        nan 0.79234336 0.79217958
 0.79221477 0.79233121 0.79221728 0.79297543 0.792131          nan
        nan        nan 0.79211047        nan        nan 0.79221477
 0.7922022         nan        nan 0.79220178 0.79221603 0.7922043
        nan 0.792211          nan 0.7922043  0.79235006 0.7922022
 0.79221477 0.79220178 0.7922022  0.79243677 0.79221603 0.7922022
 0.7922043  0.79218           nan        nan 0.7922022  0.7922022
 0.79220095        nan        nan        nan 0.79221854 0.7922022
        nan 0.7921222 ]
  warnings.warn(
No description has been provided for this image
No description has been provided for this image
Best Decision Tree parameters: {'max_depth': 17, 'min_samples_leaf': 7, 'min_samples_split': 8}
Results for Decision Tree:
Accuracy: 0.8264150943396227
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.76      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[714 151]
 [ 33 162]]
AUC Score: 0.920329035126723

============================================================

No description has been provided for this image
No description has been provided for this image
Best Random Forest parameters: {'max_depth': 15, 'min_samples_leaf': 2, 'min_samples_split': 3, 'n_estimators': 58}
Results for Random Forest:
Accuracy: 0.8320754716981132
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.83      0.89       865
           1       0.53      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[721 144]
 [ 34 161]]
AUC Score: 0.9250037053505262

============================================================

No description has been provided for this image
No description has been provided for this image
Best Support Vector Machine parameters: {'C': 3.845401188473625, 'gamma': 0.09607143064099162}
Results for Support Vector Machine:
Accuracy: 0.8141509433962264
Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.80      0.88       865
           1       0.50      0.88      0.63       195

    accuracy                           0.81      1060
   macro avg       0.73      0.84      0.75      1060
weighted avg       0.88      0.81      0.83      1060

Confusion Matrix:
 [[692 173]
 [ 24 171]]
AUC Score: 0.8974714688009485

============================================================

No description has been provided for this image
No description has been provided for this image
Best XGBoost parameters: {'learning_rate': 0.06396921323890797, 'max_depth': 9, 'n_estimators': 173}
Results for XGBoost:
Accuracy: 0.8273584905660377
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.83      0.64       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.76      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[715 150]
 [ 33 162]]
AUC Score: 0.9235808507484808

============================================================

No description has been provided for this image
No description has been provided for this image
Cross-validation scores:

Logistic Regression: 0.795 (+/- 0.012)
Decision Tree: 0.869 (+/- 0.013)
Random Forest: 0.876 (+/- 0.012)
Support Vector Machine: 0.879 (+/- 0.015)
XGBoost: 0.875 (+/- 0.012)
In [58]:
import shap

explainer = shap.Explainer(xgb_model)

shap_values = explainer.shap_values(X_test.sample(50, random_state=12345))
In [59]:
shap_values
Out[59]:
array([[-1.93381105e-02, -6.62288442e-02, -2.49365702e-01,
        -1.38345852e-01,  1.97153926e-01,  1.91319436e-01,
        -7.21931934e-01, -3.89662886e+00, -2.00946021e+00],
       [-1.01129413e+00, -5.75815499e-01, -3.17831814e-01,
        -7.82127380e-01, -1.56952870e+00, -3.33567882e+00,
        -1.91243529e+00, -1.64480269e-01, -1.47690368e+00],
       [-3.51403266e-01,  6.27360195e-02, -2.15497445e-02,
         2.33477424e-03, -7.06985295e-01,  5.86404204e-02,
        -1.28654957e+00,  8.29563200e-01, -7.26504922e-01],
       [-1.02178597e+00, -1.47476268e+00, -7.49799848e-01,
        -5.49536765e-01,  4.76419717e-01, -3.06754565e+00,
        -2.19056702e+00, -9.27880049e-01, -1.25566316e+00],
       [-1.08964421e-01,  3.37765701e-02, -3.71000692e-02,
        -7.04841018e-02,  1.38014667e-02, -6.11864567e-01,
        -2.35663557e+00, -5.09058046e+00,  2.52143323e-01],
       [ 6.74568117e-02,  2.70688143e-02,  4.56348211e-01,
        -1.05761252e-01, -2.34828368e-01,  8.41898024e-01,
        -8.76760721e-01, -4.00064290e-01, -2.83433229e-01],
       [-3.28015029e-01,  1.09333448e-01,  8.92352700e-01,
         1.12779357e-01, -2.19668254e-01,  8.18482995e-01,
        -8.35878551e-01, -3.72705385e-02, -7.28291690e-01],
       [ 7.73408934e-02,  9.99717191e-02, -1.29771918e-01,
        -1.55586660e-01,  2.48188265e-02, -6.02049351e-01,
         2.21430254e+00,  9.27423835e-01,  7.14338899e-01],
       [-2.35508129e-01,  2.77866032e-02, -3.93418521e-01,
        -9.56036672e-02,  1.70172244e-01, -7.39671230e-01,
         2.38143229e+00,  7.38609076e-01,  8.41563344e-01],
       [ 3.10718834e-01,  3.15310173e-02, -3.55146468e-01,
        -1.05335794e-01,  9.44646001e-01,  9.44702566e-01,
        -8.99213552e-01, -9.91142318e-02, -5.11548102e-01],
       [-9.51097682e-02,  3.81136164e-02, -2.51929700e-01,
         1.65722221e-02,  2.11887375e-01, -9.84828174e-01,
        -3.18203539e-01,  1.13615823e+00,  3.54789734e-01],
       [-7.89452732e-01, -1.14122164e+00, -3.80998313e-01,
        -4.06694084e-01,  3.64301354e-01, -2.51861525e+00,
        -1.43628013e+00, -3.76160407e+00, -8.53261173e-01],
       [-4.62559648e-02,  3.76737826e-02, -9.40762311e-02,
        -5.38437441e-02, -7.34988078e-02,  1.46637648e-01,
        -2.65724611e+00, -5.09845924e+00,  2.13134125e-01],
       [-8.98990482e-02,  1.11257941e-01,  4.62407134e-02,
         4.46258225e-02,  7.07178712e-01,  9.51091826e-01,
        -7.94422209e-01,  6.47840381e-01, -3.02195460e-01],
       [-1.13522041e+00, -9.49010134e-01,  2.51475833e-02,
        -8.91322553e-01, -1.09084892e+00,  6.64149642e-01,
        -1.01224065e+00, -1.95666421e-02, -6.06009364e-01],
       [ 2.58693665e-01,  4.40996438e-02,  3.99653554e-01,
        -1.55484423e-01,  4.82581943e-01,  8.58417332e-01,
        -9.39561903e-01, -2.27498159e-01, -4.97091264e-01],
       [ 3.05514455e-01,  4.71714810e-02,  2.25309402e-01,
         1.22657724e-01,  5.78781784e-01,  9.59272027e-01,
        -1.04257691e+00, -2.45525539e-02, -2.44668916e-01],
       [ 7.25928620e-02,  5.28936982e-02,  1.90763354e-01,
        -1.04034469e-01,  3.05118468e-02, -5.32749951e-01,
        -2.29144049e+00, -5.48646832e+00,  3.26812178e-01],
       [-4.85669598e-02, -4.25708368e-02, -4.27143574e-01,
        -2.74695177e-02, -1.56923199e+00, -2.82138610e+00,
        -1.82957971e+00, -1.51607239e+00, -1.82401371e+00],
       [-7.86977291e-01,  2.10416913e-02, -3.64142030e-01,
        -5.99623919e-01, -7.41931200e-02,  1.46533474e-01,
        -5.72120249e-01,  1.22435011e-01,  1.10557839e-01],
       [ 3.04077893e-01,  3.04592792e-02,  7.10107625e-01,
        -1.75621390e-01, -1.11065888e+00,  5.24037719e-01,
        -1.47408879e+00, -2.61611968e-01, -9.77721393e-01],
       [ 2.74994951e-02,  5.25919124e-02, -4.75337915e-02,
        -8.74799341e-02, -1.66352272e-01,  1.32912964e-01,
        -2.73681140e+00, -5.07298803e+00,  1.23192310e-01],
       [-1.93771213e-01,  1.95712503e-02,  9.09444690e-02,
        -1.21797666e-01,  2.37964407e-01, -2.70368242e+00,
        -1.53256130e+00, -3.48969340e+00, -6.97732508e-01],
       [-4.10069585e-01, -4.03013796e-01,  7.24365264e-02,
         2.81927347e-01,  3.14463198e-01, -6.74239039e-01,
         2.88252163e+00,  7.30022550e-01,  1.42431295e+00],
       [-1.07249916e+00, -1.13756537e+00, -1.18762247e-01,
        -4.30933535e-02,  6.33755863e-01,  7.96018243e-01,
        -1.07325172e+00,  3.98037881e-01,  5.29460013e-02],
       [ 3.10718834e-01,  3.15310173e-02, -3.55146468e-01,
        -1.05335794e-01,  9.44646001e-01,  9.44702566e-01,
        -8.99213552e-01, -9.91142318e-02, -5.11548102e-01],
       [-1.59351639e-02,  1.26875058e-01, -2.26915643e-01,
        -2.92278882e-02, -7.29034066e-01,  4.05084997e-01,
        -1.15267050e+00, -4.07824516e-02,  1.67073321e+00],
       [-1.16919227e-01,  5.82024045e-02, -1.62954524e-01,
        -1.19079441e-01,  2.43242055e-01, -8.81494939e-01,
         2.84752679e+00,  2.63449073e-01,  4.03350890e-01],
       [-3.14670265e-01,  1.14170276e-02, -4.43615854e-01,
        -1.61466256e-01, -1.32154727e+00,  4.33280051e-01,
        -1.62493134e+00, -4.37575102e-01, -1.02052307e+00],
       [-2.35888943e-01,  8.52480382e-02, -3.90788078e-01,
         1.48218693e-02, -8.04140329e-01, -3.11676788e+00,
        -1.33595455e+00,  2.26336852e-01, -1.15486991e+00],
       [-6.94697201e-02,  1.34566948e-02, -2.48693392e-01,
        -9.73307714e-02,  2.06582591e-01, -3.44762087e-01,
        -2.32455134e+00, -4.94624043e+00, -2.00959131e-01],
       [ 1.51706666e-01,  6.80131391e-02, -1.90164283e-01,
        -2.33701020e-02,  4.23217416e-02,  2.54519045e-01,
         1.85870767e-01,  1.59819931e-01, -3.21910828e-02],
       [-3.19132179e-01,  6.33101240e-02, -2.49963343e-01,
         3.97904217e-02, -8.17891285e-02,  2.93164283e-01,
        -2.41855964e-01,  2.45835111e-01,  5.71214519e-02],
       [-1.42147943e-01,  9.40644741e-02, -4.91657436e-01,
        -1.87025324e-01, -1.53674548e-02, -4.11945701e-01,
         2.33720660e+00, -1.84624746e-01,  4.14770633e-01],
       [ 2.74994951e-02,  5.25919124e-02, -4.75337915e-02,
        -8.74799341e-02, -1.66352272e-01,  1.32912964e-01,
        -2.73681140e+00, -5.07298803e+00,  1.23192310e-01],
       [ 7.73408934e-02,  9.99717191e-02, -1.29771918e-01,
        -1.55586660e-01,  2.48188265e-02, -6.02049351e-01,
         2.21430254e+00,  9.27423835e-01,  7.14338899e-01],
       [-5.17478138e-02,  1.00478930e-02,  2.24559888e-01,
        -2.14167219e-02,  3.48142356e-01, -2.78471422e+00,
        -2.01706696e+00, -2.63107419e-01, -2.47578955e+00],
       [-8.99411663e-02,  3.87514420e-02, -6.75466210e-02,
        -4.90424484e-02, -6.49665892e-02,  1.34221569e-01,
        -2.80208254e+00, -5.16013336e+00,  9.94109511e-02],
       [-1.95364729e-02,  1.04852552e-02, -1.64784491e-01,
        -1.00198261e-01, -5.47570549e-02,  6.33303598e-02,
        -2.61272931e+00, -5.31503296e+00, -3.35513987e-02],
       [ 9.66485590e-02,  3.40564027e-02, -3.73806417e-01,
        -1.83075756e-01,  1.40359879e-01, -6.58663094e-01,
         2.55919194e+00,  6.32796943e-01,  7.43776858e-01],
       [-3.40972871e-01,  1.50461972e-01,  3.91101241e-01,
        -1.19617078e-02, -1.03763819e+00, -6.90448403e-01,
        -1.55245662e+00, -1.04290509e+00,  8.81740332e-01],
       [-8.99411663e-02,  3.87514420e-02, -6.75466210e-02,
        -4.90424484e-02, -6.49665892e-02,  1.34221569e-01,
        -2.80208254e+00, -5.16013336e+00,  9.94109511e-02],
       [-3.41409564e-01,  5.56582119e-03, -3.45922589e-01,
         4.66416776e-03, -1.39189041e+00, -3.02620101e+00,
        -1.75788784e+00, -9.71480608e-01, -1.72892499e+00],
       [-3.94320637e-01,  6.14018850e-02,  6.32896602e-01,
         2.18729861e-02,  7.87348330e-01,  9.17335033e-01,
        -9.28744018e-01,  3.17282200e-01, -3.49805593e-01],
       [ 7.73408934e-02,  9.99717191e-02, -1.29771918e-01,
        -1.55586660e-01,  2.48188265e-02, -6.02049351e-01,
         2.21430254e+00,  9.27423835e-01,  7.14338899e-01],
       [ 2.93549001e-01,  1.04285046e-01,  3.14950854e-01,
        -1.12519681e-01,  6.20082855e-01, -1.64532948e+00,
        -1.40471959e+00,  6.85630023e-01, -7.29807913e-01],
       [ 1.89440057e-01,  7.47012272e-02, -2.02477023e-01,
         4.91003655e-02,  3.18141371e-01, -3.66354656e+00,
        -2.30577159e+00, -3.63638878e-01, -1.27347314e+00],
       [ 4.36498821e-02,  8.58186558e-02,  2.13198364e-02,
        -1.09476358e-01,  1.95139512e-01, -5.55412769e-01,
         3.43275619e+00,  4.67949331e-01, -1.04716980e+00],
       [-8.36332738e-01, -7.53389418e-01, -4.51529622e-01,
        -5.14398634e-01,  5.66261970e-02, -2.61760235e+00,
        -2.20572138e+00, -4.91789162e-01, -1.64379048e+00],
       [ 7.73408934e-02,  9.99717191e-02, -1.29771918e-01,
        -1.55586660e-01,  2.48188265e-02, -6.02049351e-01,
         2.21430254e+00,  9.27423835e-01,  7.14338899e-01]], dtype=float32)
In [65]:
# Create the explainer
explainer = shap.TreeExplainer(xgb_model)
# Generate SHAP values (this returns an Explanation object)
shap_values = explainer(X_test)
# Now you can use the beeswarm plot
shap.plots.beeswarm(shap_values)
No description has been provided for this image
In [66]:
X_test.sample(50, random_state=12345).iloc[1]
Out[66]:
TRAFFCTL      2
VISIBILITY    2
LIGHT         3
RDSFCOND      2
DRIVCOND      1
IMPACTYPE     1
INVTYPE       1
INVAGE        3
VEHTYPE       1
Name: 1758, dtype: int64
In [88]:
shap.plots.waterfall(shap_values[1])
No description has been provided for this image

Soft Voting Classifier¶

In [69]:
from sklearn.ensemble import VotingClassifier

# Create and train Soft Voting Classifier with Decision Tree included
soft_voting_model = VotingClassifier(estimators=[
    ('Random Forest', RandomForestClassifier(
        n_estimators=63,
        max_depth=13,
        min_samples_leaf=2,
        min_samples_split=3
    )),
    ('XGBoost', xgb.XGBClassifier(
        use_label_encoder=False,
        eval_metric='mlogloss',
        learning_rate=0.035877998160001694,
        max_depth=9,
        n_estimators=181
    )),
    ('Decision Tree', DecisionTreeClassifier(
        max_depth=17,
        min_samples_leaf=7,
        min_samples_split=8
    ))
], voting='soft')

# Train the model
soft_voting_model.fit(X_train_resampled, y_train_resampled)

# Evaluate Soft Voting Classifier
print("Soft Voting Classifier Results:")
train_and_evaluate(soft_voting_model, X_train_resampled, X_test, y_train_resampled, y_test, "Soft Voting Classifier", {})
Soft Voting Classifier Results:
Results for Soft Voting Classifier:
Accuracy: 0.8301886792452831
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.83      0.89       865
           1       0.52      0.84      0.65       195

    accuracy                           0.83      1060
   macro avg       0.74      0.83      0.77      1060
weighted avg       0.88      0.83      0.84      1060

Confusion Matrix:
 [[716 149]
 [ 31 164]]
AUC Score: 0.9259819178894324

============================================================

No description has been provided for this image
No description has been provided for this image

only numeric_transformer, no need categorical_transformer¶

In [70]:
X.dtypes
Out[70]:
TRAFFCTL      int64
VISIBILITY    int64
LIGHT         int64
RDSFCOND      int64
DRIVCOND      int64
IMPACTYPE     int64
INVTYPE       int64
INVAGE        int64
VEHTYPE       int64
dtype: object
In [89]:
# from sklearn.impute import SimpleImputer
# numeric_transformer = Pipeline([
#     ("imputer", SimpleImputer(strategy="mean")),
#     ("scaler", StandardScaler()) 
# ])
from sklearn.impute import SimpleImputer
numeric_transformer = Pipeline([
    ("imputer", SimpleImputer(strategy="mean"))
])

Model deployment: pipeline and pickle¶

In [90]:
pipeline = Pipeline(steps=[
    ('preprocessor', numeric_transformer),
    ('voting', VotingClassifier(estimators=[
        ('Random Forest', RandomForestClassifier(
            n_estimators=63,
            max_depth=13,
            min_samples_leaf=2,
            min_samples_split=3
        )),
        ('XGBoost', xgb.XGBClassifier(
            use_label_encoder=False,
            eval_metric='mlogloss',
            learning_rate=0.035877998160001694,
            max_depth=9,
            n_estimators=181
        )),
        ('Decision Tree', DecisionTreeClassifier(
        max_depth=17,
        min_samples_leaf=7,
        min_samples_split=8
        ))
    ], voting='soft'))
])
In [91]:
pipeline.fit(X_train_resampled, y_train_resampled)
pipeline.score(X_test, y_test)
Out[91]:
0.8292452830188679
In [92]:
import joblib

joblib.dump(pipeline, "KSI_model_pipeline_voting_without_scaler.pkl")
Out[92]:
['KSI_model_pipeline_voting_without_scaler.pkl']
In [93]:
pipeline
Out[93]:
Pipeline(steps=[('preprocessor',
                 Pipeline(steps=[('imputer', SimpleImputer())])),
                ('voting',
                 VotingClassifier(estimators=[('Random Forest',
                                               RandomForestClassifier(max_depth=13,
                                                                      min_samples_leaf=2,
                                                                      min_samples_split=3,
                                                                      n_estimators=63)),
                                              ('XGBoost',
                                               XGBClassifier(base_score=None,
                                                             booster=None,
                                                             callbacks=None,
                                                             colsample_bylevel=None,
                                                             colsample_bynode=None,
                                                             colsample...
                                                             max_cat_threshold=None,
                                                             max_cat_to_onehot=None,
                                                             max_delta_step=None,
                                                             max_depth=9,
                                                             max_leaves=None,
                                                             min_child_weight=None,
                                                             missing=nan,
                                                             monotone_constraints=None,
                                                             multi_strategy=None,
                                                             n_estimators=181,
                                                             n_jobs=None,
                                                             num_parallel_tree=None,
                                                             random_state=None, ...)),
                                              ('Decision Tree',
                                               DecisionTreeClassifier(max_depth=17,
                                                                      min_samples_leaf=7,
                                                                      min_samples_split=8))],
                                  voting='soft'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('preprocessor',
                 Pipeline(steps=[('imputer', SimpleImputer())])),
                ('voting',
                 VotingClassifier(estimators=[('Random Forest',
                                               RandomForestClassifier(max_depth=13,
                                                                      min_samples_leaf=2,
                                                                      min_samples_split=3,
                                                                      n_estimators=63)),
                                              ('XGBoost',
                                               XGBClassifier(base_score=None,
                                                             booster=None,
                                                             callbacks=None,
                                                             colsample_bylevel=None,
                                                             colsample_bynode=None,
                                                             colsample...
                                                             max_cat_threshold=None,
                                                             max_cat_to_onehot=None,
                                                             max_delta_step=None,
                                                             max_depth=9,
                                                             max_leaves=None,
                                                             min_child_weight=None,
                                                             missing=nan,
                                                             monotone_constraints=None,
                                                             multi_strategy=None,
                                                             n_estimators=181,
                                                             n_jobs=None,
                                                             num_parallel_tree=None,
                                                             random_state=None, ...)),
                                              ('Decision Tree',
                                               DecisionTreeClassifier(max_depth=17,
                                                                      min_samples_leaf=7,
                                                                      min_samples_split=8))],
                                  voting='soft'))])
Pipeline(steps=[('imputer', SimpleImputer())])
SimpleImputer()
VotingClassifier(estimators=[('Random Forest',
                              RandomForestClassifier(max_depth=13,
                                                     min_samples_leaf=2,
                                                     min_samples_split=3,
                                                     n_estimators=63)),
                             ('XGBoost',
                              XGBClassifier(base_score=None, booster=None,
                                            callbacks=None,
                                            colsample_bylevel=None,
                                            colsample_bynode=None,
                                            colsample_bytree=None, device=None,
                                            early_stopping_rounds=None,
                                            enable_categorical=False,
                                            eval_metric=...
                                            max_cat_threshold=None,
                                            max_cat_to_onehot=None,
                                            max_delta_step=None, max_depth=9,
                                            max_leaves=None,
                                            min_child_weight=None, missing=nan,
                                            monotone_constraints=None,
                                            multi_strategy=None,
                                            n_estimators=181, n_jobs=None,
                                            num_parallel_tree=None,
                                            random_state=None, ...)),
                             ('Decision Tree',
                              DecisionTreeClassifier(max_depth=17,
                                                     min_samples_leaf=7,
                                                     min_samples_split=8))],
                 voting='soft')
RandomForestClassifier(max_depth=13, min_samples_leaf=2, min_samples_split=3,
                       n_estimators=63)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric='mlogloss',
              feature_types=None, gamma=None, grow_policy=None,
              importance_type=None, interaction_constraints=None,
              learning_rate=0.035877998160001694, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=9, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=181, n_jobs=None,
              num_parallel_tree=None, random_state=None, ...)
DecisionTreeClassifier(max_depth=17, min_samples_leaf=7, min_samples_split=8)

Scoring of the "KSI_model_pipeline.pkl"¶

In [94]:
KSI_model_pipeline = joblib.load("KSI_model_pipeline_voting_without_scaler.pkl")
In [95]:
KSI_model_pipeline
Out[95]:
Pipeline(steps=[('preprocessor',
                 Pipeline(steps=[('imputer', SimpleImputer())])),
                ('voting',
                 VotingClassifier(estimators=[('Random Forest',
                                               RandomForestClassifier(max_depth=13,
                                                                      min_samples_leaf=2,
                                                                      min_samples_split=3,
                                                                      n_estimators=63)),
                                              ('XGBoost',
                                               XGBClassifier(base_score=None,
                                                             booster=None,
                                                             callbacks=None,
                                                             colsample_bylevel=None,
                                                             colsample_bynode=None,
                                                             colsample...
                                                             max_cat_threshold=None,
                                                             max_cat_to_onehot=None,
                                                             max_delta_step=None,
                                                             max_depth=9,
                                                             max_leaves=None,
                                                             min_child_weight=None,
                                                             missing=nan,
                                                             monotone_constraints=None,
                                                             multi_strategy=None,
                                                             n_estimators=181,
                                                             n_jobs=None,
                                                             num_parallel_tree=None,
                                                             random_state=None, ...)),
                                              ('Decision Tree',
                                               DecisionTreeClassifier(max_depth=17,
                                                                      min_samples_leaf=7,
                                                                      min_samples_split=8))],
                                  voting='soft'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('preprocessor',
                 Pipeline(steps=[('imputer', SimpleImputer())])),
                ('voting',
                 VotingClassifier(estimators=[('Random Forest',
                                               RandomForestClassifier(max_depth=13,
                                                                      min_samples_leaf=2,
                                                                      min_samples_split=3,
                                                                      n_estimators=63)),
                                              ('XGBoost',
                                               XGBClassifier(base_score=None,
                                                             booster=None,
                                                             callbacks=None,
                                                             colsample_bylevel=None,
                                                             colsample_bynode=None,
                                                             colsample...
                                                             max_cat_threshold=None,
                                                             max_cat_to_onehot=None,
                                                             max_delta_step=None,
                                                             max_depth=9,
                                                             max_leaves=None,
                                                             min_child_weight=None,
                                                             missing=nan,
                                                             monotone_constraints=None,
                                                             multi_strategy=None,
                                                             n_estimators=181,
                                                             n_jobs=None,
                                                             num_parallel_tree=None,
                                                             random_state=None, ...)),
                                              ('Decision Tree',
                                               DecisionTreeClassifier(max_depth=17,
                                                                      min_samples_leaf=7,
                                                                      min_samples_split=8))],
                                  voting='soft'))])
Pipeline(steps=[('imputer', SimpleImputer())])
SimpleImputer()
VotingClassifier(estimators=[('Random Forest',
                              RandomForestClassifier(max_depth=13,
                                                     min_samples_leaf=2,
                                                     min_samples_split=3,
                                                     n_estimators=63)),
                             ('XGBoost',
                              XGBClassifier(base_score=None, booster=None,
                                            callbacks=None,
                                            colsample_bylevel=None,
                                            colsample_bynode=None,
                                            colsample_bytree=None, device=None,
                                            early_stopping_rounds=None,
                                            enable_categorical=False,
                                            eval_metric=...
                                            max_cat_threshold=None,
                                            max_cat_to_onehot=None,
                                            max_delta_step=None, max_depth=9,
                                            max_leaves=None,
                                            min_child_weight=None, missing=nan,
                                            monotone_constraints=None,
                                            multi_strategy=None,
                                            n_estimators=181, n_jobs=None,
                                            num_parallel_tree=None,
                                            random_state=None, ...)),
                             ('Decision Tree',
                              DecisionTreeClassifier(max_depth=17,
                                                     min_samples_leaf=7,
                                                     min_samples_split=8))],
                 voting='soft')
RandomForestClassifier(max_depth=13, min_samples_leaf=2, min_samples_split=3,
                       n_estimators=63)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric='mlogloss',
              feature_types=None, gamma=None, grow_policy=None,
              importance_type=None, interaction_constraints=None,
              learning_rate=0.035877998160001694, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=9, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=181, n_jobs=None,
              num_parallel_tree=None, random_state=None, ...)
DecisionTreeClassifier(max_depth=17, min_samples_leaf=7, min_samples_split=8)
In [96]:
#'TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'ACCLASS', 'IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE', 'ALCOHOL'
KSI_features = new_df[['TRAFFCTL', 'VISIBILITY', 'LIGHT', 'RDSFCOND', 'DRIVCOND','IMPACTYPE', 'INVTYPE', 'INVAGE', 'VEHTYPE']]
In [97]:
# KSI_model_pipeline.predict_proba(KSI_features)[:5]
In [98]:
ksi_to_score = pd.DataFrame({
    "TRAFFCTL": [1, 1, 2],
    "VISIBILITY": [1, 2, 1],
    "LIGHT": [3, 2, 3],
    "RDSFCOND": [2, 2, 1],
    'DRIVCOND': [2, 2, 1],
    "IMPACTYPE": [2, 2, 1],
    "INVTYPE": [1, 1, 2],
    "INVAGE": [2, 2, 1],
    "VEHTYPE": [1, 1, 10]
})

ksi_to_score.head()
Out[98]:
TRAFFCTL VISIBILITY LIGHT RDSFCOND DRIVCOND IMPACTYPE INVTYPE INVAGE VEHTYPE
0 1 1 3 2 2 2 1 2 1
1 1 2 2 2 2 2 1 2 1
2 2 1 3 1 1 1 2 1 10
In [99]:
pd.DataFrame({
    "predicted_prob": KSI_model_pipeline.predict_proba(ksi_to_score)[:, 1]
})
Out[99]:
predicted_prob
0 0.705179
1 0.931160
2 0.100457

ML Model Scoring App¶

In [7]:
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import joblib

### Setup dash app
app = dash.Dash(__name__)
app.title = 'Machine Learning Model Deployment'
server = app.server

### load ML pipeline (or model)
model_pipeline = joblib.load("KSI_model_pipeline_voting_without_scaler.pkl")

### App Layout 
app.layout = html.Div([
    dbc.Row([html.H3(children='Predict fatality in incidents')]),
    
    dbc.Row([
        dbc.Col(html.Label(children='Traffic Control Type:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='TRAFFCTL',
            options=[
                {'label': 'No Control', 'value': 1},
                {'label': 'Traffic Control Devices', 'value': 2},
                {'label': 'Other', 'value': 3}
            ],
            value=1
        )) 
    ]),
    
    dbc.Row([
        dbc.Col(html.Label(children='Vehicle Type:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='VEHTYPE',
            options=[
                {'label': 'Small Vehicles', 'value': 1},
                {'label': 'Trucks and Vans', 'value': 2},
                {'label': 'Public Transit', 'value': 3},
                {'label': 'Emergency and Unknown', 'value': 4},
                {'label': 'Special Equipment', 'value': 5},
                {'label': 'Off-Road', 'value': 6},
                {'label': 'Bicycles and Mopeds', 'value': 7},
                {'label': 'Motorcycles', 'value': 8},
                {'label': 'Rickshaws', 'value': 9},
                {'label': 'Others', 'value': 10}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Driver Condition:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='DRIVCOND',
            options=[
                {'label': 'Normal', 'value': 1},
                {'label': 'Impaired', 'value': 2},
                {'label': 'Other', 'value': 3}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Involved Person Age Group:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='INVAGE',
            options=[
                {'label': 'Infants and Young Children (0 to 9)', 'value': 1},
                {'label': 'Adolescents (10 to 19)', 'value': 2},
                {'label': 'Young Adults (20 to 34)', 'value': 3},
                {'label': 'Middle-Aged Adults (35 to 49)', 'value': 4},
                {'label': 'Older Adults (50 and above)', 'value': 5},
                {'label': 'Unknown', 'value': 6}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Visibility:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='VISIBILITY',
            options=[
                {'label': 'Clear', 'value': 1},
                {'label': 'Adverse Weather', 'value': 2},
                {'label': 'Severe Weather', 'value': 3}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Light Condition:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='LIGHT',
            options=[
                {'label': 'Daylight', 'value': 1},
                {'label': 'Artificial Light', 'value': 2},
                {'label': 'Low Light', 'value': 3}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Road Surface Condition:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='RDSFCOND',
            options=[
                {'label': 'Dry', 'value': 1},
                {'label': 'Wet', 'value': 2},
                {'label': 'Slushy/Other', 'value': 3},
                {'label': 'Loose Surface', 'value': 4},
                {'label': 'Ice', 'value': 5}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Involved Person Type:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='INVTYPE',
            options=[
                {'label': 'Drivers', 'value': 1},
                {'label': 'Cyclists/Skaters', 'value': 2},
                {'label': 'Passengers', 'value': 3},
                {'label': 'Pedestrians', 'value': 4},
                {'label': 'Vehicle & Property Owners', 'value': 5},
                {'label': 'Other/Special Cases', 'value': 6}
            ],
            value=1
        ))
    ]),

    dbc.Row([
        dbc.Col(html.Label(children='Impact Type:'), width={"order": "first"}),
        dbc.Col(dcc.Dropdown(
            id='IMPACTYPE',
            options=[
                {'label': 'Collisions Involving Vulnerable Road Users', 'value': 1},
                {'label': 'Vehicle-to-Vehicle Collisions', 'value': 2},
                {'label': 'Other', 'value': 3}
            ],
            value=1
        ))
    ]),

    dbc.Row([dbc.Button('Submit', id='submit-val', n_clicks=0, color="primary")]),
    html.Br(),
    dbc.Row([html.Div(id='prediction output')])
    
], style={'padding': '0px 0px 0px 150px', 'width': '50%'})

### Callback to produce the model output
@app.callback(
    Output('prediction output', 'children'), 
    Input('submit-val', 'n_clicks'),
    State('TRAFFCTL', 'value'),
    State('VISIBILITY', 'value'),
    State('LIGHT', 'value'),
    State('RDSFCOND', 'value'),
    State('DRIVCOND', 'value'),
    State('IMPACTYPE', 'value'),
    State('INVTYPE', 'value'),
    State('INVAGE', 'value'),
    State('VEHTYPE', 'value')
)
def update_output(n_clicks, traffctl, visibility, light, rdsfcond, drivcond, impactype, invtype, invage, vehtype):
    if n_clicks > 0:
        # Create a DataFrame with the input values in the correct order
        x = pd.DataFrame({
            "TRAFFCTL": [traffctl], 
            "VISIBILITY": [visibility], 
            "LIGHT": [light], 
            "RDSFCOND": [rdsfcond], 
            "DRIVCOND": [drivcond], 
            "IMPACTYPE": [impactype], 
            "INVTYPE": [invtype], 
            "INVAGE": [invage], 
            "VEHTYPE": [vehtype]
        })

        # Make the prediction using the loaded model pipeline
        prediction = model_pipeline.predict_proba(x)[0]

        # Determine if the incident is fatal or nonfatal
        if prediction[1] >= 0.5:
            incident_type = 'Fatal'
        else:
            incident_type = 'Nonfatal'

        # Format the output to show both the probability and the incident type
        #
        output = [
            f'The incident is predicted to be: {incident_type}',
            html.Br(),
            f'The probability of a fatality in this incident is {round(prediction[1], 2)}'
        ]
    else:
        output = 'Please submit the form to get a prediction.'

    return output
    
### Run the App 
if __name__ == '__main__':
    app.run(debug=True, port=8001, host='127.0.0.1')
In [ ]: