Importing the Libraries¶
In [1]:
import os # For file and directory operations
import pandas as pd # For data manipulation and analysis
import kagglehub # For downloading datasets from Kaggle
import matplotlib.pyplot as plt # For plotting and visualization
import seaborn as sns # For advanced statistical data visualization
from kagglehub import KaggleDatasetAdapter # Enum for Kaggle dataset adapters
from sklearn.ensemble import RandomForestClassifier # Random Forest ML model
from sklearn.model_selection import train_test_split # For splitting data into train/test sets
from sklearn.preprocessing import LabelEncoder # For encoding categorical variables
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix # For model evaluation metrics
from sklearn.model_selection import GridSearchCV # For hyperparameter tuning with grid search
pd.set_option('display.max_columns', None) # Show all columns in DataFrame display
pd.set_option('display.width', 1000) # Set the display width for DataFrames
Loading the Data¶
In [2]:
# ======================================== Data Loading ======================================== #
file_path = "heart.csv" # File name from the Kaggle dataset
cache_file = "heart.csv" # Local cache file name
if os.path.exists(cache_file): # Check if the cache file exists
print("Loading data...")
df = pd.read_csv(cache_file) # If it exists, load the data from the cache file
else:
print("Fetching data from Kaggle...")
df = kagglehub.load_dataset(
KaggleDatasetAdapter.PANDAS,
"fedesoriano/heart-failure-prediction",
file_path,
)
if not os.path.exists(cache_file):
df.to_csv(cache_file, index=False)
print(f"Data cached to {cache_file}")
Loading data...
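The cache-or-fetch logic above can also be factored into a small reusable helper. The sketch below is not part of the original notebook: the function name load_heart_data and its defaults are illustrative, and it assumes the same kagglehub adapter call used in the cell above.

import os
import pandas as pd
import kagglehub
from kagglehub import KaggleDatasetAdapter

def load_heart_data(cache_file="heart.csv",
                    dataset="fedesoriano/heart-failure-prediction",
                    file_path="heart.csv"):
    """Return the heart dataset, reading a local cache or downloading it from Kaggle on a miss."""
    if os.path.exists(cache_file):
        return pd.read_csv(cache_file)  # Cache hit: read the local copy
    df = kagglehub.load_dataset(KaggleDatasetAdapter.PANDAS, dataset, file_path)  # Cache miss: download
    df.to_csv(cache_file, index=False)  # Persist for future runs
    return df

df = load_heart_data()  # Equivalent to the cell above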
Preprocessing the Data¶
In [3]:
# ======================================== Data Preprocessing ======================================== #
print(df, end="\n\n")
print(df.head(), end="\n\n")
print(df.columns, end="\n\n")
print(df.dtypes, end="\n\n")
print("Descriptive Statistics:")
print(df.describe(include='all'))
duplicates = df.duplicated().sum()
print(f"\nNumber of duplicate rows: {duplicates}")
missing_values = df.isnull().sum()
print(f'\nMissing values in each column:\n{missing_values}')
     Age Sex ChestPainType  RestingBP  Cholesterol  FastingBS RestingECG  MaxHR ExerciseAngina  Oldpeak ST_Slope  HeartDisease
0     40   M           ATA        140          289          0     Normal    172              N      0.0       Up             0
1     49   F           NAP        160          180          0     Normal    156              N      1.0     Flat             1
2     37   M           ATA        130          283          0         ST     98              N      0.0       Up             0
3     48   F           ASY        138          214          0     Normal    108              Y      1.5     Flat             1
4     54   M           NAP        150          195          0     Normal    122              N      0.0       Up             0
..   ...  ..           ...        ...          ...        ...        ...    ...            ...      ...      ...           ...
913   45   M            TA        110          264          0     Normal    132              N      1.2     Flat             1
914   68   M           ASY        144          193          1     Normal    141              N      3.4     Flat             1
915   57   M           ASY        130          131          0     Normal    115              Y      1.2     Flat             1
916   57   F           ATA        130          236          0        LVH    174              N      0.0     Flat             1
917   38   M           NAP        138          175          0     Normal    173              N      0.0       Up             0

[918 rows x 12 columns]

   Age Sex ChestPainType  RestingBP  Cholesterol  FastingBS RestingECG  MaxHR ExerciseAngina  Oldpeak ST_Slope  HeartDisease
0   40   M           ATA        140          289          0     Normal    172              N      0.0       Up             0
1   49   F           NAP        160          180          0     Normal    156              N      1.0     Flat             1
2   37   M           ATA        130          283          0         ST     98              N      0.0       Up             0
3   48   F           ASY        138          214          0     Normal    108              Y      1.5     Flat             1
4   54   M           NAP        150          195          0     Normal    122              N      0.0       Up             0

Index(['Age', 'Sex', 'ChestPainType', 'RestingBP', 'Cholesterol', 'FastingBS',
       'RestingECG', 'MaxHR', 'ExerciseAngina', 'Oldpeak', 'ST_Slope',
       'HeartDisease'],
      dtype='object')

Age                 int64
Sex                object
ChestPainType      object
RestingBP           int64
Cholesterol         int64
FastingBS           int64
RestingECG         object
MaxHR               int64
ExerciseAngina     object
Oldpeak           float64
ST_Slope           object
HeartDisease        int64
dtype: object

Descriptive Statistics:
               Age  Sex ChestPainType   RestingBP  Cholesterol   FastingBS RestingECG       MaxHR ExerciseAngina     Oldpeak ST_Slope  HeartDisease
count   918.000000  918           918  918.000000   918.000000  918.000000        918  918.000000            918  918.000000      918    918.000000
unique         NaN    2             4         NaN          NaN         NaN          3         NaN              2         NaN        3           NaN
top            NaN    M           ASY         NaN          NaN         NaN     Normal         NaN              N         NaN     Flat           NaN
freq           NaN  725           496         NaN          NaN         NaN        552         NaN            547         NaN      460           NaN
mean     53.510893  NaN           NaN  132.396514   198.799564    0.233115        NaN  136.809368            NaN    0.887364      NaN      0.553377
std       9.432617  NaN           NaN   18.514154   109.384145    0.423046        NaN   25.460334            NaN    1.066570      NaN      0.497414
min      28.000000  NaN           NaN    0.000000     0.000000    0.000000        NaN   60.000000            NaN   -2.600000      NaN      0.000000
25%      47.000000  NaN           NaN  120.000000   173.250000    0.000000        NaN  120.000000            NaN    0.000000      NaN      0.000000
50%      54.000000  NaN           NaN  130.000000   223.000000    0.000000        NaN  138.000000            NaN    0.600000      NaN      1.000000
75%      60.000000  NaN           NaN  140.000000   267.000000    0.000000        NaN  156.000000            NaN    1.500000      NaN      1.000000
max      77.000000  NaN           NaN  200.000000   603.000000    1.000000        NaN  202.000000            NaN    6.200000      NaN      1.000000

Number of duplicate rows: 0

Missing values in each column:
Age               0
Sex               0
ChestPainType     0
RestingBP         0
Cholesterol       0
FastingBS         0
RestingECG        0
MaxHR             0
ExerciseAngina    0
Oldpeak           0
ST_Slope          0
HeartDisease      0
dtype: int64
Cleaning the Data¶
In [4]:
# ========================================= Data Cleaning ======================================== #
zero_value_cols = ['RestingBP', 'Cholesterol'] # Handle zero values in RestingBP and Cholesterol
# Visualize distributions before changes
plt.figure(figsize=(12, 5))
for i, col in enumerate(zero_value_cols):
if col in df.columns:
plt.subplot(1, len(zero_value_cols), i + 1)
sns.histplot(df[col], bins=30, kde=True, color='skyblue', edgecolor='black')
plt.title(f"{col} (Before Zero Replacement)")
plt.xlabel(col)
plt.ylabel("Count")
plt.tight_layout()
plt.show()
for col in zero_value_cols:
if col in df.columns:
zero_count = (df[col] == 0).sum()
print(f"Number of zero values in {col}: {zero_count}")
# Replace zeros with the median of the column
if zero_count > 0:
median_value = df.loc[df[col] != 0, col].median()
df[col] = df[col].replace(0, median_value)
print(f"Replaced zero values in {col} with median value: {median_value}")
# Visualize distributions after changes
plt.figure(figsize=(12, 5))
for i, col in enumerate(zero_value_cols):
if col in df.columns:
plt.subplot(1, len(zero_value_cols), i + 1)
sns.histplot(df[col], bins=30, kde=True, color='salmon', edgecolor='black')
plt.title(f"{col} (After Zero Replacement)")
plt.xlabel(col)
plt.ylabel("Count")
plt.tight_layout()
plt.show()
# Summary of changes
print("\nSummary of zero-value handling:")
for col in zero_value_cols:
if col in df.columns:
print(f"{col}: Zero values replaced with median.")
Number of zero values in RestingBP: 1
Replaced zero values in RestingBP with median value: 130.0
Number of zero values in Cholesterol: 172
Replaced zero values in Cholesterol with median value: 237.0

Summary of zero-value handling:
RestingBP: Zero values replaced with median.
Cholesterol: Zero values replaced with median.
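The pandas replacement above treats zeros in RestingBP and Cholesterol as missing measurements and fills them with the non-zero median. An equivalent, pipeline-friendly variant (not used in this notebook) is sklearn's SimpleImputer; the sketch below assumes it is applied to the raw, pre-replacement values.

import numpy as np
from sklearn.impute import SimpleImputer

zero_cols = ['RestingBP', 'Cholesterol']
raw_values = df[zero_cols].replace(0, np.nan)   # Mark zeros as missing (assumes df still holds the raw values)
imputer = SimpleImputer(strategy='median')      # Median is computed from the non-missing entries only
df[zero_cols] = imputer.fit_transform(raw_values)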
Visualizing the Data¶
In [5]:
# ======================================== Data Visualization ======================================== #
# Identify continuous and binary categorical columns
continuous_cols = df.select_dtypes(include=['float64', 'int64']).columns
# Binary categorical columns: only 0 or 1 values
binary_cols = [col for col in continuous_cols if set(df[col].dropna().unique()).issubset({0, 1})]
# Continuous columns: exclude binary columns
cont_cols = [col for col in continuous_cols if col not in binary_cols]
# Plot continuous columns
num_cont = len(cont_cols)
rows_cont = (num_cont + 2) // 3  # Ceiling division: just enough rows for 3 plots per row
fig1, axes1 = plt.subplots(nrows=rows_cont, ncols=3, figsize=(15, 3 * rows_cont))
axes1 = axes1.flatten()
palette = sns.color_palette("mako", num_cont)
for i, col in enumerate(cont_cols):
ax = axes1[i]
hist = sns.histplot(df[col], kde=True, ax=ax, color=palette[i % len(palette)], edgecolor='black')
ax.set_title(col)
# Add count labels
for patch in hist.patches:
height = patch.get_height()
if height > 0:
ax.annotate(f'{int(height)}',
(patch.get_x() + patch.get_width() / 2, height),
ha='center', va='bottom', fontsize=7, color='black', rotation=0)
for j in range(i + 1, len(axes1)):
fig1.delaxes(axes1[j])
plt.tight_layout()
sns.despine()
plt.show()
# Plot binary categorical columns as countplots
num_bin = len(binary_cols)
if num_bin > 0:
    rows_bin = (num_bin + 1) // 2  # Ceiling division: just enough rows for 2 plots per row
    fig2, axes2 = plt.subplots(nrows=rows_bin, ncols=2, figsize=(12, 2 * rows_bin))
axes2 = axes2.flatten()
bin_palette = sns.color_palette("mako", num_bin)
for i, col in enumerate(binary_cols):
cp = sns.countplot(x=col, data=df, ax=axes2[i], hue=col, palette=bin_palette, legend=False)
axes2[i].set_title(col)
axes2[i].set_xticks([0, 1])
# Add count labels inside bars and bold
for p in cp.patches:
height = p.get_height()
if height > 0:
cp.annotate(
f'{int(height)}',
(p.get_x() + p.get_width() / 2, height / 2),
ha='center', va='center', fontsize=8, color='white', fontweight='bold'
)
for j in range(i + 1, len(axes2)):
fig2.delaxes(axes2[j])
plt.tight_layout()
sns.despine()
plt.show()
else:
print("No binary categorical columns found.")
# Plot non-binary categorical columns as countplots
cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
# Also include int columns with few unique values (excluding binary)
cat_cols += [col for col in df.select_dtypes(include=['int64', 'float64']).columns
if 2 < df[col].nunique() <= 10 and col not in binary_cols]
cat_cols = list(set(cat_cols)) # Remove duplicates
num_cat = len(cat_cols)
if num_cat > 0:
    rows_cat = (num_cat + 2) // 3  # Ceiling division: just enough rows for 3 plots per row
    fig3, axes3 = plt.subplots(nrows=rows_cat, ncols=3, figsize=(15, 3 * rows_cat))
axes3 = axes3.flatten()
cat_palette = sns.color_palette("mako", num_cat)
for i, col in enumerate(cat_cols):
unique_vals = df[col].nunique()
palette_for_col = sns.color_palette("mako", unique_vals)
cp = sns.countplot(x=col, data=df, ax=axes3[i], hue=col, palette=palette_for_col, legend=False)
axes3[i].set_title(col)
# Add count labels inside bars and bold
for p in cp.patches:
height = p.get_height()
if height > 0:
cp.annotate(
f'{int(height)}',
(p.get_x() + p.get_width() / 2, height / 2),
ha='center', va='center', fontsize=8, color='white', fontweight='bold'
)
for j in range(i + 1, len(axes3)):
fig3.delaxes(axes3[j])
plt.tight_layout()
sns.despine()
plt.show()
else:
print("No non-binary categorical columns found.")
plt.figure(figsize=(12, 8))
corr = df.corr(numeric_only=True)
sns.heatmap(corr, annot=True, fmt=".2f", cmap="mako", square=True, linewidths=0.5)
plt.title("Correlation Matrix Heatmap")
plt.tight_layout()
plt.show()
Feature Importance with Machine Learning¶
In [6]:
# ======================================== Feature Importance with Machine Learning ======================================== #
# Detect the target column (expects a column named 'HeartDisease', any letter case); set 'target_col' manually if yours differs
target_col = None
for col in df.columns:
    if col.lower() == 'heartdisease':
        target_col = col
        break
if target_col is None:
raise ValueError("Could not automatically detect the target column. Please set 'target_col' manually.")
# Encode categorical features if any
df_encoded = df.copy()
for col in df_encoded.select_dtypes(include=['object', 'category']).columns:
df_encoded[col] = LabelEncoder().fit_transform(df_encoded[col])
X = df_encoded.drop(target_col, axis=1)
y = df_encoded[target_col]
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train Random Forest
rf = RandomForestClassifier(n_estimators=200, random_state=42)
rf.fit(X_train, y_train)
# Feature importances
importances = rf.feature_importances_
feat_importance = pd.Series(importances, index=X.columns).sort_values(ascending=False)
print("\nTop contributing factors to heart disease (feature importances):")
for feature, importance in feat_importance.items():
print(f"{feature}: {importance:.4f}")
# Plot feature importances with data labels inside bars
plt.figure(figsize=(10, 6))
ax = sns.barplot(x=feat_importance.values, y=feat_importance.index, hue=feat_importance.index, palette="viridis", legend=False)
plt.title("Feature Importances for Heart Failure Prediction")
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.tight_layout()
# Add data labels inside bars
for p in ax.patches:
width = p.get_width()
if width > 0:
ax.annotate(f'{width:.2f}',
(width / 2, p.get_y() + p.get_height() / 2),
ha='center', va='center', fontsize=8, color='white', fontweight='bold')
plt.show()
print("\nTop 5 contributing factors to heart disease (by feature importance):")
for feature, importance in feat_importance.head(5).items():
print(f"{feature}: {importance:.4f}")
# ======================================== Model Evaluation ======================================== #
# Predict on test set
y_pred = rf.predict(X_test)
# Accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel Accuracy on Test Set: {accuracy:.4f}")
# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()
print("\nModel Summary:")
print(f"Model Type: Random Forest Classifier")
print(f"Number of Features: {X.shape[1]}")
print(f"Training Samples: {X_train.shape[0]}")
print(f"Test Samples: {X_test.shape[0]}")
print(f"Test Accuracy: {accuracy:.4f}")
print("Top 5 Features by Importance:")
for feature, importance in feat_importance.head(5).items():
print(f" {feature}: {importance:.4f}")
# Analyze how each of the top 5 features relates to heart disease prevalence
top_features = feat_importance.head(5).index.tolist()
for feature in top_features:
unique_vals = df[feature].nunique()
if unique_vals <= 10 or df[feature].dtype == 'object':
rates = df.groupby(feature)[target_col].mean().sort_values(ascending=False)
most_risky = rates.idxmax()
print(f"{feature} value most associated with heart disease: {most_risky} (rate: {rates.max():.2f})")
else:
median_val = df[feature].median()
above = df[df[feature] > median_val][target_col].mean()
below = df[df[feature] <= median_val][target_col].mean()
if above > below:
print(f"{feature} > {median_val:.2f} is most associated with heart disease (rate: {above:.2f})")
else:
print(f"{feature} <= {median_val:.2f} is most associated with heart disease (rate: {below:.2f})")
Top contributing factors to heart disease (feature importances):
ST_Slope: 0.2378
Oldpeak: 0.1255
MaxHR: 0.1061
ChestPainType: 0.1048
ExerciseAngina: 0.1002
Age: 0.0887
Cholesterol: 0.0803
RestingBP: 0.0696
Sex: 0.0375
RestingECG: 0.0261
FastingBS: 0.0233

Top 5 contributing factors to heart disease (by feature importance):
ST_Slope: 0.2378
Oldpeak: 0.1255
MaxHR: 0.1061
ChestPainType: 0.1048
ExerciseAngina: 0.1002

Model Accuracy on Test Set: 0.8804

Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.88      0.86        77
           1       0.91      0.88      0.90       107

    accuracy                           0.88       184
   macro avg       0.88      0.88      0.88       184
weighted avg       0.88      0.88      0.88       184

Model Summary:
Model Type: Random Forest Classifier
Number of Features: 11
Training Samples: 734
Test Samples: 184
Test Accuracy: 0.8804
Top 5 Features by Importance:
  ST_Slope: 0.2378
  Oldpeak: 0.1255
  MaxHR: 0.1061
  ChestPainType: 0.1048
  ExerciseAngina: 0.1002
ST_Slope value most associated with heart disease: Flat (rate: 0.83)
Oldpeak > 0.60 is most associated with heart disease (rate: 0.77)
MaxHR <= 138.00 is most associated with heart disease (rate: 0.72)
ChestPainType value most associated with heart disease: ASY (rate: 0.79)
ExerciseAngina value most associated with heart disease: Y (rate: 0.85)
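Impurity-based importances from tree ensembles can favour features with many distinct values, so a common cross-check (not part of the original analysis) is permutation importance on the held-out test set: each feature is shuffled in turn and the resulting drop in accuracy is recorded. The sketch below reuses rf, X, X_test, and y_test from the cell above.

from sklearn.inspection import permutation_importance

# Shuffle each feature 10 times on the test set and average the accuracy drop
perm = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42, scoring='accuracy')
perm_importance = pd.Series(perm.importances_mean, index=X.columns).sort_values(ascending=False)
print("Permutation importances (mean accuracy drop):")
print(perm_importance.round(4))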
Optimizing the Model¶
In [7]:
# ======================================== Model Optimization ======================================== #
# Define parameter grid for Random Forest
param_grid = {
'n_estimators': [100, 200, 300],
'max_depth': [None, 5, 10, 20],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4],
'bootstrap': [True, False]
}
# Initialize GridSearchCV
grid_search = GridSearchCV(
estimator=RandomForestClassifier(random_state=42),
param_grid=param_grid,
cv=5,
n_jobs=-1,
scoring='accuracy',
verbose=1
)
# Fit grid search to training data
grid_search.fit(X_train, y_train)
print(f"\nBest parameters found: {grid_search.best_params_}")
print(f"Best cross-validation accuracy: {grid_search.best_score_:.4f}")
# Evaluate the best estimator on the test set
best_rf = grid_search.best_estimator_
y_pred_best = best_rf.predict(X_test)
best_accuracy = accuracy_score(y_test, y_pred_best)
print(f"Optimized Model Accuracy on Test Set: {best_accuracy:.4f}")
print("\nOptimized Model Classification Report:")
print(classification_report(y_test, y_pred_best))
# Confusion matrix for optimized model
cm_best = confusion_matrix(y_test, y_pred_best)
plt.figure(figsize=(5, 4))
sns.heatmap(cm_best, annot=True, fmt='d', cmap='Greens', cbar=False)
plt.title("Optimized Model Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()
Fitting 5 folds for each of 216 candidates, totalling 1080 fits

Best parameters found: {'bootstrap': False, 'max_depth': 10, 'min_samples_leaf': 2, 'min_samples_split': 5, 'n_estimators': 100}
Best cross-validation accuracy: 0.8651
Optimized Model Accuracy on Test Set: 0.8750

Optimized Model Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.86      0.85        77
           1       0.90      0.89      0.89       107

    accuracy                           0.88       184
   macro avg       0.87      0.87      0.87       184
weighted avg       0.88      0.88      0.88       184
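The exhaustive grid above evaluates 216 parameter combinations (1080 fits with 5-fold CV). A cheaper alternative, not used in this run, is to sample a fixed number of candidates from the same grid with RandomizedSearchCV; the sketch below reuses param_grid, X_train, and y_train, and the choice of 40 iterations is illustrative.

from sklearn.model_selection import RandomizedSearchCV

random_search = RandomizedSearchCV(
    estimator=RandomForestClassifier(random_state=42),
    param_distributions=param_grid,  # Same search space as the grid search above
    n_iter=40,                       # Evaluate 40 sampled combinations instead of all 216
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    random_state=42,
    verbose=1,
)
random_search.fit(X_train, y_train)
print(f"Best parameters (randomized search): {random_search.best_params_}")
print(f"Best cross-validation accuracy: {random_search.best_score_:.4f}")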