homelab/ansible/ansible-old/scripts/ansible_mcp_server.py

750 lines
24 KiB
Python

#!/usr/bin/env python3
"""Ansible MCP server with path guardrails and auditable run records.
This server is intentionally conservative:
- Playbook execution is restricted to allowlisted directories.
- Write operations require explicit confirmation.
- Background jobs are tracked in a local state directory.
"""
from __future__ import annotations
import argparse
import hmac
import json
import os
import shlex
import signal
import subprocess
import tempfile
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from mcp.server.fastmcp import FastMCP
def _utc_now() -> str:
    """Return the current time in UTC as an ISO-8601 string (``+00:00`` offset)."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
@dataclass(frozen=True)
class ServerConfig:
    """Immutable runtime settings for the server, built once by ``_load_config()``."""

    # Resolved repository root; playbook paths from callers are resolved against it.
    repo_root: Path
    # Inventory passed to ansible-playbook / ansible-inventory via ``-i``.
    inventory_file: Path
    # Resolved directories a playbook must live under; consulted only when
    # ``allowed_playbooks`` is empty.
    allowed_dirs: tuple[Path, ...]
    # Explicit allowlist of repo-relative POSIX playbook paths; when non-empty
    # it replaces the ``allowed_dirs`` check.
    allowed_playbooks: tuple[str, ...]
    # Shared secret expected as ``auth_token`` on tool calls; None disables auth.
    api_token: str | None
    # Whether non-check-mode (write) runs are permitted at all.
    allow_write: bool
    # Whether write runs additionally require ``confirm=True``.
    require_confirm_for_write: bool
    # Timeout (seconds) used when a caller does not supply one.
    default_timeout_seconds: int
    # Upper bound (seconds) on any requested timeout.
    max_timeout_seconds: int
    # Maximum size, in bytes, of the JSON-encoded extra_vars payload.
    max_extra_vars_bytes: int
    # Lowercased key names rejected at any nesting depth inside extra_vars.
    blocked_extra_vars_keys: tuple[str, ...]
    # Root directory for job records, background logs and wrapper scripts.
    state_dir: Path
    # JSON-lines file that receives one audit record per event.
    audit_log_file: Path
class JobStore:
    """File-backed store of run records under ``<state_dir>``.

    Directories created on construction:
    - ``jobs/``: one ``<run_id>.json`` record per run.
    - ``logs/``: output of background runs.
    - ``wrappers/``: generated wrapper scripts for background runs.

    ``run_id`` values arrive directly from MCP tool arguments, so they are
    required to be canonical UUID strings before being turned into paths.
    Otherwise an id such as ``../../tmp/x`` would read or write JSON files
    outside ``jobs/``.
    """

    def __init__(self, state_dir: Path) -> None:
        self.state_dir = state_dir
        self.jobs_dir = self.state_dir / "jobs"
        self.logs_dir = self.state_dir / "logs"
        self.wrap_dir = self.state_dir / "wrappers"
        for directory in (self.state_dir, self.jobs_dir, self.logs_dir, self.wrap_dir):
            directory.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _validate_run_id(run_id: str) -> str:
        """Return *run_id* unchanged if it is a canonical (lowercase, hyphenated) UUID.

        Raises:
            ValueError: for anything else, including path-like strings.
        """
        try:
            canonical = str(uuid.UUID(run_id))
        except (AttributeError, TypeError, ValueError):
            canonical = None
        if canonical != run_id:
            raise ValueError(f"Invalid run_id: {run_id!r}")
        return run_id

    def _job_path(self, run_id: str) -> Path:
        """Path of the JSON record for *run_id* (raises ValueError on a bad id)."""
        return self.jobs_dir / f"{self._validate_run_id(run_id)}.json"

    def save_job(self, run_id: str, payload: dict[str, Any]) -> None:
        """Persist *payload* as the record for *run_id*.

        The record is written to a temp file in ``jobs/`` and renamed over the
        target, so a concurrent ``load_job`` never sees a half-written document.

        Raises:
            ValueError: if *run_id* is not a canonical UUID string.
        """
        target = self._job_path(run_id)
        fd, tmp_name = tempfile.mkstemp(
            dir=str(self.jobs_dir), prefix=f".{run_id}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as fh:
                json.dump(payload, fh, indent=2)
            os.replace(tmp_name, target)
        except BaseException:
            Path(tmp_name).unlink(missing_ok=True)
            raise

    def load_job(self, run_id: str) -> dict[str, Any] | None:
        """Return the stored record for *run_id*, or None if unknown or invalid."""
        try:
            path = self._job_path(run_id)
        except ValueError:
            # Not an id this store could ever have produced.
            return None
        try:
            raw = path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(raw)
def _env_flag(name: str, default: bool) -> bool:
    """Read boolean environment variable *name*, falling back to *default*.

    ``1``/``true``/``yes``/``on`` mean True and ``0``/``false``/``no``/``off``
    mean False (case-insensitive, surrounding whitespace ignored). Unset,
    blank or unrecognised values yield *default*. Callers pass the safe
    setting as the default, so a typo can never flip a guard off.
    """
    raw = os.getenv(name, "").strip().lower()
    if raw in {"1", "true", "yes", "on"}:
        return True
    if raw in {"0", "false", "no", "off"}:
        return False
    return default


def _env_int(name: str, default: int) -> int:
    """Read integer environment variable *name*; unset or blank yields *default*.

    Raises:
        ValueError: naming the variable when its value is not an integer.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None


