From 5ea97b8fba1c586e63e269dbc514cad468c6ee91 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Mon, 6 Apr 2026 18:18:36 -0400 Subject: [PATCH 1/3] Add DINOv3 vision foundation models contrib for Neuron Onboard Meta DINOv3 ViT and ConvNeXt backbones (21M-6.7B params) to Neuron. Two compilation paths: torch_neuronx.trace() for models up to 840M, and neuronx-distributed TP=4 for ViT-7B (first encoder-only vision TP on Neuron). ViT cosine sim 1.000000, ConvNeXt 0.999989, peak DP=4: 722.8 img/s (ViT-S), ViT-7B TP=4: 38.8 img/s at 25.77ms latency. --- contrib/models/DINOv3/README.md | 178 ++++ contrib/models/DINOv3/src/__init__.py | 0 contrib/models/DINOv3/src/modeling_dinov3.py | 801 ++++++++++++++++++ contrib/models/DINOv3/test/__init__.py | 0 .../DINOv3/test/integration/__init__.py | 0 .../DINOv3/test/integration/test_model.py | 361 ++++++++ contrib/models/DINOv3/test/unit/__init__.py | 0 7 files changed, 1340 insertions(+) create mode 100644 contrib/models/DINOv3/README.md create mode 100644 contrib/models/DINOv3/src/__init__.py create mode 100644 contrib/models/DINOv3/src/modeling_dinov3.py create mode 100644 contrib/models/DINOv3/test/__init__.py create mode 100644 contrib/models/DINOv3/test/integration/__init__.py create mode 100644 contrib/models/DINOv3/test/integration/test_model.py create mode 100644 contrib/models/DINOv3/test/unit/__init__.py diff --git a/contrib/models/DINOv3/README.md b/contrib/models/DINOv3/README.md new file mode 100644 index 00000000..8c006b26 --- /dev/null +++ b/contrib/models/DINOv3/README.md @@ -0,0 +1,178 @@ +# DINOv3 on AWS Neuron + +Compile and run [Meta DINOv3](https://github.com/facebookresearch/dinov3) self-supervised vision foundation models on AWS Neuron (Trainium2 and Inferentia2). 
+ +## Model + +| Property | Value | +|----------|-------| +| **Models** | DINOv3 ViT-S/B/L/H+ (21M-840M), ConvNeXt-T/B (28M-88M), ViT-7B (6.7B) | +| **Architecture** | Encoder-only vision transformer / ConvNeXt backbone | +| **Parameters** | 21M to 6.7B (FP32 by default) | +| **Input** | 224x224 RGB images | +| **Output** | Dense feature embeddings (CLS token for ViT, pooled features for ConvNeXt) | +| **Source** | https://github.com/facebookresearch/dinov3 | +| **License** | DINOv3 License (not Apache/MIT -- check distribution terms) | + +## Compilation Approaches + +DINOv3 models are compiled using two different approaches depending on model size: + +| Model | Params | Approach | Instance | +|-------|--------|----------|----------| +| ViT-S/16 | 21.6M | `torch_neuronx.trace()` | inf2.xlarge | +| ViT-B/16 | 85.7M | `torch_neuronx.trace()` | inf2.xlarge / trn2.3xlarge | +| ViT-L/16 | 303.2M | `torch_neuronx.trace()` | trn2.3xlarge | +| ViT-H+/16 | 840.6M | `torch_neuronx.trace()` | trn2.3xlarge | +| ViT-7B/16 | 6,716M | `neuronx-distributed` ModelBuilder TP=4 | trn2.3xlarge | +| ConvNeXt-T | 27.8M | `torch_neuronx.trace()` | inf2.xlarge | +| ConvNeXt-B | 87.6M | `torch_neuronx.trace()` | inf2.xlarge / trn2.3xlarge | + +**Key insight**: ViT-7B is the first encoder-only vision model to use tensor parallelism on Neuron. The 20.1 GB NEFF does not fit in single-core HBM, so TP=4 via `neuronx-distributed` is required. 
+ +## Results + +### Accuracy (matmult bf16 vs CPU FP32) + +| Model | Cosine Similarity | Max Abs Diff | +|-------|------------------:|-------------:| +| ViT-S/16 | 1.000000 | < 0.001 | +| ViT-B/16 | 1.000000 | < 0.001 | +| ViT-L/16 | 1.000000 | < 0.001 | +| ViT-H+/16 | 1.000000 | < 0.001 | +| ViT-7B/16 | Deterministic (random weights) | -- | +| ConvNeXt-T | 0.999989 | < 0.001 | +| ConvNeXt-B | 0.999989 | < 0.001 | + +### Benchmark (trn2.3xlarge, LNC=2, DP=4) + +| Model | NEFF Size | Compile Time | 1-Core (img/s) | DP=4 Peak (img/s) | +|-------|----------:|-------------:|----------------:|-------------------:| +| ViT-S/16 | 68 MB | 84s | 367 | **722.8** | +| ViT-B/16 | 264 MB | 45s | 222 | **438.7** | +| ViT-L/16 | 931 MB | 123s | 87.6 | **174.7** | +| ViT-H+/16 | 2,595 MB | 688s | 5.2 | 10.5 | +| ViT-7B/16 | TP=4 NEFF | 5.9s | OOM | **38.8** (TP=4) | +| ConvNeXt-T | 90 MB | 44s | 183 | **363.3** | +| ConvNeXt-B | 275 MB | 63s | 130 | **257.8** | + +### Key Findings + +1. **`--auto-cast=matmult` is critical**: FP32 models get 50-60% speedup with matmult bf16 autocast, consistent with SigLIP and MoLFormer results +2. **ViT 1.7x faster than ConvNeXt**: At comparable parameter counts, ViT models are significantly faster on Neuron (transformer ops are heavily optimized) +3. **DataParallel scales near-perfectly**: DP=4 achieves ~1.95-2.0x over single-core across all models +4. **ViT-H+ is HBM-bandwidth limited**: 2.5 GB NEFF saturates single-core HBM bandwidth, resulting in only 10.5 img/s DP=4 (16.6x slower than ViT-L) +5. **ViT-7B requires TP=4**: 20.1 GB NEFF exceeds single-core HBM. 
Tensor parallelism via `neuronx-distributed` ModelBuilder achieves 38.8 img/s at 25.77ms latency + +## Compatibility + +| Component | Version | +|-----------|---------| +| **Neuron SDK** | 2.28 | +| **torch-neuronx** | 2.9.0.2.11 | +| **neuronx-cc** | 2.22.12471 | +| **neuronx-distributed** | 0.16.25997 (ViT-7B TP only) | +| **Instance (small/medium)** | inf2.xlarge, trn2.3xlarge | +| **Instance (ViT-7B)** | trn2.3xlarge (TP=4, LNC=2) | +| **DLAMI** | Deep Learning AMI Neuron (Ubuntu 24.04) 20260227 | + +## Usage + +### Setup + +```bash +# Activate Neuron environment +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate + +# Clone DINOv3 repository +git clone https://github.com/facebookresearch/dinov3.git /mnt/models/dinov3 +``` + +### Trace a ViT Model + +```python +import sys +sys.path.insert(0, "contrib/models/DINOv3/src") +from modeling_dinov3 import load_dinov3_model, trace_dinov3, validate_accuracy, benchmark_model + +# Load model +model = load_dinov3_model("dinov3_vitb16", repo_dir="/mnt/models/dinov3") + +# Compile for Neuron +model_neuron = trace_dinov3(model, is_convnext=False, save_path="/tmp/dinov3_vit_b.pt") + +# Validate accuracy +metrics = validate_accuracy(model, model_neuron) +print(f"Cosine similarity: {metrics['cosine_sim']:.6f}") + +# Benchmark +perf = benchmark_model(model_neuron) +print(f"Throughput: {perf['throughput_img_s']:.1f} img/s") +``` + +### Trace a ConvNeXt Model + +```python +model = load_dinov3_model("dinov3_convnext_tiny", repo_dir="/mnt/models/dinov3") +model_neuron = trace_dinov3(model, is_convnext=True, save_path="/tmp/dinov3_convnext_t.pt") +``` + +### Compile ViT-7B with Tensor Parallelism + +```python +from modeling_dinov3 import compile_vit7b_tp, benchmark_model + +# Requires NEURON_RT_NUM_CORES=4 and trn2.3xlarge +nxd_model = compile_vit7b_tp(tp_degree=4) +perf = benchmark_model(nxd_model) +print(f"TP=4 throughput: {perf['throughput_img_s']:.1f} img/s") +``` + +### DataParallel Benchmark + +```python +from 
modeling_dinov3 import benchmark_dataparallel + +# DP=4 across all NeuronCores on trn2.3xlarge (LNC=2) +dp_results = benchmark_dataparallel(model_neuron, num_cores=4) +for bs, r in dp_results.items(): + print(f"BS={bs}: {r['throughput_img_s']:.1f} img/s") +``` + +## Running Tests + +```bash +# Activate Neuron environment +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate + +# Run integration tests +python -m pytest contrib/models/DINOv3/test/integration/test_model.py -v + +# Or standalone (with detailed output) +python contrib/models/DINOv3/test/integration/test_model.py + +# Set custom paths +DINOV3_REPO_DIR=/path/to/dinov3 python -m pytest contrib/models/DINOv3/test/ -v +``` + +## Dependencies + +Pre-installed in DLAMI PyTorch inference venv: +- torch-neuronx +- neuronx-distributed (for ViT-7B TP) +- numpy + +Required (clone separately): +- DINOv3 repository: `git clone https://github.com/facebookresearch/dinov3.git` + +## Notes + +- All models use `pretrained=False` (random weights) for architecture validation. Replace with pretrained weights for production use. +- `--model-type=transformer` compiler flag is used for ViT models only (not ConvNeXt). +- ConvNeXt models exercise different Neuron ops (Conv2d, depthwise conv, GroupNorm) -- good diversity test for Neuron compiler. +- ViT-H+ traces successfully but is HBM-bandwidth-limited (2.5 GB NEFF). Consider TP for production use of models > 500M params. +- DINOv3 License is not Apache/MIT -- review before redistribution. + +## Maintainer + +Jim Burtoft (`jimburtoft`) diff --git a/contrib/models/DINOv3/src/__init__.py b/contrib/models/DINOv3/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/DINOv3/src/modeling_dinov3.py b/contrib/models/DINOv3/src/modeling_dinov3.py new file mode 100644 index 00000000..6c70a8ef --- /dev/null +++ b/contrib/models/DINOv3/src/modeling_dinov3.py @@ -0,0 +1,801 @@ +""" +DINOv3 vision foundation models on AWS Neuron. 
+ +Supports two compilation paths: +1. torch_neuronx.trace() -- for ViT-S/B/L/H+ and ConvNeXt-T/S/B/L (single NeuronCore) +2. neuronx-distributed ModelBuilder with TP -- for ViT-7B (6.7B params, requires TP=4) + +All models are encoder-only with static input shapes -- ideal for torch_neuronx.trace(). +FP32 by default; --auto-cast=matmult is critical for performance. + +Architecture: + - ViT: patch size 16, CLS + register/storage tokens, 2D axial RoPE, SwiGLU FFN, LayerScale + - ConvNeXt: hierarchical conv backbone (Conv2d, GroupNorm, GELU, LayerScale) + - Input: 224x224 RGB images + - Output: dense feature embeddings (CLS token for ViT, pooled features for ConvNeXt) + +Reference: https://github.com/facebookresearch/dinov3 +License: DINOv3 License (not Apache/MIT) +""" + +import math +import os +import sys +import time +from functools import partial +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torch_neuronx +from torch import Tensor, nn + +# Default compiler args for trace-based models +COMPILER_ARGS_VIT = [ + "--auto-cast", + "matmult", + "--auto-cast-type", + "bf16", + "--model-type", + "transformer", +] + +COMPILER_ARGS_CONVNEXT = [ + "--auto-cast", + "matmult", + "--auto-cast-type", + "bf16", +] + +# Model registry: hub function name -> configuration +MODEL_REGISTRY = { + "vit_s": { + "hub_name": "dinov3_vits16", + "arch": "vit", + "embed_dim": 384, + "params_M": 21.6, + }, + "vit_b": { + "hub_name": "dinov3_vitb16", + "arch": "vit", + "embed_dim": 768, + "params_M": 85.7, + }, + "vit_l": { + "hub_name": "dinov3_vitl16", + "arch": "vit", + "embed_dim": 1024, + "params_M": 303.2, + }, + "vit_h_plus": { + "hub_name": "dinov3_vith16plus", + "arch": "vit", + "embed_dim": 1280, + "params_M": 840.6, + }, + "convnext_tiny": { + "hub_name": "dinov3_convnext_tiny", + "arch": "convnext", + "embed_dim": 768, + "params_M": 27.8, + }, + "convnext_base": { + "hub_name": "dinov3_convnext_base", + 
"arch": "convnext", + "embed_dim": 1024, + "params_M": 87.6, + }, +} + +IMG_SIZE = 224 +BATCH_SIZE = 1 + + +def load_dinov3_model(hub_name: str, repo_dir: str = "/mnt/models/dinov3"): + """Load a DINOv3 model from the cloned repository. + + Args: + hub_name: Model function name (e.g., 'dinov3_vits16', 'dinov3_convnext_tiny') + repo_dir: Path to cloned dinov3 repository + + Returns: + PyTorch model in eval mode with random weights (pretrained=False) + """ + if repo_dir not in sys.path: + sys.path.insert(0, repo_dir) + + from dinov3.hub.backbones import ( + dinov3_vits16, + dinov3_vitb16, + dinov3_vitl16, + dinov3_vith16plus, + dinov3_vit7b16, + dinov3_convnext_tiny, + dinov3_convnext_base, + ) + + hub_fns = { + "dinov3_vits16": dinov3_vits16, + "dinov3_vitb16": dinov3_vitb16, + "dinov3_vitl16": dinov3_vitl16, + "dinov3_vith16plus": dinov3_vith16plus, + "dinov3_vit7b16": dinov3_vit7b16, + "dinov3_convnext_tiny": dinov3_convnext_tiny, + "dinov3_convnext_base": dinov3_convnext_base, + } + + if hub_name not in hub_fns: + raise ValueError( + f"Unknown model: {hub_name}. Available: {list(hub_fns.keys())}" + ) + + model = hub_fns[hub_name](pretrained=False) + model.eval() + return model + + +def trace_dinov3( + model: nn.Module, + is_convnext: bool = False, + img_size: int = IMG_SIZE, + batch_size: int = BATCH_SIZE, + save_path: Optional[str] = None, + inline_weights: bool = True, +) -> torch.jit.ScriptModule: + """Trace a DINOv3 model for Neuron via torch_neuronx.trace(). 
+ + Args: + model: DINOv3 model (ViT or ConvNeXt) in eval mode + is_convnext: True for ConvNeXt models (uses different compiler args) + img_size: Input image size (default: 224) + batch_size: Batch size for tracing (default: 1) + save_path: Optional path to save compiled model + inline_weights: Whether to inline weights into NEFF (default: True) + + Returns: + Compiled Neuron model + """ + compiler_args = COMPILER_ARGS_CONVNEXT if is_convnext else COMPILER_ARGS_VIT + example_input = torch.randn(batch_size, 3, img_size, img_size) + + print(f" Tracing with compiler_args={compiler_args}") + t0 = time.time() + model_neuron = torch_neuronx.trace( + model, + example_input, + compiler_args=compiler_args, + inline_weights_to_neff=inline_weights, + ) + compile_time = time.time() - t0 + print(f" Compilation time: {compile_time:.1f}s") + + if save_path: + torch.jit.save(model_neuron, save_path) + neff_size_mb = os.path.getsize(save_path) / (1024 * 1024) + print(f" Saved to {save_path} ({neff_size_mb:.1f} MB)") + + return model_neuron + + +def validate_accuracy( + model_cpu: nn.Module, + model_neuron: torch.jit.ScriptModule, + img_size: int = IMG_SIZE, + batch_size: int = BATCH_SIZE, +) -> Dict[str, float]: + """Compare CPU vs Neuron outputs for accuracy validation. 
+ + Args: + model_cpu: CPU reference model + model_neuron: Compiled Neuron model + img_size: Input image size + batch_size: Batch size + + Returns: + Dict with cosine_sim, max_diff, l2_rel_error + """ + example_input = torch.randn(batch_size, 3, img_size, img_size) + + with torch.no_grad(): + cpu_out = model_cpu(example_input) + neuron_out = model_neuron(example_input) + + cpu_flat = cpu_out.flatten().float() + neuron_flat = neuron_out.flatten().float() + + cosine_sim = F.cosine_similarity( + cpu_flat.unsqueeze(0), neuron_flat.unsqueeze(0) + ).item() + + max_diff = (cpu_flat - neuron_flat).abs().max().item() + l2_rel = (torch.norm(cpu_flat - neuron_flat) / torch.norm(cpu_flat)).item() + + return { + "cosine_sim": cosine_sim, + "max_diff": max_diff, + "l2_rel_error": l2_rel, + } + + +def benchmark_model( + model_neuron, + img_size: int = IMG_SIZE, + batch_size: int = BATCH_SIZE, + warmup_iters: int = 5, + bench_iters: int = 50, +) -> Dict[str, float]: + """Benchmark latency and throughput of a compiled Neuron model. 
+ + Args: + model_neuron: Compiled Neuron model (trace or TP) + img_size: Input image size + batch_size: Batch size per inference call + warmup_iters: Number of warmup iterations + bench_iters: Number of timed iterations + + Returns: + Dict with mean_latency_ms, median_latency_ms, p99_latency_ms, throughput_img_s + """ + example_input = torch.randn(batch_size, 3, img_size, img_size) + + # For TP models that expect bfloat16 input + if hasattr(model_neuron, "_is_tp_model"): + example_input = example_input.bfloat16() + + for _ in range(warmup_iters): + model_neuron(example_input) + + latencies = [] + for _ in range(bench_iters): + t0 = time.time() + model_neuron(example_input) + latencies.append((time.time() - t0) * 1000) + + latencies = np.array(latencies) + return { + "mean_latency_ms": latencies.mean(), + "median_latency_ms": float(np.median(latencies)), + "p99_latency_ms": float(np.percentile(latencies, 99)), + "throughput_img_s": 1000.0 / latencies.mean() * batch_size, + } + + +def benchmark_dataparallel( + model_neuron, + num_cores: int = 4, + img_size: int = IMG_SIZE, + batch_sizes: Optional[List[int]] = None, + warmup_iters: int = 10, + bench_iters: int = 100, +) -> Dict[int, Dict[str, float]]: + """Benchmark DataParallel throughput across multiple batch sizes. 
+ + Args: + model_neuron: Compiled Neuron model (single-core) + num_cores: Number of NeuronCores for DataParallel + img_size: Input image size + batch_sizes: List of batch sizes to test (default: [cores, 2*cores, 4*cores]) + warmup_iters: Number of warmup iterations + bench_iters: Number of timed iterations + + Returns: + Dict mapping batch_size -> benchmark metrics + """ + if batch_sizes is None: + batch_sizes = [num_cores, num_cores * 2, num_cores * 4] + + model_dp = torch_neuronx.DataParallel( + model_neuron, + device_ids=list(range(num_cores)), + dim=0, + ) + + results = {} + for bs in batch_sizes: + dp_input = torch.randn(bs, 3, img_size, img_size) + + for _ in range(warmup_iters): + model_dp(dp_input) + + latencies = [] + for _ in range(bench_iters): + t0 = time.time() + model_dp(dp_input) + latencies.append((time.time() - t0) * 1000) + + latencies = np.array(latencies) + results[bs] = { + "mean_latency_ms": latencies.mean(), + "median_latency_ms": float(np.median(latencies)), + "throughput_img_s": 1000.0 / latencies.mean() * bs, + } + + return results + + +# ============================================================================ +# TP ViT-7B Model Definition +# ============================================================================ +# The following classes implement a standalone tensor-parallel ViT-7B/16 model +# using neuronx-distributed parallel layers. This is required because the 6.7B +# parameter model produces a 20.1 GB NEFF that does not fit in single-core HBM. +# +# This is the first encoder-only vision model to use TP on Neuron. +# ============================================================================ + + +def _rope_rotate_half(x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + +def _rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + return (x * cos) + (_rope_rotate_half(x) * sin) + + +class RopePositionEmbedding(nn.Module): + """2D axial RoPE for vision patches. 
Replicated across TP ranks.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + base: float = 100.0, + normalize_coords: str = "separate", + rescale_coords: Optional[float] = None, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + D_head = embed_dim // num_heads + self.D_head = D_head + self.normalize_coords = normalize_coords + self.rescale_coords = rescale_coords + self.dtype = dtype + + periods = base ** (2 * torch.arange(D_head // 4, dtype=dtype) / (D_head // 2)) + self.register_buffer("periods", periods, persistent=True) + + def forward(self, H: int, W: int) -> Tuple[Tensor, Tensor]: + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + + if self.normalize_coords == "separate": + coords_h = torch.arange(0.5, H, **dd) / H + coords_w = torch.arange(0.5, W, **dd) / W + else: + raise ValueError(f"Unsupported normalize_coords: {self.normalize_coords}") + + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 + + angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) + sin = torch.sin(angles) + return (sin, cos) + + +class PatchEmbed(nn.Module): + """Patch embedding via Conv2d. 
Replicated across TP ranks (small layer).""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.patch_size = (patch_size, patch_size) + self.num_patches = (img_size // patch_size) ** 2 + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.proj(x) # [B, C, H, W] + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # [B, HW, C] + x = x.reshape(-1, H, W, x.shape[-1]) # [B, H, W, C] + return x + + def reset_parameters(self): + k = 1 / (self.proj.in_channels * (self.patch_size[0] ** 2)) + nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k)) + if self.proj.bias is not None: + nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k)) + + +class LayerScale(nn.Module): + """Per-dimension scaling. Replicated across TP ranks.""" + + def __init__( + self, dim: int, init_values: float = 1e-5, dtype: torch.dtype = torch.bfloat16 + ): + super().__init__() + self.gamma = nn.Parameter(init_values * torch.ones(dim, dtype=dtype)) + + def forward(self, x: Tensor) -> Tensor: + return x * self.gamma + + def reset_parameters(self): + nn.init.ones_(self.gamma) + self.gamma.data *= 1e-5 + + +class TPSelfAttention(nn.Module): + """Self-attention with tensor-parallel Q/K/V/O projections. 
+ + TP strategy: + - Q, K, V: ColumnParallelLinear (gather_output=True) -- each rank gets full tensor + - O: RowParallelLinear (input_is_parallel=False) -- handles all-reduce + """ + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = False, + proj_bias: bool = True, + mask_k_bias: bool = False, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.dim = dim + + self.q_proj = ColumnParallelLinear( + dim, dim, bias=qkv_bias, gather_output=True, dtype=dtype, pad=True + ) + self.k_proj = ColumnParallelLinear( + dim, dim, bias=qkv_bias, gather_output=True, dtype=dtype, pad=True + ) + self.v_proj = ColumnParallelLinear( + dim, dim, bias=qkv_bias, gather_output=True, dtype=dtype, pad=True + ) + self.proj = RowParallelLinear( + dim, dim, bias=proj_bias, input_is_parallel=False, dtype=dtype + ) + self.mask_k_bias = mask_k_bias + + def apply_rope(self, q, k, rope): + q_dtype, k_dtype = q.dtype, k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q, k = q.to(dtype=rope_dtype), k.to(dtype=rope_dtype) + N = q.shape[-2] + prefix = N - sin.shape[-2] + assert prefix >= 0 + q_prefix = q[:, :, :prefix, :] + q = _rope_apply(q[:, :, prefix:, :], sin, cos) + q = torch.cat((q_prefix, q), dim=-2) + k_prefix = k[:, :, :prefix, :] + k = _rope_apply(k[:, :, prefix:, :], sin, cos) + k = torch.cat((k_prefix, k), dim=-2) + return q.to(dtype=q_dtype), k.to(dtype=k_dtype) + + def forward(self, x: Tensor, rope=None) -> Tensor: + B, N, _ = x.shape + q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + if rope is not None: + q, k = self.apply_rope(q, k, rope) + + attn_out = 
F.scaled_dot_product_attention(q, k, v) + attn_out = attn_out.transpose(1, 2).reshape(B, N, self.dim) + return self.proj(attn_out) + + +class TPSwiGLUFFN(nn.Module): + """SwiGLU FFN with tensor-parallel w1/w2 (gate/up) and w3 (down). + + TP strategy: + - w1, w2: ColumnParallelLinear (gather_output=False) -- keep sharded + - w3: RowParallelLinear (input_is_parallel=True) -- reduce + """ + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + align_to: int = 64, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + d = int(hidden_features * 2 / 3) + swiglu_hidden_features = d + (-d % align_to) + + self.w1 = ColumnParallelLinear( + in_features, + swiglu_hidden_features, + bias=bias, + gather_output=False, + dtype=dtype, + pad=True, + ) + self.w2 = ColumnParallelLinear( + in_features, + swiglu_hidden_features, + bias=bias, + gather_output=False, + dtype=dtype, + pad=True, + ) + self.w3 = RowParallelLinear( + swiglu_hidden_features, + out_features, + bias=bias, + input_is_parallel=True, + dtype=dtype, + ) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class TPSelfAttentionBlock(nn.Module): + """Transformer block: LayerNorm -> Attention -> LayerScale -> LayerNorm -> FFN -> LayerScale.""" + + def __init__( + self, + dim: int, + num_heads: int, + ffn_ratio: float = 3.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + layerscale_init: Optional[float] = 1e-5, + mask_k_bias: bool = False, + align_to: int = 64, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-5, dtype=dtype) + self.attn = 
TPSelfAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + mask_k_bias=mask_k_bias, + dtype=dtype, + ) + self.ls1 = ( + LayerScale(dim, init_values=layerscale_init, dtype=dtype) + if layerscale_init + else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(dim, eps=1e-5, dtype=dtype) + mlp_hidden_dim = int(dim * ffn_ratio) + self.mlp = TPSwiGLUFFN( + in_features=dim, + hidden_features=mlp_hidden_dim, + bias=ffn_bias, + align_to=align_to, + dtype=dtype, + ) + self.ls2 = ( + LayerScale(dim, init_values=layerscale_init, dtype=dtype) + if layerscale_init + else nn.Identity() + ) + + def forward(self, x: Tensor, rope=None) -> Tensor: + x = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class TPDinoViT7B(nn.Module): + """DINOv3 ViT-7B/16 with tensor parallelism via neuronx-distributed. + + Config (from dinov3 hub/backbones.py dinov3_vit7b16): + embed_dim=4096, depth=40, num_heads=32, head_dim=128 + ffn_layer=swiglu64 (SwiGLU with align_to=64), ffn_ratio=3 + qkv_bias=False, proj_bias=True, ffn_bias=True + n_storage_tokens=4, layerscale_init=1e-5 + RoPE: base=100, normalize_coords=separate, rescale_coords=2 + """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + embed_dim: int = 4096, + depth: int = 40, + num_heads: int = 32, + ffn_ratio: float = 3.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + n_storage_tokens: int = 4, + layerscale_init: float = 1e-5, + mask_k_bias: bool = True, + align_to: int = 64, + rope_base: float = 100.0, + rope_normalize_coords: str = "separate", + rope_rescale_coords: Optional[float] = 2.0, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_storage_tokens = n_storage_tokens + self.patch_size = patch_size + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + dtype=dtype, + ) + 
self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, dtype=dtype)) + self.storage_tokens = nn.Parameter( + torch.empty(1, n_storage_tokens, embed_dim, dtype=dtype) + ) + self.mask_token = nn.Parameter(torch.empty(1, embed_dim, dtype=dtype)) + self.rope_embed = RopePositionEmbedding( + embed_dim=embed_dim, + num_heads=num_heads, + base=rope_base, + normalize_coords=rope_normalize_coords, + rescale_coords=rope_rescale_coords, + dtype=torch.float32, + ) + + self.blocks = nn.ModuleList( + [ + TPSelfAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + ffn_ratio=ffn_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + layerscale_init=layerscale_init, + mask_k_bias=mask_k_bias, + align_to=align_to, + dtype=dtype, + ) + for _ in range(depth) + ] + ) + + self.norm = nn.LayerNorm(embed_dim, eps=1e-5, dtype=dtype) + self.head = nn.Identity() + + def init_weights(self): + nn.init.normal_(self.cls_token, std=0.02) + nn.init.normal_(self.storage_tokens, std=0.02) + nn.init.zeros_(self.mask_token) + self.patch_embed.reset_parameters() + for block in self.blocks: + if hasattr(block, "ls1") and isinstance(block.ls1, LayerScale): + block.ls1.reset_parameters() + if hasattr(block, "ls2") and isinstance(block.ls2, LayerScale): + block.ls2.reset_parameters() + + def forward(self, x: Tensor) -> Tensor: + x = self.patch_embed(x) + B, H, W, _ = x.shape + x = x.flatten(1, 2) # [B, HW, C] + + cls_token = self.cls_token + 0 * self.mask_token # tie mask_token gradient + x = torch.cat( + [cls_token.expand(B, -1, -1), self.storage_tokens.expand(B, -1, -1), x], + dim=1, + ) + + rope = self.rope_embed(H=H, W=W) + + for blk in self.blocks: + x = blk(x, rope=rope) + + x = self.norm(x) + return self.head(x[:, 0]) # [B, embed_dim] + + +def create_vit7b_tp(dtype: torch.dtype = torch.bfloat16) -> TPDinoViT7B: + """Create TP ViT-7B model with initialized weights. + + Must be called inside a NxDParallelState context for TP to work. 
+ + Returns: + TPDinoViT7B model in eval mode + """ + model = TPDinoViT7B(dtype=dtype) + model.init_weights() + model.eval() + return model + + +def compile_vit7b_tp( + tp_degree: int = 4, + img_size: int = IMG_SIZE, + batch_size: int = BATCH_SIZE, + save_dir: str = "/mnt/models/compiled/dinov3_vit7b_tp4", +) -> "NxDModel": + """Compile ViT-7B with tensor parallelism via ModelBuilder. + + Must be run with NEURON_RT_NUM_CORES >= tp_degree. + + Args: + tp_degree: Number of NeuronCores for TP (default: 4) + img_size: Input image size (default: 224) + batch_size: Batch size (default: 1) + save_dir: Directory for compiler workdir and artifacts + + Returns: + Compiled NxD model ready for inference + """ + from neuronx_distributed.trace.parallel_context import NxDParallelState + from neuronx_distributed import ModelBuilder + + os.makedirs(save_dir, exist_ok=True) + os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) + + example_input = torch.randn(batch_size, 3, img_size, img_size).bfloat16() + + compiler_args = ( + "--auto-cast=matmult " + "--auto-cast-type=bf16 " + "--model-type=transformer " + "--enable-saturate-infinity " + "-O1" + ) + + with NxDParallelState(world_size=tp_degree, tensor_model_parallel_size=tp_degree): + model = create_vit7b_tp(dtype=torch.bfloat16) + model_state = model.state_dict() + n_params = sum(p.numel() for p in model.parameters()) / 1e6 + print(f" TP ViT-7B parameters (per-rank): {n_params:.1f}M") + + builder = ModelBuilder(model) + + t0 = time.time() + builder.trace(args=example_input, tag="vit7b") + trace_time = time.time() - t0 + print(f" Trace time: {trace_time:.1f}s") + + t0 = time.time() + nxd_model = builder.compile( + compiler_workdir=os.path.join(save_dir, "compiler_workdir"), + compiler_args=compiler_args, + ) + compile_time = time.time() - t0 + print(f" Compile time: {compile_time:.1f}s") + + # Outside context: set weights and load on Neuron + sharded_checkpoint = [model_state for _ in range(tp_degree)] + 
nxd_model.set_weights(sharded_checkpoint) + nxd_model.to_neuron() + + # Tag for benchmark function + nxd_model._is_tp_model = True + + return nxd_model diff --git a/contrib/models/DINOv3/test/__init__.py b/contrib/models/DINOv3/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/DINOv3/test/integration/__init__.py b/contrib/models/DINOv3/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/DINOv3/test/integration/test_model.py b/contrib/models/DINOv3/test/integration/test_model.py new file mode 100644 index 00000000..40e4c602 --- /dev/null +++ b/contrib/models/DINOv3/test/integration/test_model.py @@ -0,0 +1,361 @@ +""" +Integration tests for DINOv3 vision models on Neuron. + +Tests accuracy (cosine similarity), compilation, inference, DataParallel, +and performance for DINOv3 ViT and ConvNeXt models compiled with torch_neuronx.trace(). + +Requires: + - Neuron instance (trn2.3xlarge recommended, inf2.xlarge for small models) + - Cloned dinov3 repository at /mnt/models/dinov3 + +Run: + python -m pytest contrib/models/DINOv3/test/integration/test_model.py -v + python contrib/models/DINOv3/test/integration/test_model.py # standalone +""" + +import json +import os +import subprocess +import sys +import time + +import numpy as np +import pytest +import torch +import torch_neuronx + +# Add src to path for modeling imports +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_SRC_DIR = os.path.normpath(os.path.join(_THIS_DIR, "..", "..", "src")) +sys.path.insert(0, _SRC_DIR) + +from modeling_dinov3 import ( + COMPILER_ARGS_CONVNEXT, + COMPILER_ARGS_VIT, + MODEL_REGISTRY, + benchmark_dataparallel, + benchmark_model, + load_dinov3_model, + trace_dinov3, + validate_accuracy, +) + +# --- Constants --- + +REPO_DIR = os.environ.get("DINOV3_REPO_DIR", "/mnt/models/dinov3") +SAVED_DIR = os.environ.get( + "DINOV3_SAVED_DIR", + os.path.join(os.path.normpath(os.path.join(_THIS_DIR, "..", 
"..")), "saved_models"), +) +IMG_SIZE = 224 + +# Models to test (subset for fast CI; full set for manual runs) +# ViT-B and ConvNeXt-Tiny are representative of both architecture families +TEST_MODELS = { + "vit_b": {"hub_name": "dinov3_vitb16", "is_convnext": False}, + "convnext_tiny": {"hub_name": "dinov3_convnext_tiny", "is_convnext": True}, +} + +# Extended model set for thorough testing +ALL_TRACE_MODELS = { + "vit_s": {"hub_name": "dinov3_vits16", "is_convnext": False}, + "vit_b": {"hub_name": "dinov3_vitb16", "is_convnext": False}, + "vit_l": {"hub_name": "dinov3_vitl16", "is_convnext": False}, + "convnext_tiny": {"hub_name": "dinov3_convnext_tiny", "is_convnext": True}, + "convnext_base": {"hub_name": "dinov3_convnext_base", "is_convnext": True}, +} + +# Accuracy thresholds (validated: ViT achieves 1.000000, ConvNeXt achieves 0.999989) +COSINE_SIM_THRESHOLD_VIT = 0.9999 +COSINE_SIM_THRESHOLD_CONVNEXT = 0.9998 +MAX_DIFF_THRESHOLD = 0.05 + +# Performance thresholds (conservative, for trn2.3xlarge single core) +MIN_THROUGHPUT_VIT_B = 150 # img/s (achieved: 222) +MIN_THROUGHPUT_CONVNEXT_T = 100 # img/s (achieved: 183) +MIN_THROUGHPUT_DP4 = 300 # img/s, DP=4 for ViT-B (achieved: 439) +MAX_P50_LATENCY_MS = 15.0 # ms, single core + + +# --- Helpers --- + + +def get_neuron_core_count(): + """Detect number of NeuronCores available.""" + try: + result = subprocess.run( + ["neuron-ls", "--json-output"], capture_output=True, text=True, timeout=10 + ) + info = json.loads(result.stdout) + return sum(d["nc_count"] for d in info) + except Exception: + return 0 + + +def compile_and_cache(model_key, config): + """Compile a DINOv3 model, caching the NEFF on disk.""" + os.makedirs(SAVED_DIR, exist_ok=True) + save_path = os.path.join(SAVED_DIR, f"dinov3_{model_key}_bs1.pt") + + if os.path.isfile(save_path): + return torch.jit.load(save_path) + + model = load_dinov3_model(config["hub_name"], repo_dir=REPO_DIR) + model_neuron = trace_dinov3( + model, + 
is_convnext=config["is_convnext"], + save_path=save_path, + ) + return model_neuron + + +# --- Fixtures --- + + +@pytest.fixture(scope="module") +def n_cores(): + return get_neuron_core_count() + + +@pytest.fixture(scope="module") +def vit_b_cpu(): + return load_dinov3_model("dinov3_vitb16", repo_dir=REPO_DIR) + + +@pytest.fixture(scope="module") +def vit_b_neuron(vit_b_cpu): + return compile_and_cache("vit_b", TEST_MODELS["vit_b"]) + + +@pytest.fixture(scope="module") +def convnext_tiny_cpu(): + return load_dinov3_model("dinov3_convnext_tiny", repo_dir=REPO_DIR) + + +@pytest.fixture(scope="module") +def convnext_tiny_neuron(convnext_tiny_cpu): + return compile_and_cache("convnext_tiny", TEST_MODELS["convnext_tiny"]) + + +# --- Test Classes --- + + +class TestModelLoads: + """Smoke tests: model loads, traces, and produces correct output shape.""" + + def test_vit_b_loads(self, vit_b_neuron): + example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE) + out = vit_b_neuron(example) + # ViT-B output: CLS token embedding [1, 768] + assert out.shape == (1, 768), f"Expected (1, 768), got {out.shape}" + + def test_convnext_tiny_loads(self, convnext_tiny_neuron): + example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE) + out = convnext_tiny_neuron(example) + # ConvNeXt-Tiny output shape depends on head configuration + assert out.ndim >= 1, f"Expected at least 1D output, got {out.ndim}D" + + def test_output_is_finite(self, vit_b_neuron): + example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE) + out = vit_b_neuron(example) + assert torch.isfinite(out).all(), "Output contains NaN or Inf values" + + def test_deterministic_output(self, vit_b_neuron): + """Same input should produce identical output (no stochastic ops in eval).""" + example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE) + out1 = vit_b_neuron(example) + out2 = vit_b_neuron(example) + assert torch.equal(out1, out2), "Model is not deterministic" + + +class TestAccuracy: + """Accuracy tests: Neuron vs CPU cosine similarity.""" + + def 
test_vit_b_cosine_similarity(self, vit_b_cpu, vit_b_neuron): + metrics = validate_accuracy(vit_b_cpu, vit_b_neuron) + assert metrics["cosine_sim"] >= COSINE_SIM_THRESHOLD_VIT, ( + f"ViT-B cosine sim {metrics['cosine_sim']:.6f} < {COSINE_SIM_THRESHOLD_VIT}" + ) + + def test_vit_b_max_diff(self, vit_b_cpu, vit_b_neuron): + metrics = validate_accuracy(vit_b_cpu, vit_b_neuron) + assert metrics["max_diff"] <= MAX_DIFF_THRESHOLD, ( + f"ViT-B max diff {metrics['max_diff']:.6f} > {MAX_DIFF_THRESHOLD}" + ) + + def test_convnext_tiny_cosine_similarity( + self, convnext_tiny_cpu, convnext_tiny_neuron + ): + metrics = validate_accuracy(convnext_tiny_cpu, convnext_tiny_neuron) + assert metrics["cosine_sim"] >= COSINE_SIM_THRESHOLD_CONVNEXT, ( + f"ConvNeXt-Tiny cosine sim {metrics['cosine_sim']:.6f} < {COSINE_SIM_THRESHOLD_CONVNEXT}" + ) + + def test_convnext_tiny_max_diff(self, convnext_tiny_cpu, convnext_tiny_neuron): + metrics = validate_accuracy(convnext_tiny_cpu, convnext_tiny_neuron) + assert metrics["max_diff"] <= MAX_DIFF_THRESHOLD, ( + f"ConvNeXt-Tiny max diff {metrics['max_diff']:.6f} > {MAX_DIFF_THRESHOLD}" + ) + + def test_vit_b_multiple_inputs(self, vit_b_cpu, vit_b_neuron): + """Accuracy holds across multiple random inputs.""" + for i in range(5): + torch.manual_seed(i * 42) + metrics = validate_accuracy(vit_b_cpu, vit_b_neuron) + assert metrics["cosine_sim"] >= COSINE_SIM_THRESHOLD_VIT, ( + f"ViT-B input {i}: cosine sim {metrics['cosine_sim']:.6f} < {COSINE_SIM_THRESHOLD_VIT}" + ) + + +class TestDataParallel: + """DataParallel tests: verify multi-core scaling.""" + + def test_dp_runs(self, vit_b_neuron, n_cores): + if n_cores < 2: + pytest.skip("Need >= 2 NeuronCores for DataParallel test") + + model_dp = torch_neuronx.DataParallel( + vit_b_neuron, + device_ids=list(range(min(n_cores, 4))), + dim=0, + ) + dp_cores = min(n_cores, 4) + dp_input = torch.randn(dp_cores, 3, IMG_SIZE, IMG_SIZE) + out = model_dp(dp_input) + assert out.shape[0] == dp_cores, ( + f"Expected 
batch={dp_cores}, got {out.shape[0]}" + ) + + def test_dp_speedup(self, vit_b_neuron, n_cores): + if n_cores < 4: + pytest.skip("Need >= 4 NeuronCores for DP speedup test") + + # Single-core benchmark + single_metrics = benchmark_model(vit_b_neuron, bench_iters=30) + + # DP=4 benchmark + dp_results = benchmark_dataparallel( + vit_b_neuron, num_cores=4, batch_sizes=[4], bench_iters=30 + ) + + dp_throughput = dp_results[4]["throughput_img_s"] + single_throughput = single_metrics["throughput_img_s"] + speedup = dp_throughput / single_throughput + + assert speedup > 1.5, ( + f"DP=4 speedup {speedup:.2f}x < 1.5x " + f"(single: {single_throughput:.1f}, DP: {dp_throughput:.1f} img/s)" + ) + + +class TestPerformance: + """Performance tests: throughput and latency thresholds.""" + + def test_vit_b_throughput(self, vit_b_neuron): + metrics = benchmark_model(vit_b_neuron, bench_iters=30) + assert metrics["throughput_img_s"] >= MIN_THROUGHPUT_VIT_B, ( + f"ViT-B throughput {metrics['throughput_img_s']:.1f} img/s < {MIN_THROUGHPUT_VIT_B}" + ) + + def test_convnext_tiny_throughput(self, convnext_tiny_neuron): + metrics = benchmark_model(convnext_tiny_neuron, bench_iters=30) + assert metrics["throughput_img_s"] >= MIN_THROUGHPUT_CONVNEXT_T, ( + f"ConvNeXt-Tiny throughput {metrics['throughput_img_s']:.1f} img/s < {MIN_THROUGHPUT_CONVNEXT_T}" + ) + + def test_vit_b_latency(self, vit_b_neuron): + metrics = benchmark_model(vit_b_neuron, bench_iters=30) + assert metrics["median_latency_ms"] <= MAX_P50_LATENCY_MS, ( + f"ViT-B P50 latency {metrics['median_latency_ms']:.2f}ms > {MAX_P50_LATENCY_MS}ms" + ) + + def test_dp4_throughput(self, vit_b_neuron, n_cores): + if n_cores < 4: + pytest.skip("Need >= 4 NeuronCores for DP throughput test") + + dp_results = benchmark_dataparallel( + vit_b_neuron, num_cores=4, batch_sizes=[8], bench_iters=30 + ) + throughput = dp_results[8]["throughput_img_s"] + assert throughput >= MIN_THROUGHPUT_DP4, ( + f"DP=4 throughput {throughput:.1f} img/s < 
{MIN_THROUGHPUT_DP4}" + ) + + +# --- Standalone Runner --- + +if __name__ == "__main__": + os.environ.setdefault("TORCHDYNAMO_DISABLE", "1") + + print("=" * 60) + print("DINOv3 Neuron Integration Tests (standalone)") + print("=" * 60) + + n_cores = get_neuron_core_count() + print(f"\nNeuronCores detected: {n_cores}") + print(f"DINOv3 repo: {REPO_DIR}") + print(f"Save directory: {SAVED_DIR}") + + all_pass = True + + for model_key, config in TEST_MODELS.items(): + is_convnext = config["is_convnext"] + arch = "ConvNeXt" if is_convnext else "ViT" + threshold = ( + COSINE_SIM_THRESHOLD_CONVNEXT if is_convnext else COSINE_SIM_THRESHOLD_VIT + ) + + print(f"\n--- {model_key} ({arch}) ---") + + # 1. Load CPU model + print(f"[1] Loading CPU model...") + cpu_model = load_dinov3_model(config["hub_name"], repo_dir=REPO_DIR) + n_params = sum(p.numel() for p in cpu_model.parameters()) / 1e6 + print(f" Parameters: {n_params:.1f}M") + + # 2. CPU reference + example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE) + with torch.no_grad(): + cpu_out = cpu_model(example) + print(f"[2] CPU output shape: {cpu_out.shape}") + + # 3. Compile + print(f"[3] Compiling for Neuron...") + neuron_model = compile_and_cache(model_key, config) + + # 4. Smoke test + neuron_out = neuron_model(example) + print(f"[4] Neuron output shape: {neuron_out.shape}") + + # 5. Accuracy + metrics = validate_accuracy(cpu_model, neuron_model) + status = "PASS" if metrics["cosine_sim"] >= threshold else "FAIL" + if status == "FAIL": + all_pass = False + print( + f"[5] Accuracy [{status}]: cosine={metrics['cosine_sim']:.6f}, " + f"max_diff={metrics['max_diff']:.6f}, l2_rel={metrics['l2_rel_error']:.6f}" + ) + + # 6. Performance (single core) + perf = benchmark_model(neuron_model, bench_iters=50) + print( + f"[6] Performance: {perf['throughput_img_s']:.1f} img/s, " + f"P50={perf['median_latency_ms']:.2f}ms, P99={perf['p99_latency_ms']:.2f}ms" + ) + + # 7. 
DataParallel (if enough cores) + if n_cores >= 4: + dp_results = benchmark_dataparallel(neuron_model, num_cores=4) + print(f"[7] DataParallel (DP=4):") + for bs, r in dp_results.items(): + print( + f" BS={bs}: {r['throughput_img_s']:.1f} img/s, P50={r['median_latency_ms']:.2f}ms" + ) + else: + print(f"[7] DataParallel: SKIPPED (need >= 4 cores, have {n_cores})") + + # Summary + print(f"\n{'=' * 60}") + print(f"RESULT: {'ALL PASS' if all_pass else 'SOME FAILED'}") + print(f"{'=' * 60}") diff --git a/contrib/models/DINOv3/test/unit/__init__.py b/contrib/models/DINOv3/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From cbe7ebe73196410c75fbf0c8302fbfb676808232 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Mon, 6 Apr 2026 18:33:09 -0400 Subject: [PATCH 2/3] Fix accuracy validation: trace from same CPU model instance pretrained=False gives different random weights on each load_dinov3_model call. compile_and_cache must receive the CPU model used for accuracy comparison rather than loading a new one internally. --- .../DINOv3/test/integration/test_model.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/contrib/models/DINOv3/test/integration/test_model.py b/contrib/models/DINOv3/test/integration/test_model.py index 40e4c602..ee53323e 100644 --- a/contrib/models/DINOv3/test/integration/test_model.py +++ b/contrib/models/DINOv3/test/integration/test_model.py @@ -92,17 +92,20 @@ def get_neuron_core_count(): return 0 -def compile_and_cache(model_key, config): - """Compile a DINOv3 model, caching the NEFF on disk.""" +def compile_and_cache(cpu_model, model_key, config): + """Compile a DINOv3 model for Neuron, tracing from the given CPU model. + + IMPORTANT: The cpu_model must be the same instance used for accuracy + validation. Since pretrained=False gives different random weights on + each call, we must trace the exact model we compare against. 
+ """ os.makedirs(SAVED_DIR, exist_ok=True) save_path = os.path.join(SAVED_DIR, f"dinov3_{model_key}_bs1.pt") - if os.path.isfile(save_path): - return torch.jit.load(save_path) - - model = load_dinov3_model(config["hub_name"], repo_dir=REPO_DIR) + # Do NOT use cached NEFFs -- they were traced from a different model instance. + # Always re-trace from the provided cpu_model to ensure weight consistency. model_neuron = trace_dinov3( - model, + cpu_model, is_convnext=config["is_convnext"], save_path=save_path, ) @@ -124,7 +127,7 @@ def vit_b_cpu(): @pytest.fixture(scope="module") def vit_b_neuron(vit_b_cpu): - return compile_and_cache("vit_b", TEST_MODELS["vit_b"]) + return compile_and_cache(vit_b_cpu, "vit_b", TEST_MODELS["vit_b"]) @pytest.fixture(scope="module") @@ -134,7 +137,9 @@ def convnext_tiny_cpu(): @pytest.fixture(scope="module") def convnext_tiny_neuron(convnext_tiny_cpu): - return compile_and_cache("convnext_tiny", TEST_MODELS["convnext_tiny"]) + return compile_and_cache( + convnext_tiny_cpu, "convnext_tiny", TEST_MODELS["convnext_tiny"] + ) # --- Test Classes --- @@ -319,9 +324,9 @@ def test_dp4_throughput(self, vit_b_neuron, n_cores): cpu_out = cpu_model(example) print(f"[2] CPU output shape: {cpu_out.shape}") - # 3. Compile + # 3. Compile (from the same cpu_model for weight consistency) print(f"[3] Compiling for Neuron...") - neuron_model = compile_and_cache(model_key, config) + neuron_model = compile_and_cache(cpu_model, model_key, config) # 4. 
Smoke test neuron_out = neuron_model(example) From 4158a3d72ed50f5537f0105b20c974ed591ddd76 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Tue, 7 Apr 2026 23:08:08 -0400 Subject: [PATCH 3/3] Add GPU comparison benchmarks to DINOv3 README (A10G vs trn2) --- contrib/models/DINOv3/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/contrib/models/DINOv3/README.md b/contrib/models/DINOv3/README.md index 8c006b26..46dc91b8 100644 --- a/contrib/models/DINOv3/README.md +++ b/contrib/models/DINOv3/README.md @@ -64,6 +64,38 @@ DINOv3 models are compiled using two different approaches depending on model siz 4. **ViT-H+ is HBM-bandwidth limited**: 2.5 GB NEFF saturates single-core HBM bandwidth, resulting in only 10.5 img/s DP=4 (16.6x slower than ViT-L) 5. **ViT-7B requires TP=4**: 20.1 GB NEFF exceeds single-core HBM. Tensor parallelism via `neuronx-distributed` ModelBuilder achieves 38.8 img/s at 25.77ms latency +### GPU Comparison (A10G g5.xlarge) + +| Model | Neuron Best (trn2 DP=4) | GPU Best (A10G torch.compile BS=16) | Winner | +|-------|------------------------:|------------------------------------:|--------| +| ViT-B/16 | **440.6 img/s** | 380.0 img/s | Neuron 1.16x | +| ConvNeXt-Tiny | 364.5 img/s | **1,156.2 img/s** | GPU 3.2x | + +Neuron excels on ViT (transformer ops), GPU excels on ConvNeXt (conv ops). The hardware advantage depends on model architecture. + +
+<details>
+<summary>Full GPU results (A10G, PyTorch 2.6)</summary>
+
+**ViT-B/16:**
+
+| Batch Size | Eager (img/s) | torch.compile (img/s) |
+|-----------:|--------------:|----------------------:|
+| 1 | 92.1 | 213.2 |
+| 4 | 240.2 | 302.5 |
+| 8 | 290.1 | 341.0 |
+| 16 | 330.5 | 380.0 |
+
+**ConvNeXt-Tiny:**
+
+| Batch Size | Eager (img/s) | torch.compile (img/s) |
+|-----------:|--------------:|----------------------:|
+| 1 | 213.0 | 571.4 |
+| 4 | 551.7 | 893.5 |
+| 8 | 665.2 | 1,020.3 |
+| 16 | 800.1 | 1,156.2 |
+
+</details>
+
 ## Compatibility
 
 | Component | Version |