477 lines
16 KiB
Python
477 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Literal
|
|
|
|
# Import AD adapter for backend selection
|
|
try:
|
|
from ad_adapter import ActiveDirectoryIdentityBackend
|
|
except ImportError:
|
|
ActiveDirectoryIdentityBackend = None # type: ignore
|
|
|
|
|
|
# Type definitions for the Phase 2 query contract. These are plain
# annotation aliases; nothing in this block validates values at runtime.

# Columns a caller is intended to be able to sort ``query_users`` results by.
SortByField = Literal[
    "display_name",
    "username",
    "last_logon_utc",
    "when_created_utc",
    "department",
]

# Sort order for ``query_users`` results.
SortDirection = Literal["asc", "desc"]

# Dimensions ``summarize_users`` is intended to bucket users by.
GroupByField = Literal[
    "enabled",
    "ou",
    "department",
    "title",
    "created_month",
    "last_logon_bucket",
]
|
|
|
|
# User attributes a caller may request for projection. Requests naming
# anything outside this set are rejected by the query layer.
ALLOWED_USER_FIELDS = set(
    # Identity and naming
    "username display_name first_name last_name email".split()
    # Account state and directory placement
    + "enabled ou description".split()
    # Timestamps
    + "last_logon_utc when_created_utc".split()
    # Organisational attributes
    + "department title".split()
)
|
|
|
|
|
|
@dataclass
class UserRecord:
    """A directory user account record (used by ``InMemoryIdentityBackend``).

    Declaration order defines the positional constructor order.
    """

    # Account identifier; the in-memory backend keys its user table by it.
    username: str
    # Given name and surname.
    first_name: str
    last_name: str
    # Human-readable name; used for name searches and default sorting.
    display_name: str
    # False for a disabled account.
    enabled: bool
    # Distinguished name of the containing organisational unit.
    ou: str
    # Free-text description (role, termination note, ...).
    description: str
    # Most recent logon time; the sample data uses timezone-aware UTC values.
    last_logon_utc: datetime
    # Names of the groups the user belongs to.
    groups: list[str]
|
|
|
|
|
|
@dataclass
class ComputerRecord:
    """A directory computer account record (used by ``InMemoryIdentityBackend``)."""

    # Computer name; the in-memory backend keys its computer table by it.
    computer_name: str
    # Distinguished name of the containing organisational unit.
    ou: str
    # Username of the assigned user, or None when nobody is assigned.
    assigned_username: str | None