def _load_config() -> ServerConfig:
    """Build the server configuration from ``ANSIBLE_MCP_*`` environment variables.

    Relative paths (inventory, allowed playbook dirs) are resolved against the
    repository root. Comma-separated lists ignore blank entries. The safety
    flags fail closed: ``ANSIBLE_MCP_ALLOW_WRITE`` stays off and
    ``ANSIBLE_MCP_REQUIRE_CONFIRM`` stays on unless explicitly set to a
    recognised opposite value.

    Raises:
        ValueError: if an integer setting is malformed, or the default timeout
            is not within ``1..ANSIBLE_MCP_MAX_TIMEOUT``.
    """
    repo_root = Path(
        os.getenv("ANSIBLE_MCP_REPO_ROOT", "/home/chester/homelab/ansible")
    ).resolve()
    inventory_file = (
        repo_root / os.getenv("ANSIBLE_MCP_INVENTORY", "inventory/hosts.ini")
    ).resolve()

    def _csv(var: str, default: str = "") -> list[str]:
        return [part.strip() for part in os.getenv(var, default).split(",") if part.strip()]

    allowed_dirs = tuple(
        (repo_root / item).resolve()
        for item in _csv("ANSIBLE_MCP_ALLOWED_PLAYBOOK_DIRS", "playbooks")
    )
    allowed_playbooks = tuple(_csv("ANSIBLE_MCP_ALLOWED_PLAYBOOKS"))
    blocked_extra_vars_keys = tuple(
        key.lower() for key in _csv("ANSIBLE_MCP_BLOCKED_EXTRA_VARS_KEYS")
    )
    api_token = os.getenv("ANSIBLE_MCP_API_TOKEN", "").strip() or None

    default_timeout_seconds = _env_int("ANSIBLE_MCP_DEFAULT_TIMEOUT", 900)
    max_timeout_seconds = _env_int("ANSIBLE_MCP_MAX_TIMEOUT", 3600)
    # Fail at startup rather than rejecting every call that relies on the default.
    if not 0 < default_timeout_seconds <= max_timeout_seconds:
        raise ValueError(
            f"ANSIBLE_MCP_DEFAULT_TIMEOUT ({default_timeout_seconds}) must be > 0 "
            f"and <= ANSIBLE_MCP_MAX_TIMEOUT ({max_timeout_seconds})"
        )

    state_dir = Path(os.getenv("ANSIBLE_MCP_STATE_DIR", "/var/lib/ansible-mcp")).resolve()
    audit_log_file = Path(
        os.getenv("ANSIBLE_MCP_AUDIT_LOG_FILE", "") or state_dir / "audit" / "events.jsonl"
    ).resolve()
    return ServerConfig(
        repo_root=repo_root,
        inventory_file=inventory_file,
        allowed_dirs=allowed_dirs,
        allowed_playbooks=allowed_playbooks,
        api_token=api_token,
        allow_write=_env_flag("ANSIBLE_MCP_ALLOW_WRITE", False),
        require_confirm_for_write=_env_flag("ANSIBLE_MCP_REQUIRE_CONFIRM", True),
        default_timeout_seconds=default_timeout_seconds,
        max_timeout_seconds=max_timeout_seconds,
        max_extra_vars_bytes=_env_int("ANSIBLE_MCP_MAX_EXTRA_VARS_BYTES", 16384),
        blocked_extra_vars_keys=blocked_extra_vars_keys,
        state_dir=state_dir,
        audit_log_file=audit_log_file,
    )
def _is_relative_to(candidate: Path, parent: Path) -> bool:
    """Tell whether *candidate* equals *parent* or lies somewhere beneath it.

    The comparison is purely lexical (no filesystem access), so callers should
    pass already-resolved paths when symlinks or ``..`` matter.
    """
    return candidate.is_relative_to(parent)
