Predicting Stroke Risk Using Machine Learning in R: A Step-by-Step Guide

Stroke remains one of the leading causes of death and long-term disability worldwide. Early detection of stroke risk can be life-saving, and with the power of machine learning (ML) we can analyze patient data to predict who might be at higher risk. In this article, we’ll learn how to use R programming to build predictive models with the Healthcare Stroke Dataset. You’ll learn how to explore the data, prepare it for modeling, and compare different algorithms such as Logistic Regression, Decision Tree, and Random Forest.


Task One: Import Data and Preprocessing


Before building any model, it’s essential to load, clean, and explore the data.


Step 1. Load Required Packages


We’ll use several powerful R packages like tidyverse for data wrangling, caret for model training, and randomForest for building ensemble models, among others.


# Load Required Libraries
library(tidyverse)
library(skimr)
library(DataExplorer)
library(GGally)
library(caret)
library(tidymodels)
library(corrplot)
library(ROSE)
library(e1071)
library(rpart)
library(randomForest)
library(pROC)
library(yardstick)
library(janitor)
library(ggthemes)
library(shiny)
library(vetiver)



Step 2. Load the Dataset


After loading the data, we flag the dataset’s "N/A" placeholders and empty strings as missing values and then remove the incomplete rows, ensuring the data is ready for analysis.


# Load Stroke Dataset, treating "N/A" and empty strings as missing values
stroke_data <- read.csv("healthcare-dataset-stroke-data.csv", na.strings = c("N/A", "", "NA"))

sum(is.na(stroke_data))         # Total number of missing values
colSums(is.na(stroke_data))     # Missing values per column

stroke_data <- na.omit(stroke_data)
sum(is.na(stroke_data))   # Should return 0
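To see why the na.strings argument matters, here is a self-contained sketch on a small hypothetical CSV: without it, placeholders like "N/A" are read as ordinary strings, so na.omit() has nothing to drop.

```r
# Hypothetical two-column CSV illustrating how na.strings converts
# placeholder text such as "N/A" into real NA values on read
tmp <- tempfile(fileext = ".csv")
writeLines(c("age,bmi", "67,36.6", "45,N/A", "80,"), tmp)

raw     <- read.csv(tmp)                                   # "N/A" stays a string
cleaned <- read.csv(tmp, na.strings = c("N/A", "", "NA"))  # placeholders become NA

sum(is.na(raw$bmi))      # 0 -- bmi was read as character, nothing is NA
sum(is.na(cleaned$bmi))  # 2 -- both placeholders are now NA
nrow(na.omit(cleaned))   # 1 -- na.omit() drops the incomplete rows
```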

Step 3. Describe and Explore the Data

Before jumping into modeling, let’s understand what the data looks like.

library(summarytools)

# View first few rows
head(stroke_data)

# Check structure of the dataset
str(stroke_data)

# Quick summary of missing values
colSums(is.na(stroke_data))

# Display basic summary statistics
summary(stroke_data)

# Display dimensions (rows and columns)
dim(stroke_data)

# View column names
names(stroke_data)

# Quick statistical summary using summarytools
dfSummary(stroke_data)

# Generate a basic exploratory report
create_report(stroke_data)
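Since corrplot is loaded in Step 1, a quick correlation heatmap of the numeric predictors is a natural addition to the exploration. A minimal sketch using simulated columns named after the dataset’s numeric variables (in practice, select the numeric columns of stroke_data instead):

```r
library(corrplot)

# Simulated stand-in for the dataset's numeric columns
# (age, avg_glucose_level, bmi)
set.seed(42)
num_data <- data.frame(
  age               = runif(200, 1, 90),
  avg_glucose_level = rnorm(200, 106, 45),
  bmi               = rnorm(200, 28, 7)
)

# Pairwise correlation matrix of the numeric predictors
cor_mat <- cor(num_data)

# Visualize the correlations as a heatmap
corrplot(cor_mat, method = "circle", type = "upper")
```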

Task Two: Build Prediction Models

Once the data is ready, it’s time to build and test our models.

Step 1. Data Preparation for Modeling

Categorical variables need to be converted into factors, and the dataset is split into training and testing subsets. We also drop the id column, since a patient identifier carries no predictive signal and would otherwise be picked up by the `stroke ~ .` formulas below.


R Code: Stroke Data Preparation

library(caret)

# Drop the id column -- it is an identifier, not a predictor
stroke_data$id <- NULL

# Convert categorical variables to factors
stroke_data$gender <- as.factor(stroke_data$gender)
stroke_data$ever_married <- as.factor(stroke_data$ever_married)
stroke_data$work_type <- as.factor(stroke_data$work_type)
stroke_data$Residence_type <- as.factor(stroke_data$Residence_type)
stroke_data$smoking_status <- as.factor(stroke_data$smoking_status)
stroke_data$stroke <- as.factor(stroke_data$stroke)

# Split dataset into training (70%) and testing (30%)
set.seed(123)
trainIndex <- createDataPartition(stroke_data$stroke, p = 0.7, list = FALSE)
train_data <- stroke_data[trainIndex, ]
test_data <- stroke_data[-trainIndex, ]
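The ROSE package loaded in Step 1 is not actually used above. Because stroke cases are rare in this dataset, the training set is heavily imbalanced, and a balancing step before modeling can help classifiers learn the minority class. A minimal sketch on a hypothetical imbalanced frame (toy_train stands in for train_data, and the age column is made up for illustration):

```r
library(ROSE)

# Hypothetical imbalanced training set standing in for train_data
set.seed(123)
toy_train <- data.frame(
  age    = runif(500, 20, 85),
  stroke = factor(rbinom(500, 1, 0.05))
)
table(toy_train$stroke)   # heavily skewed toward class 0

# Oversample the minority class until the classes are roughly balanced
balanced <- ovun.sample(stroke ~ ., data = toy_train,
                        method = "over", p = 0.5, seed = 123)$data
table(balanced$stroke)    # approximately 50/50
```

The same call, applied to train_data, would produce a balanced training set for the models that follow.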
    


Model 1. Logistic Regression

Logistic Regression provides a baseline model that helps us understand which variables significantly influence stroke risk.

R Code: Logistic Regression Model

# 🧩 Model 1: Logistic Regression
#------------------------------------------
log_model <- glm(stroke ~ ., data = train_data, family = binomial)
summary(log_model)

# Predict on test data
log_pred <- predict(log_model, newdata = test_data, type = "response")
log_pred_class <- ifelse(log_pred > 0.5, 1, 0)

# Confusion matrix
confusionMatrix(as.factor(log_pred_class), test_data$stroke, positive = "1")
      


Model 2. Random Forest

Here we train a Random Forest model to assess its accuracy and robustness.

R Code: Random Forest Model

# 🌳 Model 2: Random Forest
#------------------------------------------
rf_model <- randomForest(stroke ~ ., data = train_data, ntree = 500, importance = TRUE)
rf_pred <- predict(rf_model, newdata = test_data)

# Evaluate performance
confusionMatrix(rf_pred, test_data$stroke, positive = "1")

# Variable importance plot
varImpPlot(rf_model)
    


Model 3. Decision Tree

A decision tree helps visualize how the data splits on individual risk factors.

Decision Tree R Code

# 🌿 Model 3: Decision Tree
#------------------------------------------
library(rpart)
library(rpart.plot)
library(caret)

# Build the decision tree
tree_model <- rpart(stroke ~ ., data = train_data, method = "class")

# Visualize the tree
rpart.plot(tree_model)

# Make predictions on test data
tree_pred <- predict(tree_model, newdata = test_data, type = "class")

# Evaluate model performance
confusionMatrix(tree_pred, test_data$stroke, positive = "1")
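To wrap up, the confusion matrices from the three models can be collected into a single comparison table. A hedged sketch with simulated predictions (in practice, store each model’s confusionMatrix() result and pull its accuracy the same way):

```r
library(caret)

# Hypothetical true labels and two sets of predictions standing in
# for the outputs of the models above
set.seed(7)
truth  <- factor(rbinom(100, 1, 0.5))
pred_a <- factor(ifelse(runif(100) < 0.8, as.character(truth), "0"),
                 levels = levels(truth))
pred_b <- factor(ifelse(runif(100) < 0.6, as.character(truth), "1"),
                 levels = levels(truth))

cm_a <- confusionMatrix(pred_a, truth, positive = "1")
cm_b <- confusionMatrix(pred_b, truth, positive = "1")

# Side-by-side accuracy comparison
data.frame(
  model    = c("model_a", "model_b"),
  accuracy = c(cm_a$overall["Accuracy"], cm_b$overall["Accuracy"])
)
```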

This project demonstrates the full workflow of a health data analytics project:

1) Data import and preprocessing

2) Exploratory data analysis

3) Building and evaluating multiple prediction models

By leveraging machine learning models, healthcare organizations can identify high-risk patients early, personalize preventive strategies, and ultimately save lives.
