This repository contains the code for our ECE324 project, a framework that improves the reasoning performance of LLMs on math tasks through a multi-stage pipeline involving Chain-of-Thought prompting, Reasoning Distillation, Supervised Fine-Tuning (SFT), and Reinforcement Learning with Human Feedback (RLHF).
The full report is available in the assets directory at "./assets/reasonix_report.pdf".
While many LLMs excel at next-token prediction, they often fail to follow the logical steps required in math tasks.
The objective of our project is to improve LLMs' step-by-step reasoning ability, enabling these models to:
- Generate consistent and structured thought processes
- Reduce logical errors
- Improve performance on the GSM8K benchmark
We use GSM8K, a benchmark dataset containing 8,800 grade school math problems.
- 7,500 training samples
- 1,300 test samples
- Each sample includes a word problem and a final numerical answer.
- We augment the training dataset using LLaMA 3.2 3B-Instruct to produce a CoT-structured, step-by-step answer for each question.
For RLHF, we also use:
- PRM800K: a human-annotated dataset that labels individual reasoning steps as good, okay, or bad.
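For reasoning distillation, each GSM8K question is wrapped in an instruction that asks the teacher model for a step-by-step solution. The sketch below illustrates one way such a prompt could be assembled; the function name and exact wording are hypothetical, and the repository's actual template lives in the `data/` pipeline.

```python
def build_distillation_prompt(question: str) -> str:
    """Wrap a GSM8K question in an instruction that elicits a
    step-by-step (CoT) solution from the teacher model.
    Illustrative only: the real template is defined in data/."""
    return (
        "Solve the following grade school math problem. "
        "Reason step by step, then give the final numeric answer "
        "on its own line prefixed with '#### '.\n\n"
        f"Problem: {question}\nSolution:"
    )

prompt = build_distillation_prompt(
    "Natalia sold 48 clips in April and half as many in May. "
    "How many clips did she sell altogether?"
)
print(prompt)
```

The '#### ' convention mirrors GSM8K's own answer format, which makes the distilled data easy to score with the same extraction logic used at evaluation time.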
Our pipeline consists of the following key components:
- Few-shot Prompting: Teach the model how to reason through examples.
- Reasoning Distillation: Use a larger model to generate CoT-structured training data.
- CoT Supervised Fine-Tuning (SFT): Fine-tune the baseline model on the step-by-step solutions generated in the reasoning distillation step.
- RLHF with PPO: Further align model outputs with human-preferred reasoning using Proximal Policy Optimization.
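At the heart of the PPO step is the clipped surrogate objective, which bounds how far each update can move the policy from the one that generated the data. A minimal pure-Python sketch of that objective for a single action (this is an illustration of the math, not our training code, which uses torchrun in the RLHF directory):

```python
import math

def ppo_clip_loss(logp_new, logp_old, advantage, eps=0.2):
    """Clipped surrogate loss for one action (to be minimized).
    ratio = pi_new(a|s) / pi_old(a|s); clipping the ratio to
    [1 - eps, 1 + eps] keeps the policy update in a trust region."""
    ratio = math.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + eps), 1 - eps) * advantage
    # PPO takes the pessimistic (smaller) objective, negated as a loss.
    return -min(unclipped, clipped)

# Doubling the probability of a positively-rewarded token is capped
# once the ratio exceeds 1 + eps:
print(ppo_clip_loss(math.log(2.0), math.log(1.0), advantage=1.0))  # -1.2
```

In the RLHF stage, the advantage is derived from the reward model's score of the generated reasoning, so steps the reward model prefers are reinforced without letting any single update move the policy too far.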
We evaluate the models using:
- Accuracy on GSM8K: % of final answers matching ground truth.
- Quality of Reasoning in Four Categories:
- Category 1: Fully correct reasoning and correct final answer
- Category 2: Correct reasoning but an arithmetic error
- Category 3: Partially correct reasoning and a wrong answer
- Category 4: Completely failed reasoning
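GSM8K references mark the final answer with '#### ', so accuracy reduces to extracting that number from both the reference and the model output and comparing. A minimal sketch of this metric (illustrative; the repository's own logic lives in the inference pipeline):

```python
import re

def extract_final_answer(text):
    """Return the last number after '#### ' (or in the whole text
    if the marker is absent), with commas stripped, else None."""
    tail = text.split("####")[-1]
    nums = re.findall(r"-?\d[\d,]*\.?\d*", tail)
    if not nums:
        return None
    return nums[-1].replace(",", "").rstrip(".")

def gsm8k_accuracy(preds, refs):
    """Fraction of predictions whose final number matches the reference's."""
    correct = sum(
        extract_final_answer(p) == extract_final_answer(r)
        for p, r in zip(preds, refs)
    )
    return correct / len(refs)

acc = gsm8k_accuracy(
    ["She sold 48 + 24 = 72 clips. #### 72"],
    ["Natalia sold 48 + 24 = 72 clips altogether.\n#### 72"],
)
print(acc)  # 1.0
```

Note this final-answer match is exactly what separates Categories 1 and 2: a Category 2 response has sound reasoning but fails this check because of an arithmetic slip.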
- Baseline: 16.8% accuracy on GSM8K, with 31% of responses falling into Category 4, meaning the model completely failed to reason
- After CoT SFT: accuracy improved to 53.6% on GSM8K, and Category 4 responses dropped to 18% — the fine-tuned model is now substantially more capable of reasoning
- After applying RLHF: accuracy dipped slightly to 53.05% on GSM8K. Hardware constraints prevented us from training on the full dataset.
Top: Baseline model output (repeats steps without making progress)
Bottom: Fine-tuned model output (structured CoT, correct answer)
| .gitignore
| env.yaml
| LICENSE
| README.md
| structure.txt
|
+---assets # Folder containing the images in README
| barchart.png
| improved_example.jpg
| pipeline.png
| reasonix_report.pdf
|
+---CoT # Implementation of COT finetuning pipeline
| config.yaml
| dataset.py
| main.py
| metrics.py
| model.py
| utils.py
|
+---data # Data preprocessing with reasoning distillation
| clean_data.py
| config.yaml
| dataset.py
| gsm8k_train.jsonl
| main.py
| model.py
| train.jsonl
| utils.py
|
+---inference # Evaluation Pipeline for both baseline model and finetuned model
| | config.yaml
| | dataset.py
| | delay_run.sh
| | main.py
| | model.py
| | utils.py
| |
| \---checkpoints
|
\---RLHF # Training and Evaluation code for reward model, and training code for policy model
| config.yaml
| evaluate_reward_model.py
| main.py
| utils.py
|
+---data
| prepare_prm_data.py
| test.jsonl
| train.jsonl
|
+---policy_model
| data.py
| model.py
| train.py
| train_policy.py
| __init__.py
|
\---reward_model
data_loader.py
model.py
reward_model.py
trainer.py
__init__.py
```shell
git clone https://github.com/EdwinChacko75/reasonix.git
cd reasonix
conda env create -f env.yaml
conda activate reasonix
```

Adjust config files as required.

CoT fine-tuning:

```shell
cd CoT
python3 main.py
```

RLHF:

```shell
cd RLHF
# Reward model training
torchrun --nproc_per_node=2 main.py --task=reward
# Policy model training
torchrun --nproc_per_node=2 main.py --task=rlhf
```

Inference:

```shell
cd inference
python3 main.py
```

MIT License. See LICENSE for details.