def _resolve_allowed_playbook(config: ServerConfig, playbook: str) -> Path:
    """Resolve *playbook* against the repo root and enforce the playbook policy.

    Policy, applied to the fully resolved path (symlinks and ``..`` followed):
    - it must lie inside ``config.repo_root``;
    - if ``config.allowed_playbooks`` is non-empty, its repo-relative POSIX
      path must appear there exactly (this replaces the directory check);
    - otherwise it must lie under one of ``config.allowed_dirs``;
    - it must exist and be a file.

    The location checks run before any filesystem checks. This way the
    "does not exist" / "not a file" errors cannot be used to probe for paths
    outside the permitted area.

    Args:
        config: Active server configuration.
        playbook: Caller-supplied path, relative to the repo root.

    Returns:
        The resolved absolute path of the playbook.

    Raises:
        ValueError: if any policy check fails.
    """
    candidate = (config.repo_root / playbook).resolve()
    if not _is_relative_to(candidate, config.repo_root):
        raise ValueError(f"Playbook path is outside the repository root: {playbook}")
    relative_playbook = candidate.relative_to(config.repo_root).as_posix()
    if config.allowed_playbooks:
        if relative_playbook not in config.allowed_playbooks:
            allow_text = ", ".join(config.allowed_playbooks)
            raise ValueError(
                f"Playbook is not in explicit allowlist: {relative_playbook}. "
                f"Allowed playbooks: {allow_text}"
            )
    elif not any(_is_relative_to(candidate, allowed) for allowed in config.allowed_dirs):
        allowed_text = ", ".join(str(p) for p in config.allowed_dirs)
        raise ValueError(
            f"Playbook path is outside allowed directories: {playbook}. "
            f"Allowed roots: {allowed_text}"
        )
    # Only now, with the path known to be permitted, consult the filesystem.
    if not candidate.exists():
        raise ValueError(f"Playbook does not exist: {playbook}")
    if not candidate.is_file():
        raise ValueError(f"Playbook path is not a file: {playbook}")
    return candidate
def _sanitize_timeout(config: ServerConfig, timeout_seconds: int | None) -> int:
    """Pick the effective run timeout and enforce the configured bounds.

    ``None`` selects ``config.default_timeout_seconds``. The chosen value must
    be positive and no larger than ``config.max_timeout_seconds``.

    Raises:
        ValueError: if the effective timeout is out of range.
    """
    chosen = config.default_timeout_seconds if timeout_seconds is None else timeout_seconds
    if chosen <= 0:
        raise ValueError("timeout_seconds must be greater than 0")
    ceiling = config.max_timeout_seconds
    if chosen > ceiling:
        raise ValueError(f"timeout_seconds exceeds maximum allowed ({ceiling})")
    return chosen
def _redact_payload(value: Any) -> Any:
    """Return a copy of *value* with secret-looking dict entries masked.

    A dict entry is replaced by ``"[REDACTED]"`` when its key, stringified and
    lowercased, contains "token", "secret", "password" or "key". All other
    entries, and list elements, are processed recursively. Values of any other
    type are returned unchanged.
    """
    if isinstance(value, list):
        return [_redact_payload(element) for element in value]
    if not isinstance(value, dict):
        return value

    def _looks_sensitive(name: Any) -> bool:
        lowered = str(name).lower()
        return any(marker in lowered for marker in ("token", "secret", "password", "key"))

    return {
        name: "[REDACTED]" if _looks_sensitive(name) else _redact_payload(child)
        for name, child in value.items()
    }
def _audit_event(event: str, payload: dict[str, Any]) -> None:
    """Append one JSON line describing *event* to the configured audit log.

    Each line holds a UTC timestamp, the event name, and *payload* after
    passing it through ``_redact_payload``. The log's parent directory is
    created on demand.
    """
    entry = {
        "timestamp": _utc_now(),
        "event": event,
        "payload": _redact_payload(payload),
    }
    log_path = CONFIG.audit_log_file
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(entry)}\n")
def _require_auth(auth_token: str | None) -> None:
    """Enforce the shared-secret check when ``CONFIG.api_token`` is configured.

    Does nothing when authentication is disabled. Tokens are compared as
    UTF-8 bytes with ``hmac.compare_digest``. Comparing ``str`` objects would
    raise ``TypeError`` for any non-ASCII input, surfacing as an unexpected
    error rather than a clean authentication failure.

    Raises:
        ValueError: if the token is missing or does not match.
    """
    expected = CONFIG.api_token
    if not expected:
        return
    provided = (auth_token or "").strip()
    if not provided:
        raise ValueError("Authentication required: provide auth_token")
    # surrogatepass: JSON input and os.environ can both carry lone surrogates.
    if not hmac.compare_digest(
        provided.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    ):
        raise ValueError("Authentication failed: invalid auth_token")
def _collect_keys(node: Any, sink: set[str]) -> None:
    """Add every dict key found anywhere inside *node* to *sink*, lowercased.

    Walks nested dicts and lists. Other values are ignored.
    """
    if isinstance(node, list):
        for element in node:
            _collect_keys(element, sink)
        return
    if not isinstance(node, dict):
        return
    for name, child in node.items():
        sink.add(str(name).lower())
        _collect_keys(child, sink)
