diff --git a/src/nn/nn_train.py b/src/nn/nn_train.py index 233d6c7..e22c9e2 100644 --- a/src/nn/nn_train.py +++ b/src/nn/nn_train.py @@ -1,8 +1,3 @@ -""" - -""" - - import argparse import os import random @@ -41,6 +36,7 @@ class HydrophoneDataset(Dataset): def __init__( self, csv_path: str, + features_list: list[str], dtype: torch.dtype = torch.float32, ): """ @@ -69,25 +65,21 @@ def __init__( if not rows: raise ValueError("CSV is empty.") header = [h.strip() for h in rows[0]] - - try: - idx1 = header.index("Envelope H1") - idx2 = header.index("Envelope H2") - idx3 = header.index("Envelope H3") - idxy = header.index("Truth") - except ValueError as e: - raise ValueError( - "CSV must include header columns: Envelope H1, Envelope H2, Envelope H3, Truth" - ) from e + col_idx = {name: i for i, name in enumerate(header)} + missing = [name for name in features_list if name not in col_idx] + + if missing: + raise ValueError(f"CSV missing requested feature columns: {missing}") + + idxy = col_idx["Truth"] feats, labels = [], [] + feature_indices = [col_idx[name] for name in features_list] # e.g. [idx1, idx3] + for r in rows[1:]: if not r or all(c.strip() == "" for c in r): continue - - feats.append([float(r[idx2])]) - # feats.append([float(r[idx2]), float(r[idx3])]) - # feats.append([float(r[idx1]), float(r[idx2]), float(r[idx3])]) + feats.append([float(r[i]) for i in feature_indices]) labels.append(int(float(r[idxy]))) X = torch.tensor(feats, dtype=dtype) @@ -163,8 +155,7 @@ class using a Softmax activation at the output layer. model : nn.Sequential A sequential container implementing the layer stack. """ - # def __init__(self, in_dim=1, num_classes=4, p_drop=0.2): - # def __init__(self, in_dim=2, num_classes=4, p_drop=0.2): + def __init__(self, in_dim=3, num_classes=4, p_drop=0.2): super().__init__() @@ -233,14 +224,19 @@ def main(): parser.add_argument("--batch", type=int, default=32) parser.add_argument("--dropout", type=float, default=0.2) parser.add_argument("--save_dir", type=str, default="artifacts") + parser.add_argument("--feature_cols", type=str, default="Envelope H1,Envelope H2,Envelope H3") + parser.add_argument("--conf_thresh", type=str, default="0.5,0.8") args = parser.parse_args() set_seed(13) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") + feature_cols = [c.strip() for c in args.feature_cols.split(",")] + thresholds = [float(t.strip()) for t in args.conf_thresh.split(",") if t.strip()] + # Load dataset - ds = HydrophoneDataset(args.csv) + ds = HydrophoneDataset(csv_path=args.csv, features_list=feature_cols) n = len(ds) n_train = int(0.8 * n) n_val = n - n_train @@ -250,10 +246,7 @@ def main(): val_loader = DataLoader(val_ds, batch_size=args.batch) # Build model - model = MLPProb(in_dim=1, num_classes=4, p_drop=args.dropout).to(device) - # model = MLPProb(in_dim=2, num_classes=4, p_drop=args.dropout).to(device) - # model = MLPProb(in_dim=3, num_classes=4, p_drop=args.dropout).to(device) - + model = MLPProb(in_dim=len(feature_cols),num_classes=4, p_drop=args.dropout).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) loss_fn = nn.NLLLoss() # using log(probs) @@ -285,8 +278,9 @@ def main(): # Validation model.eval() val_loss, val_acc = 0, 0 - val_conf_acc_50, val_coverage_50 = 0, 0 - val_conf_acc_80, val_coverage_80 = 0, 0 + val_conf_acc = {t: 0.0 for t in thresholds} + val_coverage = {t: 0.0 for t in thresholds} + with torch.no_grad(): for X, y in val_loader: @@ -298,30 +292,29 @@ def main(): val_acc += accuracy(probs, y) * X.size(0) # Confidence-aware metrics - acc_50, cov_50 = accuracy_with_pass(probs, y, confidence_threshold=0.5) - acc_80, cov_80 = accuracy_with_pass(probs, y, confidence_threshold=0.8) - val_conf_acc_50 += acc_50 * X.size(0) - val_coverage_50 += cov_50 * X.size(0) - val_conf_acc_80 += acc_80 * X.size(0) - val_coverage_80 += cov_80 * X.size(0) + for t in thresholds: + acc_t, cov_t = accuracy_with_pass(probs, y, confidence_threshold=t) + val_conf_acc[t] += acc_t * X.size(0) + val_coverage[t] += cov_t * X.size(0) val_loss /= len(val_loader.dataset) val_acc /= len(val_loader.dataset) - val_conf_acc_50 /= len(val_loader.dataset) - val_coverage_50 /= len(val_loader.dataset) - val_conf_acc_80 /= len(val_loader.dataset) - val_coverage_80 /= len(val_loader.dataset) + for t in thresholds: + val_conf_acc[t] /= len(val_loader.dataset) + val_coverage[t] /= len(val_loader.dataset) if val_loss < best_val_loss: best_val_loss = val_loss best_acc = val_acc best_state = {k: v.cpu() for k, v in model.state_dict().items()} + conf_parts = " | ".join( + [f"{int(t*100)}%: acc={val_conf_acc[t]:.3f} cov={val_coverage[t]:.3f}" for t in thresholds] + ) print(f"Epoch {epoch+1:03d}: " - f"train_loss={train_loss:.4f} acc={train_acc:.3f} | " - f"val_loss={val_loss:.4f} acc={val_acc:.3f} | " - f"50%: acc={val_conf_acc_50:.3f} cov={val_coverage_50:.3f} | " - f"80%: acc={val_conf_acc_80:.3f} cov={val_coverage_80:.3f}") + f"train_loss={train_loss:.4f} acc={train_acc:.3f} | " + f"val_loss={val_loss:.4f} acc={val_acc:.3f} | " + f"{conf_parts}") # Save model + normalization stats os.makedirs(args.save_dir, exist_ok=True) @@ -330,28 +323,26 @@ def main(): # Calculate final confidence metrics for best model model.load_state_dict(best_state) model.eval() - final_conf_acc_50, final_coverage_50 = 0, 0 - final_conf_acc_80, final_coverage_80 = 0, 0 + final_conf_acc = {t: 0.0 for t in thresholds} + final_coverage = {t: 0.0 for t in thresholds} with torch.no_grad(): for X, y in val_loader: X, y = X.to(device), y.to(device) probs = model(X) - acc_50, cov_50 = accuracy_with_pass(probs, y, confidence_threshold=0.5) - acc_80, cov_80 = accuracy_with_pass(probs, y, confidence_threshold=0.8) - final_conf_acc_50 += acc_50 * X.size(0) - final_coverage_50 += cov_50 * X.size(0) - final_conf_acc_80 += acc_80 * X.size(0) - final_coverage_80 += cov_80 * X.size(0) - - final_conf_acc_50 /= len(val_loader.dataset) - final_coverage_50 /= len(val_loader.dataset) - final_conf_acc_80 /= len(val_loader.dataset) - final_coverage_80 /= len(val_loader.dataset) - - print(f"BEST: val_loss={best_val_loss:.4f} acc={best_acc:.3f} | " - f"50%: acc={final_conf_acc_50:.3f} cov={final_coverage_50:.3f} | " - f"80%: acc={final_conf_acc_80:.3f} cov={final_coverage_80:.3f}") + for t in thresholds: + acc_t, cov_t = accuracy_with_pass(probs, y, confidence_threshold=t) + final_conf_acc[t] += acc_t * X.size(0) + final_coverage[t] += cov_t * X.size(0) + + for t in thresholds: + final_conf_acc[t] /= len(val_loader.dataset) + final_coverage[t] /= len(val_loader.dataset) + + final_parts = " | ".join( + [f"{int(t*100)}%: acc={final_conf_acc[t]:.3f} cov={final_coverage[t]:.3f}" for t in thresholds] + ) + print(f"BEST: val_loss={best_val_loss:.4f} acc={best_acc:.3f} | {final_parts}") torch.save({ "model_state_dict": best_state,