From 63371822264ca92fe1a00753c8a5c353adfdfd5c Mon Sep 17 00:00:00 2001 From: nathan Date: Mon, 13 Apr 2026 10:54:06 -0400 Subject: [PATCH] feat: Add enterprise system resilience and graceful degradation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves CRITICAL #1 from code-health-report-2026-04-13.md Changes: - Add tenacity dependency for retry logic - Create lib/resilience.py with: - resilient_http_call decorator (3 retries, exponential backoff 2s→4s→8s) - CircuitBreaker class (opens after 5 consecutive failures) - handle_404_gracefully decorator for safe resource lookups - Apply retry decorators to all HTTP clients: - workday_client.py: get(), raas() - entra_client.py: get(), get_all_pages() - helix_client.py: get(), post() - intune_client.py: get() - lansweeper_client.py: gql() - fedex_client.py: post() - Add graceful degradation to audit tools: - audit_user_drift(): Wrap Workday, AD, Entra calls separately - audit_device_drift(): Wrap Lansweeper, Intune, Helix calls separately - Both now return systems_available and systems_failed fields - Create check_system_health() tool for proactive monitoring - Add comprehensive unit tests for resilience module Benefits: - HTTP clients now automatically retry transient failures (5xx, timeouts) - Circuit breaker prevents hammering failing services (fast-fail after threshold) - Audit tools continue with partial data if some systems unavailable - Health check tool enables proactive system monitoring before bulk audits --- .github/prompts/feature-add.prompt.md | 1 + nexus-mcp/lib/entra_client.py | 3 + nexus-mcp/lib/fedex_client.py | 2 + nexus-mcp/lib/helix_client.py | 3 + nexus-mcp/lib/intune_client.py | 2 + nexus-mcp/lib/lansweeper_client.py | 2 + nexus-mcp/lib/resilience.py | 248 ++++++++++++++ nexus-mcp/lib/workday_client.py | 3 + nexus-mcp/pyproject.toml | 1 + nexus-mcp/src/shards/audit.py | 464 ++++++++++++++++++-------- nexus-mcp/tests/test_resilience.py | 269 +++++++++++++++ 11 files changed, 859 insertions(+), 139 deletions(-) create mode 100644 nexus-mcp/lib/resilience.py create mode 100644 nexus-mcp/tests/test_resilience.py diff --git a/.github/prompts/feature-add.prompt.md b/.github/prompts/feature-add.prompt.md index f5db307..b34562c 100644 --- a/.github/prompts/feature-add.prompt.md +++ b/.github/prompts/feature-add.prompt.md @@ -1,4 +1,5 @@ --- +agent: Plan name: feature-add description: This prompt helps you add a new feature to your existing MCP server by guiding you through branch creation, code drafting, and deployment steps. model: Claude Opus 4.6 diff --git a/nexus-mcp/lib/entra_client.py b/nexus-mcp/lib/entra_client.py index 9db59f5..1f2d2a3 100644 --- a/nexus-mcp/lib/entra_client.py +++ b/nexus-mcp/lib/entra_client.py @@ -4,6 +4,7 @@ from typing import Any import httpx import msal from config import EntraConfig +from resilience import resilient_http_call, handle_404_gracefully GRAPH_BASE = "https://graph.microsoft.com/v1.0" GRAPH_BETA = "https://graph.microsoft.com/beta" @@ -42,6 +43,7 @@ class EntraClient: raise RuntimeError(f"MSAL token error: {result.get('error_description')}") return result["access_token"] + @resilient_http_call(service_name="Entra") async def get(self, path: str, params: dict | None = None, beta: bool = False) -> Any: token = await self.get_token() base = GRAPH_BETA if beta else GRAPH_BASE @@ -53,6 +55,7 @@ class EntraClient: resp.raise_for_status() return resp.json() + @resilient_http_call(service_name="Entra") async def get_all_pages(self, path: str, params: dict | None = None) -> list[dict]: results: list[dict] = [] data = await self.get(path, params) diff --git a/nexus-mcp/lib/fedex_client.py b/nexus-mcp/lib/fedex_client.py index c0f765d..0e39750 100644 --- a/nexus-mcp/lib/fedex_client.py +++ b/nexus-mcp/lib/fedex_client.py @@ -3,6 +3,7 @@ from typing import Any import httpx from config import FedExConfig +from resilience import resilient_http_call class FedExClient: @@ -29,6 +30,7 @@ class FedExClient: self._token = resp.json()["access_token"] return self._token + @resilient_http_call(service_name="FedEx") async def post(self, path: str, body: dict) -> Any: token = await self.get_token() resp = await self._http.post( diff --git a/nexus-mcp/lib/helix_client.py b/nexus-mcp/lib/helix_client.py index 48d1c2b..13a21b3 100644 --- a/nexus-mcp/lib/helix_client.py +++ b/nexus-mcp/lib/helix_client.py @@ -3,6 +3,7 @@ from typing import Any import httpx from config import HelixConfig +from resilience import resilient_http_call class HelixClient: @@ -28,6 +29,7 @@ class HelixClient: self._token = resp.text.strip() return self._token + @resilient_http_call(service_name="Helix") async def get(self, path: str, params: dict | None = None) -> Any: token = await self.get_token() resp = await self._http.get( @@ -38,6 +40,7 @@ class HelixClient: resp.raise_for_status() return resp.json() + @resilient_http_call(service_name="Helix") async def post(self, path: str, body: dict) -> Any: token = await self.get_token() resp = await self._http.post( diff --git a/nexus-mcp/lib/intune_client.py b/nexus-mcp/lib/intune_client.py index 8568299..322598f 100644 --- a/nexus-mcp/lib/intune_client.py +++ b/nexus-mcp/lib/intune_client.py @@ -4,6 +4,7 @@ from typing import Any import httpx import msal from config import IntuneConfig +from resilience import resilient_http_call GRAPH_BASE = "https://graph.microsoft.com/v1.0" GRAPH_BETA = "https://graph.microsoft.com/beta" @@ -39,6 +40,7 @@ class IntuneClient: raise RuntimeError(f"MSAL token error: {result.get('error_description')}") return result["access_token"] + @resilient_http_call(service_name="Intune") async def get(self, path: str, params: dict | None = None, beta: bool = False) -> Any: token = await self.get_token() base = GRAPH_BETA if beta else GRAPH_BASE diff --git a/nexus-mcp/lib/lansweeper_client.py b/nexus-mcp/lib/lansweeper_client.py index dae0c74..08cc628 100644 --- a/nexus-mcp/lib/lansweeper_client.py +++ b/nexus-mcp/lib/lansweeper_client.py @@ -3,6 +3,7 @@ from typing import Any import httpx from config import LansweeperConfig +from resilience import resilient_http_call class LansweeperClient: @@ -28,6 +29,7 @@ class LansweeperClient: self._token = resp.json()["access_token"] return self._token + @resilient_http_call(service_name="Lansweeper") async def gql(self, query: str, variables: dict | None = None) -> Any: token = await self.get_token() resp = await self._http.post( diff --git a/nexus-mcp/lib/resilience.py b/nexus-mcp/lib/resilience.py new file mode 100644 index 0000000..e16054c --- /dev/null +++ b/nexus-mcp/lib/resilience.py @@ -0,0 +1,248 @@ +"""Resilience utilities for enterprise HTTP clients. + +Provides retry logic with exponential backoff and circuit breaker pattern +to prevent cascade failures when enterprise systems are unavailable. + +Usage: + from resilience import resilient_http_call + + @resilient_http_call(service_name="Workday") + async def get(self, path: str) -> Any: + # Your HTTP call here + pass +""" + +from __future__ import annotations +import asyncio +import logging +from datetime import datetime, timedelta +from enum import Enum +from typing import Any, Callable, TypeVar +from functools import wraps + +import httpx +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + before_sleep_log, + RetryError, +) + +logger = logging.getLogger(__name__) + +# Type variable for async functions +T = TypeVar("T") + + +class CircuitState(str, Enum): + """Circuit breaker states following the classic pattern.""" + CLOSED = "closed" # Normal operation + OPEN = "open" # Failures exceeded threshold, rejecting calls + HALF_OPEN = "half_open" # Testing if service recovered + + +class CircuitBreaker: + """Circuit breaker for a single service to prevent hammering failing systems. + + Pattern: + - CLOSED: Normal operation. Count consecutive failures. + - OPEN: After threshold failures, open circuit and reject calls for timeout period. + - HALF_OPEN: After timeout, allow one test request. If succeeds → CLOSED, if fails → OPEN. + + Args: + service_name: Identifier for the protected service (e.g., "Workday", "Entra"). + failure_threshold: Number of consecutive failures before opening circuit. + timeout_seconds: How long to wait before testing recovery (half-open state). + """ + + def __init__( + self, + service_name: str, + failure_threshold: int = 5, + timeout_seconds: int = 60, + ): + self.service_name = service_name + self.failure_threshold = failure_threshold + self.timeout_seconds = timeout_seconds + + self.state = CircuitState.CLOSED + self.consecutive_failures = 0 + self.last_failure_time: datetime | None = None + self._lock = asyncio.Lock() + + async def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Execute function with circuit breaker protection. + + Raises: + CircuitBreakerOpenError: If circuit is open and not ready for retry. + """ + async with self._lock: + # Check if we should transition from OPEN → HALF_OPEN + if self.state == CircuitState.OPEN: + if self.last_failure_time and datetime.utcnow() - self.last_failure_time > timedelta(seconds=self.timeout_seconds): + logger.info(f"[{self.service_name}] Circuit transitioning OPEN → HALF_OPEN (testing recovery)") + self.state = CircuitState.HALF_OPEN + else: + raise CircuitBreakerOpenError( + f"{self.service_name} circuit breaker is OPEN (fast-fail mode). " + f"Retry after {self.timeout_seconds}s timeout." + ) + + # Execute the function + try: + result = await func(*args, **kwargs) + await self._on_success() + return result + except Exception as e: + await self._on_failure() + raise + + async def _on_success(self) -> None: + """Handle successful call — reset failure count and close circuit.""" + async with self._lock: + if self.state == CircuitState.HALF_OPEN: + logger.info(f"[{self.service_name}] Circuit HALF_OPEN → CLOSED (service recovered)") + self.state = CircuitState.CLOSED + self.consecutive_failures = 0 + self.last_failure_time = None + + async def _on_failure(self) -> None: + """Handle failed call — increment failures and potentially open circuit.""" + async with self._lock: + self.consecutive_failures += 1 + self.last_failure_time = datetime.utcnow() + + if self.state == CircuitState.HALF_OPEN: + # Half-open test failed → back to OPEN + logger.warning(f"[{self.service_name}] Circuit HALF_OPEN → OPEN (recovery test failed)") + self.state = CircuitState.OPEN + elif self.consecutive_failures >= self.failure_threshold: + # Threshold exceeded → open circuit + logger.error( + f"[{self.service_name}] Circuit CLOSED → OPEN " + f"({self.consecutive_failures} consecutive failures, threshold={self.failure_threshold})" + ) + self.state = CircuitState.OPEN + + +class CircuitBreakerOpenError(Exception): + """Raised when circuit breaker is open and rejecting calls.""" + pass + + +# Global registry of circuit breakers (one per service) +_circuit_breakers: dict[str, CircuitBreaker] = {} + + +def get_circuit_breaker(service_name: str) -> CircuitBreaker: + """Get or create circuit breaker for a service.""" + if service_name not in _circuit_breakers: + _circuit_breakers[service_name] = CircuitBreaker(service_name) + return _circuit_breakers[service_name] + + +def resilient_http_call( + service_name: str, + max_attempts: int = 3, + enable_circuit_breaker: bool = True, +) -> Callable: + """Decorator for HTTP calls that adds retry logic and circuit breaker. + + Retry strategy: + - 3 attempts by default with exponential backoff (2s → 4s → 8s) + - Retries on: httpx.TimeoutException, httpx.HTTPStatusError (5xx only) + - No retry on: 4xx errors (client errors like 404, 401) + + Circuit breaker: + - Opens after 5 consecutive failures + - Fast-fails subsequent requests for 60 seconds + - Automatically tests recovery after timeout + + Args: + service_name: Name of the service (for logging and circuit breaker tracking). + max_attempts: Maximum number of retry attempts (default: 3). + enable_circuit_breaker: Whether to use circuit breaker (default: True). + + Example: + @resilient_http_call(service_name="Workday") + async def get(self, path: str) -> dict: + resp = await self._http.get(url) + resp.raise_for_status() + return resp.json() + """ + def decorator(func: Callable[..., T]) -> Callable[..., T]: + # Determine if we should retry based on exception type and status code + def should_retry_exception(exception: Exception) -> bool: + """Only retry on transient failures (timeouts, 5xx errors).""" + if isinstance(exception, httpx.TimeoutException): + return True + if isinstance(exception, httpx.HTTPStatusError): + # Retry on 5xx (server errors), not on 4xx (client errors) + return exception.response.status_code >= 500 + if isinstance(exception, (httpx.ConnectError, httpx.RemoteProtocolError)): + return True + return False + + # Apply tenacity retry decorator + retrying_func = retry( + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type(( + httpx.TimeoutException, + httpx.HTTPStatusError, + httpx.ConnectError, + httpx.RemoteProtocolError, + )), + before_sleep=before_sleep_log(logger, logging.INFO), + reraise=True, + )(func) + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + if enable_circuit_breaker: + breaker = get_circuit_breaker(service_name) + try: + return await breaker.call(retrying_func, *args, **kwargs) + except RetryError as e: + # Extract original exception from tenacity wrapper + logger.error(f"[{service_name}] All retry attempts exhausted: {e}") + raise e.last_attempt.exception() if e.last_attempt.exception() else e + else: + # No circuit breaker, just retry logic + try: + return await retrying_func(*args, **kwargs) + except RetryError as e: + logger.error(f"[{service_name}] All retry attempts exhausted: {e}") + raise e.last_attempt.exception() if e.last_attempt.exception() else e + + return wrapper + + return decorator + + +def handle_404_gracefully(func: Callable[..., T]) -> Callable[..., T | None]: + """Decorator to convert 404 errors to None instead of raising. + + Useful for "get user/device by ID" operations where 404 = "not found" is a valid state. + + Example: + @handle_404_gracefully + @resilient_http_call(service_name="Entra") + async def get_user(self, user_id: str) -> dict | None: + resp = await self._http.get(f"/users/{user_id}") + resp.raise_for_status() # Will be caught if 404 + return resp.json() + """ + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T | None: + try: + return await func(*args, **kwargs) + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + logger.debug(f"Resource not found (404): {e.request.url}") + return None + raise + + return wrapper diff --git a/nexus-mcp/lib/workday_client.py b/nexus-mcp/lib/workday_client.py index 66681d2..773c53c 100644 --- a/nexus-mcp/lib/workday_client.py +++ b/nexus-mcp/lib/workday_client.py @@ -3,6 +3,7 @@ from typing import Any import httpx from config import WorkdayConfig +from resilience import resilient_http_call, handle_404_gracefully class WorkdayClient: @@ -33,6 +34,7 @@ class WorkdayClient: self._token = resp.json()["access_token"] return self._token + @resilient_http_call(service_name="Workday") async def get(self, path: str, params: dict | None = None) -> Any: token = await self.get_token() url = f"{self.cfg.base_url}/{self.cfg.tenant}{path}" @@ -44,6 +46,7 @@ class WorkdayClient: resp.raise_for_status() return resp.json() + @resilient_http_call(service_name="Workday") async def raas(self, report_path: str, params: dict | None = None) -> list[dict]: token = await self.get_token() url = ( diff --git a/nexus-mcp/pyproject.toml b/nexus-mcp/pyproject.toml index f33c586..dd53a79 100644 --- a/nexus-mcp/pyproject.toml +++ b/nexus-mcp/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "tabulate>=0.9.0", "python-dateutil>=2.9.0", "aiofiles>=24.1.0", + "tenacity>=8.2.0", ] [project.scripts] diff --git a/nexus-mcp/src/shards/audit.py b/nexus-mcp/src/shards/audit.py index 85c39a3..762449d 100644 --- a/nexus-mcp/src/shards/audit.py +++ b/nexus-mcp/src/shards/audit.py @@ -161,11 +161,88 @@ def _compare_users(wd_user: CanonicalUser | None, ad_user: CanonicalUser | None, return { "email": email, "systems_checked": ["Workday", "ActiveDirectory", "Entra"], + "systems_available": ["Workday", "ActiveDirectory", "Entra"], + "systems_failed": [], "workday_found": wd_user is not None, "ad_found": ad_user is not None, "entra_found": entra_user is not None, "discrepancy_count": len(drifts), - "discrepancies": [d.model_dump(mode='json') for d in drifts] + "discrepancies": [d.model_dump(mode='json') for d in drifts], + } + + # Live mode with graceful degradation — each system call wrapped separately + import logging + logger = logging.getLogger(__name__) + + systems_available: list[str] = [] + systems_failed: list[str] = [] + wd_dict: dict | None = None + ad_dict: dict | None = None + entra_dict: dict | None = None + + # Try Workday + try: + wd_data = await _get_wd().get("/staffing/v6/workers", params={"limit": 500}) + wd_dict = next( + (w for w in wd_data.get("data", []) + if (w.get("primaryWorkEmail") or "").lower() == email.lower()), + None, + ) + systems_available.append("Workday") + logger.info(f"[audit_user_drift] Workday: {'found' if wd_dict else 'not found'}") + except Exception as e: + systems_failed.append("Workday") + logger.warning(f"[audit_user_drift] Workday unavailable: {e}") + + # Try Active Directory + try: + ad_dict = await asyncio.to_thread(_get_ad().get_user_by_email, email) + systems_available.append("ActiveDirectory") + logger.info(f"[audit_user_drift] AD: {'found' if ad_dict else 'not found'}") + except Exception as e: + systems_failed.append("ActiveDirectory") + logger.warning(f"[audit_user_drift] AD unavailable: {e}") + + # Try Entra ID + try: + entra_data = await _get_entra().get( + "/users", + params={ + "$select": "id,displayName,userPrincipalName,mail,jobTitle,department,accountEnabled", + "$top": 999, + }, + ) + entra_dict = next( + (u for u in entra_data.get("value", []) + if _norm(u.get("mail")) == _norm(email) + or _norm(u.get("userPrincipalName")) == _norm(email)), + None, + ) + systems_available.append("Entra") + logger.info(f"[audit_user_drift] Entra: {'found' if entra_dict else 'not found'}") + except Exception as e: + systems_failed.append("Entra") + logger.warning(f"[audit_user_drift] Entra unavailable: {e}") + + # Transform to canonical models + wd_user = WorkdayWorkerAdapter.to_canonical(wd_dict) if wd_dict else None + ad_user = ADUserAdapter.to_canonical(ad_dict) if ad_dict else None + entra_user = EntraUserAdapter.to_canonical(entra_dict) if entra_dict else None + + # Compare using canonical models + drifts = _compare_users(wd_user, ad_user, entra_user) + + return { + "email": email, + "systems_checked": ["Workday", "ActiveDirectory", "Entra"], + "systems_available": systems_available, + "systems_failed": systems_failed, + "workday_found": wd_user is not None, + "ad_found": ad_user is not None, + "entra_found": entra_user is not None, + "discrepancy_count": len(drifts), + "discrepancies": [d.model_dump(mode='json') for d in drifts], + } # ── Shard registration ──────────────────────────────────────────────────────── def register(mcp: FastMCP) -> None: @@ -178,107 +255,34 @@ def register(mcp: FastMCP) -> None: """Audit a single user across Workday, Active Directory, and Entra ID for field drift. Compares displayName, jobTitle, and department across all three systems. + Uses graceful degradation — continues audit with available systems if some fail. Args: email: Primary work email of the user to audit. + + Returns: + dict with keys: + - email: The queried email + - systems_checked: List of all systems that were attempted + - systems_available: List of systems that responded successfully + - systems_failed: List of systems that were unavailable + - workday_found/ad_found/entra_found: Whether user exists in each system + - discrepancy_count: Number of field mismatches found + - discrepancies: List of FieldDrift objects showing differences """ if _USE_MOCK: - wd_worker = M.WORKDAY_WORKERS_BY_EMAIL.get(email.lower()) - ad_user = M.AD_USERS_BY_EMAIL.get(email.lower()) - entra_user = M.ENTRA_USERS_BY_MAIL.get(email.lower()) - diffs: list[dict] = [] - comparisons = [ - ("displayName", "Workday", _pick(wd_worker, "descriptor"), - "ActiveDirectory", _pick(ad_user, "displayName")), - ("displayName", "Workday", _pick(wd_worker, "descriptor"), - "Entra", _pick(entra_user, "displayName")), - ("displayName", "ActiveDirectory", _pick(ad_user, "displayName"), - "Entra", _pick(entra_user, "displayName")), - ("jobTitle", "Workday", _pick(wd_worker, "primaryJob", "jobProfile", "descriptor"), - "ActiveDirectory", _pick(ad_user, "title")), - ("jobTitle", "Workday", _pick(wd_worker, "primaryJob", "jobProfile", "descriptor"), - "Entra", _pick(entra_user, "jobTitle")), - ("department", "Workday", _pick(wd_worker, "primaryJob", "businessUnit", "descriptor"), - "ActiveDirectory", _pick(ad_user, "department")), - ("department", "Workday", _pick(wd_worker, "primaryJob", "businessUnit", "descriptor"), - "Entra", _pick(entra_user, "department")), - ] - for field, sa, va, sb, vb in comparisons: - d = _drift(sa, sb, field, va, vb) - if d: - diffs.append(d) - return { - "email": email, - "systems_checked": ["Workday", "ActiveDirectory", "Entra"], - "workday_found": wd_worker is not None, - "ad_found": ad_user is not None, - "entra_found": entra_user is not None, - "discrepancy_count": len(diffs), - "discrepancies": diffs, - } - wd_task = asyncio.create_task(_get_wd().get( - "/staffing/v6/workers", params={"limit": 500} - )) - entra_task = asyncio.create_task(_get_entra().get( - "/users", - params={ - "$select": "id,displayName,userPrincipalName,mail,jobTitle,department,accountEnabled", - "$top": 999, - }, - )) - wd_data, entra_data = await asyncio.gather(wd_task, entra_task) - ad_user = await asyncio.to_thread(_get_ad().get_user_by_email, email) - - wd_worker = next( - (w for w in wd_data.get("data", []) - if (w.get("primaryWorkEmail") or "").lower() == email.lower()), - None, - ) - entra_user = next( - (u for u in entra_data.get("value", []) - if _norm(u.get("mail")) == _norm(email) - or _norm(u.get("userPrincipalName")) == _norm(email)), - None, - ) - - diffs: list[dict] = [] - comparisons = [ - ("displayName", - "Workday", _pick(wd_worker, "descriptor"), - "ActiveDirectory", _pick(ad_user, "displayName")), - ("displayName", - "Workday", _pick(wd_worker, "descriptor"), - "Entra", _pick(entra_user, "displayName")), - ("displayName", - "ActiveDirectory", _pick(ad_user, "displayName"), - "Entra", _pick(entra_user, "displayName")), - ("jobTitle", - "Workday", _pick(wd_worker, "primaryJob", "jobProfile", "descriptor"), - "ActiveDirectory", _pick(ad_user, "title")), - ("jobTitle", - "Workday", _pick(wd_worker, "primaryJob", "jobProfile", "descriptor"), - "Entra", _pick(entra_user, "jobTitle")), - ("department", - "Workday", _pick(wd_worker, "primaryJob", "businessUnit", "descriptor"), - "ActiveDirectory", _pick(ad_user, "department")), - ("department", - "Workday", _pick(wd_worker, "primaryJob", "businessUnit", "descriptor"), - "Entra", _pick(entra_user, "department")), - ] - for field, sa, va, sb, vb in comparisons: - d = _drift(sa, sb, field, va, vb) - if d: - diffs.append(d) - - return { - "email": email, - "systems_checked": ["Workday", "ActiveDirectory", "Entra"], - "workday_found": wd_worker is not None, - "ad_found": ad_user is not None, - "entra_found": entra_user is not None, - "discrepancy_count": len(diffs), - "discrepancies": diffs, - } + # Get raw dicts from mock data + wd_dict = M.WORKDAY_WORKERS_BY_EMAIL.get(email.lower()) + ad_dict = M.AD_USERS_BY_EMAIL.get(email.lower()) + entra_dict = M.ENTRA_USERS_BY_MAIL.get(email.lower()) + + # Transform to canonical models + wd_user = WorkdayWorkerAdapter.to_canonical(wd_dict) if wd_dict else None + ad_user = ADUserAdapter.to_canonical(ad_dict) if ad_dict else None + entra_user = EntraUserAdapter.to_canonical(entra_dict) if entra_dict else None + + # Compare using canonical models + drifts = _compare_users(wd_user, ad_user, entra_user) @mcp.tool() async def audit_bulk_user_drift(emails: list[str]) -> list[dict]: @@ -295,15 +299,52 @@ def register(mcp: FastMCP) -> None: return await asyncio.gather(*[_one(e) for e in emails[:50]]) + @mcp.tool() @mcp.tool() async def audit_device_drift(device_name: str) -> dict: """Audit a device across Lansweeper, Intune, and BMC Helix CMDB for field drift. Compares manufacturer and serial number across all three asset systems. + Uses graceful degradation — continues audit with available systems if some fail. Args: device_name: The computer/device name to look up. + + Returns: + dict with keys: + - device_name: The queried device name + - systems_checked: List of all systems that were attempted + - systems_available: List of systems that responded successfully + - systems_failed: List of systems that were unavailable + - lansweeper_found/intune_found/helix_found: Whether device exists in each system + - discrepancy_count: Number of field mismatches found + - discrepancies: List of drift objects showing differences """ + def _safe_get(obj: dict | None, *keys: str) -> Any: + """Safely navigate nested dict keys.""" + if obj is None: + return None + current = obj + for key in keys: + if isinstance(current, dict): + current = current.get(key) + else: + return None + return current + + def _compare_device_field(field: str, sys_a: str, val_a: Any, sys_b: str, val_b: Any) -> dict | None: + """Compare device field between two systems.""" + if _norm(val_a) != _norm(val_b): + return { + "field": field, + "system_a": sys_a, + "value_a": str(val_a) if val_a else None, + "system_b": sys_b, + "value_b": str(val_b) if val_b else None, + "severity": "medium", + } + return None + if _USE_MOCK: dn = device_name.lower() ls_asset = M.LANSWEEPER_ASSETS_BY_NAME.get(dn) @@ -312,88 +353,233 @@ def register(mcp: FastMCP) -> None: diffs: list[dict] = [] for field, ik, hk in [("manufacturer", "manufacturer", "Manufacturer"), ("serialNumber", "serialNumber", "Serial Number")]: - ls_val = _pick(ls_asset, field) - intune_val = _pick(intune_device, ik) - helix_val = _pick(helix_ci, "values", hk) + ls_val = _safe_get(ls_asset, field) + intune_val = _safe_get(intune_device, ik) + helix_val = _safe_get(helix_ci, "values", hk) for sa, sb, va, vb in [ ("Lansweeper", "Intune", ls_val, intune_val), ("Lansweeper", "Helix", ls_val, helix_val), ]: - d = _drift(sa, sb, field, va, vb) + d = _compare_device_field(field, sa, sb, va, vb) if d: diffs.append(d) return { "device_name": device_name, "systems_checked": ["Lansweeper", "Intune", "HelixCMDB"], + "systems_available": ["Lansweeper", "Intune", "HelixCMDB"], + "systems_failed": [], "lansweeper_found": ls_asset is not None, "intune_found": intune_device is not None, "helix_found": helix_ci is not None, "discrepancy_count": len(diffs), "discrepancies": diffs, } - from config import LansweeperConfig - site_id = LansweeperConfig().site_id - ls_query = """ - query S($siteId: String!, $q: String!) { - site(id: $siteId) { - assetResources( - pagination: { limit: 5, page: 1 } - assetBasicFilters: { assetName: $q } - ) { - items { assetId assetName operatingSystem manufacturer serialNumber } + + # Live mode with graceful degradation + import logging + logger = logging.getLogger(__name__) + + systems_available: list[str] = [] + systems_failed: list[str] = [] + ls_asset: dict | None = None + intune_device: dict | None = None + helix_ci: dict | None = None + + # Try Lansweeper + try: + from config import LansweeperConfig + site_id = LansweeperConfig().site_id + ls_query = """ + query S($siteId: String!, $q: String!) { + site(id: $siteId) { + assetResources( + pagination: { limit: 5, page: 1 } + assetBasicFilters: { assetName: $q } + ) { + items { assetId assetName operatingSystem manufacturer serialNumber } + } + } } - } - } - """ - ls_task = asyncio.create_task( - _get_ls().gql(ls_query, {"siteId": site_id, "q": device_name}) - ) - intune_task = asyncio.create_task( - _get_intune().get("/deviceManagement/managedDevices", params={"$top": 500}) - ) - helix_task = asyncio.create_task( - _get_helix().get( + """ + ls_data = await _get_ls().gql(ls_query, {"siteId": site_id, "q": device_name}) + ls_results = ls_data["site"]["assetResources"]["items"] + ls_asset = ls_results[0] if ls_results else None + systems_available.append("Lansweeper") + logger.info(f"[audit_device_drift] Lansweeper: {'found' if ls_asset else 'not found'}") + except Exception as e: + systems_failed.append("Lansweeper") + logger.warning(f"[audit_device_drift] Lansweeper unavailable: {e}") + + # Try Intune + try: + intune_data = await _get_intune().get("/deviceManagement/managedDevices", params={"$top": 500}) + intune_device = next( + (d for d in intune_data.get("value", []) + if _norm(d.get("deviceName")) == _norm(device_name)), + None, + ) + systems_available.append("Intune") + logger.info(f"[audit_device_drift] Intune: {'found' if intune_device else 'not found'}") + except Exception as e: + systems_failed.append("Intune") + logger.warning(f"[audit_device_drift] Intune unavailable: {e}") + + # Try Helix CMDB + try: + helix_data = await _get_helix().get( "/api/arsys/v1/entry/BMC.CORE:BMC_ComputerSystem", params={"q": f"'Name' LIKE \"%{device_name}%\"", "limit": 5}, ) - ) - ls_data, intune_data, helix_data = await asyncio.gather(ls_task, intune_task, helix_task) - - ls_results = ls_data["site"]["assetResources"]["items"] - ls_asset = ls_results[0] if ls_results else None - intune_device = next( - (d for d in intune_data.get("value", []) - if _norm(d.get("deviceName")) == _norm(device_name)), - None, - ) - helix_entries = helix_data.get("entries", []) - helix_ci = helix_entries[0] if helix_entries else None - + helix_entries = helix_data.get("entries", []) + helix_ci = helix_entries[0] if helix_entries else None + systems_available.append("HelixCMDB") + logger.info(f"[audit_device_drift] Helix: {'found' if helix_ci else 'not found'}") + except Exception as e: + systems_failed.append("HelixCMDB") + logger.warning(f"[audit_device_drift] Helix unavailable: {e}") + + # Compare fields across available systems diffs: list[dict] = [] for field, lk, ik, hk in [ ("manufacturer", "manufacturer", "manufacturer", "Manufacturer"), ("serialNumber", "serialNumber", "serialNumber", "Serial Number"), ]: - ls_val = _pick(ls_asset, field) - intune_val = _pick(intune_device, ik) - helix_val = _pick(helix_ci, "values", hk) + ls_val = _safe_get(ls_asset, field) + intune_val = _safe_get(intune_device, ik) + helix_val = _safe_get(helix_ci, "values", hk) + for sa, sb, va, vb in [ ("Lansweeper", "Intune", ls_val, intune_val), ("Lansweeper", "Helix", ls_val, helix_val), ]: - d = _drift(sa, sb, field, va, vb) + d = _compare_device_field(field, sa, sb, va, vb) if d: diffs.append(d) return { "device_name": device_name, "systems_checked": ["Lansweeper", "Intune", "HelixCMDB"], + "systems_available": systems_available, + "systems_failed": systems_failed, "lansweeper_found": ls_asset is not None, "intune_found": intune_device is not None, "helix_found": helix_ci is not None, "discrepancy_count": len(diffs), "discrepancies": diffs, } + + # ── Health check tools ──────────────────────────────────────────────────── + + @mcp.tool() + async def check_system_health() -> dict: + """Check availability and response time of all enterprise systems. + + Useful for proactive monitoring before running bulk audits. + Uses resilient HTTP calls with retry logic. + + Returns: + dict with system names as keys, each containing: + - available: bool indicating if system is reachable + - response_time_ms: int response time in milliseconds (if available) + - error: str error message (if unavailable) + """ + import time + import logging + logger = logging.getLogger(__name__) + + results = {} + + # Check Workday + start = time.time() + try: + await _get_wd().get("/staffing/v6/workers", params={"limit": 1}) + elapsed = int((time.time() - start) * 1000) + results["Workday"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] Workday: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["Workday"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] Workday: FAILED - {e}") + + # Check Active Directory + start = time.time() + try: + # AD adapter uses blocking PowerShell, run in thread + await asyncio.to_thread(_get_ad().get_user, "testuser") + elapsed = int((time.time() - start) * 1000) + results["ActiveDirectory"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] AD: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["ActiveDirectory"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] AD: FAILED - {e}") + + # Check Entra ID + start = time.time() + try: + await _get_entra().get("/users", params={"$top": 1}) + elapsed = int((time.time() - start) * 1000) + results["Entra"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] Entra: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["Entra"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] Entra: FAILED - {e}") + + # Check Lansweeper + start = time.time() + try: + from config import LansweeperConfig + site_id = LansweeperConfig().site_id + query = "query { sites { total } }" + await _get_ls().gql(query, {}) + elapsed = int((time.time() - start) * 1000) + results["Lansweeper"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] Lansweeper: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["Lansweeper"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] Lansweeper: FAILED - {e}") + + # Check Intune + start = time.time() + try: + await _get_intune().get("/deviceManagement/managedDevices", params={"$top": 1}) + elapsed = int((time.time() - start) * 1000) + results["Intune"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] Intune: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["Intune"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] Intune: FAILED - {e}") + + # Check Helix + start = time.time() + try: + await _get_helix().get("/api/arsys/v1/entry/BMC.CORE:BMC_ComputerSystem", params={"limit": 1}) + elapsed = int((time.time() - start) * 1000) + results["Helix"] = {"available": True, "response_time_ms": elapsed} + logger.info(f"[Health Check] Helix: OK ({elapsed}ms)") + except Exception as e: + elapsed = int((time.time() - start) * 1000) + results["Helix"] = {"available": False, "response_time_ms": elapsed, "error": str(e)} + logger.warning(f"[Health Check] Helix: FAILED - {e}") + + # Calculate summary statistics + total_systems = len(results) + available_systems = sum(1 for r in results.values() if r["available"]) + availability_percentage = int((available_systems / total_systems) * 100) + + return { + "timestamp": datetime.datetime.utcnow().isoformat(), + "systems": results, + "summary": { + "total_systems": total_systems, + "available_systems": available_systems, + "unavailable_systems": total_systems - available_systems, + "availability_percentage": availability_percentage, + } + } @mcp.tool() async def audit_entra_ad_sync_drift(limit: int = 200) -> dict: diff --git a/nexus-mcp/tests/test_resilience.py b/nexus-mcp/tests/test_resilience.py new file mode 100644 index 0000000..024bf92 --- /dev/null +++ b/nexus-mcp/tests/test_resilience.py @@ -0,0 +1,269 @@ +"""Unit tests for resilience module (retry logic and circuit breaker).""" + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +import httpx + +from resilience import ( + resilient_http_call, + handle_404_gracefully, + CircuitBreaker, + CircuitState, + CircuitBreakerOpenError, + get_circuit_breaker, +) + + +class TestCircuitBreaker: + """Test circuit breaker state machine.""" + + @pytest.mark.asyncio + async def test_circuit_closed_to_open_after_threshold_failures(self): + """Circuit should open after consecutive failures exceed threshold.""" + breaker = CircuitBreaker("TestService", failure_threshold=3, timeout_seconds=60) + + async def failing_func(): + raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=503)) + + # Execute 3 failures + for _ in range(3): + with pytest.raises(httpx.HTTPStatusError): + await breaker.call(failing_func) + + # Circuit should now be OPEN + assert breaker.state == CircuitState.OPEN + + # Next call should fail fast with CircuitBreakerOpenError + with pytest.raises(CircuitBreakerOpenError): + await breaker.call(failing_func) + + @pytest.mark.asyncio + async def test_circuit_half_open_to_closed_on_success(self): + """Circuit should close after successful test in half-open state.""" + breaker = CircuitBreaker("TestService", failure_threshold=2, timeout_seconds=1) + + async def failing_func(): + raise httpx.TimeoutException("Timeout") + + async def success_func(): + return "OK" + + # Trigger failures to open circuit + for _ in range(2): + with pytest.raises(httpx.TimeoutException): + await breaker.call(failing_func) + + assert breaker.state == CircuitState.OPEN + + # Wait for timeout to transition to half-open + await asyncio.sleep(1.1) + + # Successful call should close circuit + result = await breaker.call(success_func) + assert result == "OK" + assert breaker.state == CircuitState.CLOSED + assert breaker.consecutive_failures == 0 + + @pytest.mark.asyncio + async def test_circuit_half_open_to_open_on_failure(self): + """Circuit should reopen if test fails in half-open state.""" + breaker = CircuitBreaker("TestService", failure_threshold=2, timeout_seconds=1) + + async def failing_func(): + raise httpx.ConnectError("Connection refused") + + # Trigger failures to open circuit + for _ in range(2): + with pytest.raises(httpx.ConnectError): + await breaker.call(failing_func) + + assert breaker.state == CircuitState.OPEN + + # Wait for timeout to transition to half-open + await asyncio.sleep(1.1) + + # Failed test should reopen circuit + with pytest.raises(httpx.ConnectError): + await breaker.call(failing_func) + + assert breaker.state == CircuitState.OPEN + + @pytest.mark.asyncio + async def test_circuit_resets_on_success(self): + """Successful calls should reset failure counter.""" + breaker = CircuitBreaker("TestService", failure_threshold=3) + + async def failing_func(): + raise httpx.TimeoutException("Timeout") + + async def success_func(): + return "OK" + + # Execute 2 failures (below threshold) + for _ in range(2): + with pytest.raises(httpx.TimeoutException): + await breaker.call(failing_func) + + assert breaker.consecutive_failures == 2 + assert breaker.state == CircuitState.CLOSED + + # Successful call resets counter + await breaker.call(success_func) + assert breaker.consecutive_failures == 0 + assert breaker.state == CircuitState.CLOSED + + +class TestResilientHttpCall: + """Test resilient_http_call decorator with retry logic.""" + + @pytest.mark.asyncio + async def test_retries_on_timeout_exception(self): + """Decorator should retry on timeout exceptions.""" + call_count = 0 + + @resilient_http_call(service_name="TestService", max_attempts=3, enable_circuit_breaker=False) + async def flaky_function(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise httpx.TimeoutException("Timeout") + return "Success" + + result = await flaky_function() + assert result == "Success" + assert call_count == 3 # 1 initial + 2 retries + + @pytest.mark.asyncio + async def test_retries_on_5xx_errors(self): + """Decorator should retry on 5xx HTTP errors.""" + call_count = 0 + + @resilient_http_call(service_name="TestService", max_attempts=3, enable_circuit_breaker=False) + async def server_error_function(): + nonlocal call_count + call_count += 1 + if call_count < 2: + response = MagicMock() + response.status_code = 503 + raise httpx.HTTPStatusError("Service Unavailable", request=MagicMock(), response=response) + return "Success" + + result = await server_error_function() + assert result == "Success" + assert call_count == 2 # 1 initial + 1 retry + + @pytest.mark.asyncio + async def test_no_retry_on_4xx_errors(self): + """Decorator should NOT retry on 4xx client errors.""" + call_count = 0 + + @resilient_http_call(service_name="TestService", max_attempts=3, enable_circuit_breaker=False) + async def client_error_function(): + nonlocal call_count + call_count += 1 + response = MagicMock() + response.status_code = 404 + raise httpx.HTTPStatusError("Not Found", request=MagicMock(), response=response) + + with pytest.raises(httpx.HTTPStatusError): + await client_error_function() + + assert call_count == 1 # No retries on 4xx + + @pytest.mark.asyncio + async def test_exhausts_retries_and_raises(self): + """Decorator should raise original exception after exhausting retries.""" + call_count = 0 + + @resilient_http_call(service_name="TestService", max_attempts=3, enable_circuit_breaker=False) + async def always_fails(): + nonlocal call_count + call_count += 1 + raise httpx.TimeoutException("Persistent timeout") + + with pytest.raises(httpx.TimeoutException) as exc_info: + await always_fails() + + assert "Persistent timeout" in str(exc_info.value) + assert call_count == 3 # 1 initial + 2 retries + + +class TestHandle404Gracefully: + """Test handle_404_gracefully decorator.""" + + @pytest.mark.asyncio + async def test_converts_404_to_none(self): + """Decorator should convert 404 errors to None.""" + @handle_404_gracefully + async def get_user(): + response = MagicMock() + response.status_code = 404 + response.request.url = "https://api.example.com/users/123" + raise httpx.HTTPStatusError("Not Found", request=MagicMock(), response=response) + + result = await get_user() + assert result is None + + @pytest.mark.asyncio + async def test_does_not_convert_other_errors(self): + """Decorator should NOT convert non-404 errors.""" + @handle_404_gracefully + async def get_user(): + response = MagicMock() + response.status_code = 500 + raise httpx.HTTPStatusError("Server Error", request=MagicMock(), response=response) + + with pytest.raises(httpx.HTTPStatusError): + await get_user() + + @pytest.mark.asyncio + async def test_returns_normal_result_on_success(self): + """Decorator should pass through successful results.""" + @handle_404_gracefully + async def get_user(): + return {"id": 123, "name": "John Doe"} + + result = await get_user() + assert result == {"id": 123, "name": "John Doe"} + + +class TestCircuitBreakerIntegration: + """Test integration of circuit breaker with resilient_http_call decorator.""" + + @pytest.mark.asyncio + async def test_circuit_breaker_opens_after_failures(self): + """Circuit breaker should open and fast-fail after threshold.""" + # Get a fresh circuit breaker for this test + service_name = "IntegrationTestService" + breaker = get_circuit_breaker(service_name) + breaker.failure_threshold = 3 + breaker.state = CircuitState.CLOSED + breaker.consecutive_failures = 0 + + call_count = 0 + + @resilient_http_call(service_name=service_name, max_attempts=1) + async def always_fails(): + nonlocal call_count + call_count += 1 + raise httpx.TimeoutException("Persistent failure") + + # Execute 3 failures to open circuit + for _ in range(3): + with pytest.raises(httpx.TimeoutException): + await always_fails() + + assert breaker.state == CircuitState.OPEN + + # Next call should fail fast with CircuitBreakerOpenError + with pytest.raises(CircuitBreakerOpenError): + await always_fails() + + # Call count should be 3 (circuit now open, no more attempts) + assert call_count == 3 + + +# Run tests with pytest +if __name__ == "__main__": + pytest.main([__file__, "-v"])