def _validate_extra_vars(extra_vars: dict[str, Any] | None) -> None:
    """Reject caller-supplied extra_vars that would break the server's guardrails.

    Checks, in order (``None`` or empty input is always accepted):
    1. The JSON encoding must fit within ``CONFIG.max_extra_vars_bytes`` bytes.
    2. No key at any depth may equal, case-insensitively, an entry of
       ``CONFIG.blocked_extra_vars_keys``.
    3. No key or string value may contain Jinja2 expression or statement
       delimiters (``{{`` / ``{%``). Ansible templates variable values when
       they are used, so an injected ``{{ lookup('pipe', ...) }}`` would run
       on the control node, in check mode too.

    NOTE(review): connection-level variables (e.g. ``ansible_connection``,
    ``ansible_python_interpreter``) can also redirect where and how code runs.
    Consider listing them in ANSIBLE_MCP_BLOCKED_EXTRA_VARS_KEYS.

    Raises:
        ValueError: describing the first check that failed.
    """
    if not extra_vars:
        return
    encoded = json.dumps(extra_vars)
    if len(encoded.encode("utf-8")) > CONFIG.max_extra_vars_bytes:
        raise ValueError(
            f"extra_vars payload exceeds max size ({CONFIG.max_extra_vars_bytes} bytes)"
        )
    if CONFIG.blocked_extra_vars_keys:
        keys: set[str] = set()
        _collect_keys(extra_vars, keys)
        blocked = sorted(k for k in keys if k in CONFIG.blocked_extra_vars_keys)
        if blocked:
            blocked_text = ", ".join(blocked)
            raise ValueError(f"extra_vars contains blocked keys: {blocked_text}")

    def _strings(node: Any):
        """Yield every dict key (stringified) and string value nested in *node*."""
        if isinstance(node, str):
            yield node
        elif isinstance(node, dict):
            for key, value in node.items():
                yield str(key)
                yield from _strings(value)
        elif isinstance(node, list):
            for value in node:
                yield from _strings(value)

    if any("{{" in text or "{%" in text for text in _strings(extra_vars)):
        raise ValueError(
            "extra_vars must not contain Jinja2 template syntax ('{{' or '{%')"
        )
def _build_command(
    config: ServerConfig,
    playbook_path: Path,
    limit: str | None,
    tags: str | None,
    skip_tags: str | None,
    check_mode: bool,
    extra_vars_file: Path | None,
) -> list[str]:
    """Assemble the ansible-playbook argv (no shell involved).

    Empty or ``None`` ``limit``/``tags``/``skip_tags`` are omitted. ``--check``
    is added for check mode. An extra-vars file is passed as ``@<path>``.
    """
    argv = ["ansible-playbook", "-i", str(config.inventory_file), str(playbook_path)]
    for flag, value in (("--limit", limit), ("--tags", tags), ("--skip-tags", skip_tags)):
        if value:
            argv += [flag, value]
    if check_mode:
        argv.append("--check")
    if extra_vars_file is not None:
        argv += ["--extra-vars", f"@{extra_vars_file}"]
    return argv
# Module-level singletons, created once at import time and shared by every
# tool below. _load_config() reads the environment and can raise (e.g. on a
# non-integer timeout setting), which aborts the import.
CONFIG = _load_config()
# Creates the jobs/, logs/ and wrappers/ directories under CONFIG.state_dir.
STORE = JobStore(CONFIG.state_dir)
# main() overwrites mcp.settings.host/port with its CLI arguments before
# running, so the host/port given here presumably only matter if `mcp` is
# started some other way.
mcp = FastMCP(
    "ansible-mcp",
    host=os.getenv("ANSIBLE_MCP_HOST", "127.0.0.1"),
    port=int(os.getenv("ANSIBLE_MCP_PORT", "8449")),
    streamable_http_path="/mcp",
)
@mcp.tool()
def health() -> dict[str, Any]:
    """Return server health and effective runtime configuration."""
    # Unlike the other tools this does not call _require_auth.
    # NOTE(review): it therefore exposes paths and blocked_extra_vars_keys to
    # any client; consider gating the detail when auth is enabled.
    cfg = CONFIG
    snapshot: dict[str, Any] = {
        "ok": True,
        "server": "ansible-mcp",
        "timestamp": _utc_now(),
    }
    snapshot.update(
        repo_root=str(cfg.repo_root),
        inventory_file=str(cfg.inventory_file),
        allowed_dirs=[str(directory) for directory in cfg.allowed_dirs],
        allowed_playbooks=list(cfg.allowed_playbooks),
        allow_write=cfg.allow_write,
        require_confirm_for_write=cfg.require_confirm_for_write,
        auth_enabled=cfg.api_token is not None,
        max_extra_vars_bytes=cfg.max_extra_vars_bytes,
        blocked_extra_vars_keys=list(cfg.blocked_extra_vars_keys),
        state_dir=str(cfg.state_dir),
    )
    return snapshot
