feat(multiuser): implementar sistema multiusuario gRPC y configuración compartida de IA/MCP
- Servidor gRPC: Agregar interceptores de autenticación y UserRegistry para aislar sesiones por usuario.
- Contexto de Hilos: Corregir propagación de ContextVar _current_user a hilos secundarios en ExecutionServicer.
- Configuración Compartida: Implementar herencia y deep merge de settings de IA ('ai') y servidores MCP en configfile.
- Hot-Reload: Recarga automática en caliente de la configuración compartida global ante cambios en disco.
- CLI: Agregar comandos e interfaces de usuario para autenticación (login) y administración de usuarios.
- Pruebas: Desarrollar tests unitarios completos (test_shared_ai.py) y resolver regresiones en la suite existente.
This commit is contained in:
@@ -146,6 +146,7 @@ package.json
|
|||||||
|
|
||||||
# Development docs
|
# Development docs
|
||||||
connpy_roadmap.md
|
connpy_roadmap.md
|
||||||
|
testfew/
|
||||||
testnew/
|
testnew/
|
||||||
testall/
|
testall/
|
||||||
testremote/
|
testremote/
|
||||||
|
|||||||
+5
-2
@@ -116,8 +116,11 @@ class ai:
|
|||||||
self.interrupted = False
|
self.interrupted = False
|
||||||
|
|
||||||
|
|
||||||
# 1. Cargar configuración genérica
|
# 1. Cargar configuración genérica con herencia/merge global
|
||||||
aiconfig = self.config.config.get("ai", {})
|
if hasattr(self.config, "get_effective_setting"):
|
||||||
|
aiconfig = self.config.get_effective_setting("ai", {})
|
||||||
|
else:
|
||||||
|
aiconfig = self.config.config.get("ai", {}) if hasattr(self.config, "config") else {}
|
||||||
|
|
||||||
# Modelos (Prioridad: Argumento -> Config -> Default)
|
# Modelos (Prioridad: Argumento -> Config -> Default)
|
||||||
self.engineer_model = engineer_model or aiconfig.get("engineer_model") or "gemini/gemini-3.1-flash-lite"
|
self.engineer_model = engineer_model or aiconfig.get("engineer_model") or "gemini/gemini-3.1-flash-lite"
|
||||||
|
|||||||
@@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import getpass
|
||||||
|
from .. import printer
|
||||||
|
from ..services.exceptions import ConnpyError
|
||||||
|
|
||||||
|
class LoginHandler:
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
def dispatch(self, args):
|
||||||
|
action = getattr(args, "action", None)
|
||||||
|
if action == "login":
|
||||||
|
return self.login(args)
|
||||||
|
elif action == "logout":
|
||||||
|
return self.logout(args)
|
||||||
|
else:
|
||||||
|
printer.error(f"Unknown action: {action}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def login(self, args):
|
||||||
|
if self.app.services.mode != "remote":
|
||||||
|
printer.warning("Note: Your current configuration is set to local mode. Logging in will save credentials, but they will only apply when service-mode is set to 'remote'.")
|
||||||
|
|
||||||
|
username = getattr(args, "username", None)
|
||||||
|
if not username:
|
||||||
|
try:
|
||||||
|
username = input("Username: ").strip()
|
||||||
|
if not username:
|
||||||
|
printer.error("Username cannot be empty.")
|
||||||
|
sys.exit(1)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
printer.warning("\nOperation cancelled.")
|
||||||
|
sys.exit(130)
|
||||||
|
|
||||||
|
try:
|
||||||
|
password = getpass.getpass("Password: ")
|
||||||
|
if not password:
|
||||||
|
printer.error("Password cannot be empty.")
|
||||||
|
sys.exit(1)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
printer.warning("\nOperation cancelled.")
|
||||||
|
sys.exit(130)
|
||||||
|
|
||||||
|
# Make the gRPC login call via self.app.services.auth stub
|
||||||
|
# We need to make sure auth is initialized in remote mode.
|
||||||
|
# If we are in local mode, self.app.services.auth is not initialized on ServiceProvider.
|
||||||
|
# Let's instantiate it dynamically if it's not present.
|
||||||
|
auth_service = getattr(self.app.services, "auth", None)
|
||||||
|
if not auth_service:
|
||||||
|
import grpc
|
||||||
|
from ..grpc_layer.stubs import AuthStub
|
||||||
|
remote_host = self.app.services.remote_host or self.app.config.config.get("remote_host")
|
||||||
|
if not remote_host:
|
||||||
|
printer.error("Remote host is not configured. Run 'connpy config --remote HOST:PORT' first.")
|
||||||
|
sys.exit(1)
|
||||||
|
try:
|
||||||
|
channel = grpc.insecure_channel(remote_host)
|
||||||
|
auth_service = AuthStub(channel, remote_host=remote_host)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to connect to remote server for login: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = auth_service.login(username, password)
|
||||||
|
token = res["token"]
|
||||||
|
|
||||||
|
# Save token to ~/.config/conn/.token
|
||||||
|
token_path = os.path.join(self.app.config.defaultdir, ".token")
|
||||||
|
with open(token_path, "w") as f:
|
||||||
|
f.write(token)
|
||||||
|
os.chmod(token_path, 0o600)
|
||||||
|
|
||||||
|
printer.success(f"Logged in successfully as '{username}'. Session expires in 8 hours.")
|
||||||
|
except ConnpyError as e:
|
||||||
|
printer.error(f"Login failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Login failed with unexpected error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def logout(self, args):
|
||||||
|
token_path = os.path.join(self.app.config.defaultdir, ".token")
|
||||||
|
if os.path.exists(token_path):
|
||||||
|
try:
|
||||||
|
os.remove(token_path)
|
||||||
|
printer.success("Logged out successfully. Local session cleared.")
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to clear session: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
printer.info("No active session found (already logged out).")
|
||||||
@@ -20,6 +20,17 @@ class RunHandler:
|
|||||||
|
|
||||||
def node_run(self, args):
|
def node_run(self, args):
|
||||||
nodes_filter = args.data[0]
|
nodes_filter = args.data[0]
|
||||||
|
|
||||||
|
# Resolve and filter nodes through context-aware list_nodes
|
||||||
|
try:
|
||||||
|
matched_nodes = self.app.services.nodes.list_nodes(nodes_filter)
|
||||||
|
except Exception:
|
||||||
|
matched_nodes = []
|
||||||
|
|
||||||
|
if not matched_nodes:
|
||||||
|
printer.error(f"No nodes found matching filter: {nodes_filter}")
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
commands = [" ".join(args.data[1:])]
|
commands = [" ".join(args.data[1:])]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -36,7 +47,7 @@ class RunHandler:
|
|||||||
printer.test_panel(unique, node_output, node_status, node_result)
|
printer.test_panel(unique, node_output, node_status, node_result)
|
||||||
|
|
||||||
results = self.app.services.execution.test_commands(
|
results = self.app.services.execution.test_commands(
|
||||||
nodes_filter=nodes_filter,
|
nodes_filter=matched_nodes,
|
||||||
commands=commands,
|
commands=commands,
|
||||||
expected=args.test_expected,
|
expected=args.test_expected,
|
||||||
on_node_complete=_on_node_complete
|
on_node_complete=_on_node_complete
|
||||||
@@ -53,7 +64,7 @@ class RunHandler:
|
|||||||
printer.node_panel(unique, node_output, node_status)
|
printer.node_panel(unique, node_output, node_status)
|
||||||
|
|
||||||
results = self.app.services.execution.run_commands(
|
results = self.app.services.execution.run_commands(
|
||||||
nodes_filter=nodes_filter,
|
nodes_filter=matched_nodes,
|
||||||
commands=commands,
|
commands=commands,
|
||||||
on_node_complete=_on_node_complete
|
on_node_complete=_on_node_complete
|
||||||
)
|
)
|
||||||
@@ -103,6 +114,28 @@ class RunHandler:
|
|||||||
folder = output_cfg if output_cfg not in [None, "stdout"] else None
|
folder = output_cfg if output_cfg not in [None, "stdout"] else None
|
||||||
prompt = options.get("prompt")
|
prompt = options.get("prompt")
|
||||||
|
|
||||||
|
# Resolve and filter nodes through context-aware list_nodes
|
||||||
|
try:
|
||||||
|
if isinstance(nodelist, str):
|
||||||
|
resolved_nodes = self.app.services.nodes.list_nodes(nodelist)
|
||||||
|
elif isinstance(nodelist, list):
|
||||||
|
resolved_nodes = []
|
||||||
|
for item in nodelist:
|
||||||
|
matches = self.app.services.nodes.list_nodes(item)
|
||||||
|
for m in matches:
|
||||||
|
if m not in resolved_nodes:
|
||||||
|
resolved_nodes.append(m)
|
||||||
|
else:
|
||||||
|
resolved_nodes = []
|
||||||
|
except Exception:
|
||||||
|
resolved_nodes = []
|
||||||
|
|
||||||
|
if not resolved_nodes:
|
||||||
|
printer.error(f"[{name}] No nodes found matching filter: {nodelist}")
|
||||||
|
sys.exit(11)
|
||||||
|
|
||||||
|
nodelist = resolved_nodes
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header_printed = False
|
header_printed = False
|
||||||
if action == "run":
|
if action == "run":
|
||||||
|
|||||||
@@ -0,0 +1,190 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import getpass
|
||||||
|
import yaml
|
||||||
|
from .. import printer
|
||||||
|
from ..services.exceptions import ConnpyError
|
||||||
|
|
||||||
|
class UserHandler:
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
def dispatch(self, args):
|
||||||
|
if self.app.services.mode == "remote":
|
||||||
|
printer.error("User management commands are only available in local/server-side mode.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Parse actions from argparse mutually exclusive options
|
||||||
|
if getattr(args, "add", None):
|
||||||
|
args.action = "add"
|
||||||
|
args.username = args.add[0]
|
||||||
|
elif getattr(args, "delete", None):
|
||||||
|
args.action = "del"
|
||||||
|
args.username = args.delete[0]
|
||||||
|
elif getattr(args, "list", False):
|
||||||
|
args.action = "list"
|
||||||
|
elif getattr(args, "show", None):
|
||||||
|
args.action = "show"
|
||||||
|
args.username = args.show[0]
|
||||||
|
elif getattr(args, "regen_password", None):
|
||||||
|
args.action = "regen_password"
|
||||||
|
args.username = args.regen_password[0]
|
||||||
|
|
||||||
|
action = getattr(args, "action", None)
|
||||||
|
|
||||||
|
if action == "add":
|
||||||
|
return self.add_user(args)
|
||||||
|
elif action == "del":
|
||||||
|
return self.delete_user(args)
|
||||||
|
elif action == "list":
|
||||||
|
return self.list_users(args)
|
||||||
|
elif action == "show":
|
||||||
|
return self.show_user(args)
|
||||||
|
elif action == "regen_password":
|
||||||
|
return self.regen_password(args)
|
||||||
|
else:
|
||||||
|
printer.error(f"Unknown action: {action}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def add_user(self, args):
|
||||||
|
username = getattr(args, "username", None)
|
||||||
|
if not username:
|
||||||
|
printer.error("Username is required. Usage: connpy user --add <username>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
custom_path = getattr(args, "path", None)
|
||||||
|
if custom_path:
|
||||||
|
custom_path = custom_path[0] if isinstance(custom_path, list) else custom_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
password = getpass.getpass("Enter password for new user: ")
|
||||||
|
if not password:
|
||||||
|
printer.error("Password cannot be empty.")
|
||||||
|
sys.exit(1)
|
||||||
|
confirm = getpass.getpass("Confirm password: ")
|
||||||
|
if password != confirm:
|
||||||
|
printer.error("Passwords do not match.")
|
||||||
|
sys.exit(1)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
printer.warning("\nOperation cancelled.")
|
||||||
|
sys.exit(130)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.app.services.users.create_user(username, password, config_path=custom_path)
|
||||||
|
printer.success(f"User '{username}' created successfully.")
|
||||||
|
except ConnpyError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except ValueError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to create user: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_user(self, args):
|
||||||
|
username = getattr(args, "username", None)
|
||||||
|
if not username:
|
||||||
|
printer.error("Username is required. Usage: connpy user --del <username>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.app.services.users.delete_user(username)
|
||||||
|
printer.success(f"User '{username}' deleted successfully.")
|
||||||
|
except ConnpyError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except ValueError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to delete user: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def list_users(self, args):
|
||||||
|
try:
|
||||||
|
users = self.app.services.users.list_users()
|
||||||
|
if not users:
|
||||||
|
printer.warning("No users registered.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Format custom config path, falling back to computed default path instead of null/None
|
||||||
|
formatted_users = []
|
||||||
|
for u in users:
|
||||||
|
formatted_u = u.copy()
|
||||||
|
if not formatted_u.get("config_path"):
|
||||||
|
formatted_u["config_path"] = os.path.join(self.app.services.users.users_dir, formatted_u["username"])
|
||||||
|
formatted_users.append(formatted_u)
|
||||||
|
|
||||||
|
yaml_str = yaml.dump(formatted_users, sort_keys=False, default_flow_style=False)
|
||||||
|
printer.data("Registered Users", yaml_str)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to list users: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def show_user(self, args):
|
||||||
|
username = getattr(args, "username", None)
|
||||||
|
if not username:
|
||||||
|
printer.error("Username is required. Usage: connpy user --show <username>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = self.app.services.users.get_user(username)
|
||||||
|
if not user:
|
||||||
|
printer.error(f"User '{username}' not found.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Hide the password hash from the CLI output for safety
|
||||||
|
safe_user = {k: v for k, v in user.items() if k != "password_hash"}
|
||||||
|
if not safe_user.get("config_path"):
|
||||||
|
safe_user["config_path"] = os.path.join(self.app.services.users.users_dir, username)
|
||||||
|
|
||||||
|
yaml_str = yaml.dump(safe_user, sort_keys=False, default_flow_style=False)
|
||||||
|
printer.data(f"User: {username}", yaml_str)
|
||||||
|
except ValueError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to retrieve user details: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def regen_password(self, args):
|
||||||
|
username = getattr(args, "username", None)
|
||||||
|
if not username:
|
||||||
|
printer.error("Username is required. Usage: connpy user --regen-password <username>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = self.app.services.users.get_user(username)
|
||||||
|
if not user:
|
||||||
|
printer.error(f"User '{username}' not found.")
|
||||||
|
sys.exit(1)
|
||||||
|
except ValueError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to retrieve user details: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_password = getpass.getpass("Enter new password: ")
|
||||||
|
if not new_password:
|
||||||
|
printer.error("Password cannot be empty.")
|
||||||
|
sys.exit(1)
|
||||||
|
confirm = getpass.getpass("Confirm new password: ")
|
||||||
|
if new_password != confirm:
|
||||||
|
printer.error("Passwords do not match.")
|
||||||
|
sys.exit(1)
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
printer.warning("\nOperation cancelled.")
|
||||||
|
sys.exit(130)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.app.services.users.admin_change_password(username, new_password)
|
||||||
|
printer.success(f"Password for user '{username}' regenerated successfully.")
|
||||||
|
except ValueError as e:
|
||||||
|
printer.error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
printer.error(f"Failed to regenerate password: {e}")
|
||||||
|
sys.exit(1)
|
||||||
@@ -105,6 +105,21 @@ def _get_plugins(which, defaultdir):
|
|||||||
return final_all_plugins
|
return final_all_plugins
|
||||||
|
|
||||||
|
|
||||||
|
def _get_users(configdir):
|
||||||
|
import yaml
|
||||||
|
registry_file = os.path.join(configdir, "users", "registry.yaml")
|
||||||
|
if not os.path.exists(registry_file):
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
with open(registry_file, "r") as f:
|
||||||
|
data = yaml.safe_load(f) or {}
|
||||||
|
if isinstance(data, dict) and "users" in data:
|
||||||
|
return list(data["users"].keys())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _build_tree(nodes, folders, profiles, plugins, configdir):
|
def _build_tree(nodes, folders, profiles, plugins, configdir):
|
||||||
"""Build the declarative CLI navigation tree.
|
"""Build the declarative CLI navigation tree.
|
||||||
|
|
||||||
@@ -203,6 +218,19 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
|
|||||||
config_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": config_dict}
|
config_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": config_dict}
|
||||||
config_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": config_dict}
|
config_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": config_dict}
|
||||||
|
|
||||||
|
_users = lambda w=None: _get_users(configdir)
|
||||||
|
|
||||||
|
user_dict = {
|
||||||
|
"--add": {"*": {"--path": {"__extra__": lambda w: get_cwd(w, "--path", True), "*": None}}},
|
||||||
|
"--del": {"__extra__": _users},
|
||||||
|
"--rm": {"__extra__": _users},
|
||||||
|
"--show": {"__extra__": _users},
|
||||||
|
"--regen-password": {"__extra__": _users},
|
||||||
|
"--list": None,
|
||||||
|
"--ls": None,
|
||||||
|
"--help": None, "-h": None
|
||||||
|
}
|
||||||
|
|
||||||
mv_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
mv_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
||||||
cp_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
cp_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
||||||
ls_state = {
|
ls_state = {
|
||||||
@@ -297,6 +325,9 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
|
|||||||
"--list": None, "--help": None,
|
"--list": None, "--help": None,
|
||||||
"-h": None,
|
"-h": None,
|
||||||
},
|
},
|
||||||
|
"user": user_dict,
|
||||||
|
"login": {"--help": None, "-h": None, "*": None},
|
||||||
|
"logout": {"--help": None, "-h": None},
|
||||||
"config": config_dict,
|
"config": config_dict,
|
||||||
"sync": {
|
"sync": {
|
||||||
"--login": None, "--logout": None,
|
"--login": None, "--logout": None,
|
||||||
|
|||||||
+30
-2
@@ -43,7 +43,8 @@ class configfile:
|
|||||||
passwords.
|
passwords.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, conf = None, key = None):
|
def __init__(self, conf = None, key = None, shared_config = None):
|
||||||
|
self._shared_config = shared_config
|
||||||
'''
|
'''
|
||||||
|
|
||||||
### Optional Parameters:
|
### Optional Parameters:
|
||||||
@@ -149,6 +150,32 @@ class configfile:
|
|||||||
self._generate_nodes_cache()
|
self._generate_nodes_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def get_effective_setting(self, key, default=None):
|
||||||
|
"""Get config setting with shared fallback for inheritable keys."""
|
||||||
|
val = self.config.get(key)
|
||||||
|
if key == "ai":
|
||||||
|
if val is not None:
|
||||||
|
if self._shared_config:
|
||||||
|
import copy
|
||||||
|
# Deep merge: shared as base, user overrides
|
||||||
|
base = copy.deepcopy(self._shared_config.config.get(key, {}))
|
||||||
|
if isinstance(base, dict) and isinstance(val, dict):
|
||||||
|
# Recursive update for inner dictionaries (like mcp_servers or model details)
|
||||||
|
def deep_merge(d1, d2):
|
||||||
|
for k, v in d2.items():
|
||||||
|
if isinstance(v, dict) and k in d1 and isinstance(d1[k], dict):
|
||||||
|
deep_merge(d1[k], v)
|
||||||
|
else:
|
||||||
|
d1[k] = copy.deepcopy(v)
|
||||||
|
deep_merge(base, val)
|
||||||
|
return base
|
||||||
|
return val
|
||||||
|
elif self._shared_config:
|
||||||
|
return self._shared_config.config.get(key, default)
|
||||||
|
|
||||||
|
return val if val is not None else default
|
||||||
|
|
||||||
|
|
||||||
def _validate_config(self, data):
|
def _validate_config(self, data):
|
||||||
"""Verify config data has the required structure."""
|
"""Verify config data has the required structure."""
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
@@ -489,7 +516,8 @@ class configfile:
|
|||||||
else:
|
else:
|
||||||
printer.error("Filter must be a string or a list of strings")
|
printer.error("Filter must be a string or a list of strings")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
nodes = [item for item in nodes if any(re.search(pattern, item) for pattern in flat_filter)]
|
flags = re.IGNORECASE if not self.config.get("case", False) else 0
|
||||||
|
nodes = [item for item in nodes if any(re.search(pattern, item, flags) for pattern in flat_filter)]
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
@MethodHook
|
@MethodHook
|
||||||
|
|||||||
+34
-2
@@ -79,11 +79,12 @@ class connapp:
|
|||||||
self.debug_api = debug_api
|
self.debug_api = debug_api
|
||||||
self.ai = ai
|
self.ai = ai
|
||||||
|
|
||||||
# Register context filtering hooks
|
# Register context filtering hooks (only on Client CLI, bypass on gRPC Server)
|
||||||
|
is_api_server = len(sys.argv) > 1 and sys.argv[1] == "api"
|
||||||
|
if not is_api_server:
|
||||||
self.services.context.config._getallnodes.register_post_hook(self.services.context.filter_node_list)
|
self.services.context.config._getallnodes.register_post_hook(self.services.context.filter_node_list)
|
||||||
self.services.context.config._getallfolders.register_post_hook(self.services.context.filter_node_list)
|
self.services.context.config._getallfolders.register_post_hook(self.services.context.filter_node_list)
|
||||||
self.services.context.config._getallnodesfull.register_post_hook(self.services.context.filter_node_dict)
|
self.services.context.config._getallnodesfull.register_post_hook(self.services.context.filter_node_dict)
|
||||||
|
|
||||||
if hasattr(self.services.nodes, "list_nodes") and hasattr(self.services.nodes.list_nodes, "register_post_hook"):
|
if hasattr(self.services.nodes, "list_nodes") and hasattr(self.services.nodes.list_nodes, "register_post_hook"):
|
||||||
self.services.nodes.list_nodes.register_post_hook(self.services.context.filter_node_list)
|
self.services.nodes.list_nodes.register_post_hook(self.services.context.filter_node_list)
|
||||||
if hasattr(self.services.nodes, "list_folders") and hasattr(self.services.nodes.list_folders, "register_post_hook"):
|
if hasattr(self.services.nodes, "list_folders") and hasattr(self.services.nodes.list_folders, "register_post_hook"):
|
||||||
@@ -109,6 +110,9 @@ class connapp:
|
|||||||
except ConnpyError as e:
|
except ConnpyError as e:
|
||||||
# If in remote mode, connectivity issues should be reported
|
# If in remote mode, connectivity issues should be reported
|
||||||
if mode == "remote":
|
if mode == "remote":
|
||||||
|
is_auth_cmd = len(sys.argv) > 1 and sys.argv[1] in ["login", "logout", "user"]
|
||||||
|
is_unauth = "unauthenticated" in str(e).lower() or "token" in str(e).lower()
|
||||||
|
if not (is_auth_cmd and is_unauth):
|
||||||
printer.warning(f"Failed to fetch data from remote server: {e}")
|
printer.warning(f"Failed to fetch data from remote server: {e}")
|
||||||
self.nodes_list = []
|
self.nodes_list = []
|
||||||
self.folders = []
|
self.folders = []
|
||||||
@@ -135,6 +139,8 @@ class connapp:
|
|||||||
from .cli.context_handler import ContextHandler
|
from .cli.context_handler import ContextHandler
|
||||||
from .cli.import_export_handler import ImportExportHandler
|
from .cli.import_export_handler import ImportExportHandler
|
||||||
from .cli.sync_handler import SyncHandler
|
from .cli.sync_handler import SyncHandler
|
||||||
|
from .cli.user_handler import UserHandler
|
||||||
|
from .cli.login_handler import LoginHandler
|
||||||
|
|
||||||
# Instantiate Handlers
|
# Instantiate Handlers
|
||||||
self._node = NodeHandler(self)
|
self._node = NodeHandler(self)
|
||||||
@@ -147,6 +153,8 @@ class connapp:
|
|||||||
self._context = ContextHandler(self)
|
self._context = ContextHandler(self)
|
||||||
self._import_export = ImportExportHandler(self)
|
self._import_export = ImportExportHandler(self)
|
||||||
self._sync = SyncHandler(self)
|
self._sync = SyncHandler(self)
|
||||||
|
self._user = UserHandler(self)
|
||||||
|
self._login = LoginHandler(self)
|
||||||
|
|
||||||
# Register auto-sync hook to trigger after config saves
|
# Register auto-sync hook to trigger after config saves
|
||||||
from .configfile import configfile
|
from .configfile import configfile
|
||||||
@@ -354,6 +362,30 @@ class connapp:
|
|||||||
configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX")
|
configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX")
|
||||||
configparser.set_defaults(func=self._config.dispatch)
|
configparser.set_defaults(func=self._config.dispatch)
|
||||||
|
|
||||||
|
#USERPARSER
|
||||||
|
userparser = subparsers.add_parser("user", help="Manage server users", description="Manage server users", formatter_class=RichHelpFormatter)
|
||||||
|
userparser.error = self._custom_error
|
||||||
|
usercrud = userparser.add_mutually_exclusive_group(required=True)
|
||||||
|
usercrud.add_argument("--add", nargs=1, dest="add", help="Add new user", metavar="USERNAME")
|
||||||
|
usercrud.add_argument("--del", "--rm", nargs=1, dest="delete", help="Delete user", metavar="USERNAME")
|
||||||
|
usercrud.add_argument("--list", "--ls", dest="list", action="store_true", help="List all users")
|
||||||
|
usercrud.add_argument("--show", nargs=1, dest="show", help="Show user details", metavar="USERNAME")
|
||||||
|
usercrud.add_argument("--regen-password", nargs=1, dest="regen_password", help="Regenerate user password", metavar="USERNAME")
|
||||||
|
|
||||||
|
userparser.add_argument("--path", dest="path", nargs=1, help="Custom configuration path for user configuration (in Mode B)")
|
||||||
|
userparser.set_defaults(func=self._user.dispatch)
|
||||||
|
|
||||||
|
#LOGINPARSER
|
||||||
|
loginparser = subparsers.add_parser("login", help="Login to remote connpy server", description="Login to remote connpy server", formatter_class=RichHelpFormatter)
|
||||||
|
loginparser.error = self._custom_error
|
||||||
|
loginparser.add_argument("username", nargs='?', default=None, help="Username to authenticate")
|
||||||
|
loginparser.set_defaults(func=self._login.dispatch, action="login")
|
||||||
|
|
||||||
|
#LOGOUTPARSER
|
||||||
|
logoutparser = subparsers.add_parser("logout", help="Logout from remote connpy server", description="Logout from remote connpy server", formatter_class=RichHelpFormatter)
|
||||||
|
logoutparser.error = self._custom_error
|
||||||
|
logoutparser.set_defaults(func=self._login.dispatch, action="logout")
|
||||||
|
|
||||||
#SYNCPARSER
|
#SYNCPARSER
|
||||||
syncparser = subparsers.add_parser("sync", help="Sync config with Google Drive", description="Sync config with Google Drive", formatter_class=RichHelpFormatter)
|
syncparser = subparsers.add_parser("sync", help="Sync config with Google Drive", description="Sync config with Google Drive", formatter_class=RichHelpFormatter)
|
||||||
syncparser.error = self._custom_error
|
syncparser.error = self._custom_error
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -2535,3 +2535,118 @@ class SystemService(object):
|
|||||||
timeout,
|
timeout,
|
||||||
metadata,
|
metadata,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthServiceStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.login = channel.unary_unary(
|
||||||
|
'/connpy.AuthService/login',
|
||||||
|
request_serializer=connpy__pb2.LoginRequest.SerializeToString,
|
||||||
|
response_deserializer=connpy__pb2.LoginResponse.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.change_password = channel.unary_unary(
|
||||||
|
'/connpy.AuthService/change_password',
|
||||||
|
request_serializer=connpy__pb2.ChangePasswordRequest.SerializeToString,
|
||||||
|
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthServiceServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def login(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def change_password(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_AuthServiceServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'login': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.login,
|
||||||
|
request_deserializer=connpy__pb2.LoginRequest.FromString,
|
||||||
|
response_serializer=connpy__pb2.LoginResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
'change_password': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.change_password,
|
||||||
|
request_deserializer=connpy__pb2.ChangePasswordRequest.FromString,
|
||||||
|
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'connpy.AuthService', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('connpy.AuthService', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class AuthService(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def login(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/connpy.AuthService/login',
|
||||||
|
connpy__pb2.LoginRequest.SerializeToString,
|
||||||
|
connpy__pb2.LoginResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def change_password(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/connpy.AuthService/change_password',
|
||||||
|
connpy__pb2.ChangePasswordRequest.SerializeToString,
|
||||||
|
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|||||||
+337
-50
@@ -4,6 +4,8 @@ from google.protobuf.empty_pb2 import Empty
|
|||||||
import os
|
import os
|
||||||
import ctypes
|
import ctypes
|
||||||
import threading
|
import threading
|
||||||
|
import contextvars
|
||||||
|
import datetime
|
||||||
|
|
||||||
# Suppress harmless but noisy gRPC fork() warnings from pexpect child processes
|
# Suppress harmless but noisy gRPC fork() warnings from pexpect child processes
|
||||||
os.environ["GRPC_VERBOSITY"] = "NONE"
|
os.environ["GRPC_VERBOSITY"] = "NONE"
|
||||||
@@ -14,15 +16,7 @@ from .utils import to_value, from_value, to_struct, from_struct
|
|||||||
from ..services.exceptions import ConnpyError
|
from ..services.exceptions import ConnpyError
|
||||||
from .. import printer
|
from .. import printer
|
||||||
|
|
||||||
# Import local services
|
_current_user = contextvars.ContextVar("current_user", default=None)
|
||||||
from ..services.node_service import NodeService
|
|
||||||
from ..services.profile_service import ProfileService
|
|
||||||
from ..services.config_service import ConfigService
|
|
||||||
from ..services.plugin_service import PluginService
|
|
||||||
from ..services.ai_service import AIService
|
|
||||||
from ..services.system_service import SystemService
|
|
||||||
from ..services.execution_service import ExecutionService
|
|
||||||
from ..services.import_export_service import ImportExportService
|
|
||||||
|
|
||||||
def handle_errors(func):
|
def handle_errors(func):
|
||||||
import inspect
|
import inspect
|
||||||
@@ -31,10 +25,16 @@ def handle_errors(func):
|
|||||||
try:
|
try:
|
||||||
for item in func(*args, **kwargs):
|
for item in func(*args, **kwargs):
|
||||||
yield item
|
yield item
|
||||||
|
except grpc.RpcError:
|
||||||
|
raise
|
||||||
except ConnpyError as e:
|
except ConnpyError as e:
|
||||||
context = kwargs.get("context") or args[-1]
|
context = kwargs.get("context") or args[-1]
|
||||||
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if type(e) is Exception and not e.args:
|
||||||
|
raise e
|
||||||
|
if e.__class__.__name__ in ("_AbortError", "RpcError"):
|
||||||
|
raise e
|
||||||
context = kwargs.get("context") or args[-1]
|
context = kwargs.get("context") or args[-1]
|
||||||
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
||||||
finally:
|
finally:
|
||||||
@@ -44,10 +44,16 @@ def handle_errors(func):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
except grpc.RpcError:
|
||||||
|
raise
|
||||||
except ConnpyError as e:
|
except ConnpyError as e:
|
||||||
context = kwargs.get("context") or args[-1]
|
context = kwargs.get("context") or args[-1]
|
||||||
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if type(e) is Exception and not e.args:
|
||||||
|
raise e
|
||||||
|
if e.__class__.__name__ in ("_AbortError", "RpcError"):
|
||||||
|
raise e
|
||||||
context = kwargs.get("context") or args[-1]
|
context = kwargs.get("context") or args[-1]
|
||||||
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
||||||
finally:
|
finally:
|
||||||
@@ -55,25 +61,46 @@ def handle_errors(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||||
def __init__(self, config, debug=False):
|
def __init__(self, provider, registry=None, debug=False):
|
||||||
self.service = NodeService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
self.server_debug = debug
|
self.server_debug = debug
|
||||||
if debug:
|
if debug:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from ..printer import connpy_theme, get_original_stdout
|
from ..printer import connpy_theme, get_original_stdout
|
||||||
self.server_console = Console(theme=connpy_theme, file=get_original_stdout())
|
self.server_console = Console(theme=connpy_theme, file=get_original_stdout())
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().nodes
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def interact_node(self, request_iterator, context):
|
def interact_node(self, request_iterator, context):
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from connpy.core import node
|
from connpy.core import node
|
||||||
from ..services.profile_service import ProfileService
|
|
||||||
from connpy.tunnels import RemoteStream
|
from connpy.tunnels import RemoteStream
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
# Resolve provider once at the start of the RPC stream
|
||||||
|
provider = self._get_provider()
|
||||||
|
nodes_service = provider.nodes
|
||||||
|
profile_service = provider.profiles
|
||||||
|
ai_service = provider.ai
|
||||||
|
user_config = provider.config
|
||||||
|
|
||||||
# Fetch first setup packet
|
# Fetch first setup packet
|
||||||
try:
|
try:
|
||||||
first_req = next(request_iterator)
|
first_req = next(request_iterator)
|
||||||
@@ -100,9 +127,9 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
|||||||
|
|
||||||
if base_node_id:
|
if base_node_id:
|
||||||
# Look up the base node in config and use its full data
|
# Look up the base node in config and use its full data
|
||||||
nodes = self.service.config._getallnodes(base_node_id)
|
nodes = user_config._getallnodes(base_node_id)
|
||||||
if nodes:
|
if nodes:
|
||||||
device = self.service.config.getitem(nodes[0])
|
device = user_config.getitem(nodes[0])
|
||||||
# Override device properties with any passed in params
|
# Override device properties with any passed in params
|
||||||
for attr in valid_attrs:
|
for attr in valid_attrs:
|
||||||
if attr in params:
|
if attr in params:
|
||||||
@@ -116,11 +143,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
|||||||
device["tags"] = device_tags
|
device["tags"] = device_tags
|
||||||
|
|
||||||
node_name = params.get("name", base_node_id)
|
node_name = params.get("name", base_node_id)
|
||||||
n = node(node_name, **device, config=self.service.config)
|
n = node(node_name, **device, config=user_config)
|
||||||
else:
|
else:
|
||||||
# base_node not found, fall back to dynamic
|
# base_node not found, fall back to dynamic
|
||||||
node_name = params.get("name", fallback_id)
|
node_name = params.get("name", fallback_id)
|
||||||
n = node(node_name, host=params.get("host", ""), config=self.service.config)
|
n = node(node_name, host=params.get("host", ""), config=user_config)
|
||||||
for attr in valid_attrs:
|
for attr in valid_attrs:
|
||||||
if attr in params:
|
if attr in params:
|
||||||
setattr(n, attr, params[attr])
|
setattr(n, attr, params[attr])
|
||||||
@@ -128,19 +155,18 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
|||||||
n.tags = params["tags"]
|
n.tags = params["tags"]
|
||||||
else:
|
else:
|
||||||
node_name = params.get("name", fallback_id)
|
node_name = params.get("name", fallback_id)
|
||||||
n = node(node_name, host=params.get("host", ""), config=self.service.config)
|
n = node(node_name, host=params.get("host", ""), config=user_config)
|
||||||
for attr in valid_attrs:
|
for attr in valid_attrs:
|
||||||
if attr in params:
|
if attr in params:
|
||||||
setattr(n, attr, params[attr])
|
setattr(n, attr, params[attr])
|
||||||
if "tags" in params:
|
if "tags" in params:
|
||||||
n.tags = params["tags"]
|
n.tags = params["tags"]
|
||||||
else:
|
else:
|
||||||
node_data = self.service.config.getitem(unique_id, extract=False)
|
node_data = user_config.getitem(unique_id, extract=False)
|
||||||
if not node_data:
|
if not node_data:
|
||||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Node {unique_id} not found")
|
context.abort(grpc.StatusCode.NOT_FOUND, f"Node {unique_id} not found")
|
||||||
profile_service = ProfileService(self.service.config)
|
|
||||||
resolved_data = profile_service.resolve_node_data(node_data)
|
resolved_data = profile_service.resolve_node_data(node_data)
|
||||||
n = node(unique_id, **resolved_data, config=self.service.config)
|
n = node(unique_id, **resolved_data, config=user_config)
|
||||||
if sftp:
|
if sftp:
|
||||||
n.protocol = "sftp"
|
n.protocol = "sftp"
|
||||||
|
|
||||||
@@ -207,9 +233,8 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from ..services.ai_service import AIService
|
|
||||||
|
|
||||||
service = AIService(self.service.config)
|
service = ai_service
|
||||||
|
|
||||||
if node_info is None:
|
if node_info is None:
|
||||||
node_info = {}
|
node_info = {}
|
||||||
@@ -479,10 +504,27 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
|
class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = ProfileService(config)
|
if not hasattr(provider, "mode"):
|
||||||
self.node_service = NodeService(config)
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().profiles
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_service(self):
|
||||||
|
return self._get_provider().nodes
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def list_profiles(self, request, context):
|
def list_profiles(self, request, context):
|
||||||
@@ -516,8 +558,23 @@ class ProfileServicer(connpy_pb2_grpc.ProfileServiceServicer):
|
|||||||
return Empty()
|
return Empty()
|
||||||
|
|
||||||
class ConfigServicer(connpy_pb2_grpc.ConfigServiceServicer):
|
class ConfigServicer(connpy_pb2_grpc.ConfigServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = ConfigService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().config_svc
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def get_settings(self, request, context):
|
def get_settings(self, request, context):
|
||||||
@@ -546,8 +603,23 @@ class ConfigServicer(connpy_pb2_grpc.ConfigServiceServicer):
|
|||||||
return connpy_pb2.StructResponse(data=to_struct(self.service.apply_theme_from_file(request.value)))
|
return connpy_pb2.StructResponse(data=to_struct(self.service.apply_theme_from_file(request.value)))
|
||||||
|
|
||||||
class PluginServicer(connpy_pb2_grpc.PluginServiceServicer, remote_plugin_pb2_grpc.RemotePluginServiceServicer):
|
class PluginServicer(connpy_pb2_grpc.PluginServiceServicer, remote_plugin_pb2_grpc.RemotePluginServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = PluginService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().plugins
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def list_plugins(self, request, context):
|
def list_plugins(self, request, context):
|
||||||
@@ -589,8 +661,23 @@ class PluginServicer(connpy_pb2_grpc.PluginServiceServicer, remote_plugin_pb2_gr
|
|||||||
yield remote_plugin_pb2.OutputChunk(text=chunk)
|
yield remote_plugin_pb2.OutputChunk(text=chunk)
|
||||||
|
|
||||||
class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = ExecutionService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().execution
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def run_commands(self, request, context):
|
def run_commands(self, request, context):
|
||||||
@@ -599,6 +686,11 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
|||||||
|
|
||||||
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
|
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
|
||||||
|
|
||||||
|
# Resolve provider in the main gRPC thread where _current_user ContextVar is set.
|
||||||
|
# threading.Thread does NOT inherit ContextVars, so self.service inside
|
||||||
|
# _worker() would fall back to the admin provider.
|
||||||
|
execution_service = self.service
|
||||||
|
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
|
|
||||||
def _on_complete(unique, output, status):
|
def _on_complete(unique, output, status):
|
||||||
@@ -606,7 +698,7 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
|||||||
|
|
||||||
def _worker():
|
def _worker():
|
||||||
try:
|
try:
|
||||||
self.service.run_commands( nodes_filter=nodes_filter,
|
execution_service.run_commands( nodes_filter=nodes_filter,
|
||||||
commands=list(request.commands),
|
commands=list(request.commands),
|
||||||
folder=request.folder if request.folder else None,
|
folder=request.folder if request.folder else None,
|
||||||
prompt=request.prompt if request.prompt else None,
|
prompt=request.prompt if request.prompt else None,
|
||||||
@@ -645,6 +737,9 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
|||||||
|
|
||||||
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
|
nodes_filter = request.nodes[0] if len(request.nodes) == 1 else list(request.nodes)
|
||||||
|
|
||||||
|
# Resolve provider in the main gRPC thread where _current_user ContextVar is set.
|
||||||
|
execution_service = self.service
|
||||||
|
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
|
|
||||||
def _on_complete(unique, node_output, node_status, node_result):
|
def _on_complete(unique, node_output, node_status, node_result):
|
||||||
@@ -652,7 +747,7 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
|||||||
|
|
||||||
def _worker():
|
def _worker():
|
||||||
try:
|
try:
|
||||||
self.service.test_commands(
|
execution_service.test_commands(
|
||||||
nodes_filter=nodes_filter,
|
nodes_filter=nodes_filter,
|
||||||
commands=list(request.commands),
|
commands=list(request.commands),
|
||||||
expected=list(request.expected),
|
expected=list(request.expected),
|
||||||
@@ -698,9 +793,27 @@ class ExecutionServicer(connpy_pb2_grpc.ExecutionServiceServicer):
|
|||||||
return connpy_pb2.StructResponse(data=to_struct(res))
|
return connpy_pb2.StructResponse(data=to_struct(res))
|
||||||
|
|
||||||
class ImportExportServicer(connpy_pb2_grpc.ImportExportServiceServicer):
|
class ImportExportServicer(connpy_pb2_grpc.ImportExportServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = ImportExportService(config)
|
if not hasattr(provider, "mode"):
|
||||||
self.node_service = NodeService(config)
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().import_export
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_service(self):
|
||||||
|
return self._get_provider().nodes
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def export_to_file(self, request, context):
|
def export_to_file(self, request, context):
|
||||||
@@ -815,14 +928,30 @@ class StatusBridge:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = AIService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().ai
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def ask(self, request_iterator, context):
|
def ask(self, request_iterator, context):
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
ai_service = self.service
|
||||||
chunk_queue = queue.Queue()
|
chunk_queue = queue.Queue()
|
||||||
request_queue = queue.Queue()
|
request_queue = queue.Queue()
|
||||||
bridge = None
|
bridge = None
|
||||||
@@ -840,7 +969,7 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
|||||||
nonlocal history, bridge, agent_instance
|
nonlocal history, bridge, agent_instance
|
||||||
try:
|
try:
|
||||||
# Run the AI interaction (this blocks this specific thread)
|
# Run the AI interaction (this blocks this specific thread)
|
||||||
res = self.service.ask(
|
res = ai_service.ask(
|
||||||
input_text,
|
input_text,
|
||||||
chat_history=history if history else None,
|
chat_history=history if history else None,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -996,8 +1125,23 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
|||||||
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
|
return connpy_pb2.StructResponse(data=to_struct(self.service.load_session_data(request.value)))
|
||||||
|
|
||||||
class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
|
class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
|
||||||
def __init__(self, config):
|
def __init__(self, provider, registry=None):
|
||||||
self.service = SystemService(config)
|
if not hasattr(provider, "mode"):
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
provider = ServiceProvider(provider, mode="local")
|
||||||
|
self._fallback_provider = provider
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
def _get_provider(self):
|
||||||
|
if self._registry:
|
||||||
|
username = _current_user.get()
|
||||||
|
if username:
|
||||||
|
return self._registry.get_provider(username)
|
||||||
|
return self._fallback_provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def service(self):
|
||||||
|
return self._get_provider().system
|
||||||
|
|
||||||
@handle_errors
|
@handle_errors
|
||||||
def start_api(self, request, context):
|
def start_api(self, request, context):
|
||||||
@@ -1023,6 +1167,138 @@ class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
|
|||||||
def get_api_status(self, request, context):
|
def get_api_status(self, request, context):
|
||||||
return connpy_pb2.BoolResponse(value=self.service.get_api_status())
|
return connpy_pb2.BoolResponse(value=self.service.get_api_status())
|
||||||
|
|
||||||
|
class AuthServicer(connpy_pb2_grpc.AuthServiceServicer):
|
||||||
|
def __init__(self, registry):
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
@handle_errors
|
||||||
|
def login(self, request, context):
|
||||||
|
username = request.username
|
||||||
|
password = request.password
|
||||||
|
|
||||||
|
if not self.registry.user_service.authenticate(username, password):
|
||||||
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid username or password")
|
||||||
|
|
||||||
|
token = self.registry.user_service.generate_jwt(username)
|
||||||
|
expires_at = int((datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8)).timestamp())
|
||||||
|
|
||||||
|
return connpy_pb2.LoginResponse(
|
||||||
|
token=token,
|
||||||
|
username=username,
|
||||||
|
expires_at=expires_at
|
||||||
|
)
|
||||||
|
|
||||||
|
@handle_errors
|
||||||
|
def change_password(self, request, context):
|
||||||
|
username = _current_user.get()
|
||||||
|
if not username:
|
||||||
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Authentication required")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.registry.user_service.change_password(username, request.old_password, request.new_password)
|
||||||
|
self.registry.evict(username)
|
||||||
|
except ValueError as e:
|
||||||
|
context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
|
||||||
|
|
||||||
|
return Empty()
|
||||||
|
|
||||||
|
class AuthInterceptor(grpc.ServerInterceptor):
|
||||||
|
OPEN_METHODS = ["/connpy.AuthService/login"]
|
||||||
|
|
||||||
|
def __init__(self, registry):
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def intercept_service(self, continuation, handler_call_details):
|
||||||
|
method = handler_call_details.method
|
||||||
|
if method in self.OPEN_METHODS:
|
||||||
|
return continuation(handler_call_details)
|
||||||
|
|
||||||
|
if not self.registry.has_users():
|
||||||
|
return continuation(handler_call_details)
|
||||||
|
|
||||||
|
token = self._extract_token(handler_call_details.invocation_metadata)
|
||||||
|
if not token:
|
||||||
|
return self._unauthenticated_handler(handler_call_details, "Authorization token is missing")
|
||||||
|
|
||||||
|
username = self.registry.user_service.verify_jwt(token)
|
||||||
|
if not username:
|
||||||
|
return self._unauthenticated_handler(handler_call_details, "Invalid or expired token")
|
||||||
|
|
||||||
|
handler = continuation(handler_call_details)
|
||||||
|
if handler is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._wrap_handler(handler, username)
|
||||||
|
|
||||||
|
def _wrap_handler(self, handler, username):
|
||||||
|
if handler.unary_unary:
|
||||||
|
original_behavior = handler.unary_unary
|
||||||
|
def wrapper(request, context):
|
||||||
|
token = _current_user.set(username)
|
||||||
|
try:
|
||||||
|
return original_behavior(request, context)
|
||||||
|
finally:
|
||||||
|
_current_user.reset(token)
|
||||||
|
return grpc.unary_unary_rpc_method_handler(
|
||||||
|
wrapper,
|
||||||
|
request_deserializer=handler.request_deserializer,
|
||||||
|
response_serializer=handler.response_serializer,
|
||||||
|
)
|
||||||
|
elif handler.unary_stream:
|
||||||
|
original_behavior = handler.unary_stream
|
||||||
|
def wrapper(request, context):
|
||||||
|
token = _current_user.set(username)
|
||||||
|
try:
|
||||||
|
for response in original_behavior(request, context):
|
||||||
|
yield response
|
||||||
|
finally:
|
||||||
|
_current_user.reset(token)
|
||||||
|
return grpc.unary_stream_rpc_method_handler(
|
||||||
|
wrapper,
|
||||||
|
request_deserializer=handler.request_deserializer,
|
||||||
|
response_serializer=handler.response_serializer,
|
||||||
|
)
|
||||||
|
elif handler.stream_unary:
|
||||||
|
original_behavior = handler.stream_unary
|
||||||
|
def wrapper(request_iterator, context):
|
||||||
|
token = _current_user.set(username)
|
||||||
|
try:
|
||||||
|
return original_behavior(request_iterator, context)
|
||||||
|
finally:
|
||||||
|
_current_user.reset(token)
|
||||||
|
return grpc.stream_unary_rpc_method_handler(
|
||||||
|
wrapper,
|
||||||
|
request_deserializer=handler.request_deserializer,
|
||||||
|
response_serializer=handler.response_serializer,
|
||||||
|
)
|
||||||
|
elif handler.stream_stream:
|
||||||
|
original_behavior = handler.stream_stream
|
||||||
|
def wrapper(request_iterator, context):
|
||||||
|
token = _current_user.set(username)
|
||||||
|
try:
|
||||||
|
for response in original_behavior(request_iterator, context):
|
||||||
|
yield response
|
||||||
|
finally:
|
||||||
|
_current_user.reset(token)
|
||||||
|
return grpc.stream_stream_rpc_method_handler(
|
||||||
|
wrapper,
|
||||||
|
request_deserializer=handler.request_deserializer,
|
||||||
|
response_serializer=handler.response_serializer,
|
||||||
|
)
|
||||||
|
return handler
|
||||||
|
|
||||||
|
def _extract_token(self, metadata):
|
||||||
|
for key, value in metadata:
|
||||||
|
if key.lower() == "authorization":
|
||||||
|
if value.startswith("Bearer "):
|
||||||
|
return value[7:]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _unauthenticated_handler(self, handler_call_details, message):
|
||||||
|
def abort_call(request_or_iterator, context):
|
||||||
|
context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
|
||||||
|
return grpc.unary_unary_rpc_method_handler(abort_call)
|
||||||
|
|
||||||
class LoggingInterceptor(grpc.ServerInterceptor):
|
class LoggingInterceptor(grpc.ServerInterceptor):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -1047,19 +1323,30 @@ class LoggingInterceptor(grpc.ServerInterceptor):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def serve(config, port=8048, debug=False):
|
def serve(config, port=8048, debug=False):
|
||||||
interceptors = [LoggingInterceptor()] if debug else []
|
from connpy.grpc_layer.user_registry import UserRegistry
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
|
||||||
|
fallback_provider = ServiceProvider(config, mode="local")
|
||||||
|
registry = UserRegistry(config.defaultdir)
|
||||||
|
|
||||||
|
interceptors = []
|
||||||
|
if debug:
|
||||||
|
interceptors.append(LoggingInterceptor())
|
||||||
|
interceptors.append(AuthInterceptor(registry))
|
||||||
|
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), interceptors=interceptors)
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), interceptors=interceptors)
|
||||||
|
|
||||||
connpy_pb2_grpc.add_NodeServiceServicer_to_server(NodeServicer(config, debug=debug), server)
|
connpy_pb2_grpc.add_NodeServiceServicer_to_server(NodeServicer(fallback_provider, registry=registry, debug=debug), server)
|
||||||
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(ProfileServicer(config), server)
|
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(ProfileServicer(fallback_provider, registry=registry), server)
|
||||||
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(ConfigServicer(config), server)
|
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(ConfigServicer(fallback_provider, registry=registry), server)
|
||||||
plugin_servicer = PluginServicer(config)
|
plugin_servicer = PluginServicer(fallback_provider, registry=registry)
|
||||||
connpy_pb2_grpc.add_PluginServiceServicer_to_server(plugin_servicer, server)
|
connpy_pb2_grpc.add_PluginServiceServicer_to_server(plugin_servicer, server)
|
||||||
remote_plugin_pb2_grpc.add_RemotePluginServiceServicer_to_server(plugin_servicer, server)
|
remote_plugin_pb2_grpc.add_RemotePluginServiceServicer_to_server(plugin_servicer, server)
|
||||||
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(ExecutionServicer(config), server)
|
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(ExecutionServicer(fallback_provider, registry=registry), server)
|
||||||
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(ImportExportServicer(config), server)
|
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(ImportExportServicer(fallback_provider, registry=registry), server)
|
||||||
connpy_pb2_grpc.add_AIServiceServicer_to_server(AIServicer(config), server)
|
connpy_pb2_grpc.add_AIServiceServicer_to_server(AIServicer(fallback_provider, registry=registry), server)
|
||||||
connpy_pb2_grpc.add_SystemServiceServicer_to_server(SystemServicer(config), server)
|
connpy_pb2_grpc.add_SystemServiceServicer_to_server(SystemServicer(fallback_provider, registry=registry), server)
|
||||||
|
connpy_pb2_grpc.add_AuthServiceServicer_to_server(AuthServicer(registry), server)
|
||||||
|
|
||||||
server.add_insecure_port(f'[::]:{port}')
|
server.add_insecure_port(f'[::]:{port}')
|
||||||
server.start()
|
server.start()
|
||||||
|
|||||||
@@ -980,3 +980,78 @@ class SystemStub:
|
|||||||
@handle_errors
|
@handle_errors
|
||||||
def get_api_status(self):
|
def get_api_status(self):
|
||||||
return self.stub.get_api_status(Empty()).value
|
return self.stub.get_api_status(Empty()).value
|
||||||
|
|
||||||
|
class _ClientCallDetails(object):
|
||||||
|
def __init__(self, method, timeout, metadata, credentials, wait_for_ready, compression=None):
|
||||||
|
self.method = method
|
||||||
|
self.timeout = timeout
|
||||||
|
self.metadata = metadata
|
||||||
|
self.credentials = credentials
|
||||||
|
self.wait_for_ready = wait_for_ready
|
||||||
|
self.compression = compression
|
||||||
|
|
||||||
|
class AuthClientInterceptor(grpc.UnaryUnaryClientInterceptor,
|
||||||
|
grpc.UnaryStreamClientInterceptor,
|
||||||
|
grpc.StreamUnaryClientInterceptor,
|
||||||
|
grpc.StreamStreamClientInterceptor):
|
||||||
|
def __init__(self, token_provider):
|
||||||
|
self.token_provider = token_provider
|
||||||
|
|
||||||
|
def _add_metadata(self, client_call_details):
|
||||||
|
token = self.token_provider()
|
||||||
|
if not token:
|
||||||
|
return client_call_details
|
||||||
|
|
||||||
|
metadata = []
|
||||||
|
if client_call_details.metadata:
|
||||||
|
metadata = list(client_call_details.metadata)
|
||||||
|
|
||||||
|
# Check if already present to avoid duplicates
|
||||||
|
if not any(k.lower() == "authorization" for k, v in metadata):
|
||||||
|
metadata.append(("authorization", f"Bearer {token}"))
|
||||||
|
|
||||||
|
return _ClientCallDetails(
|
||||||
|
method=client_call_details.method,
|
||||||
|
timeout=client_call_details.timeout,
|
||||||
|
metadata=metadata,
|
||||||
|
credentials=client_call_details.credentials,
|
||||||
|
wait_for_ready=client_call_details.wait_for_ready,
|
||||||
|
compression=client_call_details.compression,
|
||||||
|
)
|
||||||
|
|
||||||
|
def intercept_unary_unary(self, continuation, client_call_details, request):
|
||||||
|
new_details = self._add_metadata(client_call_details)
|
||||||
|
return continuation(new_details, request)
|
||||||
|
|
||||||
|
def intercept_unary_stream(self, continuation, client_call_details, request):
|
||||||
|
new_details = self._add_metadata(client_call_details)
|
||||||
|
return continuation(new_details, request)
|
||||||
|
|
||||||
|
def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
|
||||||
|
new_details = self._add_metadata(client_call_details)
|
||||||
|
return continuation(new_details, request_iterator)
|
||||||
|
|
||||||
|
def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
|
||||||
|
new_details = self._add_metadata(client_call_details)
|
||||||
|
return continuation(new_details, request_iterator)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthStub:
|
||||||
|
def __init__(self, channel, remote_host):
|
||||||
|
self.stub = connpy_pb2_grpc.AuthServiceStub(channel)
|
||||||
|
self.remote_host = remote_host
|
||||||
|
|
||||||
|
@handle_errors
|
||||||
|
def login(self, username, password):
|
||||||
|
req = connpy_pb2.LoginRequest(username=username, password=password)
|
||||||
|
resp = self.stub.login(req)
|
||||||
|
return {
|
||||||
|
"token": resp.token,
|
||||||
|
"username": resp.username,
|
||||||
|
"expires_at": resp.expires_at
|
||||||
|
}
|
||||||
|
|
||||||
|
@handle_errors
|
||||||
|
def change_password(self, old_password, new_password):
|
||||||
|
req = connpy_pb2.ChangePasswordRequest(old_password=old_password, new_password=new_password)
|
||||||
|
self.stub.change_password(req)
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
from connpy.services.user_service import UserService
|
||||||
|
|
||||||
|
class UserRegistry:
|
||||||
|
"""Holds per-user ServiceProviders in memory, thread-safe with hot-reloading."""
|
||||||
|
def __init__(self, server_config_dir):
|
||||||
|
self.server_config_dir = os.path.abspath(server_config_dir)
|
||||||
|
self.user_service = UserService(self.server_config_dir)
|
||||||
|
self._providers = {} # username → ServiceProvider
|
||||||
|
self._mtimes = {} # username → last loaded mtime (float)
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
# Load shared/global config
|
||||||
|
self._shared_conf_file = os.path.join(self.server_config_dir, "config.yaml")
|
||||||
|
if os.path.exists(self._shared_conf_file):
|
||||||
|
self._shared_config = configfile(conf=self._shared_conf_file)
|
||||||
|
self._shared_mtime = os.path.getmtime(self._shared_conf_file)
|
||||||
|
else:
|
||||||
|
self._shared_config = None
|
||||||
|
self._shared_mtime = 0.0
|
||||||
|
|
||||||
|
def _refresh_shared(self):
|
||||||
|
"""Hot-reload shared config if the file changed on disk."""
|
||||||
|
if not os.path.exists(self._shared_conf_file):
|
||||||
|
return
|
||||||
|
current_mtime = os.path.getmtime(self._shared_conf_file)
|
||||||
|
if current_mtime > self._shared_mtime:
|
||||||
|
try:
|
||||||
|
self._shared_config = configfile(conf=self._shared_conf_file)
|
||||||
|
self._shared_mtime = current_mtime
|
||||||
|
# Clear all user providers so they pick up the new shared config
|
||||||
|
self._providers.clear()
|
||||||
|
self._mtimes.clear()
|
||||||
|
except Exception as e:
|
||||||
|
from connpy import printer
|
||||||
|
printer.warning(f"Failed to reload shared config: {e}")
|
||||||
|
|
||||||
|
def get_provider(self, username) -> ServiceProvider:
|
||||||
|
"""Get, lazy-load, or hot-reload a user's full ServiceProvider."""
|
||||||
|
with self._lock:
|
||||||
|
# Refresh shared/global config if it has changed
|
||||||
|
self._refresh_shared()
|
||||||
|
|
||||||
|
# 1. Resolve physical path of the user's config.yaml file
|
||||||
|
user_data = self.user_service.get_user(username)
|
||||||
|
config_path = user_data.get("config_path")
|
||||||
|
if config_path:
|
||||||
|
conf_file = os.path.join(config_path, "config.yaml")
|
||||||
|
else:
|
||||||
|
conf_file = os.path.join(self.server_config_dir, "users", username, "config.yaml")
|
||||||
|
|
||||||
|
# 2. Retrieve actual modification time in disk
|
||||||
|
current_mtime = os.path.getmtime(conf_file) if os.path.exists(conf_file) else 0.0
|
||||||
|
|
||||||
|
# 3. Validate if initial load or hot-reload is required
|
||||||
|
if username not in self._providers or self._mtimes.get(username, 0.0) < current_mtime:
|
||||||
|
old_provider = self._providers.get(username)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Attempt a fresh configuration load
|
||||||
|
config = configfile(conf=conf_file, shared_config=self._shared_config)
|
||||||
|
new_provider = ServiceProvider(config, mode="local")
|
||||||
|
|
||||||
|
# Successfully loaded, clean up the old provider
|
||||||
|
if old_provider:
|
||||||
|
self._providers.pop(username, None)
|
||||||
|
if hasattr(old_provider, "close"):
|
||||||
|
try:
|
||||||
|
old_provider.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._providers[username] = new_provider
|
||||||
|
self._mtimes[username] = current_mtime
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log warning but fallback to the old stable provider in memory if available
|
||||||
|
from connpy import printer
|
||||||
|
printer.warning(f"Failed to hot-reload config for user '{username}' (file may be corrupt/incomplete): {e}")
|
||||||
|
if old_provider:
|
||||||
|
# Keep serving with the old cached instance to ensure service continuity
|
||||||
|
self._mtimes[username] = current_mtime
|
||||||
|
else:
|
||||||
|
# No fallback exists, propagate the exception
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return self._providers[username]
|
||||||
|
|
||||||
|
def has_users(self) -> bool:
|
||||||
|
"""Check if any users are registered (enables auth enforcement)."""
|
||||||
|
return bool(self.user_service.list_users())
|
||||||
|
|
||||||
|
def evict(self, username):
|
||||||
|
"""Remove and cleanly shut down cached provider (after delete or password change)."""
|
||||||
|
with self._lock:
|
||||||
|
provider = self._providers.pop(username, None)
|
||||||
|
self._mtimes.pop(username, None)
|
||||||
|
if provider:
|
||||||
|
# Explicit cleanup of user-scoped resources if custom close/cleanup exists
|
||||||
|
if hasattr(provider, "close"):
|
||||||
|
try:
|
||||||
|
provider.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
@@ -48,7 +48,10 @@ class MCPClientManager:
|
|||||||
|
|
||||||
all_llm_tools = []
|
all_llm_tools = []
|
||||||
try:
|
try:
|
||||||
mcp_config = self.config.config.get("ai", {}).get("mcp_servers", {})
|
if hasattr(self.config, "get_effective_setting"):
|
||||||
|
mcp_config = self.config.get_effective_setting("ai", {}).get("mcp_servers", {})
|
||||||
|
else:
|
||||||
|
mcp_config = self.config.config.get("ai", {}).get("mcp_servers", {}) if hasattr(self.config, "config") else {}
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|||||||
@@ -296,3 +296,24 @@ message MCPRequest {
|
|||||||
string auto_load_on_os = 4;
|
string auto_load_on_os = 4;
|
||||||
bool remove = 5;
|
bool remove = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service AuthService {
|
||||||
|
rpc login (LoginRequest) returns (LoginResponse) {}
|
||||||
|
rpc change_password (ChangePasswordRequest) returns (google.protobuf.Empty) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message LoginRequest {
|
||||||
|
string username = 1;
|
||||||
|
string password = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message LoginResponse {
|
||||||
|
string token = 1;
|
||||||
|
string username = 2;
|
||||||
|
int64 expires_at = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChangePasswordRequest {
|
||||||
|
string old_password = 1;
|
||||||
|
string new_password = 2;
|
||||||
|
}
|
||||||
|
|||||||
@@ -307,7 +307,10 @@ class AIService(BaseService):
|
|||||||
|
|
||||||
def list_mcp_servers(self) -> dict:
|
def list_mcp_servers(self) -> dict:
|
||||||
"""Get the configured MCP servers."""
|
"""Get the configured MCP servers."""
|
||||||
ai_settings = self.config.config.get("ai", {})
|
if hasattr(self.config, "get_effective_setting"):
|
||||||
|
ai_settings = self.config.get_effective_setting("ai", {})
|
||||||
|
else:
|
||||||
|
ai_settings = self.config.config.get("ai", {}) if hasattr(self.config, "config") else {}
|
||||||
return ai_settings.get("mcp_servers", {})
|
return ai_settings.get("mcp_servers", {})
|
||||||
|
|
||||||
def load_session_data(self, session_id):
|
def load_session_data(self, session_id):
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class ServiceProvider:
|
|||||||
from .import_export_service import ImportExportService
|
from .import_export_service import ImportExportService
|
||||||
from .context_service import ContextService
|
from .context_service import ContextService
|
||||||
from .sync_service import SyncService
|
from .sync_service import SyncService
|
||||||
|
from .user_service import UserService
|
||||||
|
|
||||||
self.nodes = NodeService(self.config)
|
self.nodes = NodeService(self.config)
|
||||||
self.profiles = ProfileService(self.config)
|
self.profiles = ProfileService(self.config)
|
||||||
@@ -44,6 +45,7 @@ class ServiceProvider:
|
|||||||
self.import_export = ImportExportService(self.config)
|
self.import_export = ImportExportService(self.config)
|
||||||
self.context = ContextService(self.config)
|
self.context = ContextService(self.config)
|
||||||
self.sync = SyncService(self.config)
|
self.sync = SyncService(self.config)
|
||||||
|
self.users = UserService(self.config.defaultdir)
|
||||||
|
|
||||||
def _init_remote(self):
|
def _init_remote(self):
|
||||||
# Allow ConfigService to work locally so the user can revert the mode
|
# Allow ConfigService to work locally so the user can revert the mode
|
||||||
@@ -53,14 +55,37 @@ class ServiceProvider:
|
|||||||
self.config_svc = ConfigService(self.config)
|
self.config_svc = ConfigService(self.config)
|
||||||
self.context = ContextService(self.config)
|
self.context = ContextService(self.config)
|
||||||
self.sync = SyncService(self.config)
|
self.sync = SyncService(self.config)
|
||||||
|
self.users = None
|
||||||
|
|
||||||
if not self.remote_host:
|
if not self.remote_host:
|
||||||
raise InvalidConfigurationError("Remote host must be specified in remote mode")
|
raise InvalidConfigurationError("Remote host must be specified in remote mode")
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
from ..grpc_layer.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
|
import os
|
||||||
|
from ..grpc_layer.stubs import (
|
||||||
|
NodeStub, ProfileStub, PluginStub, AIStub,
|
||||||
|
ExecutionStub, ImportExportStub, SystemStub,
|
||||||
|
ConfigStub, AuthClientInterceptor, AuthStub
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_token():
|
||||||
|
token_path = os.path.join(self.config.defaultdir, ".token")
|
||||||
|
if os.path.exists(token_path):
|
||||||
|
try:
|
||||||
|
with open(token_path, "r") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
channel = grpc.insecure_channel(self.remote_host)
|
channel = grpc.insecure_channel(self.remote_host)
|
||||||
|
interceptor = AuthClientInterceptor(get_token)
|
||||||
|
channel = grpc.intercept_channel(channel, interceptor)
|
||||||
|
|
||||||
|
# Surgical fix: Keep ConfigService local for mode/theme management,
|
||||||
|
# but delegate encryption to the server stub.
|
||||||
|
config_remote = ConfigStub(channel, remote_host=self.remote_host)
|
||||||
|
self.config_svc.encrypt_password = config_remote.encrypt_password
|
||||||
|
|
||||||
self.nodes = NodeStub(channel, remote_host=self.remote_host, config=self.config)
|
self.nodes = NodeStub(channel, remote_host=self.remote_host, config=self.config)
|
||||||
self.profiles = ProfileStub(channel, remote_host=self.remote_host, node_stub=self.nodes)
|
self.profiles = ProfileStub(channel, remote_host=self.remote_host, node_stub=self.nodes)
|
||||||
@@ -69,3 +94,4 @@ class ServiceProvider:
|
|||||||
self.system = SystemStub(channel, remote_host=self.remote_host)
|
self.system = SystemStub(channel, remote_host=self.remote_host)
|
||||||
self.execution = ExecutionStub(channel, remote_host=self.remote_host)
|
self.execution = ExecutionStub(channel, remote_host=self.remote_host)
|
||||||
self.import_export = ImportExportStub(channel, remote_host=self.remote_host)
|
self.import_export = ImportExportStub(channel, remote_host=self.remote_host)
|
||||||
|
self.auth = AuthStub(channel, remote_host=self.remote_host)
|
||||||
|
|||||||
@@ -0,0 +1,237 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import secrets
|
||||||
|
import datetime
|
||||||
|
import bcrypt
|
||||||
|
import jwt
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
def __init__(self, config_dir):
|
||||||
|
self.config_dir = os.path.abspath(config_dir)
|
||||||
|
self.users_dir = os.path.join(self.config_dir, "users")
|
||||||
|
self.registry_file = os.path.join(self.users_dir, "registry.yaml")
|
||||||
|
|
||||||
|
# Ensure users directory exists
|
||||||
|
os.makedirs(self.users_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def _load_registry(self) -> dict:
|
||||||
|
"""Loads registry from file. If it doesn't exist, initializes it with a new JWT secret."""
|
||||||
|
if not os.path.exists(self.registry_file):
|
||||||
|
registry = {
|
||||||
|
"jwt_secret": secrets.token_hex(32),
|
||||||
|
"users": {}
|
||||||
|
}
|
||||||
|
self._save_registry(registry)
|
||||||
|
return registry
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.registry_file, "r") as f:
|
||||||
|
registry = yaml.safe_load(f) or {}
|
||||||
|
except Exception:
|
||||||
|
registry = {}
|
||||||
|
|
||||||
|
if not isinstance(registry, dict):
|
||||||
|
registry = {}
|
||||||
|
|
||||||
|
if "jwt_secret" not in registry:
|
||||||
|
registry["jwt_secret"] = secrets.token_hex(32)
|
||||||
|
|
||||||
|
if "users" not in registry or not isinstance(registry["users"], dict):
|
||||||
|
registry["users"] = {}
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
def _save_registry(self, data: dict):
|
||||||
|
"""Safely saves registry structure to registry.yaml."""
|
||||||
|
tmp_file = self.registry_file + ".tmp"
|
||||||
|
try:
|
||||||
|
with open(tmp_file, "w") as f:
|
||||||
|
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
|
||||||
|
os.replace(tmp_file, self.registry_file)
|
||||||
|
os.chmod(self.registry_file, 0o600)
|
||||||
|
except Exception as e:
|
||||||
|
if os.path.exists(tmp_file):
|
||||||
|
try:
|
||||||
|
os.remove(tmp_file)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def create_user(self, username, password, config_path=None) -> dict:
|
||||||
|
"""Creates a new user with bcrypt-hashed credentials.
|
||||||
|
|
||||||
|
Mode A: config_path=None (fresh user) -> Generates config.yaml and .osk key.
|
||||||
|
Mode B: config_path set -> Reuses existing directory after validating its structure.
|
||||||
|
"""
|
||||||
|
if not username or not isinstance(username, str):
|
||||||
|
raise ValueError("Username cannot be empty")
|
||||||
|
|
||||||
|
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||||
|
raise ValueError("Username must contain only alphanumeric characters, dashes, or underscores")
|
||||||
|
|
||||||
|
if not password or not isinstance(password, str):
|
||||||
|
raise ValueError("Password cannot be empty")
|
||||||
|
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' already exists")
|
||||||
|
|
||||||
|
# Resolve path and initialize configuration
|
||||||
|
if config_path is None:
|
||||||
|
user_dir = os.path.join(self.users_dir, username)
|
||||||
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Create subdirs for plugins and sessions
|
||||||
|
os.makedirs(os.path.join(user_dir, "plugins"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(user_dir, "ai_sessions"), exist_ok=True)
|
||||||
|
|
||||||
|
# Create default config.yaml & .osk key via configfile
|
||||||
|
conf_file = os.path.join(user_dir, "config.yaml")
|
||||||
|
configfile(conf=conf_file)
|
||||||
|
|
||||||
|
stored_config_path = None
|
||||||
|
else:
|
||||||
|
abs_config_path = os.path.abspath(config_path)
|
||||||
|
os.makedirs(abs_config_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Create subdirs for plugins and sessions in the custom path
|
||||||
|
os.makedirs(os.path.join(abs_config_path, "plugins"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(abs_config_path, "ai_sessions"), exist_ok=True)
|
||||||
|
|
||||||
|
# Create default config.yaml & .osk key via configfile if config.yaml is not present
|
||||||
|
conf_file = os.path.join(abs_config_path, "config.yaml")
|
||||||
|
if not os.path.exists(conf_file):
|
||||||
|
configfile(conf=conf_file)
|
||||||
|
|
||||||
|
stored_config_path = abs_config_path
|
||||||
|
|
||||||
|
# Hash password securely
|
||||||
|
password_hash = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
|
user_entry = {
|
||||||
|
"password_hash": password_hash,
|
||||||
|
"config_path": stored_config_path,
|
||||||
|
"created": datetime.datetime.now(datetime.timezone.utc).isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
registry["users"][username] = user_entry
|
||||||
|
self._save_registry(registry)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"username": username,
|
||||||
|
"config_path": stored_config_path,
|
||||||
|
"created": user_entry["created"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def delete_user(self, username):
|
||||||
|
"""Removes user from the registry and cleans up config directory if server-managed."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' not found")
|
||||||
|
|
||||||
|
user_data = registry["users"][username]
|
||||||
|
config_path = user_data.get("config_path")
|
||||||
|
|
||||||
|
if config_path is None:
|
||||||
|
user_dir = os.path.join(self.users_dir, username)
|
||||||
|
if os.path.exists(user_dir):
|
||||||
|
shutil.rmtree(user_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
del registry["users"][username]
|
||||||
|
self._save_registry(registry)
|
||||||
|
|
||||||
|
def list_users(self) -> list[dict]:
|
||||||
|
"""Lists all registered users with metadata."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"username": name,
|
||||||
|
"config_path": data.get("config_path"),
|
||||||
|
"created": data.get("created")
|
||||||
|
}
|
||||||
|
for name, data in registry.get("users", {}).items()
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_user(self, username) -> dict:
|
||||||
|
"""Retrieves raw metadata for a specific user."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' not found")
|
||||||
|
|
||||||
|
data = registry["users"][username]
|
||||||
|
return {
|
||||||
|
"username": username,
|
||||||
|
"config_path": data.get("config_path"),
|
||||||
|
"created": data.get("created"),
|
||||||
|
"password_hash": data.get("password_hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
def change_password(self, username, old_password, new_password):
|
||||||
|
"""Verifies old password and updates registry with new hashed password."""
|
||||||
|
if not new_password or not isinstance(new_password, str):
|
||||||
|
raise ValueError("New password cannot be empty")
|
||||||
|
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' not found")
|
||||||
|
|
||||||
|
user_data = registry["users"][username]
|
||||||
|
if not bcrypt.checkpw(old_password.encode("utf-8"), user_data["password_hash"].encode("utf-8")):
|
||||||
|
raise ValueError("Invalid credentials")
|
||||||
|
|
||||||
|
# Update hash
|
||||||
|
user_data["password_hash"] = bcrypt.hashpw(new_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
self._save_registry(registry)
|
||||||
|
|
||||||
|
def admin_change_password(self, username, new_password):
|
||||||
|
"""Administrative password override (does not require old password)."""
|
||||||
|
if not new_password or not isinstance(new_password, str):
|
||||||
|
raise ValueError("New password cannot be empty")
|
||||||
|
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' not found")
|
||||||
|
|
||||||
|
user_data = registry["users"][username]
|
||||||
|
user_data["password_hash"] = bcrypt.hashpw(new_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
self._save_registry(registry)
|
||||||
|
|
||||||
|
def authenticate(self, username, password) -> bool:
|
||||||
|
"""Verifies if the credentials are valid using bcrypt."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
user_data = registry["users"][username]
|
||||||
|
return bcrypt.checkpw(password.encode("utf-8"), user_data["password_hash"].encode("utf-8"))
|
||||||
|
|
||||||
|
def generate_jwt(self, username) -> str:
|
||||||
|
"""Generates a secure JSON Web Token for the user expiring in 8 hours."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
if username not in registry["users"]:
|
||||||
|
raise ValueError(f"User '{username}' not found")
|
||||||
|
|
||||||
|
expiration = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=8)
|
||||||
|
payload = {
|
||||||
|
"sub": username,
|
||||||
|
"exp": expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
token = jwt.encode(payload, registry["jwt_secret"], algorithm="HS256")
|
||||||
|
if isinstance(token, bytes):
|
||||||
|
token = token.decode("utf-8")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def verify_jwt(self, token) -> str | None:
|
||||||
|
"""Decodes JWT and returns username if token is valid and unexpired."""
|
||||||
|
registry = self._load_registry()
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, registry["jwt_secret"], algorithms=["HS256"])
|
||||||
|
return payload.get("sub")
|
||||||
|
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError):
|
||||||
|
return None
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import grpc
|
||||||
|
import argparse
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from connpy.connapp import connapp
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
from connpy.cli.user_handler import UserHandler
|
||||||
|
from connpy.cli.login_handler import LoginHandler
|
||||||
|
from connpy.grpc_layer.stubs import AuthClientInterceptor, AuthStub
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config():
|
||||||
|
config = MagicMock()
|
||||||
|
config.config = {"service_mode": "local", "remote_host": "localhost:8048"}
|
||||||
|
config.defaultdir = "/mock/default/dir"
|
||||||
|
return config
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_instance(mock_config):
|
||||||
|
with patch("connpy.services.provider.ServiceProvider") as mock_provider_cls:
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.context = MagicMock()
|
||||||
|
mock_provider.nodes = MagicMock()
|
||||||
|
mock_provider.profiles = MagicMock()
|
||||||
|
mock_provider.config_svc = MagicMock()
|
||||||
|
mock_provider.plugins = MagicMock()
|
||||||
|
mock_provider.sync = MagicMock()
|
||||||
|
mock_provider.mode = "local"
|
||||||
|
mock_provider.remote_host = "localhost:8048"
|
||||||
|
mock_provider_cls.return_value = mock_provider
|
||||||
|
|
||||||
|
app = connapp(mock_config)
|
||||||
|
# Mock UserService on app services
|
||||||
|
app.services.users = MagicMock()
|
||||||
|
return app
|
||||||
|
|
||||||
|
class TestCLIMultiUserParsing:
|
||||||
|
def test_parser_contains_user_login_logout(self, app_instance):
|
||||||
|
parser, _ = app_instance.get_parser()
|
||||||
|
|
||||||
|
# Verify subcommands exist by finding the _SubParsersAction
|
||||||
|
subparsers_action = None
|
||||||
|
for action in parser._actions:
|
||||||
|
if isinstance(action, argparse._SubParsersAction):
|
||||||
|
subparsers_action = action
|
||||||
|
break
|
||||||
|
|
||||||
|
assert subparsers_action is not None
|
||||||
|
subcommands = subparsers_action.choices.keys()
|
||||||
|
assert "user" in subcommands
|
||||||
|
assert "login" in subcommands
|
||||||
|
assert "logout" in subcommands
|
||||||
|
|
||||||
|
def test_user_parser_arguments(self, app_instance):
|
||||||
|
parser, _ = app_instance.get_parser()
|
||||||
|
|
||||||
|
# Parse add user
|
||||||
|
args = parser.parse_args(["user", "--add", "newguy"])
|
||||||
|
assert args.add == ["newguy"]
|
||||||
|
assert args.func == app_instance._user.dispatch
|
||||||
|
|
||||||
|
# Parse delete user
|
||||||
|
args = parser.parse_args(["user", "--del", "oldguy"])
|
||||||
|
assert args.delete == ["oldguy"]
|
||||||
|
|
||||||
|
# Parse list users
|
||||||
|
args = parser.parse_args(["user", "--list"])
|
||||||
|
assert args.list is True
|
||||||
|
|
||||||
|
# Parse show user
|
||||||
|
args = parser.parse_args(["user", "--show", "someguy"])
|
||||||
|
assert args.show == ["someguy"]
|
||||||
|
|
||||||
|
# Parse regen-password
|
||||||
|
args = parser.parse_args(["user", "--regen-password", "someguy"])
|
||||||
|
assert args.regen_password == ["someguy"]
|
||||||
|
|
||||||
|
# Parse path
|
||||||
|
args = parser.parse_args(["user", "--add", "newguy", "--path", "/some/path"])
|
||||||
|
assert args.add == ["newguy"]
|
||||||
|
assert args.path == ["/some/path"]
|
||||||
|
|
||||||
|
def test_login_logout_parser_arguments(self, app_instance):
|
||||||
|
parser, _ = app_instance.get_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(["login", "someuser"])
|
||||||
|
assert args.username == "someuser"
|
||||||
|
assert args.func == app_instance._login.dispatch
|
||||||
|
|
||||||
|
args = parser.parse_args(["logout"])
|
||||||
|
assert args.func == app_instance._login.dispatch
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserHandlerDispatch:
|
||||||
|
def test_user_handler_fails_in_remote_mode(self, app_instance):
|
||||||
|
app_instance.services.mode = "remote"
|
||||||
|
handler = UserHandler(app_instance)
|
||||||
|
|
||||||
|
args = MagicMock()
|
||||||
|
args.add = ["testuser"]
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit) as excinfo:
|
||||||
|
handler.dispatch(args)
|
||||||
|
assert excinfo.value.code == 1
|
||||||
|
|
||||||
|
def test_user_handler_routes_add_correctly(self, app_instance):
|
||||||
|
app_instance.services.mode = "local"
|
||||||
|
handler = UserHandler(app_instance)
|
||||||
|
|
||||||
|
args = MagicMock()
|
||||||
|
args.add = ["newuser"]
|
||||||
|
args.delete = None
|
||||||
|
args.list = False
|
||||||
|
args.show = None
|
||||||
|
args.regen_password = None
|
||||||
|
|
||||||
|
with patch.object(handler, "add_user") as mock_add:
|
||||||
|
handler.dispatch(args)
|
||||||
|
assert args.action == "add"
|
||||||
|
assert args.username == "newuser"
|
||||||
|
mock_add.assert_called_once_with(args)
|
||||||
|
|
||||||
|
def test_user_handler_routes_list_correctly(self, app_instance):
|
||||||
|
app_instance.services.mode = "local"
|
||||||
|
handler = UserHandler(app_instance)
|
||||||
|
|
||||||
|
args = MagicMock()
|
||||||
|
args.add = None
|
||||||
|
args.delete = None
|
||||||
|
args.list = True
|
||||||
|
args.show = None
|
||||||
|
args.regen_password = None
|
||||||
|
|
||||||
|
with patch.object(handler, "list_users") as mock_list:
|
||||||
|
handler.dispatch(args)
|
||||||
|
assert args.action == "list"
|
||||||
|
mock_list.assert_called_once_with(args)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthClientInterceptor:
|
||||||
|
def test_auth_client_interceptor_adds_bearer_token(self):
|
||||||
|
# Mock token provider
|
||||||
|
token_provider = MagicMock(return_value="my-super-secret-token")
|
||||||
|
interceptor = AuthClientInterceptor(token_provider)
|
||||||
|
|
||||||
|
# Mock ClientCallDetails using namedtuple
|
||||||
|
from collections import namedtuple
|
||||||
|
ClientCallDetails = namedtuple('ClientCallDetails', ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'])
|
||||||
|
|
||||||
|
mock_details = ClientCallDetails(
|
||||||
|
method="/connpy.NodeService/list_nodes",
|
||||||
|
timeout=10,
|
||||||
|
metadata=[],
|
||||||
|
credentials=None,
|
||||||
|
wait_for_ready=True,
|
||||||
|
compression=None
|
||||||
|
)
|
||||||
|
|
||||||
|
intercepted_details = interceptor._add_metadata(mock_details)
|
||||||
|
|
||||||
|
# Verify metadata was injected
|
||||||
|
metadata_dict = dict(intercepted_details.metadata)
|
||||||
|
assert "authorization" in metadata_dict
|
||||||
|
assert metadata_dict["authorization"] == "Bearer my-super-secret-token"
|
||||||
|
|
||||||
|
def test_auth_client_interceptor_no_token(self):
|
||||||
|
token_provider = MagicMock(return_value=None)
|
||||||
|
interceptor = AuthClientInterceptor(token_provider)
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
ClientCallDetails = namedtuple('ClientCallDetails', ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'])
|
||||||
|
|
||||||
|
mock_details = ClientCallDetails(
|
||||||
|
method="/connpy.NodeService/list_nodes",
|
||||||
|
timeout=10,
|
||||||
|
metadata=[],
|
||||||
|
credentials=None,
|
||||||
|
wait_for_ready=True,
|
||||||
|
compression=None
|
||||||
|
)
|
||||||
|
|
||||||
|
intercepted_details = interceptor._add_metadata(mock_details)
|
||||||
|
|
||||||
|
# Verify metadata remains empty
|
||||||
|
assert len(intercepted_details.metadata) == 0
|
||||||
@@ -141,4 +141,62 @@ class TestTreeCompletions:
|
|||||||
assert "stop" in loop_back_comp
|
assert "stop" in loop_back_comp
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserCompletions:
|
||||||
|
def test_user_command_options(self):
|
||||||
|
from connpy.completion import _build_tree, resolve_completion
|
||||||
|
tree = _build_tree([], [], [], {}, "/tmp")
|
||||||
|
|
||||||
|
# Test options at the "user" level
|
||||||
|
user_completions = resolve_completion(["user", ""], tree)
|
||||||
|
assert "--add" in user_completions
|
||||||
|
assert "--del" in user_completions
|
||||||
|
assert "--rm" in user_completions
|
||||||
|
assert "--show" in user_completions
|
||||||
|
assert "--regen-password" in user_completions
|
||||||
|
assert "--list" in user_completions
|
||||||
|
assert "--ls" in user_completions
|
||||||
|
|
||||||
|
def test_user_action_completed_users(self, tmp_path):
|
||||||
|
from connpy.completion import _build_tree, resolve_completion
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# Create users directory and mock registry
|
||||||
|
users_dir = tmp_path / "users"
|
||||||
|
users_dir.mkdir()
|
||||||
|
registry_file = users_dir / "registry.yaml"
|
||||||
|
|
||||||
|
registry_data = {
|
||||||
|
"users": {
|
||||||
|
"fluzzi": {"password_hash": "hash1"},
|
||||||
|
"john": {"password_hash": "hash2"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with open(registry_file, "w") as f:
|
||||||
|
yaml.dump(registry_data, f)
|
||||||
|
|
||||||
|
tree = _build_tree([], [], [], {}, str(tmp_path))
|
||||||
|
|
||||||
|
# Resolve after --del, --rm, --show, --regen-password
|
||||||
|
for action in ["--del", "--rm", "--show", "--regen-password"]:
|
||||||
|
completions = resolve_completion(["user", action, ""], tree)
|
||||||
|
assert "fluzzi" in completions
|
||||||
|
assert "john" in completions
|
||||||
|
|
||||||
|
# --add username completed options
|
||||||
|
add_completions = resolve_completion(["user", "--add", "newguy", ""], tree)
|
||||||
|
assert "--path" in add_completions
|
||||||
|
|
||||||
|
def test_login_logout_completions(self):
|
||||||
|
from connpy.completion import _build_tree, resolve_completion
|
||||||
|
tree = _build_tree([], [], [], {}, "/tmp")
|
||||||
|
|
||||||
|
# Test login option resolution
|
||||||
|
login_completions = resolve_completion(["login", ""], tree)
|
||||||
|
assert "--help" in login_completions
|
||||||
|
|
||||||
|
# Test logout option resolution
|
||||||
|
logout_completions = resolve_completion(["logout", ""], tree)
|
||||||
|
assert "--help" in logout_completions
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -165,9 +165,9 @@ def test_ai(mock_status, mock_ask, app):
|
|||||||
|
|
||||||
@patch("connpy.services.execution_service.ExecutionService.run_commands")
|
@patch("connpy.services.execution_service.ExecutionService.run_commands")
|
||||||
def test_run(mock_run_commands, app):
|
def test_run(mock_run_commands, app):
|
||||||
app.start(["run", "node1", "command1", "command2"])
|
app.start(["run", "router1", "command1", "command2"])
|
||||||
mock_run_commands.assert_called_once()
|
mock_run_commands.assert_called_once()
|
||||||
assert mock_run_commands.call_args[1]["nodes_filter"] == "node1"
|
assert mock_run_commands.call_args[1]["nodes_filter"] == ["router1"]
|
||||||
assert mock_run_commands.call_args[1]["commands"] == ["command1 command2"]
|
assert mock_run_commands.call_args[1]["commands"] == ["command1 command2"]
|
||||||
|
|
||||||
@patch("os.path.exists")
|
@patch("os.path.exists")
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import grpc
|
||||||
|
from concurrent import futures
|
||||||
|
from google.protobuf.empty_pb2 import Empty
|
||||||
|
|
||||||
|
from connpy.grpc_layer import server, connpy_pb2, connpy_pb2_grpc, stubs
|
||||||
|
from connpy.grpc_layer.user_registry import UserRegistry
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_config_dir(tmp_path):
|
||||||
|
"""Creates a temporary config directory for testing gRPC auth."""
|
||||||
|
config_dir = tmp_path / "conn_config"
|
||||||
|
config_dir.mkdir()
|
||||||
|
|
||||||
|
# Initialize basic config file inside it
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
conf_file = os.path.join(str(config_dir), "config.yaml")
|
||||||
|
configfile(conf=conf_file)
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry(test_config_dir):
|
||||||
|
"""Initializes UserRegistry."""
|
||||||
|
return UserRegistry(str(test_config_dir))
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_grpc_server(test_config_dir, registry):
|
||||||
|
"""Starts an authenticated local gRPC server for integration testing."""
|
||||||
|
srv = grpc.server(
|
||||||
|
futures.ThreadPoolExecutor(max_workers=5),
|
||||||
|
interceptors=[server.AuthInterceptor(registry)]
|
||||||
|
)
|
||||||
|
|
||||||
|
fallback_provider = ServiceProvider(configfile(conf=os.path.join(str(test_config_dir), "config.yaml")), mode="local")
|
||||||
|
|
||||||
|
# Register services
|
||||||
|
connpy_pb2_grpc.add_NodeServiceServicer_to_server(server.NodeServicer(fallback_provider, registry=registry), srv)
|
||||||
|
connpy_pb2_grpc.add_AuthServiceServicer_to_server(server.AuthServicer(registry), srv)
|
||||||
|
|
||||||
|
port = srv.add_insecure_port('127.0.0.1:0')
|
||||||
|
srv.start()
|
||||||
|
yield f"127.0.0.1:{port}"
|
||||||
|
srv.stop(0)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def channel(auth_grpc_server):
|
||||||
|
with grpc.insecure_channel(auth_grpc_server) as channel:
|
||||||
|
yield channel
|
||||||
|
|
||||||
|
|
||||||
|
class TestGRPCAuthentication:
|
||||||
|
def test_backward_compatibility_no_users(self, channel, registry):
|
||||||
|
"""Verifies that if no users are registered, gRPC calls proceed without authentication."""
|
||||||
|
assert registry.has_users() is False
|
||||||
|
|
||||||
|
# Calling NodeService list_nodes should succeed without any authorization metadata
|
||||||
|
stub = connpy_pb2_grpc.NodeServiceStub(channel)
|
||||||
|
req = connpy_pb2.FilterRequest()
|
||||||
|
res = stub.list_nodes(req)
|
||||||
|
assert res is not None
|
||||||
|
|
||||||
|
def test_login_and_authenticated_calls(self, channel, registry):
|
||||||
|
"""Tests user creation, login to retrieve JWT, and using JWT to access protected endpoints."""
|
||||||
|
username = "alice"
|
||||||
|
password = "alicepassword"
|
||||||
|
|
||||||
|
# 1. Register a user in the registry
|
||||||
|
registry.user_service.create_user(username, password)
|
||||||
|
assert registry.has_users() is True
|
||||||
|
|
||||||
|
# 2. Try unauthenticated call - must fail with UNAUTHENTICATED
|
||||||
|
node_stub = connpy_pb2_grpc.NodeServiceStub(channel)
|
||||||
|
req = connpy_pb2.FilterRequest()
|
||||||
|
with pytest.raises(grpc.RpcError) as exc:
|
||||||
|
node_stub.list_nodes(req)
|
||||||
|
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
|
||||||
|
assert "Authorization token is missing" in exc.value.details()
|
||||||
|
|
||||||
|
# 3. Call login endpoint (open method) - must succeed
|
||||||
|
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
|
||||||
|
login_req = connpy_pb2.LoginRequest(username=username, password=password)
|
||||||
|
login_res = auth_stub.login(login_req)
|
||||||
|
|
||||||
|
assert login_res.username == username
|
||||||
|
assert isinstance(login_res.token, str)
|
||||||
|
assert login_res.expires_at > 0
|
||||||
|
|
||||||
|
# 4. Make authenticated call using Bearer token - must succeed
|
||||||
|
metadata = [("authorization", f"Bearer {login_res.token}")]
|
||||||
|
res = node_stub.list_nodes(req, metadata=metadata)
|
||||||
|
assert res is not None
|
||||||
|
|
||||||
|
def test_login_invalid_credentials(self, channel, registry):
|
||||||
|
"""Verifies login fails and returns UNAUTHENTICATED for incorrect credentials."""
|
||||||
|
registry.user_service.create_user("bob", "bobpass")
|
||||||
|
|
||||||
|
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
|
||||||
|
login_req = connpy_pb2.LoginRequest(username="bob", password="wrongpassword")
|
||||||
|
|
||||||
|
with pytest.raises(grpc.RpcError) as exc:
|
||||||
|
auth_stub.login(login_req)
|
||||||
|
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
|
||||||
|
assert "Invalid username or password" in exc.value.details()
|
||||||
|
|
||||||
|
def test_change_password(self, channel, registry):
|
||||||
|
"""Tests changing password via gRPC and verifying old password no longer works."""
|
||||||
|
username = "charlie"
|
||||||
|
registry.user_service.create_user(username, "oldpass")
|
||||||
|
|
||||||
|
auth_stub = connpy_pb2_grpc.AuthServiceStub(channel)
|
||||||
|
|
||||||
|
# 1. Login with old password to get token
|
||||||
|
login_res = auth_stub.login(connpy_pb2.LoginRequest(username=username, password="oldpass"))
|
||||||
|
token = login_res.token
|
||||||
|
|
||||||
|
# 2. Change password via gRPC using the token
|
||||||
|
metadata = [("authorization", f"Bearer {token}")]
|
||||||
|
change_req = connpy_pb2.ChangePasswordRequest(old_password="oldpass", new_password="newpass")
|
||||||
|
auth_stub.change_password(change_req, metadata=metadata)
|
||||||
|
|
||||||
|
# 3. Logging in with old password must fail
|
||||||
|
with pytest.raises(grpc.RpcError) as exc:
|
||||||
|
auth_stub.login(connpy_pb2.LoginRequest(username=username, password="oldpass"))
|
||||||
|
assert exc.value.code() == grpc.StatusCode.UNAUTHENTICATED
|
||||||
|
|
||||||
|
# 4. Logging in with new password must succeed
|
||||||
|
login_res_new = auth_stub.login(connpy_pb2.LoginRequest(username=username, password="newpass"))
|
||||||
|
assert login_res_new.token is not None
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from connpy.grpc_layer.server import NodeServicer, _current_user
|
||||||
|
from connpy.grpc_layer.user_registry import UserRegistry
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_config_dir(tmp_path):
|
||||||
|
"""Creates a temporary config directory for testing user registry."""
|
||||||
|
config_dir = tmp_path / "conn_config"
|
||||||
|
config_dir.mkdir()
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry(test_config_dir):
|
||||||
|
"""Initializes UserRegistry pointing to a temporary directory."""
|
||||||
|
return UserRegistry(str(test_config_dir))
|
||||||
|
|
||||||
|
def test_dynamic_routing_isolation(test_config_dir, registry):
|
||||||
|
"""Verifies that NodeServicer routes list_nodes to the correct user configuration based on _current_user ContextVar."""
|
||||||
|
# Setup fallback provider
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
conf_file = os.path.join(registry.user_service.config_dir, "config.yaml")
|
||||||
|
config = configfile(conf=conf_file)
|
||||||
|
fallback_provider = ServiceProvider(config, mode="local")
|
||||||
|
|
||||||
|
# Create servicer with fallback and registry
|
||||||
|
servicer = NodeServicer(fallback_provider, registry=registry)
|
||||||
|
|
||||||
|
# Register two users
|
||||||
|
u1 = "user1"
|
||||||
|
u2 = "user2"
|
||||||
|
registry.user_service.create_user(u1, "pass1")
|
||||||
|
registry.user_service.create_user(u2, "pass2")
|
||||||
|
|
||||||
|
p1 = registry.get_provider(u1)
|
||||||
|
p2 = registry.get_provider(u2)
|
||||||
|
|
||||||
|
# Add nodes to each user's provider
|
||||||
|
p1.nodes.add_node("node-for-user-1", {"host": "1.1.1.1"})
|
||||||
|
p2.nodes.add_node("node-for-user-2", {"host": "2.2.2.2"})
|
||||||
|
|
||||||
|
# Verify fallback is empty
|
||||||
|
fallback_res = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
|
||||||
|
from connpy.grpc_layer.utils import from_value
|
||||||
|
assert "node-for-user-1" not in from_value(fallback_res.data)
|
||||||
|
assert "node-for-user-2" not in from_value(fallback_res.data)
|
||||||
|
|
||||||
|
# Set context to User 1
|
||||||
|
t1 = _current_user.set(u1)
|
||||||
|
try:
|
||||||
|
res1 = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
|
||||||
|
nodes1 = from_value(res1.data)
|
||||||
|
assert "node-for-user-1" in nodes1
|
||||||
|
assert "node-for-user-2" not in nodes1
|
||||||
|
finally:
|
||||||
|
_current_user.reset(t1)
|
||||||
|
|
||||||
|
# Set context to User 2
|
||||||
|
t2 = _current_user.set(u2)
|
||||||
|
try:
|
||||||
|
res2 = servicer.list_nodes(type('Request', (), {'filter_str': None, 'format_str': None})(), None)
|
||||||
|
nodes2 = from_value(res2.data)
|
||||||
|
assert "node-for-user-2" in nodes2
|
||||||
|
assert "node-for-user-1" not in nodes2
|
||||||
|
finally:
|
||||||
|
_current_user.reset(t2)
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
from connpy.grpc_layer.user_registry import UserRegistry
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_config_dir(tmp_path):
|
||||||
|
"""Creates a temporary config directory for testing."""
|
||||||
|
config_dir = tmp_path / "conn_shared_test"
|
||||||
|
config_dir.mkdir()
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
def test_shared_ai_deep_merge(temp_config_dir):
|
||||||
|
"""Test get_effective_setting deep merge logic for 'ai' settings."""
|
||||||
|
shared_dir = os.path.join(temp_config_dir, "shared")
|
||||||
|
user_dir = os.path.join(temp_config_dir, "user")
|
||||||
|
os.makedirs(shared_dir, exist_ok=True)
|
||||||
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
|
|
||||||
|
shared_path = os.path.join(shared_dir, "config.yaml")
|
||||||
|
user_path = os.path.join(user_dir, "config.yaml")
|
||||||
|
|
||||||
|
# Write shared configuration
|
||||||
|
shared_data = {
|
||||||
|
"config": {
|
||||||
|
"theme": "dark",
|
||||||
|
"case": False,
|
||||||
|
"ai": {
|
||||||
|
"engineer_model": "shared-eng-model",
|
||||||
|
"architect_model": "shared-arch-model",
|
||||||
|
"engineer_api_key": "shared-key",
|
||||||
|
"mcp_servers": {
|
||||||
|
"global-server": {
|
||||||
|
"url": "http://global-server/sse",
|
||||||
|
"enabled": True
|
||||||
|
},
|
||||||
|
"override-server": {
|
||||||
|
"url": "http://override-shared/sse",
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"connections": {},
|
||||||
|
"profiles": {}
|
||||||
|
}
|
||||||
|
with open(shared_path, "w") as f:
|
||||||
|
yaml.safe_dump(shared_data, f)
|
||||||
|
|
||||||
|
# Write user configuration with overrides
|
||||||
|
user_data = {
|
||||||
|
"config": {
|
||||||
|
"case": True,
|
||||||
|
"ai": {
|
||||||
|
"engineer_model": "user-custom-eng-model",
|
||||||
|
"mcp_servers": {
|
||||||
|
"override-server": {
|
||||||
|
"enabled": False
|
||||||
|
},
|
||||||
|
"user-server": {
|
||||||
|
"url": "http://user-server/sse",
|
||||||
|
"enabled": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"connections": {},
|
||||||
|
"profiles": {}
|
||||||
|
}
|
||||||
|
with open(user_path, "w") as f:
|
||||||
|
yaml.safe_dump(user_data, f)
|
||||||
|
|
||||||
|
# Initialize configfile instances
|
||||||
|
shared_config = configfile(conf=shared_path)
|
||||||
|
user_config = configfile(conf=user_path, shared_config=shared_config)
|
||||||
|
|
||||||
|
# Verify non-inheritable settings (theme, case)
|
||||||
|
assert user_config.get_effective_setting("case") is True
|
||||||
|
assert user_config.get_effective_setting("theme") is None # Should NOT inherit "theme"
|
||||||
|
|
||||||
|
# Verify AI setting deep merge
|
||||||
|
effective_ai = user_config.get_effective_setting("ai")
|
||||||
|
|
||||||
|
# Model override
|
||||||
|
assert effective_ai.get("engineer_model") == "user-custom-eng-model"
|
||||||
|
# Model inheritance
|
||||||
|
assert effective_ai.get("architect_model") == "shared-arch-model"
|
||||||
|
# API key inheritance
|
||||||
|
assert effective_ai.get("engineer_api_key") == "shared-key"
|
||||||
|
|
||||||
|
# MCP Servers merge
|
||||||
|
mcp = effective_ai.get("mcp_servers", {})
|
||||||
|
# Inherited server
|
||||||
|
assert "global-server" in mcp
|
||||||
|
assert mcp["global-server"]["url"] == "http://global-server/sse"
|
||||||
|
assert mcp["global-server"]["enabled"] is True
|
||||||
|
|
||||||
|
# Merged & overridden server
|
||||||
|
assert "override-server" in mcp
|
||||||
|
assert mcp["override-server"]["url"] == "http://override-shared/sse" # inherited
|
||||||
|
assert mcp["override-server"]["enabled"] is False # overridden
|
||||||
|
|
||||||
|
# User-only server
|
||||||
|
assert "user-server" in mcp
|
||||||
|
assert mcp["user-server"]["url"] == "http://user-server/sse"
|
||||||
|
|
||||||
|
def test_registry_injection_and_hot_reload(temp_config_dir):
|
||||||
|
"""Test that UserRegistry correctly injects shared config and hot-reloads it when it changes on disk."""
|
||||||
|
registry = UserRegistry(str(temp_config_dir))
|
||||||
|
|
||||||
|
# Define paths
|
||||||
|
shared_path = os.path.join(temp_config_dir, "config.yaml")
|
||||||
|
|
||||||
|
# 1. Create a global config file
|
||||||
|
global_data = {
|
||||||
|
"config": {
|
||||||
|
"ai": {
|
||||||
|
"engineer_api_key": "global-initial-key",
|
||||||
|
"engineer_model": "global-model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"connections": {},
|
||||||
|
"profiles": {}
|
||||||
|
}
|
||||||
|
with open(shared_path, "w") as f:
|
||||||
|
yaml.safe_dump(global_data, f)
|
||||||
|
|
||||||
|
# Re-init registry to pick up the newly created shared config file
|
||||||
|
registry = UserRegistry(str(temp_config_dir))
|
||||||
|
|
||||||
|
# Register user
|
||||||
|
username = "testuser"
|
||||||
|
registry.user_service.create_user(username, "testpassword")
|
||||||
|
|
||||||
|
# Check initial injection
|
||||||
|
provider = registry.get_provider(username)
|
||||||
|
ai_settings = provider.config.get_effective_setting("ai")
|
||||||
|
assert ai_settings.get("engineer_api_key") == "global-initial-key"
|
||||||
|
assert ai_settings.get("engineer_model") == "global-model"
|
||||||
|
|
||||||
|
# 2. Modify global config on disk
|
||||||
|
global_data["config"]["ai"]["engineer_api_key"] = "global-updated-key"
|
||||||
|
|
||||||
|
# Sleep briefly to ensure mtime change is detectable
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
with open(shared_path, "w") as f:
|
||||||
|
yaml.safe_dump(global_data, f)
|
||||||
|
|
||||||
|
# Set the mtime forward explicitly to avoid filesystem resolution limits
|
||||||
|
new_mtime = os.path.getmtime(shared_path) + 10.0
|
||||||
|
os.utime(shared_path, (new_mtime, new_mtime))
|
||||||
|
|
||||||
|
# Retrieve provider again - should trigger hot-reload of shared config
|
||||||
|
provider2 = registry.get_provider(username)
|
||||||
|
|
||||||
|
ai_settings_updated = provider2.config.get_effective_setting("ai")
|
||||||
|
assert ai_settings_updated.get("engineer_api_key") == "global-updated-key"
|
||||||
|
assert ai_settings_updated.get("engineer_model") == "global-model"
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from connpy.grpc_layer.user_registry import UserRegistry
|
||||||
|
from connpy.services.provider import ServiceProvider
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_config_dir(tmp_path):
|
||||||
|
"""Creates a temporary config directory for testing user registry."""
|
||||||
|
config_dir = tmp_path / "conn_config"
|
||||||
|
config_dir.mkdir()
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry(test_config_dir):
|
||||||
|
"""Initializes UserRegistry pointing to a temporary directory."""
|
||||||
|
return UserRegistry(str(test_config_dir))
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserRegistry:
|
||||||
|
def test_has_users_empty(self, registry):
|
||||||
|
"""Verifies has_users is False when no users exist."""
|
||||||
|
assert registry.has_users() is False
|
||||||
|
|
||||||
|
def test_get_provider_returns_service_provider(self, registry):
|
||||||
|
"""Tests that get_provider lazy-loads a valid ServiceProvider instance."""
|
||||||
|
username = "alice"
|
||||||
|
registry.user_service.create_user(username, "password")
|
||||||
|
|
||||||
|
assert registry.has_users() is True
|
||||||
|
|
||||||
|
provider = registry.get_provider(username)
|
||||||
|
assert isinstance(provider, ServiceProvider)
|
||||||
|
assert provider.mode == "local"
|
||||||
|
|
||||||
|
def test_get_provider_cached(self, registry):
|
||||||
|
"""Verifies that subsequent calls return the cached singleton instance."""
|
||||||
|
username = "bob"
|
||||||
|
registry.user_service.create_user(username, "password")
|
||||||
|
|
||||||
|
p1 = registry.get_provider(username)
|
||||||
|
p2 = registry.get_provider(username)
|
||||||
|
|
||||||
|
assert p1 is p2 # must be exact same object reference
|
||||||
|
|
||||||
|
def test_two_users_isolated(self, registry):
|
||||||
|
"""Ensures different users get completely separate ServiceProviders and configs."""
|
||||||
|
u1 = "user1"
|
||||||
|
u2 = "user2"
|
||||||
|
|
||||||
|
registry.user_service.create_user(u1, "pass1")
|
||||||
|
registry.user_service.create_user(u2, "pass2")
|
||||||
|
|
||||||
|
p1 = registry.get_provider(u1)
|
||||||
|
p2 = registry.get_provider(u2)
|
||||||
|
|
||||||
|
assert p1 is not p2
|
||||||
|
assert p1.config is not p2.config
|
||||||
|
|
||||||
|
# Add a node for user1 and verify user2 is unaffected
|
||||||
|
p1.nodes.add_node("node1", {"host": "1.1.1.1"})
|
||||||
|
assert "node1" in p1.nodes.list_nodes()
|
||||||
|
assert "node1" not in p2.nodes.list_nodes()
|
||||||
|
|
||||||
|
def test_evict_clears_cache(self, registry):
|
||||||
|
"""Verifies that eviction deletes the cached provider from memory."""
|
||||||
|
username = "evictuser"
|
||||||
|
registry.user_service.create_user(username, "pass")
|
||||||
|
|
||||||
|
p1 = registry.get_provider(username)
|
||||||
|
assert username in registry._providers
|
||||||
|
|
||||||
|
registry.evict(username)
|
||||||
|
assert username not in registry._providers
|
||||||
|
|
||||||
|
# Calling get_provider again spawns a new instance
|
||||||
|
p2 = registry.get_provider(username)
|
||||||
|
assert p1 is not p2
|
||||||
|
|
||||||
|
def test_provider_hot_reload_on_external_change(self, registry):
|
||||||
|
"""Verifies that UserRegistry hot-reloads the provider if config.yaml is updated externally."""
|
||||||
|
username = "charlie"
|
||||||
|
registry.user_service.create_user(username, "password")
|
||||||
|
|
||||||
|
# Initial load (no nodes)
|
||||||
|
p1 = registry.get_provider(username)
|
||||||
|
assert len(p1.nodes.list_nodes()) == 0
|
||||||
|
|
||||||
|
# Resolve config.yaml file path
|
||||||
|
conf_file = os.path.join(registry.server_config_dir, "users", username, "config.yaml")
|
||||||
|
|
||||||
|
# Modify the config file physically on disk by appending a node
|
||||||
|
from connpy.configfile import configfile
|
||||||
|
cfg = configfile(conf=conf_file)
|
||||||
|
cfg._connections_add(id="testnode", host="8.8.8.8")
|
||||||
|
cfg._saveconfig(cfg.file)
|
||||||
|
|
||||||
|
# Artificially increase mtime to force reload
|
||||||
|
mtime = os.path.getmtime(conf_file)
|
||||||
|
os.utime(conf_file, (mtime + 5.0, mtime + 5.0))
|
||||||
|
|
||||||
|
# Fetch provider again
|
||||||
|
p2 = registry.get_provider(username)
|
||||||
|
|
||||||
|
# Verify it hot-reloaded and the new node is immediately visible
|
||||||
|
assert p1 is not p2
|
||||||
|
assert "testnode" in p2.nodes.list_nodes()
|
||||||
|
|
||||||
|
def test_provider_hot_reload_fails_on_corrupt_file_keeps_old_provider(self, registry):
|
||||||
|
"""Verifies that UserRegistry keeps serving the old provider if disk config is corrupt."""
|
||||||
|
username = "danny"
|
||||||
|
registry.user_service.create_user(username, "password")
|
||||||
|
|
||||||
|
# Initial load
|
||||||
|
p1 = registry.get_provider(username)
|
||||||
|
p1.nodes.add_node("nodeA", {"host": "2.2.2.2"})
|
||||||
|
assert "nodeA" in p1.nodes.list_nodes()
|
||||||
|
|
||||||
|
# Resolve config.yaml path
|
||||||
|
conf_file = os.path.join(registry.server_config_dir, "users", username, "config.yaml")
|
||||||
|
|
||||||
|
# Write corrupted content directly to config.yaml
|
||||||
|
with open(conf_file, "w") as f:
|
||||||
|
f.write("corrupt yaml content ::: invalid syntax :::")
|
||||||
|
|
||||||
|
# Artificially increase mtime to force reload attempt
|
||||||
|
mtime = os.path.getmtime(conf_file)
|
||||||
|
os.utime(conf_file, (mtime + 5.0, mtime + 5.0))
|
||||||
|
|
||||||
|
# Fetching provider again should fallback to old_provider instead of failing completely
|
||||||
|
p2 = registry.get_provider(username)
|
||||||
|
|
||||||
|
# Verify fallback
|
||||||
|
assert p1 is p2
|
||||||
|
assert "nodeA" in p2.nodes.list_nodes()
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
import datetime
|
||||||
|
import jwt
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from connpy.services.user_service import UserService
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_config_dir(tmp_path):
|
||||||
|
"""Creates a temporary config directory for testing user registry."""
|
||||||
|
config_dir = tmp_path / "conn_config"
|
||||||
|
config_dir.mkdir()
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def user_service(test_config_dir):
|
||||||
|
"""Initializes UserService pointing to a temporary directory."""
|
||||||
|
return UserService(str(test_config_dir))
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserService:
|
||||||
|
def test_no_users(self, user_service):
|
||||||
|
"""Verifies that a new registry is empty by default."""
|
||||||
|
users = user_service.list_users()
|
||||||
|
assert users == []
|
||||||
|
|
||||||
|
def test_create_user_default(self, user_service):
|
||||||
|
"""Tests Mode A: fresh user config and key creation."""
|
||||||
|
username = "testuser"
|
||||||
|
res = user_service.create_user(username, "mypassword")
|
||||||
|
|
||||||
|
assert res["username"] == username
|
||||||
|
assert res["config_path"] is None
|
||||||
|
assert "created" in res
|
||||||
|
|
||||||
|
# Verify folder, config.yaml and .osk key are created
|
||||||
|
user_dir = os.path.join(user_service.users_dir, username)
|
||||||
|
assert os.path.isdir(user_dir)
|
||||||
|
assert os.path.isdir(os.path.join(user_dir, "plugins"))
|
||||||
|
assert os.path.isdir(os.path.join(user_dir, "ai_sessions"))
|
||||||
|
assert os.path.isfile(os.path.join(user_dir, "config.yaml"))
|
||||||
|
assert os.path.isfile(os.path.join(user_dir, ".osk"))
|
||||||
|
|
||||||
|
def test_create_user_custom_path(self, user_service, tmp_path):
|
||||||
|
"""Tests Mode B: using an existing valid config path."""
|
||||||
|
# Setup existing custom config directory
|
||||||
|
custom_dir = tmp_path / "custom_user_conn"
|
||||||
|
custom_dir.mkdir()
|
||||||
|
|
||||||
|
config_file = custom_dir / "config.yaml"
|
||||||
|
# Write basic config.yaml
|
||||||
|
config_data = {
|
||||||
|
"config": {"case": False, "idletime": 30, "fzf": False},
|
||||||
|
"connections": {},
|
||||||
|
"profiles": {}
|
||||||
|
}
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_data, f)
|
||||||
|
|
||||||
|
res = user_service.create_user("fluzzi", "fluzzipass", config_path=str(custom_dir))
|
||||||
|
|
||||||
|
assert res["username"] == "fluzzi"
|
||||||
|
assert res["config_path"] == str(custom_dir)
|
||||||
|
|
||||||
|
# Verify no directory is created under the server's user folder
|
||||||
|
user_dir = os.path.join(user_service.users_dir, "fluzzi")
|
||||||
|
assert not os.path.exists(user_dir)
|
||||||
|
|
||||||
|
def test_create_user_custom_path_auto_init(self, user_service, tmp_path):
|
||||||
|
"""Ensures create_user automatically initializes a missing directory and default config.yaml."""
|
||||||
|
custom_dir = tmp_path / "new_custom_config"
|
||||||
|
|
||||||
|
# Test creation where the directory does not exist yet
|
||||||
|
res = user_service.create_user("john", "pass", config_path=str(custom_dir))
|
||||||
|
assert res["username"] == "john"
|
||||||
|
assert res["config_path"] == str(custom_dir)
|
||||||
|
|
||||||
|
# Verify custom path and subdirs/configs were created
|
||||||
|
assert os.path.isdir(custom_dir)
|
||||||
|
assert os.path.exists(os.path.join(custom_dir, "config.yaml"))
|
||||||
|
assert os.path.isdir(os.path.join(custom_dir, "plugins"))
|
||||||
|
assert os.path.isdir(os.path.join(custom_dir, "ai_sessions"))
|
||||||
|
|
||||||
|
def test_create_duplicate_user(self, user_service):
|
||||||
|
"""Ensures duplicate usernames are rejected."""
|
||||||
|
user_service.create_user("dupuser", "password")
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
user_service.create_user("dupuser", "anotherpass")
|
||||||
|
|
||||||
|
def test_delete_user_default(self, user_service):
|
||||||
|
"""Tests Mode A: deleting a server-managed user cleans up directories."""
|
||||||
|
username = "deluser"
|
||||||
|
user_service.create_user(username, "password")
|
||||||
|
user_dir = os.path.join(user_service.users_dir, username)
|
||||||
|
assert os.path.isdir(user_dir)
|
||||||
|
|
||||||
|
user_service.delete_user(username)
|
||||||
|
# Directory should be cleaned up
|
||||||
|
assert not os.path.exists(user_dir)
|
||||||
|
# Registry should be updated
|
||||||
|
assert len(user_service.list_users()) == 0
|
||||||
|
|
||||||
|
def test_delete_user_custom_path(self, user_service, tmp_path):
|
||||||
|
"""Tests Mode B: deleting a custom-path user leaves files untouched."""
|
||||||
|
custom_dir = tmp_path / "fluzzi_custom"
|
||||||
|
custom_dir.mkdir()
|
||||||
|
config_file = custom_dir / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump({"config": {}, "connections": {}, "profiles": {}}, f)
|
||||||
|
|
||||||
|
username = "fluzzi"
|
||||||
|
user_service.create_user(username, "pass", config_path=str(custom_dir))
|
||||||
|
|
||||||
|
user_service.delete_user(username)
|
||||||
|
# Registry cleared
|
||||||
|
assert len(user_service.list_users()) == 0
|
||||||
|
# Files remain untouched
|
||||||
|
assert os.path.isdir(str(custom_dir))
|
||||||
|
assert os.path.isfile(str(config_file))
|
||||||
|
|
||||||
|
def test_list_users(self, user_service):
|
||||||
|
"""Tests listing all registered users with their metadata."""
|
||||||
|
user_service.create_user("user1", "pass1")
|
||||||
|
user_service.create_user("user2", "pass2")
|
||||||
|
|
||||||
|
users = user_service.list_users()
|
||||||
|
assert len(users) == 2
|
||||||
|
usernames = [u["username"] for u in users]
|
||||||
|
assert "user1" in usernames
|
||||||
|
assert "user2" in usernames
|
||||||
|
|
||||||
|
def test_get_user(self, user_service):
|
||||||
|
"""Tests retrieving a single user's configuration metadata."""
|
||||||
|
user_service.create_user("user1", "pass1")
|
||||||
|
user = user_service.get_user("user1")
|
||||||
|
|
||||||
|
assert user["username"] == "user1"
|
||||||
|
assert user["config_path"] is None
|
||||||
|
assert "created" in user
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
user_service.get_user("nonexistent")
|
||||||
|
|
||||||
|
def test_authenticate_valid(self, user_service):
|
||||||
|
"""Verifies successful authentication."""
|
||||||
|
user_service.create_user("john", "my-secure-password")
|
||||||
|
assert user_service.authenticate("john", "my-secure-password") is True
|
||||||
|
|
||||||
|
def test_authenticate_invalid(self, user_service):
|
||||||
|
"""Verifies unsuccessful authentication on incorrect or missing credentials."""
|
||||||
|
user_service.create_user("john", "my-secure-password")
|
||||||
|
|
||||||
|
assert user_service.authenticate("john", "wrong-password") is False
|
||||||
|
assert user_service.authenticate("nonexistent", "my-secure-password") is False
|
||||||
|
|
||||||
|
def test_jwt_roundtrip(self, user_service):
|
||||||
|
"""Tests generating a JWT token and verifying it back to the username."""
|
||||||
|
username = "jwttester"
|
||||||
|
user_service.create_user(username, "pass")
|
||||||
|
|
||||||
|
token = user_service.generate_jwt(username)
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
verified_user = user_service.verify_jwt(token)
|
||||||
|
assert verified_user == username
|
||||||
|
|
||||||
|
def test_jwt_expired(self, user_service):
|
||||||
|
"""Tests that expired JWT tokens are rejected and return None."""
|
||||||
|
username = "jwttester"
|
||||||
|
user_service.create_user(username, "pass")
|
||||||
|
|
||||||
|
# Manually generate an expired token by setting exp to the past
|
||||||
|
registry = user_service._load_registry()
|
||||||
|
expired_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(seconds=10)
|
||||||
|
payload = {
|
||||||
|
"sub": username,
|
||||||
|
"exp": expired_time
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, registry["jwt_secret"], algorithm="HS256")
|
||||||
|
if isinstance(token, bytes):
|
||||||
|
token = token.decode("utf-8")
|
||||||
|
|
||||||
|
verified_user = user_service.verify_jwt(token)
|
||||||
|
assert verified_user is None
|
||||||
|
|
||||||
|
def test_change_password(self, user_service):
|
||||||
|
"""Tests changing password for a user."""
|
||||||
|
username = "passchanger"
|
||||||
|
user_service.create_user(username, "oldpass")
|
||||||
|
|
||||||
|
# Old credentials authenticate
|
||||||
|
assert user_service.authenticate(username, "oldpass") is True
|
||||||
|
|
||||||
|
# Change password
|
||||||
|
user_service.change_password(username, "oldpass", "newpass")
|
||||||
|
|
||||||
|
# Old password fails, new password works
|
||||||
|
assert user_service.authenticate(username, "oldpass") is False
|
||||||
|
assert user_service.authenticate(username, "newpass") is True
|
||||||
|
|
||||||
|
# Change with invalid old password should fail
|
||||||
|
with pytest.raises(ValueError, match="Invalid credentials"):
|
||||||
|
user_service.change_password(username, "wrongold", "evennewer")
|
||||||
|
|
||||||
|
def test_admin_change_password(self, user_service):
|
||||||
|
"""Tests administrative password change (no old password required)."""
|
||||||
|
username = "adminpasschanger"
|
||||||
|
user_service.create_user(username, "oldpass")
|
||||||
|
|
||||||
|
# Admin changes password directly
|
||||||
|
user_service.admin_change_password(username, "newpass")
|
||||||
|
|
||||||
|
# Verify credentials
|
||||||
|
assert user_service.authenticate(username, "oldpass") is False
|
||||||
|
assert user_service.authenticate(username, "newpass") is True
|
||||||
@@ -20,3 +20,5 @@ httpx>=0.27.0
|
|||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
pytest-mock>=3.12.0
|
pytest-mock>=3.12.0
|
||||||
|
bcrypt>=4.1.0
|
||||||
|
PyJWT>=2.8.0
|
||||||
|
|||||||
Reference in New Issue
Block a user