feat(ai): enhance session management, inquirer theme, and gRPC MCP support
- AI & Session Management:
- Add random suffix to session IDs to ensure uniqueness.
- Implement optional pagination/limit in session listing (default 20).
- Add `--all` flag to `ai` CLI commands to view all sessions.
- Keep active session ID and path synced correctly during clean session startups.
- CLI UI/UX:
- Add a custom `ConnpyTheme` for inquirer prompts that dynamically translates
active hex style colors to terminal ANSI/blessed escapes.
- gRPC & Services:
- Implement remote MCP server listing (`list_mcp_servers` RPC and services).
- Stream responder updates (`__RESPONDER__`) to toggle headers dynamically between
"Network Engineer" and "Network Architect" on remote/web clients.
- Fix remote client deadlock risk by ensuring `final_mark` is sent on exceptions.
- Hydrate client-side chat history correctly on initial streaming request.
- Testing:
- Add integration tests for AI gRPC services and MCP server listing.
This commit is contained in:
+31
-9
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import secrets
|
||||
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
@@ -165,8 +167,8 @@ class ai:
|
||||
# Session Management
|
||||
self.sessions_dir = os.path.join(self.config.defaultdir, "ai_sessions")
|
||||
os.makedirs(self.sessions_dir, exist_ok=True)
|
||||
self.session_id = None
|
||||
self.session_path = None
|
||||
self.session_id = getattr(self.config, "session_id", None)
|
||||
self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json") if self.session_id else None
|
||||
|
||||
# Prompts base agnósticos
|
||||
architect_instructions = ""
|
||||
@@ -877,16 +879,27 @@ class ai:
|
||||
continue
|
||||
return sorted(sessions, key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
def list_sessions(self):
|
||||
def list_sessions(self, limit=20):
|
||||
"""Prints a list of sessions using printer.table."""
|
||||
sessions = self._get_sessions()
|
||||
if not sessions:
|
||||
printer.info("No saved AI sessions found.")
|
||||
return
|
||||
|
||||
total = len(sessions)
|
||||
if limit and total > limit:
|
||||
sessions = sessions[:limit]
|
||||
|
||||
columns = ["ID", "Title", "Created At", "Model"]
|
||||
rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions]
|
||||
printer.table("AI Persisted Sessions", columns, rows)
|
||||
|
||||
title = "AI Persisted Sessions"
|
||||
if limit and total > limit:
|
||||
title += f" (Showing last {limit} of {total})"
|
||||
|
||||
printer.table(title, columns, rows)
|
||||
if limit and total > limit:
|
||||
printer.info(f"Use '--list --all' (if supported) or check the sessions directory to see all {total} sessions.")
|
||||
|
||||
def load_session_data(self, session_id):
|
||||
"""Loads a session's raw data by ID."""
|
||||
@@ -917,8 +930,10 @@ class ai:
|
||||
return sessions[0]["id"] if sessions else None
|
||||
|
||||
def _generate_session_id(self, query):
|
||||
"""Generates a unique session ID based on timestamp."""
|
||||
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
"""Generates a unique session ID based on timestamp and a random suffix."""
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
suffix = secrets.token_hex(2)
|
||||
return f"{ts}-{suffix}"
|
||||
|
||||
def save_session(self, history, title=None, model=None):
|
||||
"""Saves current history to the session file."""
|
||||
@@ -927,6 +942,8 @@ class ai:
|
||||
first_user_msg = next((m["content"] for m in history if m["role"] == "user"), "new-session")
|
||||
self.session_id = self._generate_session_id(first_user_msg)
|
||||
self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json")
|
||||
elif not self.session_path:
|
||||
self.session_path = os.path.join(self.sessions_dir, f"{self.session_id}.json")
|
||||
|
||||
# If it's a new file, we might want to set a better title
|
||||
if not os.path.exists(self.session_path) and not title:
|
||||
@@ -970,7 +987,12 @@ class ai:
|
||||
if chat_history is None: chat_history = []
|
||||
|
||||
# Load session if provided and history is empty
|
||||
if session_id and not chat_history:
|
||||
if session_id:
|
||||
# Force the session_id even if it doesn't exist yet
|
||||
self.session_id = session_id
|
||||
self.session_path = os.path.join(self.sessions_dir, f"{session_id}.json")
|
||||
|
||||
if not chat_history:
|
||||
session_data = self.load_session_data(session_id)
|
||||
if session_data:
|
||||
chat_history = session_data.get("history", [])
|
||||
@@ -1058,8 +1080,8 @@ class ai:
|
||||
|
||||
label = "[architect][bold]Architect[/bold][/architect]" if current_brain == "architect" else "[engineer][bold]Engineer[/bold][/engineer]"
|
||||
if status:
|
||||
# Notify responder identity ONLY for web/remote clients (StatusBridge has is_web)
|
||||
if getattr(status, "is_web", False):
|
||||
# Notify responder identity for web/remote clients
|
||||
if getattr(status, "is_web", False) or getattr(status, "is_remote", False):
|
||||
status.update(f"__RESPONDER__:{current_brain}")
|
||||
status.update(f"{label} is thinking... (step {iteration})")
|
||||
|
||||
|
||||
@@ -15,13 +15,22 @@ class AIHandler:
|
||||
|
||||
def dispatch(self, args):
|
||||
if args.list_sessions:
|
||||
sessions = self.app.services.ai.list_sessions()
|
||||
limit = 20 if not getattr(args, "all", False) else None
|
||||
sessions, total = self.app.services.ai.list_sessions(limit=limit)
|
||||
if not sessions:
|
||||
printer.info("No saved AI sessions found.")
|
||||
return
|
||||
|
||||
columns = ["ID", "Title", "Created At", "Model"]
|
||||
rows = [[s["id"], s["title"], s["created_at"], s["model"]] for s in sessions]
|
||||
printer.table("AI Persisted Sessions", columns, rows)
|
||||
|
||||
title = "AI Persisted Sessions"
|
||||
if limit and total > limit:
|
||||
title += f" (Showing last {limit} of {total})"
|
||||
|
||||
printer.table(title, columns, rows)
|
||||
if limit and total > limit:
|
||||
printer.info(f"Use '--list --all' to see all {total} sessions.")
|
||||
return
|
||||
|
||||
if args.delete_session:
|
||||
@@ -102,7 +111,7 @@ class AIHandler:
|
||||
if history:
|
||||
mdprint(f"[debug]Analyzing {len(history)} previous messages...[/debug]\n")
|
||||
else:
|
||||
printer.error(f"Could not load session {session_id}. Starting clean.")
|
||||
printer.info(f"Session '{session_id}' not found. Starting clean.")
|
||||
|
||||
if not history:
|
||||
mdprint(Rule(style="engineer"))
|
||||
@@ -116,7 +125,7 @@ class AIHandler:
|
||||
if user_query.lower() in ['exit', 'quit', 'bye', 'cancel']: break
|
||||
|
||||
with console.status("[ai_status]Agent is thinking...") as status:
|
||||
result = self.app.myai.ask(user_query, chat_history=history, status=status, debug=args.debug, trust=args.trust, **self.ai_overrides)
|
||||
result = self.app.myai.ask(user_query, chat_history=history, status=status, debug=args.debug, trust=args.trust, session_id=session_id, **self.ai_overrides)
|
||||
|
||||
new_history = result.get("chat_history")
|
||||
if new_history is not None:
|
||||
@@ -147,8 +156,7 @@ class AIHandler:
|
||||
action = mcp_args[0].lower()
|
||||
|
||||
if action == "list":
|
||||
settings = self.app.services.config_svc.get_settings()
|
||||
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
|
||||
mcp_servers = self.app.services.ai.list_mcp_servers()
|
||||
if not mcp_servers:
|
||||
printer.info("No MCP servers configured.")
|
||||
else:
|
||||
@@ -213,8 +221,7 @@ class AIHandler:
|
||||
from .forms import Forms
|
||||
self.app.cli_forms = Forms(self.app)
|
||||
|
||||
settings = self.app.services.config_svc.get_settings()
|
||||
mcp_servers = settings.get("ai", {}).get("mcp_servers", {})
|
||||
mcp_servers = self.app.services.ai.list_mcp_servers()
|
||||
|
||||
result = self.app.cli_forms.mcp_wizard(mcp_servers)
|
||||
if not result:
|
||||
|
||||
+72
-1
@@ -1,10 +1,81 @@
|
||||
import os
|
||||
import inquirer
|
||||
from inquirer.themes import Default, term
|
||||
|
||||
try:
|
||||
from pyfzf.pyfzf import FzfPrompt
|
||||
except ImportError:
|
||||
FzfPrompt = None
|
||||
|
||||
def hex_to_blessed(hex_str):
|
||||
"""Convert hex color string to blessed/ansi format."""
|
||||
if not hex_str or not isinstance(hex_str, str):
|
||||
return term.normal
|
||||
|
||||
# Check for bold prefix
|
||||
prefix = ""
|
||||
if hex_str.startswith('bold '):
|
||||
prefix = term.bold
|
||||
hex_str = hex_str.replace('bold ', '').strip()
|
||||
|
||||
# If it's a standard color name
|
||||
if not hex_str.startswith('#'):
|
||||
return prefix + getattr(term, hex_str, term.normal)
|
||||
|
||||
# Parse hex
|
||||
try:
|
||||
h = hex_str.lstrip('#')
|
||||
if len(h) == 3:
|
||||
h = ''.join([c*2 for c in h])
|
||||
r = int(h[0:2], 16)
|
||||
g = int(h[2:4], 16)
|
||||
b = int(h[4:6], 16)
|
||||
|
||||
# Try RGB, fallback to standard cyan if it fails or returns empty
|
||||
try:
|
||||
c = term.color_rgb(r, g, b)
|
||||
if not c: # Some terms return empty for RGB
|
||||
return prefix + term.cyan
|
||||
return prefix + c
|
||||
except:
|
||||
return prefix + term.cyan
|
||||
except:
|
||||
return prefix + term.normal
|
||||
|
||||
# Custom inquirer theme matching connpy colors
|
||||
class ConnpyTheme(Default):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
from ..printer import _global_active_styles
|
||||
# Use user_prompt as primary accent, fallback to info/cyan
|
||||
accent = _global_active_styles.get("user_prompt", _global_active_styles.get("info", "cyan"))
|
||||
accent_color = hex_to_blessed(accent)
|
||||
|
||||
self.Question.mark_color = accent_color
|
||||
self.List.selection_color = accent_color
|
||||
self.List.selection_cursor = ">"
|
||||
except:
|
||||
# Absolute fallback to standard cyan
|
||||
self.Question.mark_color = term.cyan
|
||||
self.List.selection_color = term.bold_cyan
|
||||
self.List.selection_cursor = ">"
|
||||
|
||||
def get_theme():
|
||||
"""Returns a fresh instance of the theme with current colors."""
|
||||
return ConnpyTheme()
|
||||
|
||||
class ThemeProxy:
|
||||
"""Proxy to ensure theme colors are resolved at runtime."""
|
||||
def __getattr__(self, name):
|
||||
return getattr(get_theme(), name)
|
||||
def __iter__(self):
|
||||
return iter(get_theme())
|
||||
def __getitem__(self, item):
|
||||
return get_theme()[item]
|
||||
|
||||
theme = ThemeProxy()
|
||||
|
||||
def get_config_dir():
|
||||
home = os.path.expanduser("~")
|
||||
defaultdir = os.path.join(home, '.config/conn')
|
||||
@@ -56,7 +127,7 @@ def choose(app, list_, name, action):
|
||||
return answer[0]
|
||||
else:
|
||||
questions = [inquirer.List(name, message="Pick {} to {}:".format(name,action), choices=list_, carousel=True)]
|
||||
answer = inquirer.prompt(questions)
|
||||
answer = inquirer.prompt(questions, theme=theme)
|
||||
if answer == None:
|
||||
return None
|
||||
else:
|
||||
|
||||
@@ -281,6 +281,7 @@ class connapp:
|
||||
aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls")
|
||||
aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation")
|
||||
aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions")
|
||||
aiparser.add_argument("--all", action="store_true", help="Show all sessions without limit")
|
||||
aiparser.add_argument("--session", nargs=1, help="Resume a specific AI session by ID")
|
||||
aiparser.add_argument("--resume", action="store_true", help="Resume the most recent AI session")
|
||||
aiparser.add_argument("--delete", "--delete-session", dest="delete_session", nargs=1, help="Delete an AI session by ID")
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -483,6 +483,7 @@ class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
|
||||
self.service = ProfileService(config)
|
||||
self.node_service = NodeService(config)
|
||||
|
||||
|
||||
@handle_errors
|
||||
def list_profiles(self, request, context):
|
||||
f = request.filter_str if request.filter_str else None
|
||||
@@ -731,6 +732,7 @@ class StatusBridge:
|
||||
self.on_interrupt = self._force_interrupt
|
||||
self.thread = None
|
||||
self.is_web = is_web
|
||||
self.is_remote = True
|
||||
|
||||
def _force_interrupt(self):
|
||||
"""Forcefully raise KeyboardInterrupt in the target thread."""
|
||||
@@ -862,9 +864,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
print(f"AI Task Error: {e}")
|
||||
traceback.print_exc()
|
||||
chunk_queue.put(("status", f"Error: {str(e)}"))
|
||||
# Crucial: always send final_mark to avoid client deadlock
|
||||
chunk_queue.put(("final_mark", {"response": f"Error: {str(e)}", "chat_history": history, "error": True}))
|
||||
|
||||
def request_listener():
|
||||
nonlocal bridge, is_web, ai_thread, agent_instance
|
||||
nonlocal bridge, is_web, ai_thread, agent_instance, history
|
||||
try:
|
||||
for req in request_iterator:
|
||||
if req.interrupt:
|
||||
@@ -878,6 +882,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
|
||||
if req.input_text:
|
||||
is_web = "web" in (req.session_id or "").lower() or (req.session_id or "").lower().startswith("ws-")
|
||||
|
||||
# Hydrate history from client if it's the first interaction in this stream
|
||||
if not history and req.chat_history:
|
||||
from .utils import from_value
|
||||
history = from_value(req.chat_history) or []
|
||||
if not bridge:
|
||||
bridge = StatusBridge(chunk_queue, request_queue=request_queue, is_web=is_web)
|
||||
|
||||
@@ -948,7 +957,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
|
||||
@handle_errors
|
||||
def list_sessions(self, request, context):
|
||||
return connpy_pb2.ValueResponse(data=to_value(self.service.list_sessions()))
|
||||
sessions, total = self.service.list_sessions()
|
||||
return connpy_pb2.ValueResponse(data=to_value(sessions))
|
||||
|
||||
@handle_errors
|
||||
def delete_session(self, request, context):
|
||||
@@ -971,6 +981,11 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
)
|
||||
return Empty()
|
||||
|
||||
@handle_errors
|
||||
def list_mcp_servers(self, request, context):
|
||||
mcp_servers = self.service.list_mcp_servers()
|
||||
return connpy_pb2.ValueResponse(data=to_value(mcp_servers))
|
||||
|
||||
@handle_errors
|
||||
def load_session_data(self, request, context):
|
||||
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
|
||||
|
||||
@@ -758,6 +758,7 @@ class AIStub:
|
||||
|
||||
full_content = ""
|
||||
header_printed = False
|
||||
current_responder = "engineer"
|
||||
final_result = {"response": "", "chat_history": []}
|
||||
|
||||
# Background thread to pull responses from gRPC into a local queue
|
||||
@@ -802,6 +803,10 @@ class AIStub:
|
||||
break
|
||||
|
||||
if response.status_update:
|
||||
if response.status_update.startswith("__RESPONDER__:"):
|
||||
current_responder = response.status_update.split(":")[1].lower()
|
||||
continue
|
||||
|
||||
if response.requires_confirmation:
|
||||
if status: status.stop()
|
||||
|
||||
@@ -854,7 +859,9 @@ class AIStub:
|
||||
stable_console = RichConsole(theme=connpy_theme, file=get_original_stdout())
|
||||
|
||||
# Print header on first chunk
|
||||
stable_console.print(Rule("[bold engineer]Network Engineer[/bold engineer]", style="engineer"))
|
||||
alias = "architect" if current_responder == "architect" else "engineer"
|
||||
role_label = "Network Architect" if current_responder == "architect" else "Network Engineer"
|
||||
stable_console.print(Rule(f"[bold {alias}]{role_label}[/bold {alias}]", style=alias))
|
||||
header_printed = True
|
||||
|
||||
# Initialize parser
|
||||
@@ -906,8 +913,13 @@ class AIStub:
|
||||
return self.stub.confirm(connpy_pb2.StringRequest(value=input_text)).value
|
||||
|
||||
@handle_errors
|
||||
def list_sessions(self):
|
||||
return from_value(self.stub.list_sessions(Empty()).data)
|
||||
def list_sessions(self, limit=None):
|
||||
from .utils import from_value
|
||||
res = self.stub.list_sessions(Empty())
|
||||
sessions = from_value(res.data) or []
|
||||
if limit and len(sessions) > limit:
|
||||
return sessions[:limit], len(sessions)
|
||||
return sessions, len(sessions)
|
||||
|
||||
@handle_errors
|
||||
def delete_session(self, session_id):
|
||||
@@ -929,6 +941,11 @@ class AIStub:
|
||||
)
|
||||
self.stub.configure_mcp(req)
|
||||
|
||||
@handle_errors
|
||||
def list_mcp_servers(self):
|
||||
res = self.stub.list_mcp_servers(Empty())
|
||||
return from_value(res.data) or {}
|
||||
|
||||
@handle_errors
|
||||
def load_session_data(self, session_id):
|
||||
return from_struct(self.stub.load_session_data(connpy_pb2.StringRequest(value=session_id)).data)
|
||||
|
||||
@@ -70,6 +70,7 @@ service AIService {
|
||||
rpc delete_session (StringRequest) returns (google.protobuf.Empty) {}
|
||||
rpc configure_provider (ProviderRequest) returns (google.protobuf.Empty) {}
|
||||
rpc configure_mcp (MCPRequest) returns (google.protobuf.Empty) {}
|
||||
rpc list_mcp_servers (google.protobuf.Empty) returns (ValueResponse) {}
|
||||
rpc load_session_data (StringRequest) returns (StructResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -167,11 +167,14 @@ class AIService(BaseService):
|
||||
return await asyncio.wrap_future(future)
|
||||
|
||||
|
||||
def list_sessions(self):
|
||||
"""Return a list of all saved AI sessions."""
|
||||
def list_sessions(self, limit=None):
|
||||
"""Return a list of saved AI sessions, optionally limited."""
|
||||
from connpy.ai import ai
|
||||
agent = ai(self.config)
|
||||
return agent._get_sessions()
|
||||
sessions = agent._get_sessions()
|
||||
if limit and len(sessions) > limit:
|
||||
return sessions[:limit], len(sessions)
|
||||
return sessions, len(sessions)
|
||||
|
||||
def delete_session(self, session_id):
|
||||
"""Delete an AI session by ID."""
|
||||
@@ -228,6 +231,11 @@ class AIService(BaseService):
|
||||
self.config.config["ai"] = ai_settings
|
||||
self.config._saveconfig(self.config.file)
|
||||
|
||||
def list_mcp_servers(self) -> dict:
|
||||
"""Get the configured MCP servers."""
|
||||
ai_settings = self.config.config.get("ai", {})
|
||||
return ai_settings.get("mcp_servers", {})
|
||||
|
||||
def load_session_data(self, session_id):
|
||||
"""Load a session's raw data by ID."""
|
||||
from connpy.ai import ai
|
||||
|
||||
@@ -120,6 +120,7 @@ class TestGRPCIntegration:
|
||||
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(server.ConfigServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(server.ExecutionServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(server.ImportExportServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_AIServiceServicer_to_server(server.AIServicer(populated_config), srv)
|
||||
|
||||
port = srv.add_insecure_port('127.0.0.1:0')
|
||||
srv.start()
|
||||
@@ -143,6 +144,10 @@ class TestGRPCIntegration:
|
||||
def config_stub(self, channel):
|
||||
return stubs.ConfigStub(channel, "localhost")
|
||||
|
||||
@pytest.fixture
|
||||
def ai_stub(self, channel):
|
||||
return stubs.AIStub(channel, "localhost")
|
||||
|
||||
def test_list_nodes_integration(self, node_stub):
|
||||
nodes = node_stub.list_nodes()
|
||||
assert "router1" in nodes
|
||||
@@ -170,6 +175,12 @@ class TestGRPCIntegration:
|
||||
settings = config_stub.get_settings()
|
||||
assert settings["idletime"] == 99
|
||||
|
||||
def test_list_mcp_servers_integration(self, ai_stub):
|
||||
ai_stub.configure_mcp("test-mcp", url="http://localhost:8080", enabled=True)
|
||||
servers = ai_stub.list_mcp_servers()
|
||||
assert "test-mcp" in servers
|
||||
assert servers["test-mcp"]["url"] == "http://localhost:8080"
|
||||
|
||||
def test_add_delete_node_integration(self, node_stub):
|
||||
node_stub.add_node("integration-test-node", {"host": "9.9.9.9"})
|
||||
assert "integration-test-node" in node_stub.list_nodes()
|
||||
|
||||
Reference in New Issue
Block a user