@mcp.tool()
def list_inventory(limit: str | None = None, auth_token: str | None = None) -> dict[str, Any]:
    """Return inventory graph information from ansible-inventory --list."""
    # A hung ansible-inventory is reported as a structured, audited failure
    # instead of letting subprocess.TimeoutExpired escape the tool.
    _require_auth(auth_token)
    timeout_seconds = 60
    cmd = ["ansible-inventory", "-i", str(CONFIG.inventory_file), "--list"]
    if limit:
        cmd.extend(["--limit", limit])
    command_text = " ".join(shlex.quote(c) for c in cmd)
    try:
        result = subprocess.run(
            cmd,
            cwd=CONFIG.repo_root,
            capture_output=True,
            text=True,
            timeout=timeout_seconds,
            check=False,
        )
    except subprocess.TimeoutExpired:
        _audit_event(
            "list_inventory",
            {"limit": limit, "returncode": None, "ok": False, "timed_out": True},
        )
        return {
            "ok": False,
            "returncode": None,
            "error": f"ansible-inventory timed out after {timeout_seconds} seconds",
            "command": command_text,
        }
    payload: dict[str, Any] = {
        "ok": result.returncode == 0,
        "returncode": result.returncode,
        "stderr": result.stderr,
        "command": command_text,
    }
    if result.returncode == 0:
        try:
            payload["inventory"] = json.loads(result.stdout)
        except json.JSONDecodeError:
            payload["ok"] = False
            payload["error"] = "ansible-inventory returned non-JSON output"
            payload["raw_stdout"] = result.stdout
    else:
        payload["stdout"] = result.stdout
    _audit_event(
        "list_inventory",
        {"limit": limit, "returncode": result.returncode, "ok": payload["ok"]},
    )
    return payload
@mcp.tool()
def validate_syntax(
    playbook: str,
    limit: str | None = None,
    extra_vars: dict[str, Any] | None = None,
    auth_token: str | None = None,
) -> dict[str, Any]:
    """Run ansible-playbook --syntax-check on an allowlisted playbook."""
    # An extra-vars file is created only when there is something to put in it.
    # An empty "@file" does not parse as an extra-vars mapping, and every file
    # that is created is removed in `finally`. A timeout becomes an audited
    # error result rather than an uncaught exception.
    _require_auth(auth_token)
    playbook_path = _resolve_allowed_playbook(CONFIG, playbook)
    _validate_extra_vars(extra_vars)
    relative_playbook = str(playbook_path.relative_to(CONFIG.repo_root))
    timeout_seconds = 120
    extra_vars_file: Path | None = None
    try:
        if extra_vars:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False, encoding="utf-8"
            ) as tf:
                # Record the path first so `finally` removes it even if dump fails.
                extra_vars_file = Path(tf.name)
                json.dump(extra_vars, tf)
        cmd = _build_command(
            config=CONFIG,
            playbook_path=playbook_path,
            limit=limit,
            tags=None,
            skip_tags=None,
            check_mode=False,
            extra_vars_file=extra_vars_file,
        )
        cmd.append("--syntax-check")
        command_text = " ".join(shlex.quote(c) for c in cmd)
        try:
            result = subprocess.run(
                cmd,
                cwd=CONFIG.repo_root,
                capture_output=True,
                text=True,
                timeout=timeout_seconds,
                check=False,
            )
        except subprocess.TimeoutExpired:
            _audit_event(
                "validate_syntax",
                {
                    "playbook": relative_playbook,
                    "limit": limit,
                    "returncode": None,
                    "ok": False,
                    "timed_out": True,
                },
            )
            return {
                "ok": False,
                "returncode": None,
                "error": f"ansible-playbook --syntax-check timed out after {timeout_seconds} seconds",
                "command": command_text,
                "playbook": relative_playbook,
            }
        payload = {
            "ok": result.returncode == 0,
            "returncode": result.returncode,
            "stdout": result.stdout,
            "stderr": result.stderr,
            "command": command_text,
            "playbook": relative_playbook,
        }
        _audit_event(
            "validate_syntax",
            {
                "playbook": payload["playbook"],
                "limit": limit,
                "returncode": payload["returncode"],
                "ok": payload["ok"],
            },
        )
        return payload
    finally:
        if extra_vars_file is not None:
            extra_vars_file.unlink(missing_ok=True)