|
|
|
|
|
|
class IdentityBackend:
    """Backend interface for identity data providers.

    Every method here raises ``NotImplementedError``. Concrete providers
    (such as ``InMemoryIdentityBackend`` in this module, or an approved
    AD/Entra integration) subclass this class and override the methods.
    """

    def _not_implemented(self, method: str) -> NotImplementedError:
        """Build the error raised by an abstract method, naming the backend."""
        return NotImplementedError(f"{type(self).__name__} does not implement {method}")

    async def get_user(self, username: str) -> dict[str, Any] | None:
        """Return the attributes of one user, or None if the user is unknown."""
        raise self._not_implemented("get_user")

    async def search_users_by_name(
        self, name_query: str, limit: int = 20
    ) -> list[dict[str, Any]]:
        """Return at most *limit* users whose names match *name_query*."""
        raise self._not_implemented("search_users_by_name")

    async def get_user_groups(self, username: str) -> list[str]:
        """Return the names of the groups *username* belongs to."""
        raise self._not_implemented("get_user_groups")

    async def get_group_members(self, group_name: str) -> list[str]:
        """Return the usernames of the members of *group_name*."""
        raise self._not_implemented("get_group_members")

    async def find_stale_users(self, days: int) -> list[dict[str, Any]]:
        """Return users whose last logon is older than *days* days."""
        raise self._not_implemented("find_stale_users")

    async def get_computer(self, computer_name: str) -> dict[str, Any] | None:
        """Return one computer's attributes, or None if it is unknown."""
        raise self._not_implemented("get_computer")

    async def query_users(
        self,
        filter_params: dict[str, Any] | None = None,
        fields: list[str] | None = None,
        sort_by: str = "display_name",
        sort_direction: str = "asc",
        page_size: int = 50,
        cursor: str | None = None,
    ) -> dict[str, Any]:
        """Query users with flexible validated filters and pagination.

        Returns:
            dict with keys: items, next_cursor, page_size, total_estimate,
            applied_filter, warnings
        """
        raise self._not_implemented("query_users")

    async def count_users(
        self,
        filter_params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Count users matching filter without returning full records.

        Returns:
            dict with keys: count, applied_filter
        """
        raise self._not_implemented("count_users")

    async def summarize_users(
        self,
        filter_params: dict[str, Any] | None = None,
        group_by: str = "enabled",
        top: int = 20,
    ) -> dict[str, Any]:
        """Return grouped aggregates for users matching filter.

        Returns:
            dict with keys: buckets, total, applied_filter
        """
        raise self._not_implemented("summarize_users")
|
|
|
|
|
|
class InMemoryIdentityBackend(IdentityBackend):
    """Local-safe backend for initial MCP wiring and tool contract validation.

    Serves a fixed set of sample users and one sample computer from plain
    dictionaries. Sample logon timestamps are relative to the moment the
    instance is constructed, while age-based checks (stale users, logon
    filters and logon buckets) compare against the current time at call.
    """

    def __init__(self) -> None:
        now = datetime.now(tz=timezone.utc)
        # Keyed by lowercase username; lookups lowercase the caller's input.
        self._users: dict[str, UserRecord] = {
            "jane.doe": UserRecord(
                username="jane.doe",
                first_name="Jane",
                last_name="Doe",
                display_name="Jane Doe",
                enabled=True,
                ou="OU=Users,DC=example,DC=local",
                description="Service Desk",
                last_logon_utc=now - timedelta(days=2),
                groups=["GG-Global-VPN", "GG-ServiceDesk"],
            ),
            "john.smith": UserRecord(
                username="john.smith",
                first_name="John",
                last_name="Smith",
                display_name="John Smith",
                enabled=False,
                ou="OU=DisabledUsers,DC=example,DC=local",
                description="Terminated 2026-02-20",
                last_logon_utc=now - timedelta(days=65),
                groups=["GG-FormerEmployees"],
            ),
            "alice.tech": UserRecord(
                username="alice.tech",
                first_name="Alice",
                last_name="Tech",
                display_name="Alice Tech",
                enabled=True,
                ou="OU=IT,OU=Users,DC=example,DC=local",
                description="IT Infrastructure",
                last_logon_utc=now - timedelta(days=1),
                groups=["GG-IT-Infrastructure", "GG-Global-VPN"],
            ),
            "bob.sales": UserRecord(
                username="bob.sales",
                first_name="Bob",
                last_name="Sales",
                display_name="Bob Sales",
                enabled=False,
                ou="OU=DisabledUsers,DC=example,DC=local",
                description="Inactive 2025-12-01",
                last_logon_utc=now - timedelta(days=120),
                groups=["GG-Sales-Disabled"],
            ),
        }
        # Keyed by uppercase computer name; lookups uppercase the caller's input.
        self._computers: dict[str, ComputerRecord] = {
            "LT-1001": ComputerRecord(
                computer_name="LT-1001",
                ou="OU=Workstations,DC=example,DC=local",
                assigned_username="jane.doe",
            )
        }

    async def get_user(self, username: str) -> dict[str, Any] | None:
        """Return the attributes of *username* (case-insensitive), or None.

        Group membership is not included; see ``get_user_groups``.
        """
        user = self._users.get(username.lower())
        if user is None:
            return None
        return {
            "username": user.username,
            "first_name": user.first_name,
            "last_name": user.last_name,
            "display_name": user.display_name,
            "enabled": user.enabled,
            "ou": user.ou,
            "description": user.description,
            "last_logon_utc": user.last_logon_utc.isoformat(),
        }

    async def search_users_by_name(
        self, name_query: str, limit: int = 20
    ) -> list[dict[str, Any]]:
        """Return users whose first, last or display name contains *name_query*.

        Matching is a case-insensitive substring test on the stripped query.
        Results are sorted by display name and truncated to *limit*, clamped
        to the range 1-100. A blank query returns an empty list.
        """
        query = name_query.strip().lower()
        if not query:
            return []

        max_results = max(1, min(limit, 100))
        results: list[dict[str, Any]] = []
        for user in self._users.values():
            searchable = [user.first_name, user.last_name, user.display_name]
            if any(query in value.lower() for value in searchable):
                results.append(
                    {
                        "username": user.username,
                        "first_name": user.first_name,
                        "last_name": user.last_name,
                        "display_name": user.display_name,
                        "enabled": user.enabled,
                        "ou": user.ou,
                    }
                )

        results.sort(key=lambda row: row["display_name"].lower())
        return results[:max_results]

    async def get_user_groups(self, username: str) -> list[str]:
        """Return a copy of *username*'s group names; [] for an unknown user."""
        user = self._users.get(username.lower())
        if user is None:
            return []
        return list(user.groups)

    async def get_group_members(self, group_name: str) -> list[str]:
        """Return the sorted usernames of members of *group_name* (case-insensitive)."""
        wanted = group_name.lower()
        return sorted(
            user.username
            for user in self._users.values()
            if any(g.lower() == wanted for g in user.groups)
        )

    async def find_stale_users(self, days: int) -> list[dict[str, Any]]:
        """Return users whose last logon is older than *days* days.

        Results are sorted by username. A negative *days* returns [].
        """
        if days < 0:
            return []
        cutoff = datetime.now(tz=timezone.utc) - timedelta(days=days)
        results = [
            {
                "username": user.username,
                "enabled": user.enabled,
                "last_logon_utc": user.last_logon_utc.isoformat(),
            }
            for user in self._users.values()
            if user.last_logon_utc < cutoff
        ]
        return sorted(results, key=lambda row: row["username"])

    async def get_computer(self, computer_name: str) -> dict[str, Any] | None:
        """Return one computer's attributes (case-insensitive name), or None."""
        computer = self._computers.get(computer_name.upper())
        if computer is None:
            return None
        return {
            "computer_name": computer.computer_name,
            "ou": computer.ou,
            "assigned_username": computer.assigned_username,
        }

    async def query_users(
        self,
        filter_params: dict[str, Any] | None = None,
        fields: list[str] | None = None,
        sort_by: str = "display_name",
        sort_direction: str = "asc",
        page_size: int = 50,
        cursor: str | None = None,
    ) -> dict[str, Any]:
        """Query users with validated filters, sorting, projection and paging.

        Args:
            filter_params: Criteria understood by ``_user_matches_filter``;
                None or empty matches every user.
            fields: Attributes to include per item; must be a subset of
                ``ALLOWED_USER_FIELDS``. None or empty selects all of them.
            sort_by: ``display_name``, ``username`` or ``last_logon_utc``.
                Any other value sorts by display name and adds a warning.
            sort_direction: ``"desc"`` reverses the order; anything else is
                ascending.
            page_size: Items per page, clamped to the range 1-200.
            cursor: A ``next_cursor`` value from a previous call (a decimal
                offset). Negative offsets are treated as 0; a non-numeric
                cursor restarts from the first page and adds a warning.

        Returns:
            dict with keys: items, next_cursor, page_size, total_estimate,
            applied_filter, warnings. If any requested field is invalid, no
            items are returned and ``warnings`` names the offending fields.
        """
        filter_params = filter_params or {}
        fields = fields or list(ALLOWED_USER_FIELDS)
        clamped_size = min(max(1, page_size), 200)

        invalid_fields = set(fields) - ALLOWED_USER_FIELDS
        if invalid_fields:
            # Sorted (and stringified) so the message is stable across runs;
            # raw set iteration order is not.
            names = ", ".join(sorted(str(f) for f in invalid_fields))
            return {
                "items": [],
                "next_cursor": None,
                "page_size": clamped_size,
                "total_estimate": 0,
                "applied_filter": filter_params,
                "warnings": [f"Invalid fields requested: {names}"],
            }

        warnings: list[str] = []
        matching_users = self._filter_users(filter_params)

        sort_key_map = {
            "display_name": lambda u: u.display_name.lower(),
            "username": lambda u: u.username.lower(),
            "last_logon_utc": lambda u: u.last_logon_utc,
        }
        sort_func = sort_key_map.get(sort_by)
        if sort_func is None:
            warnings.append(
                f"Unsupported sort_by {sort_by!r}; sorted by display_name instead"
            )
            sort_func = sort_key_map["display_name"]
        matching_users.sort(key=sort_func, reverse=(sort_direction == "desc"))

        start_index = 0
        if cursor:
            try:
                # Clamp: a negative offset would slice from the end of the list.
                start_index = max(0, int(cursor))
            except ValueError:
                warnings.append(
                    f"Ignoring invalid cursor {cursor!r}; starting from the first page"
                )
        end_index = start_index + clamped_size
        page_users = matching_users[start_index:end_index]
        next_cursor = str(end_index) if end_index < len(matching_users) else None

        return {
            "items": [self._project_user(user, fields) for user in page_users],
            "next_cursor": next_cursor,
            "page_size": clamped_size,
            "total_estimate": len(matching_users),
            "applied_filter": filter_params,
            "warnings": warnings,
        }

    async def count_users(
        self,
        filter_params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Count users matching filter without returning full records.

        Returns:
            dict with keys: count, applied_filter
        """
        filter_params = filter_params or {}
        count = sum(
            1
            for user in self._users.values()
            if self._user_matches_filter(user, filter_params)
        )
        return {
            "count": count,
            "applied_filter": filter_params,
        }

    async def summarize_users(
        self,
        filter_params: dict[str, Any] | None = None,
        group_by: str = "enabled",
        top: int = 20,
    ) -> dict[str, Any]:
        """Return grouped counts for users matching filter.

        Buckets are keyed by ``_get_group_key`` and ordered by count
        (descending), then key (ascending). At most *top* buckets are
        returned, with *top* clamped to the range 1-50; ``total`` counts all
        matching users, including those in dropped buckets.

        Returns:
            dict with keys: buckets, total, applied_filter
        """
        filter_params = filter_params or {}
        matching_users = self._filter_users(filter_params)

        bucket_counts: dict[str, int] = {}
        for user in matching_users:
            key = self._get_group_key(user, group_by)
            bucket_counts[key] = bucket_counts.get(key, 0) + 1

        # Tie-break on key so bucket order does not depend on the order in
        # which records happen to be stored.
        ranked = sorted(bucket_counts.items(), key=lambda kv: (-kv[1], kv[0]))
        clamped_top = min(max(1, top), 50)
        buckets = [{"key": key, "count": count} for key, count in ranked[:clamped_top]]

        return {
            "buckets": buckets,
            "total": len(matching_users),
            "applied_filter": filter_params,
        }

    def _filter_users(self, filter_params: dict[str, Any]) -> list[UserRecord]:
        """Return users matching *filter_params*, in user-table order."""
        return [
            user
            for user in self._users.values()
            if self._user_matches_filter(user, filter_params)
        ]

    @staticmethod
    def _project_user(user: UserRecord, fields: list[str]) -> dict[str, Any]:
        """Build the output item for *user* containing only the requested *fields*.

        Keys appear in a fixed order regardless of the order of *fields*.
        ``department``, ``title`` and ``when_created_utc`` are not stored on
        ``UserRecord`` and are emitted as empty strings; ``email`` is
        synthesized from the username.
        """
        wanted = set(fields)
        projectors = (
            ("username", lambda u: u.username),
            ("display_name", lambda u: u.display_name),
            ("first_name", lambda u: u.first_name),
            ("last_name", lambda u: u.last_name),
            ("enabled", lambda u: u.enabled),
            ("ou", lambda u: u.ou),
            ("description", lambda u: u.description),
            ("last_logon_utc", lambda u: u.last_logon_utc.isoformat()),
            ("department", lambda u: ""),
            ("title", lambda u: ""),
            ("email", lambda u: f"{u.username}@example.local"),
            ("when_created_utc", lambda u: ""),
        )
        return {name: get(user) for name, get in projectors if name in wanted}

    def _user_matches_filter(self, user: UserRecord, filter_params: dict[str, Any]) -> bool:
        """Return True if *user* satisfies every criterion in *filter_params*.

        Recognised keys (all optional; other keys are ignored):
            enabled: value ``user.enabled`` must equal.
            name_contains: case-insensitive substring of the display name.
            username_prefix: case-insensitive username prefix.
            ou_contains: case-insensitive substring of the OU.
            description_contains: case-insensitive substring of the description.
            last_logon_before_days: last logon must be at least this many days ago.
            group_any: a group name or an iterable of names; the user must be
                in at least one of them (case-insensitive).

        A key whose value is None is treated as absent, so callers may forward
        unset optional arguments verbatim.
        """
        criteria = {key: value for key, value in filter_params.items() if value is not None}

        if "enabled" in criteria and user.enabled != criteria["enabled"]:
            return False

        if "name_contains" in criteria:
            if criteria["name_contains"].lower() not in user.display_name.lower():
                return False

        if "username_prefix" in criteria:
            if not user.username.lower().startswith(criteria["username_prefix"].lower()):
                return False

        if "ou_contains" in criteria:
            if criteria["ou_contains"].lower() not in user.ou.lower():
                return False

        if "description_contains" in criteria:
            if criteria["description_contains"].lower() not in user.description.lower():
                return False

        if "last_logon_before_days" in criteria:
            days = criteria["last_logon_before_days"]
            cutoff = datetime.now(tz=timezone.utc) - timedelta(days=days)
            if user.last_logon_utc >= cutoff:
                return False

        if "group_any" in criteria:
            requested = criteria["group_any"]
            # A bare string would otherwise be iterated character by character
            # and never match a real group name.
            if isinstance(requested, str):
                requested = [requested]
            wanted_groups = {g.lower() for g in requested}
            if wanted_groups.isdisjoint(g.lower() for g in user.groups):
                return False

        return True

    def _get_group_key(self, user: UserRecord, group_by: str) -> str:
        """Return the summary bucket label for *user* under *group_by*.

        ``department``, ``title`` and ``created_month`` are not stored on
        ``UserRecord`` and always map to "Unknown", as does any unrecognised
        *group_by*.
        """
        if group_by == "enabled":
            return "Enabled" if user.enabled else "Disabled"
        if group_by == "ou":
            return user.ou
        if group_by == "last_logon_bucket":
            days_ago = (datetime.now(tz=timezone.utc) - user.last_logon_utc).days
            if days_ago < 7:
                return "Last 7 days"
            if days_ago < 30:
                return "Last 30 days"
            if days_ago < 90:
                return "Last 90 days"
            return "90+ days"
        return "Unknown"
|
|
|