support fro multiple litellm auth methods

This commit is contained in:
2026-05-21 18:20:24 -03:00
parent cd8eeaad79
commit 3be9935541
18 changed files with 617 additions and 111 deletions
+4
View File
@@ -146,6 +146,7 @@ package.json
# Development docs # Development docs
connpy_roadmap.md connpy_roadmap.md
testnew/
testall/ testall/
testremote/ testremote/
*.db *.db
@@ -170,3 +171,6 @@ MULTI_USER_IMPLEMENTATION_STEPS.md
#themes #themes
nord.yml nord.yml
theme.py theme.py
#ai auth
auth.json
+49 -16
View File
@@ -108,7 +108,7 @@ class ai:
r'^systemctl\s+status\s+', r'^journalctl\s+' r'^systemctl\s+status\s+', r'^journalctl\s+'
] ]
def __init__(self, config, org=None, api_key=None, engineer_model=None, architect_model=None, engineer_api_key=None, architect_api_key=None, console=None, confirm_handler=None, trust=False): def __init__(self, config, org=None, api_key=None, engineer_model=None, architect_model=None, engineer_api_key=None, architect_api_key=None, console=None, confirm_handler=None, trust=False, engineer_auth=None, architect_auth=None, **kwargs):
self.config = config self.config = config
self.console = console or printer.console self.console = console or printer.console
self.confirm_handler = confirm_handler or self._local_confirm_handler self.confirm_handler = confirm_handler or self._local_confirm_handler
@@ -127,6 +127,29 @@ class ai:
self.engineer_key = engineer_api_key or aiconfig.get("engineer_api_key") self.engineer_key = engineer_api_key or aiconfig.get("engineer_api_key")
self.architect_key = architect_api_key or aiconfig.get("architect_api_key") self.architect_key = architect_api_key or aiconfig.get("architect_api_key")
# Auth configurations (Prioridad: Argumento -> Config)
self.engineer_auth = engineer_auth if engineer_auth is not None else aiconfig.get("engineer_auth")
if self.engineer_auth is None:
self.engineer_auth = {}
elif not isinstance(self.engineer_auth, dict):
self.engineer_auth = {}
self.architect_auth = architect_auth if architect_auth is not None else aiconfig.get("architect_auth")
if self.architect_auth is None:
self.architect_auth = {}
elif not isinstance(self.architect_auth, dict):
self.architect_auth = {}
# Backward compatibility fallbacks: only inject api_key if the auth dict is empty/not configured
if self.engineer_key and not self.engineer_auth:
self.engineer_auth["api_key"] = self.engineer_key
if self.architect_key and not self.architect_auth:
self.architect_auth["api_key"] = self.architect_key
# Strategic Reasoning Engine (Architect) availability
is_architect_keyless = "vertex" in self.architect_model.lower() or "ollama" in self.architect_model.lower() or "local" in self.architect_model.lower()
self.has_architect = bool(self.architect_key or self.architect_auth or is_architect_keyless)
# Custom Trusted Commands Regexes # Custom Trusted Commands Regexes
custom_trusted = aiconfig.get("trusted_commands", []) custom_trusted = aiconfig.get("trusted_commands", [])
if isinstance(custom_trusted, str): if isinstance(custom_trusted, str):
@@ -172,7 +195,7 @@ class ai:
# Prompts base agnósticos # Prompts base agnósticos
architect_instructions = "" architect_instructions = ""
if self.architect_key: if self.has_architect:
architect_instructions = """ architect_instructions = """
CRITICAL - CONSULT vs ESCALATE: CRITICAL - CONSULT vs ESCALATE:
- ALWAYS use 'consult_architect' for: Configuration planning, design decisions, complex troubleshooting. - ALWAYS use 'consult_architect' for: Configuration planning, design decisions, complex troubleshooting.
@@ -188,7 +211,7 @@ class ai:
else: else:
architect_instructions = """ architect_instructions = """
CRITICAL - ARCHITECT UNAVAILABLE: CRITICAL - ARCHITECT UNAVAILABLE:
- The Strategic Reasoning Engine (Architect) is currently UNAVAILABLE because its API key is not configured. - The Strategic Reasoning Engine (Architect) is currently UNAVAILABLE because its API key or authentication is not configured.
- DO NOT attempt to consult or escalate to the architect. - DO NOT attempt to consult or escalate to the architect.
- If the user asks to consult the architect, inform them that the Architect is offline and offer to help them directly to the best of your abilities. - If the user asks to consult the architect, inform them that the Architect is offline and offer to help them directly to the best of your abilities.
""" """
@@ -294,15 +317,19 @@ class ai:
if status_formatter: if status_formatter:
self.tool_status_formatters[name] = status_formatter self.tool_status_formatters[name] = status_formatter
def _stream_completion(self, model, messages, tools, api_key, status=None, label="", debug=False, chunk_callback=None, **kwargs): def _stream_completion(self, model, messages, tools, api_key=None, status=None, label="", debug=False, chunk_callback=None, auth=None, **kwargs):
"""Stream a completion call, rendering styled Markdown in real-time. """Stream a completion call, rendering styled Markdown in real-time.
Returns (response, streamed) where: Returns (response, streamed) where:
- response: reconstructed ModelResponse (same as non-streaming) - response: reconstructed ModelResponse (same as non-streaming)
- streamed: True if text was rendered to console during streaming - streamed: True if text was rendered to console during streaming
""" """
auth_dict = auth if auth is not None else {}
if api_key and "api_key" not in auth_dict:
auth_dict = auth_dict.copy()
auth_dict["api_key"] = api_key
stream_resp = completion(model=model, messages=messages, tools=tools, api_key=api_key, stream=True, **kwargs) stream_resp = completion(model=model, messages=messages, tools=tools, stream=True, **auth_dict, **kwargs)
chunks = [] chunks = []
full_content = "" full_content = ""
@@ -745,7 +772,7 @@ class ai:
try: try:
safe_messages = self._sanitize_messages(messages) safe_messages = self._sanitize_messages(messages)
response = completion(model=self.engineer_model, messages=safe_messages, tools=tools, api_key=self.engineer_key) response = completion(model=self.engineer_model, messages=safe_messages, tools=tools, **self.engineer_auth)
except Exception as e: except Exception as e:
if status: status.stop() if status: status.stop()
raise ValueError(f"Engineer failed to connect: {str(e)}") raise ValueError(f"Engineer failed to connect: {str(e)}")
@@ -981,8 +1008,9 @@ class ai:
@MethodHook @MethodHook
def ask(self, user_input, dryrun=False, chat_history=None, status=None, debug=False, stream=True, session_id=None, chunk_callback=None): def ask(self, user_input, dryrun=False, chat_history=None, status=None, debug=False, stream=True, session_id=None, chunk_callback=None):
if not self.engineer_key: is_engineer_keyless = "vertex" in self.engineer_model.lower() or "ollama" in self.engineer_model.lower() or "local" in self.engineer_model.lower()
raise ValueError("Engineer API key not configured. Use 'connpy config --engineer-api-key <key>' to set it.") if not self.engineer_key and not self.engineer_auth and not is_engineer_keyless:
raise ValueError("Engineer API key or authentication not configured. Use 'connpy config --engineer-auth <auth>' to set it.")
if chat_history is None: chat_history = [] if chat_history is None: chat_history = []
@@ -1031,6 +1059,7 @@ class ai:
tools = self._get_architect_tools() if current_brain == "architect" else self._get_engineer_tools() tools = self._get_architect_tools() if current_brain == "architect" else self._get_engineer_tools()
model = self.architect_model if current_brain == "architect" else self.engineer_model model = self.architect_model if current_brain == "architect" else self.engineer_model
key = self.architect_key if current_brain == "architect" else self.engineer_key key = self.architect_key if current_brain == "architect" else self.engineer_key
current_auth = self.architect_auth if current_brain == "architect" else self.engineer_auth
# Estructura optimizada para Prompt Caching (Solo para Anthropic directo, Vertex tiene reglas distintas) # Estructura optimizada para Prompt Caching (Solo para Anthropic directo, Vertex tiene reglas distintas)
if "claude" in model.lower() and "vertex" not in model.lower(): if "claude" in model.lower() and "vertex" not in model.lower():
@@ -1090,12 +1119,12 @@ class ai:
safe_messages = self._sanitize_messages(messages) safe_messages = self._sanitize_messages(messages)
if stream: if stream:
response, streamed_response = self._stream_completion( response, streamed_response = self._stream_completion(
model=model, messages=safe_messages, tools=tools, api_key=key, model=model, messages=safe_messages, tools=tools, auth=current_auth,
status=status, label=label, debug=debug, num_retries=3, status=status, label=label, debug=debug, num_retries=3,
chunk_callback=chunk_callback chunk_callback=chunk_callback
) )
else: else:
response = completion(model=model, messages=safe_messages, tools=tools, api_key=key, num_retries=3) response = completion(model=model, messages=safe_messages, tools=tools, num_retries=3, **current_auth)
except Exception as e: except Exception as e:
if current_brain == "architect": if current_brain == "architect":
if status: status.update("[unavailable]Architect unavailable! Falling back to Engineer...") if status: status.update("[unavailable]Architect unavailable! Falling back to Engineer...")
@@ -1104,6 +1133,7 @@ class ai:
model = self.engineer_model model = self.engineer_model
tools = self._get_engineer_tools() tools = self._get_engineer_tools()
key = self.engineer_key key = self.engineer_key
current_auth = self.engineer_auth
# Rebuild messages with Engineer system prompt and original user request # Rebuild messages with Engineer system prompt and original user request
messages = [{"role": "system", "content": self.engineer_system_prompt}] messages = [{"role": "system", "content": self.engineer_system_prompt}]
# Add chat history if exists (excluding system prompt) # Add chat history if exists (excluding system prompt)
@@ -1196,6 +1226,7 @@ class ai:
model = self.architect_model model = self.architect_model
tools = self._get_architect_tools() tools = self._get_architect_tools()
key = self.architect_key key = self.architect_key
current_auth = self.architect_auth
messages[0] = {"role": "system", "content": self.architect_system_prompt} messages[0] = {"role": "system", "content": self.architect_system_prompt}
# Prepare handover context to inject AFTER all tool responses # Prepare handover context to inject AFTER all tool responses
handover_msg = f"HANDOVER FROM EXECUTION ENGINE\n\nReason: {args['reason']}\n\nContext: {args['context']}\n\nYou are now in control of this conversation." handover_msg = f"HANDOVER FROM EXECUTION ENGINE\n\nReason: {args['reason']}\n\nContext: {args['context']}\n\nYou are now in control of this conversation."
@@ -1217,6 +1248,7 @@ class ai:
model = self.engineer_model model = self.engineer_model
tools = self._get_engineer_tools() tools = self._get_engineer_tools()
key = self.engineer_key key = self.engineer_key
current_auth = self.engineer_auth
messages[0] = {"role": "system", "content": self.engineer_system_prompt} messages[0] = {"role": "system", "content": self.engineer_system_prompt}
# Prepare handover context to inject AFTER all tool responses # Prepare handover context to inject AFTER all tool responses
handover_msg = f"HANDOVER FROM ARCHITECT\n\nSummary: {args['summary']}\n\nYou are now back in control. Continue handling the user's requests." handover_msg = f"HANDOVER FROM ARCHITECT\n\nSummary: {args['summary']}\n\nYou are now back in control. Continue handling the user's requests."
@@ -1258,7 +1290,7 @@ class ai:
messages.append({"role": "user", "content": "Hard iteration limit reached. Please provide a summary of your findings so far."}) messages.append({"role": "user", "content": "Hard iteration limit reached. Please provide a summary of your findings so far."})
try: try:
safe_messages = self._sanitize_messages(messages) safe_messages = self._sanitize_messages(messages)
response = completion(model=model, messages=safe_messages, tools=[], api_key=key) response = completion(model=model, messages=safe_messages, tools=[], **current_auth)
resp_msg = response.choices[0].message resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True)) messages.append(resp_msg.model_dump(exclude_none=True))
except Exception as e: except Exception as e:
@@ -1278,7 +1310,7 @@ class ai:
try: try:
safe_messages = self._sanitize_messages(summary_messages) safe_messages = self._sanitize_messages(summary_messages)
# Use tools=None to force a text summary during interruption # Use tools=None to force a text summary during interruption
response = completion(model=model, messages=safe_messages, tools=None, api_key=key) response = completion(model=model, messages=safe_messages, tools=None, **current_auth)
resp_msg = response.choices[0].message resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True)) messages.append(resp_msg.model_dump(exclude_none=True))
@@ -1415,6 +1447,7 @@ Node: {node_name}"""
# Use models based on persona # Use models based on persona
current_model = self.architect_model if persona == "architect" else self.engineer_model current_model = self.architect_model if persona == "architect" else self.engineer_model
current_key = self.architect_key if persona == "architect" else self.engineer_key current_key = self.architect_key if persona == "architect" else self.engineer_key
current_auth = self.architect_auth if persona == "architect" else self.engineer_auth
try: try:
while iteration < max_iterations: while iteration < max_iterations:
@@ -1424,8 +1457,8 @@ Node: {node_name}"""
model=current_model, model=current_model,
messages=messages, messages=messages,
tools=mcp_tools if mcp_tools else None, tools=mcp_tools if mcp_tools else None,
api_key=current_key, stream=True,
stream=True **current_auth
) )
full_content = "" full_content = ""
@@ -1498,8 +1531,8 @@ Node: {node_name}"""
model=self.engineer_model, model=self.engineer_model,
messages=messages, messages=messages,
tools=None, tools=None,
api_key=self.engineer_key, stream=True,
stream=True **self.engineer_auth
) )
full_content = "" full_content = ""
+44 -7
View File
@@ -47,7 +47,7 @@ class AIHandler:
# Determinar session_id para retomar # Determinar session_id para retomar
session_id = None session_id = None
if args.resume: if args.resume:
sessions = self.app.services.ai.list_sessions() sessions, _ = self.app.services.ai.list_sessions()
session_id = sessions[0]["id"] if sessions else None session_id = sessions[0]["id"] if sessions else None
if not session_id: if not session_id:
printer.warning("No previous session found to resume.") printer.warning("No previous session found to resume.")
@@ -66,15 +66,22 @@ class AIHandler:
elif settings.get(key): elif settings.get(key):
arguments[key] = settings.get(key) arguments[key] = settings.get(key)
for key in ["engineer_auth", "architect_auth"]:
cli_val = getattr(args, key, None)
if cli_val:
arguments[key] = self._parse_auth_value(cli_val[0])
elif settings.get(key):
arguments[key] = settings.get(key)
# Check keys only if running in local mode (not remote) # Check keys only if running in local mode (not remote)
if getattr(self.app.services, "mode", "local") == "local": if getattr(self.app.services, "mode", "local") == "local":
if not arguments.get("engineer_api_key"): if not arguments.get("engineer_api_key") and not arguments.get("engineer_auth"):
printer.error("Engineer API key not configured. The chat cannot start.") printer.error("Engineer API key/auth not configured. The chat cannot start.")
printer.info("Use 'connpy config --engineer-api-key <key>' to set it.") printer.info("Use 'connpy config --engineer-api-key <key>' or 'connpy config --engineer-auth <auth>' to set it.")
sys.exit(1) sys.exit(1)
if not arguments.get("architect_api_key"): if not arguments.get("architect_api_key") and not arguments.get("architect_auth"):
printer.warning("Architect API key not configured. Architect will be unavailable.") printer.warning("Architect API key/auth not configured. Architect will be unavailable.")
printer.info("Use 'connpy config --architect-api-key <key>' to enable it.") printer.info("Use 'connpy config --architect-api-key <key>' or 'connpy config --architect-auth <auth>' to enable it.")
# El resto de la interacción el CLI la maneja con el agente subyacente # El resto de la interacción el CLI la maneja con el agente subyacente
self.app.myai = self.app.services.ai self.app.myai = self.app.services.ai
@@ -256,3 +263,33 @@ class AIHandler:
except Exception as e: except Exception as e:
printer.error(str(e)) printer.error(str(e))
def _parse_auth_value(self, value):
if not value or value.lower() in ["none", "clear"]:
return None
import os
import yaml
import json
if os.path.exists(value):
try:
with open(value, "r") as f:
content = f.read()
try:
return json.loads(content)
except ValueError:
return yaml.safe_load(content)
except Exception as e:
printer.error(f"Failed to read/parse auth file '{value}': {e}")
sys.exit(1)
try:
return json.loads(value)
except ValueError:
try:
parsed = yaml.safe_load(value)
if isinstance(parsed, dict):
return parsed
raise ValueError()
except Exception:
printer.error("Auth parameter must be a valid JSON/YAML string, or a path to a JSON/YAML file.")
sys.exit(1)
+52 -2
View File
@@ -19,8 +19,10 @@ class ConfigHandler:
"theme": self.set_theme, "theme": self.set_theme,
"engineer_model": self.set_ai_config, "engineer_model": self.set_ai_config,
"engineer_api_key": self.set_ai_config, "engineer_api_key": self.set_ai_config,
"engineer_auth": self.set_ai_config,
"architect_model": self.set_ai_config, "architect_model": self.set_ai_config,
"architect_api_key": self.set_ai_config, "architect_api_key": self.set_ai_config,
"architect_auth": self.set_ai_config,
"trusted_commands": self.set_ai_config, "trusted_commands": self.set_ai_config,
"service_mode": self.set_service_mode, "service_mode": self.set_service_mode,
"remote_host": self.set_remote_host, "remote_host": self.set_remote_host,
@@ -127,9 +129,57 @@ class ConfigHandler:
try: try:
settings = self.app.services.config_svc.get_settings() settings = self.app.services.config_svc.get_settings()
aiconfig = settings.get("ai", {}) aiconfig = settings.get("ai", {})
aiconfig[args.command] = args.data[0] val = args.data[0]
# Check for unset/clear request
if val.lower() in ["none", "clear", ""]:
if args.command in aiconfig:
del aiconfig[args.command]
else:
# If configuring auth, parse as dictionary (JSON/YAML or file path)
if args.command in ["engineer_auth", "architect_auth"]:
parsed_val = self._parse_auth_value(val)
if parsed_val is not None:
aiconfig[args.command] = parsed_val
else:
if args.command in aiconfig:
del aiconfig[args.command]
else:
aiconfig[args.command] = val
self.app.services.config_svc.update_setting("ai", aiconfig) self.app.services.config_svc.update_setting("ai", aiconfig)
printer.success("Config saved") printer.success("Config saved")
except ConnpyError as e: except (ConnpyError, InvalidConfigurationError) as e:
printer.error(str(e)) printer.error(str(e))
def _parse_auth_value(self, value):
if value.lower() in ["none", "clear", ""]:
return None
# Check if it's a file path
import os
if os.path.exists(value):
try:
with open(value, "r") as f:
content = f.read()
import json
try:
return json.loads(content)
except ValueError:
return yaml.safe_load(content)
except Exception as e:
raise InvalidConfigurationError(f"Failed to read/parse auth file '{value}': {e}")
# Try parsing as inline JSON/YAML
try:
import json
return json.loads(value)
except ValueError:
try:
parsed = yaml.safe_load(value)
if isinstance(parsed, dict):
return parsed
raise ValueError()
except Exception:
raise InvalidConfigurationError("Auth parameter must be a valid JSON/YAML string, or a path to a JSON/YAML file.")
+18 -16
View File
@@ -181,11 +181,28 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
ai_dict = {"__exclude_used__": True, "--help": None, "-h": None} ai_dict = {"__exclude_used__": True, "--help": None, "-h": None}
for opt in ["--engineer-model", "--engineer-api-key", "--architect-model", "--architect-api-key"]: for opt in ["--engineer-model", "--engineer-api-key", "--architect-model", "--architect-api-key"]:
ai_dict[opt] = {"*": ai_dict} # takes value, loops back ai_dict[opt] = {"*": ai_dict} # takes value, loops back
ai_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": ai_dict}
ai_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": ai_dict}
for opt in ["--debug", "--trust", "--list", "--list-sessions", "--session", "--resume", "--delete", "--delete-session", "-y"]: for opt in ["--debug", "--trust", "--list", "--list-sessions", "--session", "--resume", "--delete", "--delete-session", "-y"]:
ai_dict[opt] = ai_dict # takes no value, loops back ai_dict[opt] = ai_dict # takes no value, loops back
ai_dict["--mcp"] = mcp_dict ai_dict["--mcp"] = mcp_dict
ai_dict["*"] = ai_dict ai_dict["*"] = ai_dict
config_dict = {
"--allow-uppercase": ["true", "false"],
"--fzf": ["true", "false"],
"--completion": ["bash", "zsh"],
"--fzf-wrapper": ["bash", "zsh"],
"--service-mode": ["local", "remote"],
"--sync-remote": ["true", "false"],
"--help": None, "-h": None,
}
for opt in ["--keepalive", "--engineer-model", "--engineer-api-key", "--architect-model", "--architect-api-key", "--theme", "--remote", "--trusted-commands"]:
config_dict[opt] = {"*": config_dict}
config_dict["--configfolder"] = {"__extra__": lambda w: get_cwd(w, "--configfolder", True), "*": config_dict}
config_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": config_dict}
config_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": config_dict}
mv_state = {"__extra__": _nodes, "--help": None, "-h": None} mv_state = {"__extra__": _nodes, "--help": None, "-h": None}
cp_state = {"__extra__": _nodes, "--help": None, "-h": None} cp_state = {"__extra__": _nodes, "--help": None, "-h": None}
ls_state = { ls_state = {
@@ -280,22 +297,7 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
"--list": None, "--help": None, "--list": None, "--help": None,
"-h": None, "-h": None,
}, },
"config": { "config": config_dict,
"--allow-uppercase": ["true", "false"],
"--fzf": ["true", "false"],
"--keepalive": None,
"--completion": ["bash", "zsh"],
"--fzf-wrapper": ["bash", "zsh"],
"--configfolder": lambda w: get_cwd(w, "--configfolder", True),
"--engineer-model": None, "--engineer-api-key": None,
"--architect-model": None, "--architect-api-key": None,
"--theme": None,
"--service-mode": ["local", "remote"],
"--remote": None,
"--sync-remote": ["true", "false"],
"--trusted-commands": None,
"--help": None, "-h": None,
},
"sync": { "sync": {
"--login": None, "--logout": None, "--login": None, "--logout": None,
"--status": None, "--list": None, "--status": None, "--list": None,
+4
View File
@@ -276,8 +276,10 @@ class connapp:
aiparser.add_argument("ask", nargs='*', help="Ask connpy AI something") aiparser.add_argument("ask", nargs='*', help="Ask connpy AI something")
aiparser.add_argument("--engineer-model", nargs=1, help="Override engineer model") aiparser.add_argument("--engineer-model", nargs=1, help="Override engineer model")
aiparser.add_argument("--engineer-api-key", nargs=1, help="Override engineer api key") aiparser.add_argument("--engineer-api-key", nargs=1, help="Override engineer api key")
aiparser.add_argument("--engineer-auth", nargs=1, help="Override engineer auth (inline JSON/YAML or file path)")
aiparser.add_argument("--architect-model", nargs=1, help="Override architect model") aiparser.add_argument("--architect-model", nargs=1, help="Override architect model")
aiparser.add_argument("--architect-api-key", nargs=1, help="Override architect api key") aiparser.add_argument("--architect-api-key", nargs=1, help="Override architect api key")
aiparser.add_argument("--architect-auth", nargs=1, help="Override architect auth (inline JSON/YAML or file path)")
aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls") aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls")
aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation") aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation")
aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions") aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions")
@@ -341,11 +343,13 @@ class connapp:
configcrud.add_argument("--configfolder", dest="configfolder", nargs=1, action=self._store_type, help="Set the default location for config file", metavar="FOLDER") configcrud.add_argument("--configfolder", dest="configfolder", nargs=1, action=self._store_type, help="Set the default location for config file", metavar="FOLDER")
configcrud.add_argument("--engineer-model", dest="engineer_model", nargs=1, action=self._store_type, help="Set engineer model", metavar="MODEL") configcrud.add_argument("--engineer-model", dest="engineer_model", nargs=1, action=self._store_type, help="Set engineer model", metavar="MODEL")
configcrud.add_argument("--engineer-api-key", dest="engineer_api_key", nargs=1, action=self._store_type, help="Set engineer api_key", metavar="API_KEY") configcrud.add_argument("--engineer-api-key", dest="engineer_api_key", nargs=1, action=self._store_type, help="Set engineer api_key", metavar="API_KEY")
configcrud.add_argument("--engineer-auth", dest="engineer_auth", nargs=1, action=self._store_type, help="Set engineer auth (inline JSON/YAML or file path)", metavar="AUTH")
configcrud.add_argument("--theme", dest="theme", nargs=1, action=self._store_type, help="Set application theme (dark, light, or YAML file path)", metavar="THEME") configcrud.add_argument("--theme", dest="theme", nargs=1, action=self._store_type, help="Set application theme (dark, light, or YAML file path)", metavar="THEME")
configcrud.add_argument("--service-mode", dest="service_mode", nargs=1, action=self._store_type, help="Set the backend service mode (local or remote)", choices=["local", "remote"]) configcrud.add_argument("--service-mode", dest="service_mode", nargs=1, action=self._store_type, help="Set the backend service mode (local or remote)", choices=["local", "remote"])
configcrud.add_argument("--remote", dest="remote_host", nargs=1, action=self._store_type, help="Connect to a remote connpy service via gRPC", metavar="HOST:PORT") configcrud.add_argument("--remote", dest="remote_host", nargs=1, action=self._store_type, help="Connect to a remote connpy service via gRPC", metavar="HOST:PORT")
configcrud.add_argument("--architect-model", dest="architect_model", nargs=1, action=self._store_type, help="Set architect model", metavar="MODEL") configcrud.add_argument("--architect-model", dest="architect_model", nargs=1, action=self._store_type, help="Set architect model", metavar="MODEL")
configcrud.add_argument("--architect-api-key", dest="architect_api_key", nargs=1, action=self._store_type, help="Set architect api_key", metavar="API_KEY") configcrud.add_argument("--architect-api-key", dest="architect_api_key", nargs=1, action=self._store_type, help="Set architect api_key", metavar="API_KEY")
configcrud.add_argument("--architect-auth", dest="architect_auth", nargs=1, action=self._store_type, help="Set architect auth (inline JSON/YAML or file path)", metavar="AUTH")
configcrud.add_argument("--sync-remote", dest="sync_remote", nargs=1, action=self._store_type, help="Sync remote nodes to Google Drive", choices=["true","false"]) configcrud.add_argument("--sync-remote", dest="sync_remote", nargs=1, action=self._store_type, help="Sync remote nodes to Google Drive", choices=["true","false"])
configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX") configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX")
configparser.set_defaults(func=self._config.dispatch) configparser.set_defaults(func=self._config.dispatch)
+4 -9
View File
@@ -439,21 +439,16 @@ class node:
# Remove any stray \x00 bytes and forward normally # Remove any stray \x00 bytes and forward normally
clean_data = data.replace(b'\x00', b'') clean_data = data.replace(b'\x00', b'')
if clean_data: if clean_data:
# Track command boundaries when user hits Enter # Track command boundaries when user hits Enter or presses Ctrl+C
if hasattr(self, 'mylog') and (b'\r' in clean_data or b'\n' in clean_data): if hasattr(self, 'mylog') and (b'\r' in clean_data or b'\n' in clean_data or b'\x03' in clean_data):
# Introduce a tiny 20ms delay to allow late-arriving tab-completion bytes
# to be written to mylog before finalizing the boundary marker.
async def delayed_marker():
await asyncio.sleep(0.02)
if hasattr(self, 'mylog'):
pos = self.mylog.tell() pos = self.mylog.tell()
self.cmd_byte_positions.append((pos, None)) marker_cmd = "CANCELLED" if b'\x03' in clean_data else None
self.cmd_byte_positions.append((pos, marker_cmd))
if hasattr(self, 'current_local_stream') and self.current_local_stream is not None: if hasattr(self, 'current_local_stream') and self.current_local_stream is not None:
try: try:
await self.current_local_stream.write(f'\x1b]133;B;{pos}\x07'.encode()) await self.current_local_stream.write(f'\x1b]133;B;{pos}\x07'.encode())
except Exception: except Exception:
pass pass
asyncio.create_task(delayed_marker())
try: try:
os.write(child_fd, clean_data) os.write(child_fd, clean_data)
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -3,7 +3,7 @@
import grpc import grpc
import warnings import warnings
from . import connpy_pb2 as connpy__pb2 import connpy_pb2 as connpy__pb2
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
GRPC_GENERATED_VERSION = '1.80.0' GRPC_GENERATED_VERSION = '1.80.0'
+6 -1
View File
@@ -893,6 +893,10 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
overrides = {} overrides = {}
if req.engineer_model: overrides["engineer_model"] = req.engineer_model if req.engineer_model: overrides["engineer_model"] = req.engineer_model
if req.engineer_api_key: overrides["engineer_api_key"] = req.engineer_api_key if req.engineer_api_key: overrides["engineer_api_key"] = req.engineer_api_key
if req.architect_model: overrides["architect_model"] = req.architect_model
if req.architect_api_key: overrides["architect_api_key"] = req.architect_api_key
if req.HasField("engineer_auth"): overrides["engineer_auth"] = from_struct(req.engineer_auth)
if req.HasField("architect_auth"): overrides["architect_auth"] = from_struct(req.architect_auth)
# Start AI in its own thread so we can keep listening for interrupts # Start AI in its own thread so we can keep listening for interrupts
ai_thread = threading.Thread( ai_thread = threading.Thread(
@@ -967,7 +971,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
@handle_errors @handle_errors
def configure_provider(self, request, context): def configure_provider(self, request, context):
self.service.configure_provider(request.provider, request.model, request.api_key) auth_dict = from_struct(request.auth) if request.HasField("auth") else None
self.service.configure_provider(request.provider, request.model, request.api_key, auth=auth_dict)
return Empty() return Empty()
@handle_errors @handle_errors
+7 -1
View File
@@ -745,6 +745,10 @@ class AIStub:
) )
if chat_history is not None: if chat_history is not None:
initial_req.chat_history.CopyFrom(to_value(chat_history)) initial_req.chat_history.CopyFrom(to_value(chat_history))
if "engineer_auth" in overrides and overrides["engineer_auth"]:
initial_req.engineer_auth.CopyFrom(to_struct(overrides["engineer_auth"]))
if "architect_auth" in overrides and overrides["architect_auth"]:
initial_req.architect_auth.CopyFrom(to_struct(overrides["architect_auth"]))
req_queue.put(initial_req) req_queue.put(initial_req)
@@ -926,8 +930,10 @@ class AIStub:
self.stub.delete_session(connpy_pb2.StringRequest(value=session_id)) self.stub.delete_session(connpy_pb2.StringRequest(value=session_id))
@handle_errors @handle_errors
def configure_provider(self, provider, model=None, api_key=None): def configure_provider(self, provider, model=None, api_key=None, auth=None):
req = connpy_pb2.ProviderRequest(provider=provider, model=model or "", api_key=api_key or "") req = connpy_pb2.ProviderRequest(provider=provider, model=model or "", api_key=api_key or "")
if auth:
req.auth.CopyFrom(to_struct(auth))
self.stub.configure_provider(req) self.stub.configure_provider(req)
@handle_errors @handle_errors
+3
View File
@@ -235,6 +235,8 @@ message AskRequest {
bool trust = 10; bool trust = 10;
string confirmation_answer = 11; string confirmation_answer = 11;
bool interrupt = 12; bool interrupt = 12;
google.protobuf.Struct engineer_auth = 13;
google.protobuf.Struct architect_auth = 14;
} }
message AIResponse { message AIResponse {
@@ -255,6 +257,7 @@ message ProviderRequest {
string provider = 1; string provider = 1;
string model = 2; string model = 2;
string api_key = 3; string api_key = 3;
google.protobuf.Struct auth = 4;
} }
message IntRequest { message IntRequest {
+8 -3
View File
@@ -58,6 +58,9 @@ class AIService(BaseService):
prev_pos = cmd_byte_positions[i-1][0] prev_pos = cmd_byte_positions[i-1][0]
if known_cmd: if known_cmd:
if known_cmd == "CANCELLED":
parsed_positions.append({"pos": pos, "type": "CANCELLED", "preview": ""})
else:
prev_chunk = raw_bytes[prev_pos:pos] prev_chunk = raw_bytes[prev_pos:pos]
prev_cleaned = self._clean_cisco_scrolling(prev_chunk.decode(errors='replace')) prev_cleaned = self._clean_cisco_scrolling(prev_chunk.decode(errors='replace'))
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()] prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
@@ -129,11 +132,11 @@ class AIService(BaseService):
start_pos = item["pos"] start_pos = item["pos"]
preview = item["preview"] preview = item["preview"]
# Find the end position: next VALID_CMD or EMPTY_PROMPT # Find the end position: next VALID_CMD or EMPTY_PROMPT or CANCELLED
end_pos = current_prompt_pos end_pos = current_prompt_pos
for j in range(i + 1, len(parsed_positions)): for j in range(i + 1, len(parsed_positions)):
next_item = parsed_positions[j] next_item = parsed_positions[j]
if next_item["type"] in ("VALID_CMD", "EMPTY_PROMPT"): if next_item["type"] in ("VALID_CMD", "EMPTY_PROMPT", "CANCELLED"):
end_pos = next_item["pos"] end_pos = next_item["pos"]
break break
@@ -254,13 +257,15 @@ class AIService(BaseService):
else: else:
raise InvalidConfigurationError(f"Session '{session_id}' not found.") raise InvalidConfigurationError(f"Session '{session_id}' not found.")
def configure_provider(self, provider, model=None, api_key=None): def configure_provider(self, provider, model=None, api_key=None, auth=None):
"""Update AI provider settings in the configuration.""" """Update AI provider settings in the configuration."""
settings = self.config.config.get("ai", {}) settings = self.config.config.get("ai", {})
if model: if model:
settings[f"{provider}_model"] = model settings[f"{provider}_model"] = model
if api_key: if api_key:
settings[f"{provider}_api_key"] = api_key settings[f"{provider}_api_key"] = api_key
if auth is not None:
settings[f"{provider}_auth"] = auth
self.config.config["ai"] = settings self.config.config["ai"] = settings
self.config._saveconfig(self.config.file) self.config._saveconfig(self.config.file)
+76 -3
View File
@@ -23,7 +23,7 @@ class TestAIInit:
myai = ai(config) myai = ai(config)
with pytest.raises(ValueError) as exc: with pytest.raises(ValueError) as exc:
myai.ask("hello") myai.ask("hello")
assert "Engineer API key not configured" in str(exc.value) assert "Engineer API key or authentication not configured" in str(exc.value)
def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm): def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm):
"""Warns if architect key is missing but doesn't crash.""" """Warns if architect key is missing but doesn't crash."""
@@ -58,6 +58,77 @@ class TestAIInit:
pass # May fail on other file opens, that's ok pass # May fail on other file opens, that's ok
# =========================================================================
# AI Auth Dict tests
# =========================================================================
class TestAIAuthDict:
def test_init_with_auth_dict(self, ai_config):
"""Initializes correctly when auth dicts are configured."""
from connpy.ai import ai
ai_config.config["ai"]["engineer_api_key"] = None
ai_config.config["ai"]["architect_api_key"] = None
ai_config.config["ai"]["engineer_auth"] = {"my_key": "my_val"}
ai_config.config["ai"]["architect_auth"] = {"another_key": "another_val"}
myai = ai(ai_config)
assert myai.engineer_auth == {"my_key": "my_val"}
assert myai.architect_auth == {"another_key": "another_val"}
def test_compat_key_injection(self, ai_config):
"""Injects API key into auth dict if auth is empty or doesn't have it."""
from connpy.ai import ai
ai_config.config["ai"]["engineer_api_key"] = "compat-eng-key"
ai_config.config["ai"]["architect_api_key"] = "compat-arch-key"
ai_config.config["ai"]["engineer_auth"] = {}
ai_config.config["ai"]["architect_auth"] = {}
myai = ai(ai_config)
assert myai.engineer_auth == {"api_key": "compat-eng-key"}
assert myai.architect_auth == {"api_key": "compat-arch-key"}
def test_has_architect_keyless(self, ai_config):
"""Evaluates has_architect correctly for keyless models and auth configs."""
from connpy.ai import ai
# 1. Keyless model (Vertex)
ai_config.config["ai"]["architect_api_key"] = None
ai_config.config["ai"]["architect_auth"] = {}
ai_config.config["ai"]["architect_model"] = "vertex/gemini-pro"
myai = ai(ai_config)
assert myai.has_architect is True
# 2. Architect auth dict is set
ai_config.config["ai"]["architect_model"] = "custom-model"
ai_config.config["ai"]["architect_auth"] = {"vertex_project": "proj-1"}
myai = ai(ai_config)
assert myai.has_architect is True
def test_ask_unpacks_auth_dict(self, ai_config, mock_litellm):
"""Verifies that ask unpacks engineer_auth when calling completion."""
from connpy.ai import ai
ai_config.config["ai"]["engineer_api_key"] = None
ai_config.config["ai"]["engineer_auth"] = {"vertex_project": "my-project", "vertex_location": "us-east1"}
myai = ai(ai_config)
myai.ask("test query", stream=False)
# Check mock_litellm completion call
mock_litellm["completion"].assert_called()
kwargs = mock_litellm["completion"].call_args.kwargs
assert kwargs.get("vertex_project") == "my-project"
assert kwargs.get("vertex_location") == "us-east1"
assert "api_key" not in kwargs
def test_auth_precedence_no_api_key_injection(self, ai_config):
"""Verifies that api_key is not injected into the auth dict when auth is already set (non-empty)."""
from connpy.ai import ai
ai_config.config["ai"]["engineer_api_key"] = "legacy-eng-key"
ai_config.config["ai"]["architect_api_key"] = "legacy-arch-key"
ai_config.config["ai"]["engineer_auth"] = {"vertex_project": "proj-eng"}
ai_config.config["ai"]["architect_auth"] = {"vertex_project": "proj-arch"}
myai = ai(ai_config)
assert myai.engineer_auth == {"vertex_project": "proj-eng"}
assert "api_key" not in myai.engineer_auth
assert myai.architect_auth == {"vertex_project": "proj-arch"}
assert "api_key" not in myai.architect_auth
# ========================================================================= # =========================================================================
# register_ai_tool tests # register_ai_tool tests
# ========================================================================= # =========================================================================
@@ -427,12 +498,14 @@ class TestAISessions:
def test_generate_session_id(self, myai): def test_generate_session_id(self, myai):
session_id = myai._generate_session_id("Any query") session_id = myai._generate_session_id("Any query")
# Format: YYYYMMDD-HHMMSS # Format: YYYYMMDD-HHMMSS-suffix
assert len(session_id) == 15 assert len(session_id) == 20
assert "-" in session_id assert "-" in session_id
parts = session_id.split("-") parts = session_id.split("-")
assert len(parts) == 3
assert len(parts[0]) == 8 # YYYYMMDD assert len(parts[0]) == 8 # YYYYMMDD
assert len(parts[1]) == 6 # HHMMSS assert len(parts[1]) == 6 # HHMMSS
assert len(parts[2]) == 4 # suffix
def test_save_and_load_session(self, myai): def test_save_and_load_session(self, myai):
history = [ history = [
+25
View File
@@ -193,3 +193,28 @@ def test_build_context_blocks_horizontal_scrolling_ansi():
assert len(blocks) >= 1 assert len(blocks) >= 1
start, end, preview = blocks[0] start, end, preview = blocks[0]
assert "RP/0/RP0/CPU0:xrd# s show interfaces * | inc" in preview assert "RP/0/RP0/CPU0:xrd# s show interfaces * | inc" in preview
def test_build_context_blocks_cancelled_command():
from connpy.services.ai_service import AIService
svc = AIService(None)
node_info = {"prompt": "router#"}
# Command 1: cancelled with Ctrl+C. Command 2: executed successfully.
raw_bytes = b"router# show plat\x03\r\nrouter# show ver\r\nrouter# "
# 0: initial boundary
# 18: Ctrl+C pressed (ends Command 1, marked CANCELLED)
# 36: Enter pressed (ends Command 2)
cmd_byte_positions = [(0, None), (18, "CANCELLED"), (36, None)]
blocks = svc.build_context_blocks(raw_bytes, cmd_byte_positions, node_info)
# The cancelled command block (0 to 18) should NOT be registered as a VALID_CMD block.
# The block for "show ver" should be registered (starting at 36, ending at current_prompt_pos).
# Plus, the final block for "CURRENT CONTEXT".
valid_blocks = [b for b in blocks if "CURRENT CONTEXT" not in b[2]]
assert len(valid_blocks) == 1
assert "show ver" in valid_blocks[0][2]
assert "show plat" not in valid_blocks[0][2]
+76
View File
@@ -65,4 +65,80 @@ class TestGetCwd:
assert len(dirs_in_result) > 0 assert len(dirs_in_result) > 0
# =========================================================================
# Tree completions tests
# =========================================================================
class TestTreeCompletions:
def test_config_auth_completions(self):
from connpy.completion import _build_tree, resolve_completion
tree = _build_tree([], [], [], {}, "/tmp")
# Test config completions
config_completions = resolve_completion(["config", ""], tree)
assert "--engineer-auth" in config_completions
assert "--architect-auth" in config_completions
# Resolve when --engineer-auth is chosen in config
auth_comp = resolve_completion(["config", "--engineer-auth", ""], tree)
assert isinstance(auth_comp, list)
# Loop back check:
# e.g., connpy config --engineer-auth some_val
# should loop back and resolve to config options
loop_back_comp = resolve_completion(["config", "--engineer-auth", "some_val", ""], tree)
assert "--architect-auth" in loop_back_comp
assert "--engineer-auth" in loop_back_comp
def test_ai_auth_completions(self):
from connpy.completion import _build_tree, resolve_completion
tree = _build_tree([], [], [], {}, "/tmp")
# Test ai completions
ai_completions = resolve_completion(["ai", ""], tree)
assert "--engineer-auth" in ai_completions
assert "--architect-auth" in ai_completions
# Resolve after choosing option
auth_comp = resolve_completion(["ai", "--engineer-auth", ""], tree)
assert isinstance(auth_comp, list)
# Loop back check:
# e.g., connpy ai --engineer-auth some_val
# should loop back and resolve to ai options, excluding --engineer-auth
loop_back_comp = resolve_completion(["ai", "--engineer-auth", "some_val", ""], tree)
assert "--architect-auth" in loop_back_comp
assert "--engineer-auth" not in loop_back_comp
def test_sixwindmcp_plugin_completions(self):
from connpy.completion import resolve_completion, get_cwd
import importlib.util
# Load the testremote/remote_plugins/sixwindmcp.py plugin
plugin_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"testremote", "remote_plugins", "sixwindmcp.py"
)
spec = importlib.util.spec_from_file_location("sixwindmcp", plugin_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module.get_cwd = get_cwd
plugin_node = module._connpy_tree()
assert "--set-path" in plugin_node
assert "--path" in plugin_node
assert "start" in plugin_node
tree = {"sixwindmcp": plugin_node}
# Test resolution when --set-path is chosen
res = resolve_completion(["sixwindmcp", "--set-path", ""], tree)
assert isinstance(res, list)
# Loop back check:
# e.g., connpy sixwindmcp --set-path /tmp start
# should loop back and resolve to plugin options
loop_back_comp = resolve_completion(["sixwindmcp", "--set-path", "/tmp", ""], tree)
assert "start" in loop_back_comp
assert "stop" in loop_back_comp
+53 -1
View File
@@ -246,7 +246,7 @@ def test_plugin_disable(mock_disable, app):
@patch("connpy.services.ai_service.AIService.list_sessions") @patch("connpy.services.ai_service.AIService.list_sessions")
def test_ai_list(mock_list_sessions, app): def test_ai_list(mock_list_sessions, app):
mock_list_sessions.return_value = [{"id": "1", "title": "t", "created_at": "now", "model": "m"}] mock_list_sessions.return_value = ([{"id": "1", "title": "t", "created_at": "now", "model": "m"}], 1)
app.start(["ai", "--list"]) app.start(["ai", "--list"])
mock_list_sessions.assert_called_once() mock_list_sessions.assert_called_once()
@@ -262,3 +262,55 @@ def test_type_node_reserved_word(app):
with pytest.raises(SystemExit) as exc: with pytest.raises(SystemExit) as exc:
app._type_node("bulk") app._type_node("bulk")
assert exc.value.code == 2 assert exc.value.code == 2
@patch("connpy.services.config_service.ConfigService.update_setting")
@patch("connpy.services.config_service.ConfigService.get_settings")
def test_config_auth_inline_json(mock_get_settings, mock_update_setting, app):
mock_get_settings.return_value = {"ai": {}}
app.start(["config", "--engineer-auth", '{"vertex_project": "test-123"}'])
mock_update_setting.assert_called_once()
args, kwargs = mock_update_setting.call_args
assert args[0] == "ai"
assert args[1]["engineer_auth"] == {"vertex_project": "test-123"}
@patch("connpy.services.config_service.ConfigService.update_setting")
@patch("connpy.services.config_service.ConfigService.get_settings")
def test_config_auth_inline_yaml(mock_get_settings, mock_update_setting, app):
mock_get_settings.return_value = {"ai": {}}
app.start(["config", "--architect-auth", 'project: test-yaml'])
mock_update_setting.assert_called_once()
args, kwargs = mock_update_setting.call_args
assert args[0] == "ai"
assert args[1]["architect_auth"] == {"project": "test-yaml"}
@patch("connpy.services.config_service.ConfigService.update_setting")
@patch("connpy.services.config_service.ConfigService.get_settings")
def test_config_clear_auth(mock_get_settings, mock_update_setting, app):
mock_get_settings.return_value = {"ai": {"engineer_auth": {"project": "123"}, "engineer_api_key": "some-key"}}
app.start(["config", "--engineer-auth", "clear"])
args, kwargs = mock_update_setting.call_args
assert "engineer_auth" not in args[1]
app.start(["config", "--engineer-api-key", "none"])
args, kwargs = mock_update_setting.call_args
assert "engineer_api_key" not in args[1]
@patch("os.path.exists")
@patch("builtins.open")
@patch("connpy.services.config_service.ConfigService.update_setting")
@patch("connpy.services.config_service.ConfigService.get_settings")
def test_config_auth_file_path(mock_get_settings, mock_update_setting, mock_open, mock_exists, app):
mock_get_settings.return_value = {"ai": {}}
mock_exists.side_effect = lambda p: True if p == "/path/to/creds.json" else False
mock_file = MagicMock()
mock_file.read.return_value = '{"vertex_project": "file-project"}'
mock_open.return_value.__enter__.return_value = mock_file
app.start(["config", "--engineer-auth", "/path/to/creds.json"])
mock_update_setting.assert_called_once()
args, kwargs = mock_update_setting.call_args
assert args[0] == "ai"
assert args[1]["engineer_auth"] == {"vertex_project": "file-project"}
+136
View File
@@ -0,0 +1,136 @@
"""
Tests for gRPC auth serialization/deserialization (engineer_auth, architect_auth, provider auth).
These tests verify that:
1. to_struct/from_struct round-trips correctly for auth dicts.
2. AIStub.ask() correctly serializes engineer_auth and architect_auth into AskRequest.
3. AIServicer.ask() correctly deserializes them and passes them to the service.
4. AIStub.configure_provider() serializes auth into ProviderRequest.
5. AIServicer.configure_provider() deserializes auth and forwards it to the service.
"""
import pytest
from unittest.mock import MagicMock, patch, call
from connpy.grpc_layer import connpy_pb2
from connpy.grpc_layer.utils import to_struct, from_struct
# --- Unit: Struct round-trip ---
class TestStructRoundTrip:
def test_simple_dict(self):
d = {"api_key": "secret", "region": "us-east-1"}
assert from_struct(to_struct(d)) == d
def test_nested_dict(self):
d = {"vertex_project": "my-project", "vertex_location": "us-central1", "nested": {"key": "val"}}
assert from_struct(to_struct(d)) == d
def test_empty_dict(self):
assert from_struct(to_struct({})) == {}
def test_none_returns_empty(self):
assert from_struct(to_struct(None)) == {}
# --- Unit: AskRequest Struct fields ---
class TestAskRequestStructFields:
def test_engineer_auth_round_trip(self):
auth = {"vertex_project": "proj", "vertex_location": "us-central1"}
req = connpy_pb2.AskRequest(input_text="hi")
req.engineer_auth.CopyFrom(to_struct(auth))
assert from_struct(req.engineer_auth) == auth
def test_architect_auth_round_trip(self):
auth = {"api_key": "sk-abc", "base_url": "https://custom.api/v1"}
req = connpy_pb2.AskRequest(input_text="hi")
req.architect_auth.CopyFrom(to_struct(auth))
assert from_struct(req.architect_auth) == auth
def test_has_field_false_when_unset(self):
req = connpy_pb2.AskRequest(input_text="hi")
assert not req.HasField("engineer_auth")
assert not req.HasField("architect_auth")
def test_has_field_true_when_set(self):
req = connpy_pb2.AskRequest(input_text="hi")
req.engineer_auth.CopyFrom(to_struct({"k": "v"}))
assert req.HasField("engineer_auth")
# --- Unit: ProviderRequest Struct field ---
class TestProviderRequestStructField:
def test_auth_round_trip(self):
auth = {"vertex_project": "proj", "vertex_location": "eu-west1"}
req = connpy_pb2.ProviderRequest(provider="vertex", model="gemini-pro")
req.auth.CopyFrom(to_struct(auth))
assert from_struct(req.auth) == auth
def test_has_field_false_when_unset(self):
req = connpy_pb2.ProviderRequest(provider="openai", model="gpt-4o")
assert not req.HasField("auth")
def test_has_field_true_when_set(self):
req = connpy_pb2.ProviderRequest(provider="vertex")
req.auth.CopyFrom(to_struct({"vertex_project": "p"}))
assert req.HasField("auth")
# --- Integration: Server deserializes auth and passes to service ---
class TestAIServicerAuthDeserialization:
@pytest.fixture
def servicer(self, populated_config):
from connpy.grpc_layer.server import AIServicer
return AIServicer(populated_config)
def test_configure_provider_passes_auth_to_service(self, servicer):
auth = {"vertex_project": "my-proj", "vertex_location": "us-central1"}
req = connpy_pb2.ProviderRequest(provider="vertex", model="gemini/gemini-pro", api_key="")
req.auth.CopyFrom(to_struct(auth))
with patch.object(servicer.service, "configure_provider") as mock_cp:
mock_context = MagicMock()
servicer.configure_provider(req, mock_context)
mock_cp.assert_called_once_with("vertex", "gemini/gemini-pro", "", auth=auth)
def test_configure_provider_no_auth(self, servicer):
req = connpy_pb2.ProviderRequest(provider="openai", model="gpt-4o", api_key="sk-test")
with patch.object(servicer.service, "configure_provider") as mock_cp:
mock_context = MagicMock()
servicer.configure_provider(req, mock_context)
mock_cp.assert_called_once_with("openai", "gpt-4o", "sk-test", auth=None)
# --- Integration: Stub serializes auth into request ---
class TestAIStubAuthSerialization:
@pytest.fixture
def ai_stub(self):
from connpy.grpc_layer.stubs import AIStub
mock_channel = MagicMock()
stub = AIStub(mock_channel, "localhost:8048")
return stub
def test_configure_provider_with_auth_serializes_struct(self, ai_stub):
auth = {"vertex_project": "proj", "vertex_location": "us-central1"}
ai_stub.stub.configure_provider = MagicMock()
ai_stub.configure_provider("vertex", model="gemini/gemini-pro", auth=auth)
ai_stub.stub.configure_provider.assert_called_once()
sent_req = ai_stub.stub.configure_provider.call_args[0][0]
assert sent_req.provider == "vertex"
assert sent_req.model == "gemini/gemini-pro"
assert sent_req.HasField("auth")
assert from_struct(sent_req.auth) == auth
def test_configure_provider_without_auth_no_struct(self, ai_stub):
ai_stub.stub.configure_provider = MagicMock()
ai_stub.configure_provider("openai", model="gpt-4o", api_key="sk-x")
sent_req = ai_stub.stub.configure_provider.call_args[0][0]
assert not sent_req.HasField("auth")