def _as_text(data: bytes | str | None) -> str | None:
    """Return captured subprocess output as ``str`` (or ``None``).

    ``subprocess.TimeoutExpired.stdout``/``.stderr`` hold ``bytes`` (or
    ``None``) even when the process was run with ``text=True``. Job records
    are JSON, so bytes are decoded as UTF-8 with replacement characters.
    """
    if isinstance(data, bytes):
        return data.decode("utf-8", errors="replace")
    return data


def _background_wrapper_script(
    repo_root: Path,
    cmd: list[str],
    done_file: Path,
    extra_vars_file: Path | None,
) -> str:
    """Render the bash wrapper used to run *cmd* detached from the server.

    The wrapper cd's into *repo_root* and runs *cmd*. It then writes
    ``{"completed_at": <UTC ISO-8601>, "returncode": <int>}`` to *done_file*
    via a ``.tmp`` file plus ``os.replace``, so a poller never reads a
    partial document, and exits with the command's status. The exit code
    reaches the embedded Python through the ``RC`` environment variable,
    because the quoted heredoc (``<<'PY'``) does not expand shell variables.

    If *extra_vars_file* is given, an EXIT trap deletes it when the wrapper
    exits; only the wrapper knows when ansible-playbook is done with it.
    NOTE(review): whether the trap also fires when the wrapper itself is
    killed by a signal depends on bash behaviour; confirm.
    """
    lines = ["#!/usr/bin/env bash", "set -o pipefail"]
    if extra_vars_file is not None:
        cleanup = "rm -f -- " + shlex.quote(str(extra_vars_file))
        lines.append(f"trap {shlex.quote(cleanup)} EXIT")
    lines += [
        f"cd {shlex.quote(str(repo_root))}",
        "{ " + " ".join(shlex.quote(c) for c in cmd) + "; }",
        "rc=$?",
        "RC=$rc python3 - <<'PY'",
        "import json",
        "import os",
        "from datetime import datetime, timezone",
        f"done_file = {str(done_file)!r}",
        "payload = {",
        "    'completed_at': datetime.now(timezone.utc).isoformat(),",
        "    'returncode': int(os.environ['RC']),",
        "}",
        "tmp_file = done_file + '.tmp'",
        "with open(tmp_file, 'w', encoding='utf-8') as f:",
        "    json.dump(payload, f)",
        "os.replace(tmp_file, done_file)",
        "PY",
        'exit "$rc"',
    ]
    return "\n".join(lines) + "\n"


