feat(ai): enhance session management, inquirer theme, and gRPC MCP support

- AI & Session Management:
  - Add random suffix to session IDs to ensure uniqueness.
  - Implement optional pagination/limit in session listing (default 20).
  - Add `--all` flag to `ai` CLI commands to view all sessions.
  - Keep active session ID and path synced correctly during clean session startups.

- CLI UI/UX:
  - Add a custom `ConnpyTheme` for inquirer prompts that dynamically translates
    active hex style colors to terminal ANSI/blessed escapes.

- gRPC & Services:
  - Implement remote MCP server listing (`list_mcp_servers` RPC and services).
  - Stream responder updates (`__RESPONDER__`) to toggle headers dynamically between
    "Network Engineer" and "Network Architect" on remote/web clients.
  - Fix remote client deadlock risk by ensuring `final_mark` is sent on exceptions.
  - Hydrate client-side chat history correctly on initial streaming request.

- Testing:
  - Add integration tests for AI gRPC services and MCP server listing.
This commit is contained in:
2026-05-20 12:27:02 -03:00
parent 468868ac18
commit dce9982454
11 changed files with 515 additions and 319 deletions
+34 -12
View File
@@ -1,4 +1,6 @@
import os import os
import secrets
import sys import sys
import json import json
import re import re
@@ -165,8 +167,8 @@ class ai:
# Session Management # Session Management
self.sessions_dir = os.path.join(self.config.defaultdir, "ai_sessions") self.sessions_dir = os.path.join(self.config.defaultdir, "ai_sessions")
os.makedirs(self.sessions_dir, exist_ok=True) os.makedirs(self.sessions_dir, exist_ok=True)
self.session_id = None self.session_id = getattr(self.config, "session_id", None)
self.session_path = None self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json") if self.session_id else None
# Prompts base agnósticos # Prompts base agnósticos
architect_instructions = "" architect_instructions = ""
@@ -877,16 +879,27 @@ class ai:
continue continue
return sorted(sessions, key=lambda x: x["created_at"], reverse=True) return sorted(sessions, key=lambda x: x["created_at"], reverse=True)
def list_sessions(self): def list_sessions(self, limit=20):
"""Prints a list of sessions using printer.table.""" """Prints a list of sessions using printer.table."""
sessions = self._get_sessions() sessions = self._get_sessions()
if not sessions: if not sessions:
printer.info("No saved AI sessions found.") printer.info("No saved AI sessions found.")
return return
total = len(sessions)
if limit and total > limit:
sessions = sessions[:limit]
columns = ["ID", "Title", "Created At", "Model"] columns = ["ID", "Title", "Created At", "Model"]
rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions] rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions]
printer.table("AI Persisted Sessions", columns, rows)
title = "AI Persisted Sessions"
if limit and total > limit:
title += f" (Showing last {limit} of {total})"
printer.table(title, columns, rows)
if limit and total > limit:
printer.info(f"Use '--list --all' (if supported) or check the sessions directory to see all {total} sessions.")
def load_session_data(self, session_id): def load_session_data(self, session_id):
"""Loads a session's raw data by ID.""" """Loads a session's raw data by ID."""
@@ -917,8 +930,10 @@ class ai:
return sessions[0]["id"] if sessions else None return sessions[0]["id"] if sessions else None
def _generate_session_id(self, query): def _generate_session_id(self, query):
"""Generates a unique session ID based on timestamp.""" """Generates a unique session ID based on timestamp and a random suffix."""
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S") ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
suffix = secrets.token_hex(2)
return f"{ts}-{suffix}"
def save_session(self, history, title=None, model=None): def save_session(self, history, title=None, model=None):
"""Saves current history to the session file.""" """Saves current history to the session file."""
@@ -927,6 +942,8 @@ class ai:
first_user_msg = next((m["content"] for m in history if m["role"] == "user"), "new-session") first_user_msg = next((m["content"] for m in history if m["role"] == "user"), "new-session")
self.session_id = self._generate_session_id(first_user_msg) self.session_id = self._generate_session_id(first_user_msg)
self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json") self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json")
elif not self.session_path:
self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json")
# If it's a new file, we might want to set a better title # If it's a new file, we might want to set a better title
if not os.path.exists(self.session_path) and not title: if not os.path.exists(self.session_path) and not title:
@@ -970,10 +987,15 @@ class ai:
if chat_history is None: chat_history = [] if chat_history is None: chat_history = []
# Load session if provided and history is empty # Load session if provided and history is empty
if session_id and not chat_history: if session_id:
session_data = self.load_session_data(session_id) # Force the session_id even if it doesn't exist yet
if session_data: self.session_id = session_id
chat_history = session_data.get("history", []) self.session_path = os.path.join(self.sessions_dir, f"{session_id}.json")
if not chat_history:
session_data = self.load_session_data(session_id)
if session_data:
chat_history = session_data.get("history", [])
# If we loaded history, the caller might need it back # If we loaded history, the caller might need it back
# But typically ask() is called in a loop with an external history object # But typically ask() is called in a loop with an external history object
@@ -1058,8 +1080,8 @@ class ai:
label = "[architect][bold]Architect[/bold][/architect]" if current_brain == "architect" else "[engineer][bold]Engineer[/bold][/engineer]" label = "[architect][bold]Architect[/bold][/architect]" if current_brain == "architect" else "[engineer][bold]Engineer[/bold][/engineer]"
if status: if status:
# Notify responder identity ONLY for web/remote clients (StatusBridge has is_web) # Notify responder identity for web/remote clients
if getattr(status, "is_web", False): if getattr(status, "is_web", False) or getattr(status, "is_remote", False):
status.update(f"__RESPONDER__:{current_brain}") status.update(f"__RESPONDER__:{current_brain}")
status.update(f"{label} is thinking... (step {iteration})") status.update(f"{label} is thinking... (step {iteration})")
+15 -8
View File
@@ -15,13 +15,22 @@ class AIHandler:
def dispatch(self, args): def dispatch(self, args):
if args.list_sessions: if args.list_sessions:
sessions = self.app.services.ai.list_sessions() limit = 20 if not getattr(args, "all", False) else None
sessions, total = self.app.services.ai.list_sessions(limit=limit)
if not sessions: if not sessions:
printer.info("No saved AI sessions found.") printer.info("No saved AI sessions found.")
return return
columns = ["ID", "Title", "Created At", "Model"] columns = ["ID", "Title", "Created At", "Model"]
rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions] rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions]
printer.table("AI Persisted Sessions", columns, rows)
title = "AI Persisted Sessions"
if limit and total > limit:
title += f" (Showing last {limit} of {total})"
printer.table(title, columns, rows)
if limit and total > limit:
printer.info(f"Use '--list --all' to see all {total} sessions.")
return return
if args.delete_session: if args.delete_session:
@@ -102,7 +111,7 @@ class AIHandler:
if history: if history:
mdprint(f"[debug]Analyzing {len(history)} previous messages...[/debug]\n") mdprint(f"[debug]Analyzing {len(history)} previous messages...[/debug]\n")
else: else:
printer.error(f"Could not load session {session_id}. Starting clean.") printer.info(f"Session '{session_id}' not found. Starting clean.")
if not history: if not history:
mdprint(Rule(style="engineer")) mdprint(Rule(style="engineer"))
@@ -116,7 +125,7 @@ class AIHandler:
if user_query.lower() in ['exit', 'quit', 'bye', 'cancel']: break if user_query.lower() in ['exit', 'quit', 'bye', 'cancel']: break
with console.status("[ai_status]Agent is thinking...") as status: with console.status("[ai_status]Agent is thinking...") as status:
result = self.app.myai.ask(user_query, chat_history=history, status=status, debug=args.debug, trust=args.trust, **self.ai_overrides) result = self.app.myai.ask(user_query, chat_history=history, status=status, debug=args.debug, trust=args.trust, session_id=session_id, **self.ai_overrides)
new_history = result.get("chat_history") new_history = result.get("chat_history")
if new_history is not None: if new_history is not None:
@@ -147,8 +156,7 @@ class AIHandler:
action = mcp_args[0].lower() action = mcp_args[0].lower()
if action == "list": if action == "list":
settings = self.app.services.config_svc.get_settings() mcp_servers = self.app.services.ai.list_mcp_servers()
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
if not mcp_servers: if not mcp_servers:
printer.info("No MCP servers configured.") printer.info("No MCP servers configured.")
else: else:
@@ -213,8 +221,7 @@ class AIHandler:
from .forms import Forms from .forms import Forms
self.app.cli_forms = Forms(self.app) self.app.cli_forms = Forms(self.app)
settings = self.app.services.config_svc.get_settings() mcp_servers = self.app.services.ai.list_mcp_servers()
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
result = self.app.cli_forms.mcp_wizard(mcp_servers) result = self.app.cli_forms.mcp_wizard(mcp_servers)
if not result: if not result:
+72 -1
View File
@@ -1,10 +1,81 @@
import os import os
import inquirer import inquirer
from inquirer.themes import Default, term
try: try:
from pyfzf.pyfzf import FzfPrompt from pyfzf.pyfzf import FzfPrompt
except ImportError: except ImportError:
FzfPrompt = None FzfPrompt = None
def hex_to_blessed(hex_str):
"""Convert hex color string to blessed/ansi format."""
if not hex_str or not isinstance(hex_str, str):
return term.normal
# Check for bold prefix
prefix = ""
if hex_str.startswith('bold '):
prefix = term.bold
hex_str = hex_str.replace('bold ', '').strip()
# If it's a standard color name
if not hex_str.startswith('#'):
return prefix + getattr(term, hex_str, term.normal)
# Parse hex
try:
h = hex_str.lstrip('#')
if len(h) == 3:
h = ''.join([c*2 for c in h])
r = int(h[0:2], 16)
g = int(h[2:4], 16)
b = int(h[4:6], 16)
# Try RGB, fallback to standard cyan if it fails or returns empty
try:
c = term.color_rgb(r, g, b)
if not c: # Some terms return empty for RGB
return prefix + term.cyan
return prefix + c
except:
return prefix + term.cyan
except:
return prefix + term.normal
# Custom inquirer theme matching connpy colors
class ConnpyTheme(Default):
def __init__(self):
super().__init__()
try:
from ..printer import _global_active_styles
# Use user_prompt as primary accent, fallback to info/cyan
accent = _global_active_styles.get("user_prompt", _global_active_styles.get("info", "cyan"))
accent_color = hex_to_blessed(accent)
self.Question.mark_color = accent_color
self.List.selection_color = accent_color
self.List.selection_cursor = ">"
except:
# Absolute fallback to standard cyan
self.Question.mark_color = term.cyan
self.List.selection_color = term.bold_cyan
self.List.selection_cursor = ">"
def get_theme():
"""Returns a fresh instance of the theme with current colors."""
return ConnpyTheme()
class ThemeProxy:
"""Proxy to ensure theme colors are resolved at runtime."""
def __getattr__(self, name):
return getattr(get_theme(), name)
def __iter__(self):
return iter(get_theme())
def __getitem__(self, item):
return get_theme()[item]
theme = ThemeProxy()
def get_config_dir(): def get_config_dir():
home = os.path.expanduser("~") home = os.path.expanduser("~")
defaultdir = os.path.join(home, '.config/conn') defaultdir = os.path.join(home, '.config/conn')
@@ -56,7 +127,7 @@ def choose(app, list_, name, action):
return answer[0] return answer[0]
else: else:
questions = [inquirer.List(name, message="Pick {} to {}:".format(name,action), choices=list_, carousel=True)] questions = [inquirer.List(name, message="Pick {} to {}:".format(name,action), choices=list_, carousel=True)]
answer = inquirer.prompt(questions) answer = inquirer.prompt(questions, theme=theme)
if answer == None: if answer == None:
return None return None
else: else:
+1
View File
@@ -281,6 +281,7 @@ class connapp:
aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls") aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls")
aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation") aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation")
aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions") aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions")
aiparser.add_argument("--all", action="store_true", help="Show all sessions without limit")
aiparser.add_argument("--session", nargs=1, help="Resume a specific AI session by ID") aiparser.add_argument("--session", nargs=1, help="Resume a specific AI session by ID")
aiparser.add_argument("--resume", action="store_true", help="Resume the most recent AI session") aiparser.add_argument("--resume", action="store_true", help="Resume the most recent AI session")
aiparser.add_argument("--delete", "--delete-session", dest="delete_session", nargs=1, help="Delete an AI session by ID") aiparser.add_argument("--delete", "--delete-session", dest="delete_session", nargs=1, help="Delete an AI session by ID")
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff
+17 -2
View File
@@ -483,6 +483,7 @@ class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
self.service = ProfileService(config) self.service = ProfileService(config)
self.node_service = NodeService(config) self.node_service = NodeService(config)
@handle_errors @handle_errors
def list_profiles(self, request, context): def list_profiles(self, request, context):
f = request.filter_str if request.filter_str else None f = request.filter_str if request.filter_str else None
@@ -731,6 +732,7 @@ class StatusBridge:
self.on_interrupt = self._force_interrupt self.on_interrupt = self._force_interrupt
self.thread = None self.thread = None
self.is_web = is_web self.is_web = is_web
self.is_remote = True
def _force_interrupt(self): def _force_interrupt(self):
"""Forcefully raise KeyboardInterrupt in the target thread.""" """Forcefully raise KeyboardInterrupt in the target thread."""
@@ -862,9 +864,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
print(f"AI Task Error: {e}") print(f"AI Task Error: {e}")
traceback.print_exc() traceback.print_exc()
chunk_queue.put(("status", f"Error: {str(e)}")) chunk_queue.put(("status", f"Error: {str(e)}"))
# Crucial: always send final_mark to avoid client deadlock
chunk_queue.put(("final_mark", {"response": f"Error: {str(e)}", "chat_history": history, "error": True}))
def request_listener(): def request_listener():
nonlocal bridge, is_web, ai_thread, agent_instance nonlocal bridge, is_web, ai_thread, agent_instance, history
try: try:
for req in request_iterator: for req in request_iterator:
if req.interrupt: if req.interrupt:
@@ -878,6 +882,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
if req.input_text: if req.input_text:
is_web = "web" in (req.session_id or "").lower() or (req.session_id or "").lower().startswith("ws-") is_web = "web" in (req.session_id or "").lower() or (req.session_id or "").lower().startswith("ws-")
# Hydrate history from client if it's the first interaction in this stream
if not history and req.chat_history:
from .utils import from_value
history = from_value(req.chat_history) or []
if not bridge: if not bridge:
bridge = StatusBridge(chunk_queue, request_queue=request_queue, is_web=is_web) bridge = StatusBridge(chunk_queue, request_queue=request_queue, is_web=is_web)
@@ -948,7 +957,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
@handle_errors @handle_errors
def list_sessions(self, request, context): def list_sessions(self, request, context):
return connpy_pb2.ValueResponse(data=to_value(self.service.list_sessions())) sessions, total = self.service.list_sessions()
return connpy_pb2.ValueResponse(data=to_value(sessions))
@handle_errors @handle_errors
def delete_session(self, request, context): def delete_session(self, request, context):
@@ -971,6 +981,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
) )
return Empty() return Empty()
@handle_errors
def list_mcp_servers(self, request, context):
mcp_servers = self.service.list_mcp_servers()
return connpy_pb2.ValueResponse(data=to_value(mcp_servers))
@handle_errors @handle_errors
def load_session_data(self, request, context): def load_session_data(self, request, context):
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value))) return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
+20 -3
View File
@@ -758,6 +758,7 @@ class AIStub:
full_content = "" full_content = ""
header_printed = False header_printed = False
current_responder = "engineer"
final_result = {"response": "", "chat_history": []} final_result = {"response": "", "chat_history": []}
# Background thread to pull responses from gRPC into a local queue # Background thread to pull responses from gRPC into a local queue
@@ -802,6 +803,10 @@ class AIStub:
break break
if response.status_update: if response.status_update:
if response.status_update.startswith("__RESPONDER__:"):
current_responder = response.status_update.split(":")[1].lower()
continue
if response.requires_confirmation: if response.requires_confirmation:
if status: status.stop() if status: status.stop()
@@ -854,7 +859,9 @@ class AIStub:
stable_console = RichConsole(theme=connpy_theme, file=get_original_stdout()) stable_console = RichConsole(theme=connpy_theme, file=get_original_stdout())
# Print header on first chunk # Print header on first chunk
stable_console.print(Rule("[bold engineer]Network Engineer[/bold engineer]", style="engineer")) alias = "architect" if current_responder == "architect" else "engineer"
role_label = "Network Architect" if current_responder == "architect" else "Network Engineer"
stable_console.print(Rule(f"[bold {alias}]{role_label}[/bold {alias}]", style=alias))
header_printed = True header_printed = True
# Initialize parser # Initialize parser
@@ -906,8 +913,13 @@ class AIStub:
return self.stub.confirm(connpy_pb2.StringRequest(value=input_text)).value return self.stub.confirm(connpy_pb2.StringRequest(value=input_text)).value
@handle_errors @handle_errors
def list_sessions(self): def list_sessions(self, limit=None):
return from_value(self.stub.list_sessions(Empty()).data) from .utils import from_value
res = self.stub.list_sessions(Empty())
sessions = from_value(res.data) or []
if limit and len(sessions) > limit:
return sessions[:limit], len(sessions)
return sessions, len(sessions)
@handle_errors @handle_errors
def delete_session(self, session_id): def delete_session(self, session_id):
@@ -929,6 +941,11 @@ class AIStub:
) )
self.stub.configure_mcp(req) self.stub.configure_mcp(req)
@handle_errors
def list_mcp_servers(self):
res = self.stub.list_mcp_servers(Empty())
return from_value(res.data) or {}
@handle_errors @handle_errors
def load_session_data(self, session_id): def load_session_data(self, session_id):
return from_struct(self.stub.load_session_data(connpy_pb2.StringRequest(value=session_id)).data) return from_struct(self.stub.load_session_data(connpy_pb2.StringRequest(value=session_id)).data)
+1
View File
@@ -70,6 +70,7 @@ service AIService {
rpc delete_session (StringRequest) returns (google.protobuf.Empty) {} rpc delete_session (StringRequest) returns (google.protobuf.Empty) {}
rpc configure_provider (ProviderRequest) returns (google.protobuf.Empty) {} rpc configure_provider (ProviderRequest) returns (google.protobuf.Empty) {}
rpc configure_mcp (MCPRequest) returns (google.protobuf.Empty) {} rpc configure_mcp (MCPRequest) returns (google.protobuf.Empty) {}
rpc list_mcp_servers (google.protobuf.Empty) returns (ValueResponse) {}
rpc load_session_data (StringRequest) returns (StructResponse) {} rpc load_session_data (StringRequest) returns (StructResponse) {}
} }
+11 -3
View File
@@ -167,11 +167,14 @@ class AIService(BaseService):
return await asyncio.wrap_future(future) return await asyncio.wrap_future(future)
def list_sessions(self): def list_sessions(self, limit=None):
"""Return a list of all saved AI sessions.""" """Return a list of saved AI sessions, optionally limited."""
from connpy.ai import ai from connpy.ai import ai
agent = ai(self.config) agent = ai(self.config)
return agent._get_sessions() sessions = agent._get_sessions()
if limit and len(sessions) > limit:
return sessions[:limit], len(sessions)
return sessions, len(sessions)
def delete_session(self, session_id): def delete_session(self, session_id):
"""Delete an AI session by ID.""" """Delete an AI session by ID."""
@@ -228,6 +231,11 @@ class AIService(BaseService):
self.config.config["ai"] = ai_settings self.config.config["ai"] = ai_settings
self.config._saveconfig(self.config.file) self.config._saveconfig(self.config.file)
def list_mcp_servers(self) -> dict:
"""Get the configured MCP servers."""
ai_settings = self.config.config.get("ai", {})
return ai_settings.get("mcp_servers", {})
def load_session_data(self, session_id): def load_session_data(self, session_id):
"""Load a session's raw data by ID.""" """Load a session's raw data by ID."""
from connpy.ai import ai from connpy.ai import ai
+11
View File
@@ -120,6 +120,7 @@ class TestGRPCIntegration:
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(server.ConfigServicer(populated_config), srv) connpy_pb2_grpc.add_ConfigServiceServicer_to_server(server.ConfigServicer(populated_config), srv)
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(server.ExecutionServicer(populated_config), srv) connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(server.ExecutionServicer(populated_config), srv)
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(server.ImportExportServicer(populated_config), srv) connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(server.ImportExportServicer(populated_config), srv)
connpy_pb2_grpc.add_AIServiceServicer_to_server(server.AIServicer(populated_config), srv)
port = srv.add_insecure_port('127.0.0.1:0') port = srv.add_insecure_port('127.0.0.1:0')
srv.start() srv.start()
@@ -143,6 +144,10 @@ class TestGRPCIntegration:
def config_stub(self, channel): def config_stub(self, channel):
return stubs.ConfigStub(channel, "localhost") return stubs.ConfigStub(channel, "localhost")
@pytest.fixture
def ai_stub(self, channel):
return stubs.AIStub(channel, "localhost")
def test_list_nodes_integration(self, node_stub): def test_list_nodes_integration(self, node_stub):
nodes = node_stub.list_nodes() nodes = node_stub.list_nodes()
assert "router1" in nodes assert "router1" in nodes
@@ -170,6 +175,12 @@ class TestGRPCIntegration:
settings = config_stub.get_settings() settings = config_stub.get_settings()
assert settings["idletime"] == 99 assert settings["idletime"] == 99
def test_list_mcp_servers_integration(self, ai_stub):
ai_stub.configure_mcp("test-mcp", url="http://localhost:8080", enabled=True)
servers = ai_stub.list_mcp_servers()
assert "test-mcp" in servers
assert servers["test-mcp"]["url"] == "http://localhost:8080"
def test_add_delete_node_integration(self, node_stub): def test_add_delete_node_integration(self, node_stub):
node_stub.add_node("integration-test-node", {"host": "9.9.9.9"}) node_stub.add_node("integration-test-node", {"host": "9.9.9.9"})
assert "integration-test-node" in node_stub.list_nodes() assert "integration-test-node" in node_stub.list_nodes()