"""Enhanced VPN management with VPNTray naming and route control.

Wraps NetworkManager's ``nmcli`` to import OpenVPN/WireGuard config files as
connections named with a ``vpntray_`` prefix, restrict the routes they
install, and connect/disconnect/remove them.
"""

import logging
import os
import re
import subprocess
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

from models import Location


class VPNStatus(Enum):
    """VPN connection status."""

    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    DISCONNECTING = "disconnecting"
    FAILED = "failed"
    UNKNOWN = "unknown"


class VPNConnectionError(Exception):
    """Exception raised for VPN connection errors."""

    pass


@dataclass
class VPNConnectionInfo:
    """Information about a VPN connection."""

    name: str
    uuid: str
    vpntray_name: str  # Our custom name with vpntray_ prefix
    status: VPNStatus
    device: Optional[str] = None  # Active device, or None when not active
    # List of routes added; defaults to None (not filled in by VPNManager).
    routes: Optional[List[str]] = None


class VPNManager:
    """Enhanced VPN manager with VPNTray naming and route management.

    Every connection this class creates is named ``vpntray_<config stem>``
    so it can be told apart from the user's own NetworkManager profiles.
    """

    VPNTRAY_PREFIX = "vpntray_"
    VPN_CONFIG_DIR = Path.home() / ".vpntray" / "vpn"

    # nmcli properties whose *value* argument is secret and must never be
    # written to logs or exception messages.
    _SECRET_PROPERTIES = frozenset({
        "vpn.secrets",
        "+vpn.secrets",
        "-vpn.secrets",
        "wireguard.private-key",
    })

    # nmcli's import summary: "Connection 'name' (uuid) successfully added."
    # The UUID part is optional so a slightly different format still yields
    # the name.
    _IMPORT_RESULT_RE = re.compile(
        r"Connection '(?P<name>[^']+)'(?: \((?P<uuid>[0-9a-fA-F-]+)\))?")

    # Terse (-t) output separates fields with ':'; a literal ':' or '\'
    # inside a value is backslash-escaped.
    _TERSE_FIELD_SEP_RE = re.compile(r"(?<!\\):")
    _TERSE_ESCAPE_RE = re.compile(r"\\(.)")

    # GENERAL.STATE value (C locale) -> status. Matched exactly, because
    # "activating" is a substring of "deactivating" and "activated" of
    # "deactivated". Any other reported state maps to DISCONNECTED.
    _STATE_MAP: Dict[str, VPNStatus] = {
        "activated": VPNStatus.CONNECTED,
        "activating": VPNStatus.CONNECTING,
        "deactivating": VPNStatus.DISCONNECTING,
    }

    def __init__(self):
        """Initialize VPN manager.

        Raises:
            VPNConnectionError: if ``nmcli`` cannot be run.
        """
        self.logger = logging.getLogger(__name__)
        self._check_nmcli_available()
        self._ensure_vpn_config_dir()

    def _check_nmcli_available(self) -> None:
        """Check if nmcli is available.

        Raises:
            VPNConnectionError: if ``nmcli --version`` fails, hangs, or the
                binary cannot be executed.
        """
        try:
            subprocess.run(['nmcli', '--version'],
                           capture_output=True, check=True, timeout=10)
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
                OSError) as err:
            raise VPNConnectionError(
                "nmcli is not available. Install NetworkManager.") from err

    def _ensure_vpn_config_dir(self) -> None:
        """Ensure VPN config directory exists."""
        self.VPN_CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    @classmethod
    def _redact_command(cls, command: List[str]) -> str:
        """Join *command* for logging, masking the value of secret properties.

        For ``['nmcli', ..., '+vpn.secrets', 'password=hunter2']`` this
        returns ``'nmcli ... +vpn.secrets password=********'``.
        """
        redacted = []
        mask_next = False
        for arg in command:
            if mask_next:
                key, sep, _ = arg.partition('=')
                redacted.append(f"{key}{sep}********" if sep else "********")
                mask_next = False
            else:
                redacted.append(arg)
                mask_next = arg in cls._SECRET_PROPERTIES
        return ' '.join(redacted)

    def _run_nmcli(self, args: List[str], check: bool = True,
                   timeout: int = 30) -> subprocess.CompletedProcess:
        """Run nmcli command with logging and timeout.

        nmcli runs with ``LC_ALL=C`` so that the output this class parses
        (import summary, GENERAL.STATE values) is not localized.

        Args:
            args: Arguments passed after ``nmcli``.
            check: Raise on a non-zero exit status when True.
            timeout: Seconds before the command is abandoned.

        Returns:
            The completed process, with text stdout/stderr.

        Raises:
            VPNConnectionError: on timeout, failure to execute nmcli, or (when
                *check* is True) a non-zero exit status. Secret values are
                never included in the message.
        """
        command = ['nmcli', *args]
        command_str = self._redact_command(command)
        self.logger.debug(f"Command: {command_str}")
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                check=check,
                timeout=timeout,  # Prevent hanging on an unresponsive daemon
                env={**os.environ, 'LC_ALL': 'C'},
            )
        except subprocess.TimeoutExpired as e:
            self.logger.error(
                f"Command timed out after {timeout}s: {command_str}")
            raise VPNConnectionError(
                f"nmcli command timed out after {timeout} seconds") from e
        except subprocess.CalledProcessError as e:
            self.logger.debug(f"Failed command: {command_str}")
            if e.stdout and e.stdout.strip():
                self.logger.debug(f"Output: {e.stdout.strip()}")
            stderr = (e.stderr or '').strip()
            if stderr:
                self.logger.error(f"Error: {stderr}")
            # Deliberately not str(e): it embeds the full argv, secrets included.
            raise VPNConnectionError(
                f"nmcli command failed (exit code {e.returncode}): "
                f"{stderr or 'no error output'}") from e
        except OSError as e:
            self.logger.error(f"Could not execute: {command_str}")
            raise VPNConnectionError(f"Failed to run nmcli: {e}") from e

        if result.stdout.strip():
            self.logger.debug(f"Output: {result.stdout.strip()}")
        if result.stderr.strip():
            self.logger.warning(f"Stderr: {result.stderr.strip()}")
        if result.returncode == 0:
            self.logger.debug("Command completed successfully")
        else:
            self.logger.error(f"Command exited with code: {result.returncode}")
        return result

    def _get_vpntray_connection_name(self, config_filename: str) -> str:
        """Generate VPNTray-specific connection name.

        The extension is dropped and every character outside
        ``[a-zA-Z0-9_-]`` becomes ``_``; e.g. ``"office vpn.ovpn"`` ->
        ``"vpntray_office_vpn"``.
        """
        base_name = Path(config_filename).stem
        sanitized = re.sub(r'[^a-zA-Z0-9_-]', '_', base_name)
        return f"{self.VPNTRAY_PREFIX}{sanitized}"

    def get_vpn_config_path(self, filename: str) -> Path:
        """Get full path to VPN config file inside VPN_CONFIG_DIR."""
        return self.VPN_CONFIG_DIR / filename

    def list_vpntray_connections(self) -> List[VPNConnectionInfo]:
        """List all VPNTray-managed connections.

        Only profiles whose name *starts with* VPNTRAY_PREFIX are returned.
        Returns an empty list if nmcli fails (the failure is logged).
        """
        connections: List[VPNConnectionInfo] = []
        try:
            # Terse output with explicit fields is stable to parse, unlike
            # the whitespace-aligned human-readable table.
            result = self._run_nmcli(
                ['-t', '-f', 'NAME,UUID,DEVICE', 'connection', 'show'])
        except VPNConnectionError as e:
            self.logger.warning(f"Could not list connections: {e}")
            return connections

        for line in result.stdout.splitlines():
            fields = [self._TERSE_ESCAPE_RE.sub(r"\1", field)
                      for field in self._TERSE_FIELD_SEP_RE.split(line)]
            if len(fields) < 3:
                continue
            name, uuid, device = fields[0], fields[1], fields[2]
            if not name.startswith(self.VPNTRAY_PREFIX):
                continue
            connections.append(VPNConnectionInfo(
                name=name,
                uuid=uuid,
                vpntray_name=name,
                status=self._get_connection_status(name),
                # Inactive profiles have an empty (terse) or '--' device.
                device=device if device not in ('', '--') else None,
            ))
        return connections

    def _get_connection_status(self, connection_name: str) -> VPNStatus:
        """Get the status of a specific connection.

        Returns UNKNOWN if the profile cannot be shown (e.g. it does not
        exist). Returns DISCONNECTED for an existing profile without a
        GENERAL.STATE line or with an unrecognized state.
        """
        try:
            result = self._run_nmcli(['connection', 'show', connection_name])
        except VPNConnectionError:
            return VPNStatus.UNKNOWN

        for line in result.stdout.splitlines():
            key, sep, value = line.partition(':')
            if sep and key.strip() == 'GENERAL.STATE':
                return self._STATE_MAP.get(value.strip().lower(),
                                           VPNStatus.DISCONNECTED)
        # The GENERAL.* (active-connection) fields are only printed while the
        # profile is active, so a shown profile without them is inactive.
        return VPNStatus.DISCONNECTED

    def import_vpn_config(self, location: Location) -> str:
        """Import VPN configuration for a location with VPNTray naming.

        Returns:
            The VPNTray connection name. If a connection with that name
            already exists, it is reused unchanged.

        Raises:
            VPNConnectionError: if the config file is missing, the VPN type is
                unsupported, or the nmcli import/rename fails.
        """
        config_path = self.get_vpn_config_path(location.vpn_config)
        if not config_path.exists():
            raise VPNConnectionError(f"VPN config not found: {config_path}")

        self.logger.info(
            f"Config file exists: {config_path} "
            f"({config_path.stat().st_size} bytes)")
        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)

        if self._get_connection_by_name(vpntray_name):
            self.logger.info(f"Connection already imported: {vpntray_name}")
            return vpntray_name

        vpn_type = location.vpn_type.value
        self.logger.info(f"Importing {vpn_type} config: {config_path.name}")
        importers = {
            "OpenVPN": self._import_openvpn,
            "WireGuard": self._import_wireguard,
        }
        importer = importers.get(vpn_type)
        if importer is None:
            raise VPNConnectionError(f"Unsupported VPN type: {vpn_type}")
        return importer(config_path, vpntray_name, location)

    def _import_openvpn(self, config_path: Path, vpntray_name: str,
                        location: Location) -> str:
        """Import OpenVPN configuration with route control."""
        return self._import_config(config_path, vpntray_name, location,
                                   nm_type='openvpn', label='OpenVPN')

    def _import_wireguard(self, config_path: Path, vpntray_name: str,
                          location: Location) -> str:
        """Import WireGuard configuration with route control."""
        return self._import_config(config_path, vpntray_name, location,
                                   nm_type='wireguard', label='WireGuard')

    def _import_config(self, config_path: Path, vpntray_name: str,
                       location: Location, nm_type: str, label: str) -> str:
        """Import *config_path* via nmcli, rename it, then apply creds/routes.

        Args:
            nm_type: The ``type`` argument for ``nmcli connection import``.
            label: Human-readable VPN type used in log messages.

        Raises:
            VPNConnectionError: if the import or the rename fails. If the
                rename fails, the just-imported profile is deleted so it does
                not linger under its auto-generated name.
        """
        # nmcli picks its own name on import; we rename afterwards.
        import_args = [
            'connection', 'import', 'type', nm_type, 'file', str(config_path)
        ]
        self.logger.info(f"Running nmcli import: {' '.join(import_args)}")
        try:
            result = self._run_nmcli(import_args)
            match = self._IMPORT_RESULT_RE.search(result.stdout)
            if not match:
                raise VPNConnectionError(
                    "Failed to parse imported connection name from nmcli output")
            auto_generated_name = match.group('name')
            # Prefer the UUID: the auto-generated name may collide with an
            # existing profile of the same name.
            imported_ref = match.group('uuid') or auto_generated_name
            self.logger.info(
                f"Config imported with auto name: {auto_generated_name}")

            self.logger.info(f"Renaming to: {vpntray_name}")
            try:
                self._run_nmcli([
                    'connection', 'modify', imported_ref,
                    'connection.id', vpntray_name
                ])
            except VPNConnectionError:
                self._delete_connection_quietly(imported_ref)
                raise
            self.logger.info(f"{label} config imported as {vpntray_name}")
        except VPNConnectionError as e:
            self.logger.error(f"{label} import failed: {e}")
            raise

        # Configure credentials immediately after import if provided
        if location.vpn_credentials:
            self._configure_credentials(vpntray_name, location)

        # Configure the connection to not route everything by default
        self._configure_connection_routes(vpntray_name, location)
        return vpntray_name

    def _delete_connection_quietly(self, connection_ref: str) -> None:
        """Best-effort ``nmcli connection delete``; failures are only logged."""
        try:
            self._run_nmcli(['connection', 'delete', connection_ref],
                            check=False)
        except VPNConnectionError as e:
            self.logger.warning(
                f"Could not remove partially imported connection "
                f"{connection_ref}: {e}")

    def _configure_connection_routes(self, connection_name: str,
                                     location: Location) -> None:
        """Configure connection to only route specified network segments.

        Sets ``ipv4.never-default`` and replaces ``ipv4.routes`` with the
        CIDRs of ``location.network_segments``. Errors are logged and do not
        fail the import.
        """
        try:
            # Disable automatic default route
            self._run_nmcli([
                'connection', 'modify', connection_name,
                'ipv4.never-default', 'true'
            ])

            # `or []` tolerates a missing segment list (assumed possible for
            # Location; confirm against the models module).
            routes = [segment.cidr
                      for segment in (location.network_segments or [])]
            if routes:
                routes_str = ','.join(routes)
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    'ipv4.routes', routes_str
                ])
                self.logger.info(
                    f"Configured routes for {connection_name}: {routes_str}")
        except VPNConnectionError as e:
            # Don't fail the import, just log the error
            self.logger.error(f"Failed to configure routes: {e}")

    def _get_connection_by_name(self, name: str) -> Optional[VPNConnectionInfo]:
        """Return minimal info if a connection called *name* exists, else None.

        The returned object has ``uuid="unknown"`` and ``status=UNKNOWN``;
        callers needing the real status should use _get_connection_status.
        """
        try:
            result = self._run_nmcli(['connection', 'show', name], check=False)
        except VPNConnectionError:
            return None
        if result.returncode != 0:
            return None
        return VPNConnectionInfo(
            name=name,
            uuid="unknown",
            vpntray_name=name,
            status=VPNStatus.UNKNOWN,  # Status will be checked when needed
        )

    def connect_vpn(self, location: Location) -> bool:
        """Connect to VPN for a location, importing the config on first use.

        Returns:
            True if ``nmcli connection up`` succeeded, False otherwise (the
            reason is logged).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)
            config_path = self.get_vpn_config_path(location.vpn_config)
            self.logger.info(f"VPN config: {config_path}")
            self.logger.info(f"Connection name: {vpntray_name}")

            if not config_path.exists():
                self.logger.error(f"VPN config file not found: {config_path}")
                return False

            if self._get_connection_by_name(vpntray_name) is None:
                self.logger.info("Importing VPN config for first time...")
                try:
                    self.import_vpn_config(location)
                    self.logger.info("VPN config imported successfully")
                except Exception as import_error:
                    self.logger.error(
                        f"Failed to import VPN config: {import_error}")
                    return False
            else:
                self.logger.info(f"Using existing connection: {vpntray_name}")

            # Credentials were already stored on the profile during import.
            self.logger.info("Attempting connection...")
            self._run_nmcli(['connection', 'up', vpntray_name], timeout=60)
            self.logger.info(f"Connected to {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"VPN connection failed: {e}")
            return False
        except Exception as e:
            self.logger.exception(f"Unexpected error during connection: {e}")
            return False

    def disconnect_vpn(self, location: Location) -> bool:
        """Disconnect VPN for a location.

        Returns:
            True on success; False if the connection does not exist or
            ``nmcli connection down`` fails (the reason is logged).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)
            self.logger.info(f"Disconnecting from {vpntray_name}...")

            if not self._get_connection_by_name(vpntray_name):
                self.logger.error(f"Connection {vpntray_name} not found")
                return False

            self._run_nmcli(['connection', 'down', vpntray_name])
            self.logger.info(f"Disconnected from {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"Failed to disconnect: {e}")
            return False
        except Exception as e:
            self.logger.exception(
                f"Unexpected error during disconnection: {e}")
            return False

    def get_connection_status(self, location: Location) -> VPNStatus:
        """Get connection status for a location."""
        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
        return self._get_connection_status(vpntray_name)

    def remove_vpn_config(self, location: Location) -> bool:
        """Remove VPN connection configuration (disconnecting it first).

        Returns:
            True if the profile was deleted, False otherwise (logged).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)

            # First disconnect if connected; failure here is expected when
            # the connection is already down.
            try:
                self._run_nmcli(['connection', 'down', vpntray_name],
                                check=False)
            except VPNConnectionError:
                pass

            self._run_nmcli(['connection', 'delete', vpntray_name])
            self.logger.info(f"Removed VPN configuration {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"Failed to remove config: {e}")
            return False

    def cleanup_vpntray_connections(self) -> int:
        """Remove all VPNTray-managed connections. Returns count removed.

        Connections are addressed by UUID when known, so a duplicate name
        never causes the wrong profile to be deleted. Per-connection
        failures are logged and skipped.
        """
        removed_count = 0
        for conn in self.list_vpntray_connections():
            ref = conn.uuid or conn.name
            try:
                self._run_nmcli(['connection', 'down', ref], check=False)
                self._run_nmcli(['connection', 'delete', ref])
            except VPNConnectionError as e:
                self.logger.warning(f"Failed to remove {conn.name}: {e}")
                continue  # Continue with other connections
            removed_count += 1

        if removed_count > 0:
            self.logger.info(f"Cleaned up {removed_count} VPNTray connections")
        return removed_count

    def _configure_credentials(self, connection_name: str,
                               location: Location) -> None:
        """Configure VPN credentials directly in the connection.

        Only dict credentials with optional ``username``/``password`` keys
        are supported: the username goes to ``vpn.data`` and the password to
        ``vpn.secrets``. Errors are logged and do not fail the caller.
        """
        credentials = location.vpn_credentials
        if not credentials:
            self.logger.info(f"No credentials provided for {connection_name}")
            return
        if not isinstance(credentials, dict):
            self.logger.warning(
                f"Unsupported credentials format for {connection_name}: "
                f"{type(credentials).__name__}")
            return

        username = credentials.get('username')
        password = credentials.get('password')
        self.logger.info(f"Setting credentials for {connection_name}...")

        try:
            if username:
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.data', f'username={username}'
                ])
                self.logger.info(f"Username configured for {connection_name}")

            if password:
                # NOTE(review): the password travels on nmcli's command line,
                # so it is visible in the process list while nmcli runs;
                # _run_nmcli only masks it in logs and error messages.
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.secrets', f'password={password}'
                ])
                self.logger.info(f"Password configured for {connection_name}")
        except VPNConnectionError as e:
            # Don't fail the whole operation for credential issues
            self.logger.error(f"Failed to configure credentials: {e}")
            return

        if username and password:
            self.logger.info(
                f"Full credentials configured for {connection_name}")
        elif username or password:
            self.logger.info(
                f"Partial credentials configured for {connection_name}")