@mcp.tool()
def run_playbook(
    playbook: str,
    limit: str | None = None,
    extra_vars: dict[str, Any] | None = None,
    tags: str | None = None,
    skip_tags: str | None = None,
    check_mode: bool = True,
    confirm: bool = False,
    timeout_seconds: int | None = None,
    background: bool = False,
    auth_token: str | None = None,
) -> dict[str, Any]:
    """Run an allowlisted playbook with guardrails and run tracking.
    Safety model:
    - check_mode defaults to true.
    - write operations require allow_write=true and confirm=true.
    """
    # Maintainer notes (kept out of the docstring, which FastMCP publishes as
    # the tool description):
    # - Every run gets a UUID run_id and a record in STORE.
    # - Foreground runs are bounded by timeout_seconds via subprocess.run.
    # - NOTE(review): background runs record timeout_seconds but nothing
    #   enforces it; only cancel_job stops them early.
    _require_auth(auth_token)
    playbook_path = _resolve_allowed_playbook(CONFIG, playbook)
    _validate_extra_vars(extra_vars)
    safe_timeout = _sanitize_timeout(CONFIG, timeout_seconds)
    relative_playbook = str(playbook_path.relative_to(CONFIG.repo_root))
    is_write = not check_mode
    if is_write and not CONFIG.allow_write:
        _audit_event(
            "run_playbook_denied",
            {"playbook": relative_playbook, "reason": "write_disabled", "check_mode": check_mode},
        )
        return {
            "ok": False,
            "error": "Write operations are disabled (ANSIBLE_MCP_ALLOW_WRITE=false)",
            "hint": "Set check_mode=true or enable ANSIBLE_MCP_ALLOW_WRITE",
        }
    if is_write and CONFIG.require_confirm_for_write and not confirm:
        _audit_event(
            "run_playbook_denied",
            {"playbook": relative_playbook, "reason": "missing_confirm", "check_mode": check_mode},
        )
        return {
            "ok": False,
            "error": "Write operation requires explicit confirm=true",
            "hint": "Retry with confirm=true after review",
        }

    run_id = str(uuid.uuid4())
    started_at = _utc_now()
    # Deleted in `finally` unless ownership is handed to a background wrapper
    # (signalled by resetting this to None).
    extra_vars_file: Path | None = None
    try:
        if extra_vars:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False, encoding="utf-8"
            ) as tf:
                extra_vars_file = Path(tf.name)
                json.dump(extra_vars, tf)
        cmd = _build_command(
            config=CONFIG,
            playbook_path=playbook_path,
            limit=limit,
            tags=tags,
            skip_tags=skip_tags,
            check_mode=check_mode,
            extra_vars_file=extra_vars_file,
        )
        base_job = {
            "run_id": run_id,
            "playbook": relative_playbook,
            "check_mode": check_mode,
            "confirm": confirm,
            "limit": limit,
            "tags": tags,
            "skip_tags": skip_tags,
            "timeout_seconds": safe_timeout,
            "command": " ".join(shlex.quote(c) for c in cmd),
            "started_at": started_at,
        }

        if background:
            log_file = STORE.logs_dir / f"{run_id}.log"
            done_file = STORE.jobs_dir / f"{run_id}.done.json"
            wrapper = STORE.wrap_dir / f"{run_id}.sh"
            wrapper.write_text(
                _background_wrapper_script(CONFIG.repo_root, cmd, done_file, extra_vars_file),
                encoding="utf-8",
            )
            wrapper.chmod(0o750)
            # The child inherits its own copy of the fd, so the parent's handle
            # can be closed as soon as Popen returns (or raises).
            with log_file.open("w", encoding="utf-8") as log_handle:
                proc = subprocess.Popen(
                    [str(wrapper)],
                    cwd=CONFIG.repo_root,
                    stdout=log_handle,
                    stderr=subprocess.STDOUT,
                    start_new_session=True,
                )
            # The wrapper now owns the extra-vars file; deleting it here would
            # race ansible-playbook reading it.
            extra_vars_file = None
            payload = {
                **base_job,
                "background": True,
                "status": "running",
                "pid": proc.pid,
                "log_file": str(log_file),
                "done_file": str(done_file),
            }
            STORE.save_job(run_id, payload)
            _audit_event(
                "run_playbook_background_started",
                {
                    "run_id": run_id,
                    "playbook": payload["playbook"],
                    "check_mode": check_mode,
                    "pid": proc.pid,
                },
            )
            return {
                "ok": True,
                "run_id": run_id,
                "status": "running",
                "pid": proc.pid,
                "log_file": str(log_file),
                "message": "Playbook started in background",
            }

        try:
            result = subprocess.run(
                cmd,
                cwd=CONFIG.repo_root,
                capture_output=True,
                text=True,
                timeout=safe_timeout,
                check=False,
            )
        except subprocess.TimeoutExpired as err:
            stdout_text = _as_text(err.stdout)
            stderr_text = _as_text(err.stderr)
            timed_out_payload = {
                **base_job,
                "background": False,
                "status": "timed_out",
                "completed_at": _utc_now(),
                "stdout": stdout_text,
                "stderr": stderr_text,
            }
            STORE.save_job(run_id, timed_out_payload)
            _audit_event(
                "run_playbook_timed_out",
                {
                    "run_id": run_id,
                    "playbook": timed_out_payload["playbook"],
                    "timeout_seconds": safe_timeout,
                },
            )
            return {
                "ok": False,
                "run_id": run_id,
                "status": "timed_out",
                "timeout_seconds": safe_timeout,
                "stdout": stdout_text,
                "stderr": stderr_text,
                "message": "Playbook exceeded timeout",
            }

        completed_payload = {
            **base_job,
            "background": False,
            "status": "succeeded" if result.returncode == 0 else "failed",
            "returncode": result.returncode,
            "completed_at": _utc_now(),
            "stdout": result.stdout,
            "stderr": result.stderr,
        }
        STORE.save_job(run_id, completed_payload)
        _audit_event(
            "run_playbook_completed",
            {
                "run_id": run_id,
                "playbook": completed_payload["playbook"],
                "status": completed_payload["status"],
                "returncode": result.returncode,
            },
        )
        return {
            "ok": result.returncode == 0,
            "run_id": run_id,
            "status": completed_payload["status"],
            "returncode": result.returncode,
            "stdout": result.stdout,
            "stderr": result.stderr,
            "playbook": completed_payload["playbook"],
            "command": completed_payload["command"],
        }
    finally:
        if extra_vars_file is not None:
            extra_vars_file.unlink(missing_ok=True)
