feat(multiuser): implementar sistema multiusuario gRPC y configuración compartida de IA/MCP

- Servidor gRPC: Agregar interceptores de autenticación y UserRegistry para aislar sesiones por usuario.
- Contexto de Hilos: Corregir propagación de ContextVar _current_user a hilos secundarios en ExecutionServicer.
- Configuración Compartida: Implementar herencia y deep merge de settings de IA ('ai') y servidores MCP en configfile.
- Hot-Reload: Recarga automática en caliente de la configuración compartida global ante cambios en disco.
- CLI: Agregar comandos e interfaces de usuario para autenticación (login) y administración de usuarios.
- Pruebas: Desarrollar tests unitarios completos (test_shared_ai.py) y resolver regresiones en la suite existente.
This commit is contained in:
2026-05-28 09:27:54 -03:00
parent aa542cb6eb
commit 0adaaad971
28 changed files with 2339 additions and 88 deletions
+1
View File
@@ -146,6 +146,7 @@ package.json
# Development docs
connpy_roadmap.md
testfew/
testnew/
testall/
testremote/
+5 -2
View File
@@ -116,8 +116,11 @@ class ai:
self.interrupted = False
# 1. Cargar configuración genérica
aiconfig = self.config.config.get("ai", {})
# 1. Cargar configuración genérica con herencia/merge global
if hasattr(self.config, "get_effective_setting"):
aiconfig = self.config.get_effective_setting("ai", {})
else:
aiconfig = self.config.config.get("ai", {}) if hasattr(self.config, "config") else {}
# Modelos (Prioridad: Argumento -> Config -> Default)
self.engineer_model = engineer_model or aiconfig.get("engineer_model") or "gemini/gemini-3.1-flash-lite"
+92
View File
@@ -0,0 +1,92 @@
import os
import sys
import getpass
from .. import printer
from ..services.exceptions import ConnpyError
class LoginHandler:
def __init__(self, app):
self.app = app
def dispatch(self, args):
action = getattr(args, "action", None)
if action == "login":
return self.login(args)
elif action == "logout":
return self.logout(args)
else:
printer.error(f"Unknown action: {action}")
sys.exit(1)
def login(self, args):
if self.app.services.mode != "remote":
printer.warning("Note: Your current configuration is set to local mode. Logging in will save credentials, but they will only apply when service-mode is set to 'remote'.")
username = getattr(args, "username", None)
if not username:
try:
username = input("Username: ").strip()
if not username:
printer.error("Username cannot be empty.")
sys.exit(1)
except (KeyboardInterrupt, EOFError):
printer.warning("\nOperation cancelled.")
sys.exit(130)
try:
password = getpass.getpass("Password: ")
if not password:
printer.error("Password cannot be empty.")
sys.exit(1)
except (KeyboardInterrupt, EOFError):
printer.warning("\nOperation cancelled.")
sys.exit(130)
# Make the gRPC login call via self.app.services.auth stub
# We need to make sure auth is initialized in remote mode.
# If we are in local mode, self.app.services.auth is not initialized on ServiceProvider.
# Let's instantiate it dynamically if it's not present.
auth_service = getattr(self.app.services, "auth", None)
if not auth_service:
import grpc
from ..grpc_layer.stubs import AuthStub
remote_host = self.app.services.remote_host or self.app.config.config.get("remote_host")
if not remote_host:
printer.error("Remote host is not configured. Run 'connpy config --remote HOST:PORT' first.")
sys.exit(1)
try:
channel = grpc.insecure_channel(remote_host)
auth_service = AuthStub(channel, remote_host=remote_host)
except Exception as e:
printer.error(f"Failed to connect to remote server for login: {e}")
sys.exit(1)
try:
res = auth_service.login(username, password)
token = res["token"]
# Save token to ~/.config/conn/.token
token_path = os.path.join(self.app.config.defaultdir, ".token")
with open(token_path, "w") as f:
f.write(token)
os.chmod(token_path, 0o600)
printer.success(f"Logged in successfully as '{username}'. Session expires in 8 hours.")
except ConnpyError as e:
printer.error(f"Login failed: {e}")
sys.exit(1)
except Exception as e:
printer.error(f"Login failed with unexpected error: {e}")
sys.exit(1)
def logout(self, args):
token_path = os.path.join(self.app.config.defaultdir, ".token")
if os.path.exists(token_path):
try:
os.remove(token_path)
printer.success("Logged out successfully. Local session cleared.")
except Exception as e:
printer.error(f"Failed to clear session: {e}")
sys.exit(1)
else:
printer.info("No active session found (already logged out).")
+35 -2
View File
@@ -20,6 +20,17 @@ class RunHandler:
def node_run(self, args):
nodes_filter = args.data[0]
# Resolve and filter nodes through context-aware list_nodes
try:
matched_nodes = self.app.services.nodes.list_nodes(nodes_filter)
except Exception:
matched_nodes = []
if not matched_nodes:
printer.error(f"No nodes found matching filter: {nodes_filter}")
sys.exit(2)
commands = [" ".join(args.data[1:])]
try:
@@ -36,7 +47,7 @@ class RunHandler:
printer.test_panel(unique, node_output, node_status, node_result)
results = self.app.services.execution.test_commands(
nodes_filter=nodes_filter,
nodes_filter=matched_nodes,
commands=commands,
expected=args.test_expected,
on_node_complete=_on_node_complete
@@ -53,7 +64,7 @@ class RunHandler:
printer.node_panel(unique, node_output, node_status)
results = self.app.services.execution.run_commands(
nodes_filter=nodes_filter,
nodes_filter=matched_nodes,
commands=commands,
on_node_complete=_on_node_complete
)
@@ -103,6 +114,28 @@ class RunHandler:
folder = output_cfg if output_cfg not in [None, "stdout"] else None
prompt = options.get("prompt")
# Resolve and filter nodes through context-aware list_nodes
try:
if isinstance(nodelist, str):
resolved_nodes = self.app.services.nodes.list_nodes(nodelist)
elif isinstance(nodelist, list):
resolved_nodes = []
for item in nodelist:
matches = self.app.services.nodes.list_nodes(item)
for m in matches:
if m not in resolved_nodes:
resolved_nodes.append(m)
else:
resolved_nodes = []
except Exception:
resolved_nodes = []
if not resolved_nodes:
printer.error(f"[{name}] No nodes found matching filter: {nodelist}")
sys.exit(11)
nodelist = resolved_nodes
try:
header_printed = False
if action == "run":
+190
View File
@@ -0,0 +1,190 @@
import sys
import os
import getpass
import yaml
from .. import printer
from ..services.exceptions import ConnpyError
class UserHandler:
def __init__(self, app):
self.app = app
def dispatch(self, args):
if self.app.services.mode == "remote":
printer.error("User management commands are only available in local/server-side mode.")
sys.exit(1)
# Parse actions from argparse mutually exclusive options
if getattr(args, "add", None):
args.action = "add"
args.username = args.add[0]
elif getattr(args, "delete", None):
args.action = "del"
args.username = args.delete[0]
elif getattr(args, "list", False):
args.action = "list"
elif getattr(args, "show", None):
args.action = "show"
args.username = args.show[0]
elif getattr(args, "regen_password", None):
args.action = "regen_password"
args.username = args.regen_password[0]
action = getattr(args, "action", None)
if action == "add":
return self.add_user(args)
elif action == "del":
return self.delete_user(args)
elif action == "list":
return self.list_users(args)
elif action == "show":
return self.show_user(args)
elif action == "regen_password":
return self.regen_password(args)
else:
printer.error(f"Unknown action: {action}")
sys.exit(1)
def add_user(self, args):
username = getattr(args, "username", None)
if not username:
printer.error("Username is required. Usage: connpy user --add <username>")
sys.exit(1)
custom_path = getattr(args, "path", None)
if custom_path:
custom_path = custom_path[0] if isinstance(custom_path, list) else custom_path
try:
password = getpass.getpass("Enter password for new user: ")
if not password:
printer.error("Password cannot be empty.")
sys.exit(1)
confirm = getpass.getpass("Confirm password: ")
if password != confirm:
printer.error("Passwords do not match.")
sys.exit(1)
except (KeyboardInterrupt, EOFError):
printer.warning("\nOperation cancelled.")
sys.exit(130)
try:
self.app.services.users.create_user(username, password, config_path=custom_path)
printer.success(f"User '{username}' created successfully.")
except ConnpyError as e:
printer.error(str(e))
sys.exit(1)
except ValueError as e:
printer.error(str(e))
sys.exit(1)
except Exception as e:
printer.error(f"Failed to create user: {e}")
sys.exit(1)
def delete_user(self, args):
username = getattr(args, "username", None)
if not username:
printer.error("Username is required. Usage: connpy user --del <username>")
sys.exit(1)
try:
self.app.services.users.delete_user(username)
printer.success(f"User '{username}' deleted successfully.")
except ConnpyError as e:
printer.error(str(e))
sys.exit(1)
except ValueError as e:
printer.error(str(e))
sys.exit(1)
except Exception as e:
printer.error(f"Failed to delete user: {e}")
sys.exit(1)
def list_users(self, args):
try:
users = self.app.services.users.list_users()
if not users:
printer.warning("No users registered.")
return
# Format custom config path, falling back to computed default path instead of null/None
formatted_users = []
for u in users:
formatted_u = u.copy()
if not formatted_u.get("config_path"):
formatted_u["config_path"] = os.path.join(self.app.services.users.users_dir, formatted_u["username"])
formatted_users.append(formatted_u)
yaml_str = yaml.dump(formatted_users, sort_keys=False, default_flow_style=False)
printer.data("Registered Users", yaml_str)
except Exception as e:
printer.error(f"Failed to list users: {e}")
sys.exit(1)
def show_user(self, args):
username = getattr(args, "username", None)
if not username:
printer.error("Username is required. Usage: connpy user --show <username>")
sys.exit(1)
try:
user = self.app.services.users.get_user(username)
if not user:
printer.error(f"User '{username}' not found.")
sys.exit(1)
# Hide the password hash from the CLI output for safety
safe_user = {k: v for k, v in user.items() if k != "password_hash"}
if not safe_user.get("config_path"):
safe_user["config_path"] = os.path.join(self.app.services.users.users_dir, username)
yaml_str = yaml.dump(safe_user, sort_keys=False, default_flow_style=False)
printer.data(f"User: {username}", yaml_str)
except ValueError as e:
printer.error(str(e))
sys.exit(1)
except Exception as e:
printer.error(f"Failed to retrieve user details: {e}")
sys.exit(1)
def regen_password(self, args):
username = getattr(args, "username", None)
if not username:
printer.error("Username is required. Usage: connpy user --regen-password <username>")
sys.exit(1)
try:
user = self.app.services.users.get_user(username)
if not user:
printer.error(f"User '{username}' not found.")
sys.exit(1)
except ValueError as e:
printer.error(str(e))
sys.exit(1)
except Exception as e:
printer.error(f"Failed to retrieve user details: {e}")
sys.exit(1)
try:
new_password = getpass.getpass("Enter new password: ")
if not new_password:
printer.error("Password cannot be empty.")
sys.exit(1)
confirm = getpass.getpass("Confirm new password: ")
if new_password != confirm:
printer.error("Passwords do not match.")
sys.exit(1)
except (KeyboardInterrupt, EOFError):
printer.warning("\nOperation cancelled.")
sys.exit(130)
try:
self.app.services.users.admin_change_password(username, new_password)
printer.success(f"Password for user '{username}' regenerated successfully.")
except ValueError as e:
printer.error(str(e))
sys.exit(1)
except Exception as e:
printer.error(f"Failed to regenerate password: {e}")
sys.exit(1)
+31
View File
@@ -105,6 +105,21 @@ def _get_plugins(which, defaultdir):
return final_all_plugins
def _get_users(configdir):
import yaml
registry_file = os.path.join(configdir, "users", "registry.yaml")
if not os.path.exists(registry_file):
return []
try:
with open(registry_file, "r") as f:
data = yaml.safe_load(f) or {}
if isinstance(data, dict) and "users" in data:
return list(data["users"].keys())
except Exception:
pass
return []
def _build_tree(nodes, folders, profiles, plugins, configdir):
"""Build the declarative CLI navigation tree.
@@ -203,6 +218,19 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
config_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": config_dict}
config_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": config_dict}
_users = lambda w=None: _get_users(configdir)
user_dict = {
"--add": {"*": {"--path": {"__extra__": lambda w: get_cwd(w, "--path", True), "*": None}}},
"--del": {"__extra__": _users},
"--rm": {"__extra__": _users},
"--show": {"__extra__": _users},
"--regen-password": {"__extra__": _users},
"--list": None,
"--ls": None,
"--help": None, "-h": None
}
mv_state = {"__extra__": _nodes, "--help": None, "-h": None}
cp_state = {"__extra__": _nodes, "--help": None, "-h": None}
ls_state = {
@@ -297,6 +325,9 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
"--list": None, "--help": None,
"-h": None,
},
"user": user_dict,
"login": {"--help": None, "-h": None, "*": None},
"logout": {"--help": None, "-h": None},
"config": config_dict,
"sync": {
"--login": None, "--logout": None,
+30 -2
View File
@@ -43,7 +43,8 @@ class configfile:
passwords.
'''
def __init__(self, conf = None, key = None):
def __init__(self, conf = None, key = None, shared_config = None):
self._shared_config = shared_config
'''
### Optional Parameters:
@@ -149,6 +150,32 @@ class configfile:
self._generate_nodes_cache()
def get_effective_setting(self, key, default=None):
"""Get config setting with shared fallback for inheritable keys."""
val = self.config.get(key)
if key == "ai":
if val is not None:
if self._shared_config:
import copy
# Deep merge: shared as base, user overrides
base = copy.deepcopy(self._shared_config.config.get(key, {}))
if isinstance(base, dict) and isinstance(val, dict):
# Recursive update for inner dictionaries (like mcp_servers or model details)
def deep_merge(d1, d2):
for k, v in d2.items():
if isinstance(v, dict) and k in d1 and isinstance(d1[k], dict):
deep_merge(d1[k], v)
else:
d1[k] = copy.deepcopy(v)
deep_merge(base, val)
return base
return val
elif self._shared_config:
return self._shared_config.config.get(key, default)
return val if val is not None else default
def _validate_config(self, data):
"""Verify config data has the required structure."""
if not isinstance(data, dict):
@@ -489,7 +516,8 @@ class configfile:
else:
printer.error("Filter must be a string or a list of strings")
sys.exit(1)
nodes = [item for item in nodes if any(re.search(pattern, item) for pattern in flat_filter)]
flags = re.IGNORECASE if not self.config.get("case", False) else 0
nodes = [item for item in nodes if any(re.search(pattern, item, flags) for pattern in flat_filter)]
return nodes
@MethodHook
+42 -10
View File
@@ -79,15 +79,16 @@ class connapp:
self.debug_api = debug_api
self.ai = ai
# Register context filtering hooks
self.services.context.config._getallnodes.register_post_hook(self.services.context.filter_node_list)
self.services.context.config._getallfolders.register_post_hook(self.services.context.filter_node_list)
self.services.context.config._getallnodesfull.register_post_hook(self.services.context.filter_node_dict)
if hasattr(self.services.nodes, "list_nodes") and hasattr(self.services.nodes.list_nodes, "register_post_hook"):
self.services.nodes.list_nodes.register_post_hook(self.services.context.filter_node_list)
if hasattr(self.services.nodes, "list_folders") and hasattr(self.services.nodes.list_folders, "register_post_hook"):
self.services.nodes.list_folders.register_post_hook(self.services.context.filter_node_list)
# Register context filtering hooks (only on Client CLI, bypass on gRPC Server)
is_api_server = len(sys.argv) > 1 and sys.argv[1] == "api"
if not is_api_server:
self.services.context.config._getallnodes.register_post_hook(self.services.context.filter_node_list)
self.services.context.config._getallfolders.register_post_hook(self.services.context.filter_node_list)
self.services.context.config._getallnodesfull.register_post_hook(self.services.context.filter_node_dict)
if hasattr(self.services.nodes, "list_nodes") and hasattr(self.services.nodes.list_nodes, "register_post_hook"):
self.services.nodes.list_nodes.register_post_hook(self.services.context.filter_node_list)
if hasattr(self.services.nodes, "list_folders") and hasattr(self.services.nodes.list_folders, "register_post_hook"):
self.services.nodes.list_folders.register_post_hook(self.services.context.filter_node_list)
# Apply theme from config if exists before remote connection attempts
user_theme = self.config.config.get("theme", {})
@@ -109,7 +110,10 @@ class connapp:
except ConnpyError as e:
# If in remote mode, connectivity issues should be reported
if mode == "remote":
printer.warning(f"Failed to fetch data from remote server: {e}")
is_auth_cmd = len(sys.argv) > 1 and sys.argv[1] in ["login", "logout", "user"]
is_unauth = "unauthenticated" in str(e).lower() or "token" in str(e).lower()
if not (is_auth_cmd and is_unauth):
printer.warning(f"Failed to fetch data from remote server: {e}")
self.nodes_list = []
self.folders = []
self.profiles = []
@@ -135,6 +139,8 @@ class connapp:
from .cli.context_handler import ContextHandler
from .cli.import_export_handler import ImportExportHandler
from .cli.sync_handler import SyncHandler
from .cli.user_handler import UserHandler
from .cli.login_handler import LoginHandler
# Instantiate Handlers
self._node = NodeHandler(self)
@@ -147,6 +153,8 @@ class connapp:
self._context = ContextHandler(self)
self._import_export = ImportExportHandler(self)
self._sync = SyncHandler(self)
self._user = UserHandler(self)
self._login = LoginHandler(self)
# Register auto-sync hook to trigger after config saves
from .configfile import configfile
@@ -354,6 +362,30 @@ class connapp:
configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX")
configparser.set_defaults(func=self._config.dispatch)
#USERPARSER
userparser = subparsers.add_parser("user", help="Manage server users", description="Manage server users", formatter_class=RichHelpFormatter)
userparser.error = self._custom_error
usercrud = userparser.add_mutually_exclusive_group(required=True)
usercrud.add_argument("--add", nargs=1, dest="add", help="Add new user", metavar="USERNAME")
usercrud.add_argument("--del", "--rm", nargs=1, dest="delete", help="Delete user", metavar="USERNAME")
usercrud.add_argument("--list", "--ls", dest="list", action="store_true", help="List all users")
usercrud.add_argument("--show", nargs=1, dest="show", help="Show user details", metavar="USERNAME")
usercrud.add_argument("--regen-password", nargs=1, dest="regen_password", help="Regenerate user password", metavar="USERNAME")
userparser.add_argument("--path", dest="path", nargs=1, help="Custom configuration path for user configuration (in Mode B)")
userparser.set_defaults(func=self._user.dispatch)
#LOGINPARSER
loginparser = subparsers.add_parser("login", help="Login to remote connpy server", description="Login to remote connpy server", formatter_class=RichHelpFormatter)
loginparser.error = self._custom_error
loginparser.add_argument("username", nargs='?', default=None, help="Username to authenticate")
loginparser.set_defaults(func=self._login.dispatch, action="login")
#LOGOUTPARSER
logoutparser = subparsers.add_parser("logout", help="Logout from remote connpy server", description="Logout from remote connpy server", formatter_class=RichHelpFormatter)
logoutparser.error = self._custom_error
logoutparser.set_defaults(func=self._login.dispatch, action="logout")
#SYNCPARSER
syncparser = subparsers.add_parser("sync", help="Sync config with Google Drive", description="Sync config with Google Drive", formatter_class=RichHelpFormatter)
syncparser.error = self._custom_error
File diff suppressed because one or more lines are too long
+115
View File
@@ -2535,3 +2535,118 @@ class SystemService(object):
timeout,
metadata,
_registered_method=True)
class AuthServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.login = channel.unary_unary(
'/connpy.AuthService/login',
request_serializer=connpy__pb2.LoginRequest.SerializeToString,
response_deserializer=connpy__pb2.LoginResponse.FromString,
_registered_method=True)
self.change_password = channel.unary_unary(
'/connpy.AuthService/change_password',
request_serializer=connpy__pb2.ChangePasswordRequest.SerializeToString,
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
_registered_method=True)
class AuthServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def login(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def change_password(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AuthServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'login': grpc.unary_unary_rpc_method_handler(
servicer.login,
request_deserializer=connpy__pb2.LoginRequest.FromString,
response_serializer=connpy__pb2.LoginResponse.SerializeToString,
),
'change_password': grpc.unary_unary_rpc_method_handler(
servicer.change_password,
request_deserializer=connpy__pb2.ChangePasswordRequest.FromString,
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'connpy.AuthService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('connpy.AuthService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AuthService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def login(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/connpy.AuthService/login',
connpy__pb2.LoginRequest.SerializeToString,
connpy__pb2.LoginResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def change_password(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/connpy.AuthService/change_password',
connpy__pb2.ChangePasswordRequest.SerializeToString,
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
+337 -50
View File
@@ -4,6 +4,8 @@ from google.protobuf.empty_pb2 import Empty
import os
import ctypes
import threading
import contextvars
import datetime
# Suppress harmless but noisy gRPC fork() warnings from pexpect child processes
os.environ["GRPC_VERBOSITY"] = "NONE"
@@ -14,15 +16,7 @@ from .utils import to_value, from_value, to_struct, from_struct
from ..services.exceptions import ConnpyError
from .. import printer
# Import local services
from ..services.node_service import NodeService
from ..services.profile_service import ProfileService
from ..services.config_service import ConfigService
from ..services.plugin_service import PluginService
from ..services.ai_service import AIService
from ..services.system_service import SystemService
from ..services.execution_service import ExecutionService
from ..services.import_export_service import ImportExportService
_current_user = contextvars.ContextVar("current_user", default=None)
def handle_errors(func):
import inspect
@@ -31,10 +25,16 @@ def handle_errors(func):
try:
for item in func(*args, **kwargs):
yield item
except grpc.RpcError:
raise
except ConnpyError as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.INTERNAL, str(e))
except Exception as e:
if type(e) is Exception and not e.args:
raise e
if e.__class__.__name__ in ("_AbortError", "RpcError"):
raise e
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.UNKNOWN, str(e))
finally:
@@ -44,10 +44,16 @@ def handle_errors(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except grpc.RpcError:
raise
except ConnpyError as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.INTERNAL, str(e))
except Exception as e:
if type(e) is Exception and not e.args:
raise e
if e.__class__.__name__ in ("_AbortError", "RpcError"):
raise e
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.UNKNOWN, str(e))
finally:
@@ -55,25 +61,46 @@ def handle_errors(func):
return wrapper
class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
def __init__(self, config, debug=False):
self.service = NodeService(config)
def __init__(self, provider, registry=None, debug=False):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
self.server_debug = debug
if debug:
from rich.console import Console
from ..printer import connpy_theme, get_original_stdout
self.server_console = Console(theme=connpy_theme, file=get_original_stdout())
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().nodes
@handle_errors
def interact_node(self, request_iterator, context):
import sys
import os
import asyncio
from connpy.core import node
from ..services.profile_service import ProfileService
from connpy.tunnels import RemoteStream
import queue
import threading
# Resolve provider once at the start of the RPC stream
provider = self._get_provider()
nodes_service = provider.nodes
profile_service = provider.profiles
ai_service = provider.ai
user_config = provider.config
# Fetch first setup packet
try:
first_req = next(request_iterator)
@@ -100,9 +127,9 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
if base_node_id:
# Look up the base node in config and use its full data
nodes = self.service.config._getallnodes(base_node_id)
nodes = user_config._getallnodes(base_node_id)
if nodes:
device = self.service.config.getitem(nodes[0])
device = user_config.getitem(nodes[0])
# Override device properties with any passed in params
for attr in valid_attrs:
if attr in params:
@@ -116,11 +143,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
device["tags"] = device_tags
node_name = params.get("name", base_node_id)
n = node(node_name, **device, config=self.service.config)
n = node(node_name, **device, config=user_config)
else:
# base_node not found, fall back to dynamic
node_name = params.get("name", fallback_id)
n = node(node_name, host=params.get("host", ""), config=self.service.config)
n = node(node_name, host=params.get("host", ""), config=user_config)
for attr in valid_attrs:
if attr in params:
setattr(n, attr, params[attr])
@@ -128,19 +155,18 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
n.tags = params["tags"]
else:
node_name = params.get("name", fallback_id)
n = node(node_name, host=params.get("host", ""), config=self.service.config)
n = node(node_name, host=params.get("host", ""), config=user_config)
for attr in valid_attrs:
if attr in params:
setattr(n, attr, params[attr])
if "tags" in params:
n.tags = params["tags"]
else:
node_data = self.service.config.getitem(unique_id, extract=False)
node_data = user_config.getitem(unique_id, extract=False)
if not node_data:
context.abort(grpc.StatusCode.NOT_FOUND, f"Node {unique_id} not found")
profile_service = ProfileService(self.service.config)
resolved_data = profile_service.resolve_node_data(node_data)
n = node(unique_id, **resolved_data, config=self.service.config)
n = node(unique_id, **resolved_data, config=user_config)
if sftp:
n.protocol = "sftp"
@@ -207,9 +233,8 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
import json
import asyncio
import os
from ..services.ai_service import AIService
service = AIService(self.service.config)
service = ai_service
if node_info is None:
node_info = {}
@@ -479,10 +504,27 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
)
class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
def __init__(self, config):
self.service = ProfileService(config)
self.node_service = NodeService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().profiles
@property
def node_service(self):
return self._get_provider().nodes
@handle_errors
def list_profiles(self, request, context):
@@ -516,8 +558,23 @@ class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
return Empty()
class ConfigServicer(connpy_pb2_grpc.ConfigServiceServicer):
def __init__(self, config):
self.service = ConfigService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().config_svc
@handle_errors
def get_settings(self, request, context):
@@ -546,8 +603,23 @@ class ConfigServicer(connpy_pb2_grpc.ConfigServiceServicer):
return connpy_pb2.StructResponse(data=to_struct(self.service.apply_theme_from_file(request.value)))
class PluginServicer(connpy_pb2_grpc.PluginServiceServicer, remote_plugin_pb2_grpc.RemotePluginServiceServicer):
def __init__(self, config):
self.service = PluginService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().plugins
@handle_errors
def list_plugins(self, request, context):
@@ -589,8 +661,23 @@ class PluginServicer(connpy_pb2_grpc.PluginServiceServicer, remote_plugin_pb2_gr
yield remote_plugin_pb2.OutputChunk(text=chunk)
class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
def __init__(self, config):
self.service = ExecutionService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().execution
@handle_errors
def run_commands(self, request, context):
@@ -599,6 +686,11 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
# Resolve provider in the main gRPC thread where _current_user ContextVar is set.
# threading.Thread does NOT inherit ContextVars, so self.service inside
# _worker() would fall back to the admin provider.
execution_service = self.service
q = queue.Queue()
def _on_complete(unique, output, status):
@@ -606,7 +698,7 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
def _worker():
try:
self.service.run_commands( nodes_filter=nodes_filter,
execution_service.run_commands( nodes_filter=nodes_filter,
commands=list(request.commands),
folder=request.folder if request.folder else None,
prompt=request.prompt if request.prompt else None,
@@ -645,6 +737,9 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
# Resolve provider in the main gRPC thread where _current_user ContextVar is set.
execution_service = self.service
q = queue.Queue()
def _on_complete(unique, node_output, node_status, node_result):
@@ -652,7 +747,7 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
def _worker():
try:
self.service.test_commands(
execution_service.test_commands(
nodes_filter=nodes_filter,
commands=list(request.commands),
expected=list(request.expected),
@@ -698,9 +793,27 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
return connpy_pb2.StructResponse(data=to_struct(res))
class ImportExportServicer(connpy_pb2_grpc.ImportExportServiceServicer):
def __init__(self, config):
self.service = ImportExportService(config)
self.node_service = NodeService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().import_export
@property
def node_service(self):
return self._get_provider().nodes
@handle_errors
def export_to_file(self, request, context):
@@ -815,14 +928,30 @@ class StatusBridge:
return default
class AIServicer(connpy_pb2_grpc.AIServiceServicer):
def __init__(self, config):
self.service = AIService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().ai
@handle_errors
def ask(self, request_iterator, context):
import queue
import threading
ai_service = self.service
chunk_queue = queue.Queue()
request_queue = queue.Queue()
bridge = None
@@ -840,7 +969,7 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
nonlocal history, bridge, agent_instance
try:
# Run the AI interaction (this blocks this specific thread)
res = self.service.ask(
res = ai_service.ask(
input_text,
chat_history=history if history else None,
session_id=session_id,
@@ -996,8 +1125,23 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
def __init__(self, config):
self.service = SystemService(config)
def __init__(self, provider, registry=None):
if not hasattr(provider, "mode"):
from connpy.services.provider import ServiceProvider
provider = ServiceProvider(provider, mode="local")
self._fallback_provider = provider
self._registry = registry
def _get_provider(self):
if self._registry:
username = _current_user.get()
if username:
return self._registry.get_provider(username)
return self._fallback_provider
@property
def service(self):
return self._get_provider().system
@handle_errors
def start_api(self, request, context):
@@ -1023,6 +1167,138 @@ class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
def get_api_status(self, request, context):
return connpy_pb2.BoolResponse(value=self.service.get_api_status())
class AuthServicer(connpy_pb2_grpc.AuthServiceServicer):
def __init__(self, registry):
self.registry = registry
@handle_errors
def login(self, request, context):
username = request.username
password = request.password
if not self.registry.user_service.authenticate(username, password):
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid username or password")
token = self.registry.user_service.generate_jwt(username)
expires_at = int((datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8)).timestamp())
return connpy_pb2.LoginResponse(
token=token,
username=username,
expires_at=expires_at
)
@handle_errors
def change_password(self, request, context):
username = _current_user.get()
if not username:
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication required")
try:
self.registry.user_service.change_password(username, request.old_password, request.new_password)
self.registry.evict(username)
except ValueError as e:
context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
return Empty()
class AuthInterceptor(grpc.ServerInterceptor):
OPEN_METHODS = ["/connpy.AuthService/login"]
def __init__(self, registry):
self.registry = registry
def intercept_service(self, continuation, handler_call_details):
method = handler_call_details.method
if method in self.OPEN_METHODS:
return continuation(handler_call_details)
if not self.registry.has_users():
return continuation(handler_call_details)
token = self._extract_token(handler_call_details.invocation_metadata)
if not token:
return self._unauthenticated_handler(handler_call_details, "Authorization token is missing")
username = self.registry.user_service.verify_jwt(token)
if not username:
return self._unauthenticated_handler(handler_call_details, "Invalid or expired token")
handler = continuation(handler_call_details)
if handler is None:
return None
return self._wrap_handler(handler, username)
def _wrap_handler(self, handler, username):
if handler.unary_unary:
original_behavior = handler.unary_unary
def wrapper(request, context):
token = _current_user.set(username)
try:
return original_behavior(request, context)
finally:
_current_user.reset(token)
return grpc.unary_unary_rpc_method_handler(
wrapper,
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
elif handler.unary_stream:
original_behavior = handler.unary_stream
def wrapper(request, context):
token = _current_user.set(username)
try:
for response in original_behavior(request, context):
yield response
finally:
_current_user.reset(token)
return grpc.unary_stream_rpc_method_handler(
wrapper,
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
elif handler.stream_unary:
original_behavior = handler.stream_unary
def wrapper(request_iterator, context):
token = _current_user.set(username)
try:
return original_behavior(request_iterator, context)
finally:
_current_user.reset(token)
return grpc.stream_unary_rpc_method_handler(
wrapper,
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
elif handler.stream_stream:
original_behavior = handler.stream_stream
def wrapper(request_iterator, context):
token = _current_user.set(username)
try:
for response in original_behavior(request_iterator, context):
yield response
finally:
_current_user.reset(token)
return grpc.stream_stream_rpc_method_handler(
wrapper,
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
return handler
def _extract_token(self, metadata):
for key, value in metadata:
if key.lower() == "authorization":
if value.startswith("Bearer "):
return value[7:]
return None
def _unauthenticated_handler(self, handler_call_details, message):
def abort_call(request_or_iterator, context):
context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
return grpc.unary_unary_rpc_method_handler(abort_call)
class LoggingInterceptor(grpc.ServerInterceptor):
def __init__(self):
from rich.console import Console
@@ -1047,19 +1323,30 @@ class LoggingInterceptor(grpc.ServerInterceptor):
return result
def serve(config, port=8048, debug=False):
interceptors = [LoggingInterceptor()] if debug else []
from connpy.grpc_layer.user_registry import UserRegistry
from connpy.services.provider import ServiceProvider
fallback_provider = ServiceProvider(config, mode="local")
registry = UserRegistry(config.defaultdir)
interceptors = []
if debug:
interceptors.append(LoggingInterceptor())
interceptors.append(AuthInterceptor(registry))
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), interceptors=interceptors)
connpy_pb2_grpc.add_NodeServiceServicer_to_server(NodeServicer(config, debug=debug), server)
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(ProfileServicer(config), server)
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(ConfigServicer(config), server)
plugin_servicer = PluginServicer(config)
connpy_pb2_grpc.add_NodeServiceServicer_to_server(NodeServicer(fallback_provider, registry=registry, debug=debug), server)
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(ProfileServicer(fallback_provider, registry=registry), server)
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(ConfigServicer(fallback_provider, registry=registry), server)
plugin_servicer = PluginServicer(fallback_provider, registry=registry)
connpy_pb2_grpc.add_PluginServiceServicer_to_server(plugin_servicer, server)
remote_plugin_pb2_grpc.add_RemotePluginServiceServicer_to_server(plugin_servicer, server)
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(ExecutionServicer(config), server)
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(ImportExportServicer(config), server)
connpy_pb2_grpc.add_AIServiceServicer_to_server(AIServicer(config), server)
connpy_pb2_grpc.add_SystemServiceServicer_to_server(SystemServicer(config), server)
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(ExecutionServicer(fallback_provider, registry=registry), server)
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(ImportExportServicer(fallback_provider, registry=registry), server)
connpy_pb2_grpc.add_AIServiceServicer_to_server(AIServicer(fallback_provider, registry=registry), server)
connpy_pb2_grpc.add_SystemServiceServicer_to_server(SystemServicer(fallback_provider, registry=registry), server)
connpy_pb2_grpc.add_AuthServiceServicer_to_server(AuthServicer(registry), server)
server.add_insecure_port(f'[::]:{port}')
server.start()
+75
View File
@@ -980,3 +980,78 @@ class SystemStub:
@handle_errors
def get_api_status(self):
return self.stub.get_api_status(Empty()).value
class _ClientCallDetails(object):
def __init__(self, method, timeout, metadata, credentials, wait_for_ready, compression=None):
self.method = method
self.timeout = timeout
self.metadata = metadata
self.credentials = credentials
self.wait_for_ready = wait_for_ready
self.compression = compression
class AuthClientInterceptor(grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor):
def __init__(self, token_provider):
self.token_provider = token_provider
def _add_metadata(self, client_call_details):
token = self.token_provider()
if not token:
return client_call_details
metadata = []
if client_call_details.metadata:
metadata = list(client_call_details.metadata)
# Check if already present to avoid duplicates
if not any(k.lower() == "authorization" for k, v in metadata):
metadata.append(("authorization", f"Bearer {token}"))
return _ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
compression=client_call_details.compression,
)
def intercept_unary_unary(self, continuation, client_call_details, request):
new_details = self._add_metadata(client_call_details)
return continuation(new_details, request)
def intercept_unary_stream(self, continuation, client_call_details, request):
new_details = self._add_metadata(client_call_details)
return continuation(new_details, request)
def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
new_details = self._add_metadata(client_call_details)
return continuation(new_details, request_iterator)
def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
new_details = self._add_metadata(client_call_details)
return continuation(new_details, request_iterator)
class AuthStub:
def __init__(self, channel, remote_host):
self.stub = connpy_pb2_grpc.AuthServiceStub(channel)
self.remote_host = remote_host
@handle_errors
def login(self, username, password):
req = connpy_pb2.LoginRequest(username=username, password=password)
resp = self.stub.login(req)
return {
"token": resp.token,
"username": resp.username,
"expires_at": resp.expires_at
}
@handle_errors
def change_password(self, old_password, new_password):
req = connpy_pb2.ChangePasswordRequest(old_password=old_password, new_password=new_password)
self.stub.change_password(req)
+107
View File
@@ -0,0 +1,107 @@
import os
import threading
from connpy.configfile import configfile
from connpy.services.provider import ServiceProvider
from connpy.services.user_service import UserService
class UserRegistry:
"""Holds per-user ServiceProviders in memory, thread-safe with hot-reloading."""
def __init__(self, server_config_dir):
self.server_config_dir = os.path.abspath(server_config_dir)
self.user_service = UserService(self.server_config_dir)
self._providers = {} # username → ServiceProvider
self._mtimes = {} # username → last loaded mtime (float)
self._lock = threading.Lock()
# Load shared/global config
self._shared_conf_file = os.path.join(self.server_config_dir, "config.yaml")
if os.path.exists(self._shared_conf_file):
self._shared_config = configfile(conf=self._shared_conf_file)
self._shared_mtime = os.path.getmtime(self._shared_conf_file)
else:
self._shared_config = None
self._shared_mtime = 0.0
def _refresh_shared(self):
"""Hot-reload shared config if the file changed on disk."""
if not os.path.exists(self._shared_conf_file):
return
current_mtime = os.path.getmtime(self._shared_conf_file)
if current_mtime > self._shared_mtime:
try:
self._shared_config = configfile(conf=self._shared_conf_file)
self._shared_mtime = current_mtime
# Clear all user providers so they pick up the new shared config
self._providers.clear()
self._mtimes.clear()
except Exception as e:
from connpy import printer
printer.warning(f"Failed to reload shared config: {e}")
def get_provider(self, username) -> ServiceProvider:
"""Get, lazy-load, or hot-reload a user's full ServiceProvider."""
with self._lock:
# Refresh shared/global config if it has changed
self._refresh_shared()
# 1. Resolve physical path of the user's config.yaml file
user_data = self.user_service.get_user(username)
config_path = user_data.get("config_path")
if config_path:
conf_file = os.path.join(config_path, "config.yaml")
else:
conf_file = os.path.join(self.server_config_dir, "users", username, "config.yaml")
# 2. Retrieve actual modification time in disk
current_mtime = os.path.getmtime(conf_file) if os.path.exists(conf_file) else 0.0
# 3. Validate if initial load or hot-reload is required
if username not in self._providers or self._mtimes.get(username, 0.0) < current_mtime:
old_provider = self._providers.get(username)
try:
# Attempt a fresh configuration load
config = configfile(conf=conf_file, shared_config=self._shared_config)
new_provider = ServiceProvider(config, mode="local")
# Successfully loaded, clean up the old provider
if old_provider:
self._providers.pop(username, None)
if hasattr(old_provider, "close"):
try:
old_provider.close()
except Exception:
pass
self._providers[username] = new_provider
self._mtimes[username] = current_mtime
except Exception as e:
# Log warning but fallback to the old stable provider in memory if available
from connpy import printer
printer.warning(f"Failed to hot-reload config for user '{username}' (file may be corrupt/incomplete): {e}")
if old_provider:
# Keep serving with the old cached instance to ensure service continuity
self._mtimes[username] = current_mtime
else:
# No fallback exists, propagate the exception
raise e
return self._providers[username]
def has_users(self) -> bool:
"""Check if any users are registered (enables auth enforcement)."""
return bool(self.user_service.list_users())
def evict(self, username):
"""Remove and cleanly shut down cached provider (after delete or password change)."""
with self._lock:
provider = self._providers.pop(username, None)
self._mtimes.pop(username, None)
if provider:
# Explicit cleanup of user-scoped resources if custom close/cleanup exists
if hasattr(provider, "close"):
try:
provider.close()
except Exception:
pass
+4 -1
View File
@@ -48,7 +48,10 @@ class MCPClientManager:
all_llm_tools = []
try:
mcp_config = self.config.config.get("ai", {}).get("mcp_servers", {})
if hasattr(self.config, "get_effective_setting"):
mcp_config = self.config.get_effective_setting("ai", {}).get("mcp_servers", {})
else:
mcp_config = self.config.config.get("ai", {}).get("mcp_servers", {}) if hasattr(self.config, "config") else {}
except Exception:
return []
+21
View File
@@ -296,3 +296,24 @@ message MCPRequest {
string auto_load_on_os = 4;
bool remove = 5;
}
service AuthService {
rpc login (LoginRequest) returns (LoginResponse) {}
rpc change_password (ChangePasswordRequest) returns (google.protobuf.Empty) {}
}
message LoginRequest {
string username = 1;
string password = 2;
}
message LoginResponse {
string token = 1;
string username = 2;
int64 expires_at = 3;
}
message ChangePasswordRequest {
string old_password = 1;
string new_password = 2;
}
+4 -1
View File
@@ -307,7 +307,10 @@ class AIService(BaseService):
def list_mcp_servers(self) -> dict:
"""Get the configured MCP servers."""
ai_settings = self.config.config.get("ai", {})
if hasattr(self.config, "get_effective_setting"):
ai_settings = self.config.get_effective_setting("ai", {})
else:
ai_settings = self.config.config.get("ai", {}) if hasattr(self.config, "config") else {}
return ai_settings.get("mcp_servers", {})
def load_session_data(self, session_id):
+27 -1
View File
@@ -33,6 +33,7 @@ class ServiceProvider:
from .import_export_service import ImportExportService
from .context_service import ContextService
from .sync_service import SyncService
from .user_service import UserService
self.nodes = NodeService(self.config)
self.profiles = ProfileService(self.config)
@@ -44,6 +45,7 @@ class ServiceProvider:
self.import_export = ImportExportService(self.config)
self.context = ContextService(self.config)
self.sync = SyncService(self.config)
self.users = UserService(self.config.defaultdir)
def _init_remote(self):
# Allow ConfigService to work locally so the user can revert the mode
@@ -53,14 +55,37 @@ class ServiceProvider:
self.config_svc = ConfigService(self.config)
self.context = ContextService(self.config)
self.sync = SyncService(self.config)
self.users = None
if not self.remote_host:
raise InvalidConfigurationError("Remote host must be specified in remote mode")
import grpc
from ..grpc_layer.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
import os
from ..grpc_layer.stubs import (
NodeStub, ProfileStub, PluginStub, AIStub,
ExecutionStub, ImportExportStub, SystemStub,
ConfigStub, AuthClientInterceptor, AuthStub
)
def get_token():
token_path = os.path.join(self.config.defaultdir, ".token")
if os.path.exists(token_path):
try:
with open(token_path, "r") as f:
return f.read().strip()
except Exception:
pass
return None
channel = grpc.insecure_channel(self.remote_host)
interceptor = AuthClientInterceptor(get_token)
channel = grpc.intercept_channel(channel, interceptor)
# Surgical fix: Keep ConfigService local for mode/theme management,
# but delegate encryption to the server stub.
config_remote = ConfigStub(channel, remote_host=self.remote_host)
self.config_svc.encrypt_password = config_remote.encrypt_password
self.nodes = NodeStub(channel, remote_host=self.remote_host, config=self.config)
self.profiles = ProfileStub(channel, remote_host=self.remote_host, node_stub=self.nodes)
@@ -69,3 +94,4 @@ class ServiceProvider:
self.system = SystemStub(channel, remote_host=self.remote_host)
self.execution = ExecutionStub(channel, remote_host=self.remote_host)
self.import_export = ImportExportStub(channel, remote_host=self.remote_host)
self.auth = AuthStub(channel, remote_host=self.remote_host)
+237
View File
@@ -0,0 +1,237 @@
import os
import re
import shutil
import secrets
import datetime
import bcrypt
import jwt
import yaml
from pathlib import Path
from connpy.configfile import configfile
class UserService:
def __init__(self, config_dir):
self.config_dir = os.path.abspath(config_dir)
self.users_dir = os.path.join(self.config_dir, "users")
self.registry_file = os.path.join(self.users_dir, "registry.yaml")
# Ensure users directory exists
os.makedirs(self.users_dir, exist_ok=True)
def _load_registry(self) -> dict:
"""Loads registry from file. If it doesn't exist, initializes it with a new JWT secret."""
if not os.path.exists(self.registry_file):
registry = {
"jwt_secret": secrets.token_hex(32),
"users": {}
}
self._save_registry(registry)
return registry
try:
with open(self.registry_file, "r") as f:
registry = yaml.safe_load(f) or {}
except Exception:
registry = {}
if not isinstance(registry, dict):
registry = {}
if "jwt_secret" not in registry:
registry["jwt_secret"] = secrets.token_hex(32)
if "users" not in registry or not isinstance(registry["users"], dict):
registry["users"] = {}
return registry
def _save_registry(self, data: dict):
"""Safely saves registry structure to registry.yaml."""
tmp_file = self.registry_file + ".tmp"
try:
with open(tmp_file, "w") as f:
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
os.replace(tmp_file, self.registry_file)
os.chmod(self.registry_file, 0o600)
except Exception as e:
if os.path.exists(tmp_file):
try:
os.remove(tmp_file)
except OSError:
pass
raise e
def create_user(self, username, password, config_path=None) -> dict:
"""Creates a new user with bcrypt-hashed credentials.
Mode A: config_path=None (fresh user) -> Generates config.yaml and .osk key.
Mode B: config_path set -> Reuses existing directory after validating its structure.
"""
if not username or not isinstance(username, str):
raise ValueError("Username cannot be empty")
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
raise ValueError("Username must contain only alphanumeric characters, dashes, or underscores")
if not password or not isinstance(password, str):
raise ValueError("Password cannot be empty")
registry = self._load_registry()
if username in registry["users"]:
raise ValueError(f"User '{username}' already exists")
# Resolve path and initialize configuration
if config_path is None:
user_dir = os.path.join(self.users_dir, username)
os.makedirs(user_dir, exist_ok=True)
# Create subdirs for plugins and sessions
os.makedirs(os.path.join(user_dir, "plugins"), exist_ok=True)
os.makedirs(os.path.join(user_dir, "ai_sessions"), exist_ok=True)
# Create default config.yaml & .osk key via configfile
conf_file = os.path.join(user_dir, "config.yaml")
configfile(conf=conf_file)
stored_config_path = None
else:
abs_config_path = os.path.abspath(config_path)
os.makedirs(abs_config_path, exist_ok=True)
# Create subdirs for plugins and sessions in the custom path
os.makedirs(os.path.join(abs_config_path, "plugins"), exist_ok=True)
os.makedirs(os.path.join(abs_config_path, "ai_sessions"), exist_ok=True)
# Create default config.yaml & .osk key via configfile if config.yaml is not present
conf_file = os.path.join(abs_config_path, "config.yaml")
if not os.path.exists(conf_file):
configfile(conf=conf_file)
stored_config_path = abs_config_path
# Hash password securely
password_hash = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
user_entry = {
"password_hash": password_hash,
"config_path": stored_config_path,
"created": datetime.datetime.now(datetime.timezone.utc).isoformat()
}
registry["users"][username] = user_entry
self._save_registry(registry)
return {
"username": username,
"config_path": stored_config_path,
"created": user_entry["created"]
}
def delete_user(self, username):
"""Removes user from the registry and cleans up config directory if server-managed."""
registry = self._load_registry()
if username not in registry["users"]:
raise ValueError(f"User '{username}' not found")
user_data = registry["users"][username]
config_path = user_data.get("config_path")
if config_path is None:
user_dir = os.path.join(self.users_dir, username)
if os.path.exists(user_dir):
shutil.rmtree(user_dir, ignore_errors=True)
del registry["users"][username]
self._save_registry(registry)
def list_users(self) -> list[dict]:
"""Lists all registered users with metadata."""
registry = self._load_registry()
return [
{
"username": name,
"config_path": data.get("config_path"),
"created": data.get("created")
}
for name, data in registry.get("users", {}).items()
]
def get_user(self, username) -> dict:
"""Retrieves raw metadata for a specific user."""
registry = self._load_registry()
if username not in registry["users"]:
raise ValueError(f"User '{username}' not found")
data = registry["users"][username]
return {
"username": username,
"config_path": data.get("config_path"),
"created": data.get("created"),
"password_hash": data.get("password_hash")
}
def change_password(self, username, old_password, new_password):
"""Verifies old password and updates registry with new hashed password."""
if not new_password or not isinstance(new_password, str):
raise ValueError("New password cannot be empty")
registry = self._load_registry()
if username not in registry["users"]:
raise ValueError(f"User '{username}' not found")
user_data = registry["users"][username]
if not bcrypt.checkpw(old_password.encode("utf-8"), user_data["password_hash"].encode("utf-8")):
raise ValueError("Invalid credentials")
# Update hash
user_data["password_hash"] = bcrypt.hashpw(new_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
self._save_registry(registry)
def admin_change_password(self, username, new_password):
"""Administrative password override (does not require old password)."""
if not new_password or not isinstance(new_password, str):
raise ValueError("New password cannot be empty")
registry = self._load_registry()
if username not in registry["users"]:
raise ValueError(f"User '{username}' not found")
user_data = registry["users"][username]
user_data["password_hash"] = bcrypt.hashpw(new_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
self._save_registry(registry)
def authenticate(self, username, password) -> bool:
"""Verifies if the credentials are valid using bcrypt."""
registry = self._load_registry()
if username not in registry["users"]:
return False
user_data = registry["users"][username]
return bcrypt.checkpw(password.encode("utf-8"), user_data["password_hash"].encode("utf-8"))
def generate_jwt(self, username) -> str:
"""Generates a secure JSON Web Token for the user expiring in 8 hours."""
registry = self._load_registry()
if username not in registry["users"]:
raise ValueError(f"User '{username}' not found")
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8)
payload = {
"sub": username,
"exp": expiration
}
token = jwt.encode(payload, registry["jwt_secret"], algorithm="HS256")
if isinstance(token, bytes):
token = token.decode("utf-8")
return token
def verify_jwt(self, token) -> str | None:
"""Decodes JWT and returns username if token is valid and unexpired."""
registry = self._load_registry()
try:
payload = jwt.decode(token, registry["jwt_secret"], algorithms=["HS256"])
return payload.get("sub")
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError):
return None
+186
View File
@@ -0,0 +1,186 @@
import os
import pytest
import grpc
import argparse
from unittest.mock import MagicMock, patch
from connpy.connapp import connapp
from connpy.services.provider import ServiceProvider
from connpy.cli.user_handler import UserHandler
from connpy.cli.login_handler import LoginHandler
from connpy.grpc_layer.stubs import AuthClientInterceptor, AuthStub
@pytest.fixture
def mock_config():
config = MagicMock()
config.config = {"service_mode": "local", "remote_host": "localhost:8048"}
config.defaultdir = "/mock/default/dir"
return config
@pytest.fixture
def app_instance(mock_config):
with patch("connpy.services.provider.ServiceProvider") as mock_provider_cls:
mock_provider = MagicMock()
mock_provider.context = MagicMock()
mock_provider.nodes = MagicMock()
mock_provider.profiles = MagicMock()
mock_provider.config_svc = MagicMock()
mock_provider.plugins = MagicMock()
mock_provider.sync = MagicMock()
mock_provider.mode = "local"
mock_provider.remote_host = "localhost:8048"
mock_provider_cls.return_value = mock_provider
app = connapp(mock_config)
# Mock UserService on app services
app.services.users = MagicMock()
return app
class TestCLIMultiUserParsing:
def test_parser_contains_user_login_logout(self, app_instance):
parser, _ = app_instance.get_parser()
# Verify subcommands exist by finding the _SubParsersAction
subparsers_action = None
for action in parser._actions:
if isinstance(action, argparse._SubParsersAction):
subparsers_action = action
break
assert subparsers_action is not None
subcommands = subparsers_action.choices.keys()
assert "user" in subcommands
assert "login" in subcommands
assert "logout" in subcommands
def test_user_parser_arguments(self, app_instance):
parser, _ = app_instance.get_parser()
# Parse add user
args = parser.parse_args(["user", "--add", "newguy"])
assert args.add == ["newguy"]
assert args.func == app_instance._user.dispatch
# Parse delete user
args = parser.parse_args(["user", "--del", "oldguy"])
assert args.delete == ["oldguy"]
# Parse list users
args = parser.parse_args(["user", "--list"])
assert args.list is True
# Parse show user
args = parser.parse_args(["user", "--show", "someguy"])
assert args.show == ["someguy"]
# Parse regen-password
args = parser.parse_args(["user", "--regen-password", "someguy"])
assert args.regen_password == ["someguy"]
# Parse path
args = parser.parse_args(["user", "--add", "newguy", "--path", "/some/path"])
assert args.add == ["newguy"]
assert args.path == ["/some/path"]
def test_login_logout_parser_arguments(self, app_instance):
parser, _ = app_instance.get_parser()
args = parser.parse_args(["login", "someuser"])
assert args.username == "someuser"
assert args.func == app_instance._login.dispatch
args = parser.parse_args(["logout"])
assert args.func == app_instance._login.dispatch
class TestUserHandlerDispatch:
def test_user_handler_fails_in_remote_mode(self, app_instance):
app_instance.services.mode = "remote"
handler = UserHandler(app_instance)
args = MagicMock()
args.add = ["testuser"]
with pytest.raises(SystemExit) as excinfo:
handler.dispatch(args)
assert excinfo.value.code == 1
def test_user_handler_routes_add_correctly(self, app_instance):
app_instance.services.mode = "local"
handler = UserHandler(app_instance)
args = MagicMock()
args.add = ["newuser"]
args.delete = None
args.list = False
args.show = None
args.regen_password = None
with patch.object(handler, "add_user") as mock_add:
handler.dispatch(args)
assert args.action == "add"
assert args.username == "newuser"
mock_add.assert_called_once_with(args)
def test_user_handler_routes_list_correctly(self, app_instance):
app_instance.services.mode = "local"
handler = UserHandler(app_instance)
args = MagicMock()
args.add = None
args.delete = None
args.list = True
args.show = None
args.regen_password = None
with patch.object(handler, "list_users") as mock_list:
handler.dispatch(args)
assert args.action == "list"
mock_list.assert_called_once_with(args)
class TestAuthClientInterceptor:
def test_auth_client_interceptor_adds_bearer_token(self):
# Mock token provider
token_provider = MagicMock(return_value="my-super-secret-token")
interceptor = AuthClientInterceptor(token_provider)
# Mock ClientCallDetails using namedtuple
from collections import namedtuple
ClientCallDetails = namedtuple('ClientCallDetails', ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'])
mock_details = ClientCallDetails(
method="/connpy.NodeService/list_nodes",
timeout=10,
metadata=[],
credentials=None,
wait_for_ready=True,
compression=None
)
intercepted_details = interceptor._add_metadata(mock_details)
# Verify metadata was injected
metadata_dict = dict(intercepted_details.metadata)
assert "authorization" in metadata_dict
assert metadata_dict["authorization"] == "Bearer my-super-secret-token"
def test_auth_client_interceptor_no_token(self):
token_provider = MagicMock(return_value=None)
interceptor = AuthClientInterceptor(token_provider)
from collections import namedtuple
ClientCallDetails = namedtuple('ClientCallDetails', ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'])
mock_details = ClientCallDetails(
method="/connpy.NodeService/list_nodes",
timeout=10,
metadata=[],
credentials=None,
wait_for_ready=True,
compression=None
)
intercepted_details = interceptor._add_metadata(mock_details)
# Verify metadata remains empty
assert len(intercepted_details.metadata) == 0
+58
View File
@@ -141,4 +141,62 @@ class TestTreeCompletions:
assert "stop" in loop_back_comp
class TestUserCompletions:
def test_user_command_options(self):
from connpy.completion import _build_tree, resolve_completion
tree = _build_tree([], [], [], {}, "/tmp")
# Test options at the "user" level
user_completions = resolve_completion(["user", ""], tree)
assert "--add" in user_completions
assert "--del" in user_completions
assert "--rm" in user_completions
assert "--show" in user_completions
assert "--regen-password" in user_completions
assert "--list" in user_completions
assert "--ls" in user_completions
def test_user_action_completed_users(self, tmp_path):
from connpy.completion import _build_tree, resolve_completion
import yaml
# Create users directory and mock registry
users_dir = tmp_path / "users"
users_dir.mkdir()
registry_file = users_dir / "registry.yaml"
registry_data = {
"users": {
"fluzzi": {"password_hash": "hash1"},
"john": {"password_hash": "hash2"}
}
}
with open(registry_file, "w") as f:
yaml.dump(registry_data, f)
tree = _build_tree([], [], [], {}, str(tmp_path))
# Resolve after --del, --rm, --show, --regen-password
for action in ["--del", "--rm", "--show", "--regen-password"]:
completions = resolve_completion(["user", action, ""], tree)
assert "fluzzi" in completions
assert "john" in completions
# --add username completed options
add_completions = resolve_completion(["user", "--add", "newguy", ""], tree)
assert "--path" in add_completions
def test_login_logout_completions(self):
from connpy.completion import _build_tree, resolve_completion
tree = _build_tree([], [], [], {}, "/tmp")
# Test login option resolution
login_completions = resolve_completion(["login", ""], tree)
assert "--help" in login_completions
# Test logout option resolution
logout_completions = resolve_completion(["logout", ""], tree)
assert "--help" in logout_completions
+2 -2
View File
@@ -165,9 +165,9 @@ def test_ai(mock_status, mock_ask, app):
@patch("connpy.services.execution_service.ExecutionService.run_commands")
def test_run(mock_run_commands, app):
app.start(["run", "node1", "command1", "command2"])
app.start(["run", "router1", "command1", "command2"])
mock_run_commands.assert_called_once()
assert mock_run_commands.call_args[1]["nodes_filter"] == "node1"
assert mock_run_commands.call_args[1]["nodes_filter"] == ["router1"]
assert mock_run_commands.call_args[1]["commands"] == ["command1 command2"]
@patch("os.path.exists")
+131
View File
@@ -0,0 +1,131 @@
import os
import pytest
import grpc
from concurrent import futures
from google.protobuf.empty_pb2 import Empty
from connpy.grpc_layer import server, connpy_pb2, connpy_pb2_grpc, stubs
from connpy.grpc_layer.user_registry import UserRegistry
from connpy.services.provider import ServiceProvider
from connpy.configfile import configfile
@pytest.fixture
def test_config_dir(tmp_path):
"""Creates a temporary config directory for testing gRPC auth."""
config_dir = tmp_path / "conn_config"
config_dir.mkdir()
# Initialize basic config file inside it
from connpy.configfile import configfile
conf_file = os.path.join(str(config_dir), "config.yaml")
configfile(conf=conf_file)
return config_dir
@pytest.fixture
def registry(test_config_dir):
"""Initializes UserRegistry."""
return UserRegistry(str(test_config_dir))
@pytest.fixture
def auth_grpc_server(test_config_dir, registry):
"""Starts an authenticated local gRPC server for integration testing."""
srv = grpc.server(
futures.ThreadPoolExecutor(max_workers=5),
interceptors=[server.AuthInterceptor(registry)]
)
fallback_provider = ServiceProvider(configfile(conf=os.path.join(str(test_config_dir), "config.yaml")), mode="local")
# Register services
connpy_pb2_grpc.add_NodeServiceServicer_to_server(server.NodeServicer(fallback_provider, registry=registry), srv)
connpy_pb2_grpc.add_AuthServiceServicer_to_server(server.AuthServicer(registry), srv)
port = srv.add_insecure_port('127.0.0.1:0')
srv.start()
yield f"127.0.0.1:{port}"
srv.stop(0)
@pytest.fixture
def channel(auth_grpc_server):
with grpc.insecure_channel(auth_grpc_server) as channel:
yield channel
class TestGRPCAuthentication:
def test_backward_compatibility_no_users(self, channel, registry):
"""Verifies that if no users are registered, gRPC calls proceed without authentication."""
assert registry.has_users() is False
# Calling NodeService list_nodes should succeed without any authorization metadata
stub = connpy_pb2_grpc.NodeServiceStub(channel)
req = connpy_pb2.FilterRequest()
res = stub.list_nodes(req)
assert res is not None
def test_login_and_authenticated_calls(self, channel, registry):
"""Tests user creation, login to retrieve JWT, and using JWT to access protected endpoints."""
username = "alice"
password = "alicepassword"
# 1. Register a user in the registry
registry.user_service.create_user(username, password)
assert registry.has_users() is True
# 2. Try unauthenticated call - must fail with UNAUTHENTICATED
node_stub = connpy_pb2_grpc.NodeServiceStub(channel)
req = connpy_pb2.FilterRequest()
with pytest.raises(grpc.RpcError) as exc:
node_stub.list_nodes(req)
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
assert "Authorization token is missing" in exc.value.details()
# 3. Call login endpoint (open method) - must succeed
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
login_req = connpy_pb2.LoginRequest(username=username, password=password)
login_res = auth_stub.login(login_req)
assert login_res.username == username
assert isinstance(login_res.token, str)
assert login_res.expires_at > 0
# 4. Make authenticated call using Bearer token - must succeed
metadata = [("authorization", f"Bearer {login_res.token}")]
res = node_stub.list_nodes(req, metadata=metadata)
assert res is not None
def test_login_invalid_credentials(self, channel, registry):
"""Verifies login fails and returns UNAUTHENTICATED for incorrect credentials."""
registry.user_service.create_user("bob", "bobpass")
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
login_req = connpy_pb2.LoginRequest(username="bob", password="wrongpassword")
with pytest.raises(grpc.RpcError) as exc:
auth_stub.login(login_req)
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
assert "Invalid username or password" in exc.value.details()
def test_change_password(self, channel, registry):
"""Tests changing password via gRPC and verifying old password no longer works."""
username = "charlie"
registry.user_service.create_user(username, "oldpass")
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
# 1. Login with old password to get token
login_res = auth_stub.login(connpy_pb2.LoginRequest(username=username, password="oldpass"))
token = login_res.token
# 2. Change password via gRPC using the token
metadata = [("authorization", f"Bearer {token}")]
change_req = connpy_pb2.ChangePasswordRequest(old_password="oldpass", new_password="newpass")
auth_stub.change_password(change_req, metadata=metadata)
# 3. Logging in with old password must fail
with pytest.raises(grpc.RpcError) as exc:
auth_stub.login(connpy_pb2.LoginRequest(username=username, password="oldpass"))
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
# 4. Logging in with new password must succeed
login_res_new = auth_stub.login(connpy_pb2.LoginRequest(username=username, password="newpass"))
assert login_res_new.token is not None
+67
View File
@@ -0,0 +1,67 @@
import os
import pytest
from connpy.grpc_layer.server import NodeServicer, _current_user
from connpy.grpc_layer.user_registry import UserRegistry
from connpy.services.provider import ServiceProvider
@pytest.fixture
def test_config_dir(tmp_path):
"""Creates a temporary config directory for testing user registry."""
config_dir = tmp_path / "conn_config"
config_dir.mkdir()
return config_dir
@pytest.fixture
def registry(test_config_dir):
"""Initializes UserRegistry pointing to a temporary directory."""
return UserRegistry(str(test_config_dir))
def test_dynamic_routing_isolation(test_config_dir, registry):
"""Verifies that NodeServicer routes list_nodes to the correct user configuration based on _current_user ContextVar."""
# Setup fallback provider
from connpy.configfile import configfile
conf_file = os.path.join(registry.user_service.config_dir, "config.yaml")
config = configfile(conf=conf_file)
fallback_provider = ServiceProvider(config, mode="local")
# Create servicer with fallback and registry
servicer = NodeServicer(fallback_provider, registry=registry)
# Register two users
u1 = "user1"
u2 = "user2"
registry.user_service.create_user(u1, "pass1")
registry.user_service.create_user(u2, "pass2")
p1 = registry.get_provider(u1)
p2 = registry.get_provider(u2)
# Add nodes to each user's provider
p1.nodes.add_node("node-for-user-1", {"host": "1.1.1.1"})
p2.nodes.add_node("node-for-user-2", {"host": "2.2.2.2"})
# Verify fallback is empty
fallback_res = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
from connpy.grpc_layer.utils import from_value
assert "node-for-user-1" not in from_value(fallback_res.data)
assert "node-for-user-2" not in from_value(fallback_res.data)
# Set context to User 1
t1 = _current_user.set(u1)
try:
res1 = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
nodes1 = from_value(res1.data)
assert "node-for-user-1" in nodes1
assert "node-for-user-2" not in nodes1
finally:
_current_user.reset(t1)
# Set context to User 2
t2 = _current_user.set(u2)
try:
res2 = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
nodes2 = from_value(res2.data)
assert "node-for-user-2" in nodes2
assert "node-for-user-1" not in nodes2
finally:
_current_user.reset(t2)
+162
View File
@@ -0,0 +1,162 @@
import os
import time
import pytest
import yaml
from connpy.configfile import configfile
from connpy.grpc_layer.user_registry import UserRegistry
from connpy.services.provider import ServiceProvider
@pytest.fixture
def temp_config_dir(tmp_path):
"""Creates a temporary config directory for testing."""
config_dir = tmp_path / "conn_shared_test"
config_dir.mkdir()
return config_dir
def test_shared_ai_deep_merge(temp_config_dir):
"""Test get_effective_setting deep merge logic for 'ai' settings."""
shared_dir = os.path.join(temp_config_dir, "shared")
user_dir = os.path.join(temp_config_dir, "user")
os.makedirs(shared_dir, exist_ok=True)
os.makedirs(user_dir, exist_ok=True)
shared_path = os.path.join(shared_dir, "config.yaml")
user_path = os.path.join(user_dir, "config.yaml")
# Write shared configuration
shared_data = {
"config": {
"theme": "dark",
"case": False,
"ai": {
"engineer_model": "shared-eng-model",
"architect_model": "shared-arch-model",
"engineer_api_key": "shared-key",
"mcp_servers": {
"global-server": {
"url": "http://global-server/sse",
"enabled": True
},
"override-server": {
"url": "http://override-shared/sse",
"enabled": True
}
}
}
},
"connections": {},
"profiles": {}
}
with open(shared_path, "w") as f:
yaml.safe_dump(shared_data, f)
# Write user configuration with overrides
user_data = {
"config": {
"case": True,
"ai": {
"engineer_model": "user-custom-eng-model",
"mcp_servers": {
"override-server": {
"enabled": False
},
"user-server": {
"url": "http://user-server/sse",
"enabled": True
}
}
}
},
"connections": {},
"profiles": {}
}
with open(user_path, "w") as f:
yaml.safe_dump(user_data, f)
# Initialize configfile instances
shared_config = configfile(conf=shared_path)
user_config = configfile(conf=user_path, shared_config=shared_config)
# Verify non-inheritable settings (theme, case)
assert user_config.get_effective_setting("case") is True
assert user_config.get_effective_setting("theme") is None # Should NOT inherit "theme"
# Verify AI setting deep merge
effective_ai = user_config.get_effective_setting("ai")
# Model override
assert effective_ai.get("engineer_model") == "user-custom-eng-model"
# Model inheritance
assert effective_ai.get("architect_model") == "shared-arch-model"
# API key inheritance
assert effective_ai.get("engineer_api_key") == "shared-key"
# MCP Servers merge
mcp = effective_ai.get("mcp_servers", {})
# Inherited server
assert "global-server" in mcp
assert mcp["global-server"]["url"] == "http://global-server/sse"
assert mcp["global-server"]["enabled"] is True
# Merged & overridden server
assert "override-server" in mcp
assert mcp["override-server"]["url"] == "http://override-shared/sse" # inherited
assert mcp["override-server"]["enabled"] is False # overridden
# User-only server
assert "user-server" in mcp
assert mcp["user-server"]["url"] == "http://user-server/sse"
def test_registry_injection_and_hot_reload(temp_config_dir):
"""Test that UserRegistry correctly injects shared config and hot-reloads it when it changes on disk."""
registry = UserRegistry(str(temp_config_dir))
# Define paths
shared_path = os.path.join(temp_config_dir, "config.yaml")
# 1. Create a global config file
global_data = {
"config": {
"ai": {
"engineer_api_key": "global-initial-key",
"engineer_model": "global-model"
}
},
"connections": {},
"profiles": {}
}
with open(shared_path, "w") as f:
yaml.safe_dump(global_data, f)
# Re-init registry to pick up the newly created shared config file
registry = UserRegistry(str(temp_config_dir))
# Register user
username = "testuser"
registry.user_service.create_user(username, "testpassword")
# Check initial injection
provider = registry.get_provider(username)
ai_settings = provider.config.get_effective_setting("ai")
assert ai_settings.get("engineer_api_key") == "global-initial-key"
assert ai_settings.get("engineer_model") == "global-model"
# 2. Modify global config on disk
global_data["config"]["ai"]["engineer_api_key"] = "global-updated-key"
# Sleep briefly to ensure mtime change is detectable
time.sleep(0.1)
with open(shared_path, "w") as f:
yaml.safe_dump(global_data, f)
# Set the mtime forward explicitly to avoid filesystem resolution limits
new_mtime = os.path.getmtime(shared_path) + 10.0
os.utime(shared_path, (new_mtime, new_mtime))
# Retrieve provider again - should trigger hot-reload of shared config
provider2 = registry.get_provider(username)
ai_settings_updated = provider2.config.get_effective_setting("ai")
assert ai_settings_updated.get("engineer_api_key") == "global-updated-key"
assert ai_settings_updated.get("engineer_model") == "global-model"
+134
View File
@@ -0,0 +1,134 @@
import os
import pytest
from connpy.grpc_layer.user_registry import UserRegistry
from connpy.services.provider import ServiceProvider
@pytest.fixture
def test_config_dir(tmp_path):
"""Creates a temporary config directory for testing user registry."""
config_dir = tmp_path / "conn_config"
config_dir.mkdir()
return config_dir
@pytest.fixture
def registry(test_config_dir):
"""Initializes UserRegistry pointing to a temporary directory."""
return UserRegistry(str(test_config_dir))
class TestUserRegistry:
def test_has_users_empty(self, registry):
"""Verifies has_users is False when no users exist."""
assert registry.has_users() is False
def test_get_provider_returns_service_provider(self, registry):
"""Tests that get_provider lazy-loads a valid ServiceProvider instance."""
username = "alice"
registry.user_service.create_user(username, "password")
assert registry.has_users() is True
provider = registry.get_provider(username)
assert isinstance(provider, ServiceProvider)
assert provider.mode == "local"
def test_get_provider_cached(self, registry):
"""Verifies that subsequent calls return the cached singleton instance."""
username = "bob"
registry.user_service.create_user(username, "password")
p1 = registry.get_provider(username)
p2 = registry.get_provider(username)
assert p1 is p2 # must be exact same object reference
def test_two_users_isolated(self, registry):
"""Ensures different users get completely separate ServiceProviders and configs."""
u1 = "user1"
u2 = "user2"
registry.user_service.create_user(u1, "pass1")
registry.user_service.create_user(u2, "pass2")
p1 = registry.get_provider(u1)
p2 = registry.get_provider(u2)
assert p1 is not p2
assert p1.config is not p2.config
# Add a node for user1 and verify user2 is unaffected
p1.nodes.add_node("node1", {"host": "1.1.1.1"})
assert "node1" in p1.nodes.list_nodes()
assert "node1" not in p2.nodes.list_nodes()
def test_evict_clears_cache(self, registry):
"""Verifies that eviction deletes the cached provider from memory."""
username = "evictuser"
registry.user_service.create_user(username, "pass")
p1 = registry.get_provider(username)
assert username in registry._providers
registry.evict(username)
assert username not in registry._providers
# Calling get_provider again spawns a new instance
p2 = registry.get_provider(username)
assert p1 is not p2
def test_provider_hot_reload_on_external_change(self, registry):
"""Verifies that UserRegistry hot-reloads the provider if config.yaml is updated externally."""
username = "charlie"
registry.user_service.create_user(username, "password")
# Initial load (no nodes)
p1 = registry.get_provider(username)
assert len(p1.nodes.list_nodes()) == 0
# Resolve config.yaml file path
conf_file = os.path.join(registry.server_config_dir, "users", username, "config.yaml")
# Modify the config file physically on disk by appending a node
from connpy.configfile import configfile
cfg = configfile(conf=conf_file)
cfg._connections_add(id="testnode", host="8.8.8.8")
cfg._saveconfig(cfg.file)
# Artificially increase mtime to force reload
mtime = os.path.getmtime(conf_file)
os.utime(conf_file, (mtime + 5.0, mtime + 5.0))
# Fetch provider again
p2 = registry.get_provider(username)
# Verify it hot-reloaded and the new node is immediately visible
assert p1 is not p2
assert "testnode" in p2.nodes.list_nodes()
def test_provider_hot_reload_fails_on_corrupt_file_keeps_old_provider(self, registry):
"""Verifies that UserRegistry keeps serving the old provider if disk config is corrupt."""
username = "danny"
registry.user_service.create_user(username, "password")
# Initial load
p1 = registry.get_provider(username)
p1.nodes.add_node("nodeA", {"host": "2.2.2.2"})
assert "nodeA" in p1.nodes.list_nodes()
# Resolve config.yaml path
conf_file = os.path.join(registry.server_config_dir, "users", username, "config.yaml")
# Write corrupted content directly to config.yaml
with open(conf_file, "w") as f:
f.write("corrupt yaml content ::: invalid syntax :::")
# Artificially increase mtime to force reload attempt
mtime = os.path.getmtime(conf_file)
os.utime(conf_file, (mtime + 5.0, mtime + 5.0))
# Fetching provider again should fallback to old_provider instead of failing completely
p2 = registry.get_provider(username)
# Verify fallback
assert p1 is p2
assert "nodeA" in p2.nodes.list_nodes()
+217
View File
@@ -0,0 +1,217 @@
import os
import shutil
import pytest
import datetime
import jwt
import yaml
from pathlib import Path
from connpy.services.user_service import UserService
@pytest.fixture
def test_config_dir(tmp_path):
"""Creates a temporary config directory for testing user registry."""
config_dir = tmp_path / "conn_config"
config_dir.mkdir()
return config_dir
@pytest.fixture
def user_service(test_config_dir):
"""Initializes UserService pointing to a temporary directory."""
return UserService(str(test_config_dir))
class TestUserService:
def test_no_users(self, user_service):
"""Verifies that a new registry is empty by default."""
users = user_service.list_users()
assert users == []
def test_create_user_default(self, user_service):
"""Tests Mode A: fresh user config and key creation."""
username = "testuser"
res = user_service.create_user(username, "mypassword")
assert res["username"] == username
assert res["config_path"] is None
assert "created" in res
# Verify folder, config.yaml and .osk key are created
user_dir = os.path.join(user_service.users_dir, username)
assert os.path.isdir(user_dir)
assert os.path.isdir(os.path.join(user_dir, "plugins"))
assert os.path.isdir(os.path.join(user_dir, "ai_sessions"))
assert os.path.isfile(os.path.join(user_dir, "config.yaml"))
assert os.path.isfile(os.path.join(user_dir, ".osk"))
def test_create_user_custom_path(self, user_service, tmp_path):
"""Tests Mode B: using an existing valid config path."""
# Setup existing custom config directory
custom_dir = tmp_path / "custom_user_conn"
custom_dir.mkdir()
config_file = custom_dir / "config.yaml"
# Write basic config.yaml
config_data = {
"config": {"case": False, "idletime": 30, "fzf": False},
"connections": {},
"profiles": {}
}
with open(config_file, "w") as f:
yaml.dump(config_data, f)
res = user_service.create_user("fluzzi", "fluzzipass", config_path=str(custom_dir))
assert res["username"] == "fluzzi"
assert res["config_path"] == str(custom_dir)
# Verify no directory is created under the server's user folder
user_dir = os.path.join(user_service.users_dir, "fluzzi")
assert not os.path.exists(user_dir)
def test_create_user_custom_path_auto_init(self, user_service, tmp_path):
"""Ensures create_user automatically initializes a missing directory and default config.yaml."""
custom_dir = tmp_path / "new_custom_config"
# Test creation where the directory does not exist yet
res = user_service.create_user("john", "pass", config_path=str(custom_dir))
assert res["username"] == "john"
assert res["config_path"] == str(custom_dir)
# Verify custom path and subdirs/configs were created
assert os.path.isdir(custom_dir)
assert os.path.exists(os.path.join(custom_dir, "config.yaml"))
assert os.path.isdir(os.path.join(custom_dir, "plugins"))
assert os.path.isdir(os.path.join(custom_dir, "ai_sessions"))
def test_create_duplicate_user(self, user_service):
"""Ensures duplicate usernames are rejected."""
user_service.create_user("dupuser", "password")
with pytest.raises(ValueError, match="already exists"):
user_service.create_user("dupuser", "anotherpass")
def test_delete_user_default(self, user_service):
"""Tests Mode A: deleting a server-managed user cleans up directories."""
username = "deluser"
user_service.create_user(username, "password")
user_dir = os.path.join(user_service.users_dir, username)
assert os.path.isdir(user_dir)
user_service.delete_user(username)
# Directory should be cleaned up
assert not os.path.exists(user_dir)
# Registry should be updated
assert len(user_service.list_users()) == 0
def test_delete_user_custom_path(self, user_service, tmp_path):
"""Tests Mode B: deleting a custom-path user leaves files untouched."""
custom_dir = tmp_path / "fluzzi_custom"
custom_dir.mkdir()
config_file = custom_dir / "config.yaml"
with open(config_file, "w") as f:
yaml.dump({"config": {}, "connections": {}, "profiles": {}}, f)
username = "fluzzi"
user_service.create_user(username, "pass", config_path=str(custom_dir))
user_service.delete_user(username)
# Registry cleared
assert len(user_service.list_users()) == 0
# Files remain untouched
assert os.path.isdir(str(custom_dir))
assert os.path.isfile(str(config_file))
def test_list_users(self, user_service):
"""Tests listing all registered users with their metadata."""
user_service.create_user("user1", "pass1")
user_service.create_user("user2", "pass2")
users = user_service.list_users()
assert len(users) == 2
usernames = [u["username"] for u in users]
assert "user1" in usernames
assert "user2" in usernames
def test_get_user(self, user_service):
"""Tests retrieving a single user's configuration metadata."""
user_service.create_user("user1", "pass1")
user = user_service.get_user("user1")
assert user["username"] == "user1"
assert user["config_path"] is None
assert "created" in user
with pytest.raises(ValueError, match="not found"):
user_service.get_user("nonexistent")
def test_authenticate_valid(self, user_service):
"""Verifies successful authentication."""
user_service.create_user("john", "my-secure-password")
assert user_service.authenticate("john", "my-secure-password") is True
def test_authenticate_invalid(self, user_service):
"""Verifies unsuccessful authentication on incorrect or missing credentials."""
user_service.create_user("john", "my-secure-password")
assert user_service.authenticate("john", "wrong-password") is False
assert user_service.authenticate("nonexistent", "my-secure-password") is False
def test_jwt_roundtrip(self, user_service):
"""Tests generating a JWT token and verifying it back to the username."""
username = "jwttester"
user_service.create_user(username, "pass")
token = user_service.generate_jwt(username)
assert isinstance(token, str)
verified_user = user_service.verify_jwt(token)
assert verified_user == username
def test_jwt_expired(self, user_service):
"""Tests that expired JWT tokens are rejected and return None."""
username = "jwttester"
user_service.create_user(username, "pass")
# Manually generate an expired token by setting exp to the past
registry = user_service._load_registry()
expired_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(seconds=10)
payload = {
"sub": username,
"exp": expired_time
}
token = jwt.encode(payload, registry["jwt_secret"], algorithm="HS256")
if isinstance(token, bytes):
token = token.decode("utf-8")
verified_user = user_service.verify_jwt(token)
assert verified_user is None
def test_change_password(self, user_service):
"""Tests changing password for a user."""
username = "passchanger"
user_service.create_user(username, "oldpass")
# Old credentials authenticate
assert user_service.authenticate(username, "oldpass") is True
# Change password
user_service.change_password(username, "oldpass", "newpass")
# Old password fails, new password works
assert user_service.authenticate(username, "oldpass") is False
assert user_service.authenticate(username, "newpass") is True
# Change with invalid old password should fail
with pytest.raises(ValueError, match="Invalid credentials"):
user_service.change_password(username, "wrongold", "evennewer")
def test_admin_change_password(self, user_service):
"""Tests administrative password change (no old password required)."""
username = "adminpasschanger"
user_service.create_user(username, "oldpass")
# Admin changes password directly
user_service.admin_change_password(username, "newpass")
# Verify credentials
assert user_service.authenticate(username, "oldpass") is False
assert user_service.authenticate(username, "newpass") is True
+2
View File
@@ -20,3 +20,5 @@ httpx>=0.27.0
requests>=2.31.0
pytest>=8.0.0
pytest-mock>=3.12.0
bcrypt>=4.1.0
PyJWT>=2.8.0
+2
View File
@@ -49,6 +49,8 @@ install_requires =
aiohttp>=3.9.0,<4.0.0
httpx>=0.27.0,<1.0.0
requests>=2.31.0
bcrypt>=4.1.0
PyJWT>=2.8.0
[options.entry_points]
console_scripts =