This project is a clean prototype for semantic face parsing with strict constraints:

- train only on the provided `data/train` - no external data, no pretrained weights, random initialization
- no ensembles
- lightweight model under 1,821,085 trainable parameters
- default input resolution: 512 x 512
- model: MobileNetV2-style encoder + UNet decoder (FPN decoder also supported via config)
```
.
├─ train.py
├─ validate.py
├─ infer.py
├─ count_params.py
├─ requirements.txt
├─ README.md
├─ config.yaml
├─ src/
│  ├─ datasets/
│  │  ├─ celebamask_dataset.py
│  │  └─ transforms.py
│  ├─ models/
│  │  ├─ blocks.py
│  │  ├─ bisenet.py
│  │  ├─ pidnet.py
│  │  ├─ lightweight_unet.py
│  │  └─ attention.py
│  ├─ losses/
│  │  ├─ dice.py
│  │  └─ segmentation_loss.py
│  ├─ engine/
│  │  ├─ trainer.py
│  │  ├─ evaluator.py
│  │  └─ inference.py
│  └─ utils/
│     ├─ metrics.py
│     ├─ class_weights.py
│     ├─ class_names.py
│     ├─ checkpoint.py
│     ├─ plotting.py
│     ├─ seed.py
│     └─ param_count.py
└─ experiments/
```
Expected data layout:

```
data/
├─ train/
│  ├─ images/
│  └─ masks/
└─ val/
   └─ images/
```

Optional labeled validation is supported via `data/val/masks/`.
Install dependencies:

```
pip install -r requirements.txt
```

- Loss: CrossEntropy + 0.5 * Dice
  - Dice defaults to present-only averaging (only GT-present classes in a batch)
  - Optional boundary regularization is enabled by default:
    - total = CE + 0.5*Dice + 0.05*Boundary (after warmup)
    - warmup: first 8 epochs without boundary term
- Weighted CE (class-aware) is supported from `experiments/mask_stats.json`
- Optimizer: AdamW(lr=3e-4, weight_decay=5e-4)
- Scheduler: cosine annealing
- Epochs: 100
- Batch size: 10
- AMP: enabled
- Augmentations:
  - random resized crop
  - random horizontal flip (with left/right label swap for paired classes)
  - random rotation (+/- 15 deg)
  - color jitter
  - gaussian blur
- Internal validation split from the train set by default (`val_split=0.1`)
- Training curve is saved to `experiments/<run>/val_fscore_curve.png`:
  - train loss is always plotted
  - validation F-score is overlaid when validation is enabled
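The combined objective above can be sketched as follows. This is a minimal illustration, not the exact code in `src/losses/`; the function names and the boundary-map arguments are hypothetical, assuming logits of shape (N, C, H, W) and integer masks of shape (N, H, W):

```python
import torch
import torch.nn.functional as F

def dice_loss_present_only(logits, target, eps=1e-6):
    # Soft Dice averaged only over classes that appear in the batch GT
    # ("present-only" averaging, as described above).
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    inter = (probs * one_hot).sum(dims)
    union = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * inter + eps) / (union + eps)
    present = one_hot.sum(dims) > 0  # mask of GT-present classes
    return 1 - dice[present].mean()

def total_loss(logits, target, boundary_logits=None, boundary_gt=None,
               epoch=0, dice_w=0.5, boundary_w=0.05, warmup_epochs=8):
    # total = CE + 0.5*Dice, plus 0.05*Boundary after the warmup epochs.
    loss = F.cross_entropy(logits, target) \
        + dice_w * dice_loss_present_only(logits, target)
    if boundary_logits is not None and epoch >= warmup_epochs:
        loss = loss + boundary_w * F.binary_cross_entropy_with_logits(
            boundary_logits, boundary_gt)
    return loss
```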
`model.encoder_type` supports:

- `mobilenetv2` (default)
- `resnet` (light residual backbone)

`model.decoder_type` supports:

- `unet` (default)
- `fpn`
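For example, selecting the backbone and decoder in `config.yaml` could look like this (illustrative fragment; the surrounding key layout is assumed from the options above):

```yaml
model:
  encoder_type: mobilenetv2   # or: resnet
  decoder_type: unet          # or: fpn
```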
For a submit-ready residual+FPN variant under the parameter cap, use:

```
python train.py --config experiments/config_residual_fpn.yaml
```

For a BiSeNetV2-style variant under the parameter cap, use:

```
python train.py --config experiments/config_bisenet.yaml
```

For a PIDNet-style variant under the parameter cap, use:

```
python train.py --config experiments/config_pid.yaml
```

`config_pid.yaml` uses a closer-to-paper PIDNet setup:

- `pidnet.m` / `pidnet.n` / `pidnet.planes` / `pidnet.ppm_planes` / `pidnet.head_planes`
- `pidnet.augment: true` (enables auxiliary P and D heads)
- `loss.pid_aux` controls auxiliary loss weights for training:
  - `p_weight`: segmentation aux head weight
  - `d_weight`: boundary aux head weight (BCE on boundary map)
- Count parameters:

```
python count_params.py --config config.yaml
```

- Train:

```
python train.py --config config.yaml
```

If `train.run_val_data: true`, after training finishes the script will:

- load `best.pt`
- run inference on `data/val/images`
- write predictions to `data/val/masks`
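The parameter check behind `count_params.py` presumably reduces to counting trainable elements; a minimal sketch (the tiny stand-in model below is hypothetical, not the project's architecture):

```python
import torch.nn as nn

def count_trainable_params(model: nn.Module) -> int:
    # Sum the element counts of all parameters with requires_grad=True.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Example with a tiny stand-in model:
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),  # 3*8*3*3 weights + 8 biases = 224
    nn.ReLU(),
    nn.Conv2d(8, 19, 1),            # 8*19 weights + 19 biases = 171
)
print(count_trainable_params(model))  # prints 395
```

The real model must keep this number under the 1,821,085 cap stated above.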
- Evaluate checkpoint:

```
python validate.py --config config.yaml --checkpoint experiments/baseline/best.pt --source internal
```

`validate.py` prints:

- overall loss / pixel accuracy / mean F-score
- per-class F-score with class names from `data.class_names` in `config.yaml`

If `data/val/masks` exists:

```
python validate.py --config config.yaml --checkpoint experiments/baseline/best.pt --source val
```

- Infer on unlabeled val images:

```
python infer.py --config config.yaml --checkpoint experiments/baseline/best.pt --output-dir outputs --tta-flip
```

Predictions are saved as indexed masks (default `.png`) in `outputs/`.
Masks are written in palette (P) mode by default for direct visual inspection.
When --tta-flip is used, left/right class channels are swapped back before averaging logits.
Scale TTA is configurable via inference.tta_scales (for example [0.75, 1.0, 1.25]).
Validation-stage TTA is controlled by validation.use_tta (default: true).
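Flip TTA with the left/right swap can be sketched as below. This is an illustration of the idea, not the code in `src/engine/inference.py`; the `LR_PAIRS` index list is a hypothetical example of paired left/right class channels:

```python
import torch

# Hypothetical left/right channel pairs, e.g. (l_eye, r_eye), (l_brow, r_brow).
LR_PAIRS = [(2, 3), (4, 5), (7, 8)]

def flip_tta_logits(model, image):
    # Average original logits with logits from a horizontally flipped input,
    # un-flipping spatially and swapping paired left/right channels back
    # before the average (otherwise e.g. "left eye" votes for "right eye").
    logits = model(image)
    flipped = torch.flip(model(torch.flip(image, dims=[-1])), dims=[-1])
    idx = list(range(flipped.shape[1]))
    for l, r in LR_PAIRS:
        idx[l], idx[r] = idx[r], idx[l]
    return 0.5 * (logits + flipped[:, idx])
```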
To compute exact label IDs and class proportions from `data/train/masks`:

```
python analyze_masks.py --config config.yaml --masks-dir data/train/masks
```

Outputs:

- `experiments/mask_stats.json`
- `experiments/mask_stats.csv`

Recommended weights for weighted CE are under `weights.median_freq_norm_mean1` inside `mask_stats.json`.

Current training reads weighted CE settings from `config.yaml`:

- `loss.ce_weighting.enabled`
- `loss.ce_weighting.stats_json`
- `loss.ce_weighting.key` (for example `recommended_weighted_ce.weights`)
- `loss.dice_present_only`
- `loss.boundary.enabled / weight / pos_weight / warmup_epochs / pred_scale`
- `inference.tta_enabled / tta_flip / tta_scales`
- `validation.use_tta`
Two submit-ready scripts are included under `experiments/`:

- `experiments/env_setup.batch`: loads modules, creates a virtualenv, installs requirements
- `experiments/train.batch`: trains and stores metrics/plot/checkpoints in one output folder

```
sbatch experiments/env_setup.batch
```

Optional overrides:

```
sbatch --export=ALL,VENV_DIR=/path/to/venv experiments/env_setup.batch
```

```
sbatch experiments/train.batch config.yaml
```

Optional overrides:

```
sbatch --export=ALL,RUN_NAME=my_run,TRAIN_ARGS="--epochs 50 --batch-size 8 --num-workers 4" experiments/train.batch config.yaml
```

All key outputs are written into the same run directory, default:

- `experiments/run_<SLURM_JOB_ID>/best.pt`
- `experiments/run_<SLURM_JOB_ID>/last.pt`
- `experiments/run_<SLURM_JOB_ID>/history.json`
- `experiments/run_<SLURM_JOB_ID>/val_fscore_curve.png` (train loss + optional val F-score)
- `experiments/run_<SLURM_JOB_ID>/val_metrics.txt`
- `experiments/run_<SLURM_JOB_ID>/param_count.txt`
PBS versions are also included:

- `experiments/env_setup.pbs`
- `experiments/train.pbs`

Note: these PBS scripts load cluster PyTorch modules and skip pip-installing torch/torchvision.

```
qsub experiments/env_setup.pbs
```

Optional overrides:

```
qsub -v VENV_DIR=/path/to/venv experiments/env_setup.pbs
```

```
qsub -v CONFIG_PATH=config.yaml experiments/train.pbs
```

Optional overrides:

```
qsub -v CONFIG_PATH=config.yaml,RUN_NAME=my_run,TRAIN_ARGS="--epochs 50 --batch-size 8 --num-workers 4" experiments/train.pbs
```