@mcp.tool()
def get_job_status(
    run_id: str,
    tail_lines: int = 80,
    auth_token: str | None = None,
) -> dict[str, Any]:
    """Get status and recent logs for a tracked run_id."""
    # For background runs the status is refreshed on each call:
    # - a readable done file makes it succeeded/failed (and is persisted);
    # - otherwise a signal-0 probe of the recorded PID reports running/unknown.
    # A "cancelled" status written by cancel_job is kept as-is. cancel_job
    # SIGTERMs the wrapper's process group, which typically prevents the done
    # file from being written, so re-deriving would report "unknown".
    _require_auth(auth_token)
    if tail_lines <= 0:
        raise ValueError("tail_lines must be greater than 0")
    job = STORE.load_job(run_id)
    if not job:
        return {"ok": False, "error": f"Unknown run_id: {run_id}"}
    if job.get("background") and job.get("status") != "cancelled":
        done_payload: dict[str, Any] | None = None
        try:
            loaded = json.loads(Path(job["done_file"]).read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing (still running) or not yet fully written: use the PID probe.
            loaded = None
        if isinstance(loaded, dict) and "returncode" in loaded:
            done_payload = loaded
        if done_payload is not None:
            job["status"] = "succeeded" if done_payload["returncode"] == 0 else "failed"
            job["returncode"] = done_payload["returncode"]
            job["completed_at"] = done_payload.get("completed_at")
            STORE.save_job(run_id, job)
        else:
            pid = int(job.get("pid", 0))
            if pid > 0:
                try:
                    os.kill(pid, 0)
                    job["status"] = "running"
                except OSError:
                    job["status"] = "unknown"
            else:
                job["status"] = "unknown"
    response = {"ok": True, **job}
    log_file = job.get("log_file")
    if log_file and Path(log_file).exists():
        lines = Path(log_file).read_text(encoding="utf-8", errors="replace").splitlines()
        response["log_tail"] = lines[-tail_lines:]
    _audit_event(
        "get_job_status",
        {"run_id": run_id, "status": response.get("status")},
    )
    return response
@mcp.tool()
def cancel_job(run_id: str, auth_token: str | None = None) -> dict[str, Any]:
    """Cancel a running background job."""
    # SIGTERMs the whole process group. run_playbook starts the wrapper with
    # start_new_session=True, so its PID is also its process-group ID.
    # The stored status is only refreshed by get_job_status, so it can still
    # read "running" after the run has exited. The done file is checked first
    # because a finished run's PID may already belong to an unrelated process.
    _require_auth(auth_token)
    job = STORE.load_job(run_id)
    if not job:
        return {"ok": False, "error": f"Unknown run_id: {run_id}"}
    if not job.get("background"):
        return {"ok": False, "error": "cancel_job is only valid for background jobs"}
    if job.get("status") not in {"running", "unknown"}:
        return {"ok": False, "error": f"Job is not running (status={job.get('status')})"}
    done_file = job.get("done_file")
    if done_file and Path(done_file).exists():
        return {
            "ok": False,
            "error": "Job has already finished; call get_job_status for its result",
        }
    pid = int(job.get("pid", 0))
    if pid <= 0:
        return {"ok": False, "error": "Job PID is invalid"}
    try:
        os.killpg(pid, signal.SIGTERM)
    except ProcessLookupError:
        return {"ok": False, "error": "Process does not exist"}
    except PermissionError as err:
        return {"ok": False, "error": f"Permission denied terminating process group: {err}"}
    job["status"] = "cancelled"
    job["completed_at"] = _utc_now()
    STORE.save_job(run_id, job)
    payload = {"ok": True, "run_id": run_id, "status": "cancelled"}
    _audit_event("cancel_job", payload)
    return payload
def main() -> None:
    """Parse CLI options and run the MCP server over stdio or streamable HTTP.

    ``--host``/``--port`` (used by streamable-http) default to
    ANSIBLE_MCP_HOST / ANSIBLE_MCP_PORT, then to 127.0.0.1:8449. These are
    the same fallbacks used when ``mcp`` is constructed, so the server is
    loopback-only unless a wider bind address is requested explicitly. When
    streamable-http listens on a non-loopback address without an API token,
    a warning is printed to stderr.
    """
    import sys  # local: only needed for the startup warning below

    parser = argparse.ArgumentParser(description="Run the Ansible MCP server")
    parser.add_argument(
        "--transport",
        choices=["stdio", "streamable-http"],
        default=os.getenv("ANSIBLE_MCP_TRANSPORT", "stdio"),
        help="MCP transport to use",
    )
    parser.add_argument(
        "--host",
        default=os.getenv("ANSIBLE_MCP_HOST", "127.0.0.1"),
        help="Host for streamable-http transport",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("ANSIBLE_MCP_PORT", "8449")),
        help="Port for streamable-http transport",
    )
    args = parser.parse_args()
    # FastMCP transport settings are configured on the server object in this SDK version.
    mcp.settings.host = args.host
    mcp.settings.port = args.port
    if (
        args.transport == "streamable-http"
        and CONFIG.api_token is None
        and args.host not in {"127.0.0.1", "localhost", "::1"}
    ):
        # stderr, not stdout: stdout is the protocol channel for stdio transport.
        print(
            f"WARNING: ansible-mcp is listening on {args.host}:{args.port} without "
            "ANSIBLE_MCP_API_TOKEN; any client that can reach it can call every tool.",
            file=sys.stderr,
        )
    mcp.run(transport=args.transport)


if __name__ == "__main__":
    main()