Files
VPNTray/services/vpn_manager.py
2025-09-07 23:33:55 +02:00

508 lines
19 KiB
Python

"""Enhanced VPN management with VPNTray naming and route control."""
import subprocess
import re
import logging
from dataclasses import dataclass
from typing import Optional, List
from enum import Enum
from pathlib import Path
from models import Location
class VPNStatus(Enum):
    """Lifecycle state of a VPN connection as tracked by VPNTray.

    Every member's value is simply its lowercase name.
    """

    # Steady states
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    # Transitional states
    CONNECTING = "connecting"
    DISCONNECTING = "disconnecting"
    # Error / indeterminate states
    FAILED = "failed"
    UNKNOWN = "unknown"
class VPNConnectionError(Exception):
    """Raised when a VPN operation driven through nmcli cannot be completed."""
@dataclass
class VPNConnectionInfo:
    """Information about a single NetworkManager VPN connection."""

    # NetworkManager connection name (connection.id).
    name: str
    # NetworkManager connection UUID; may be a placeholder such as "unknown"
    # when only existence was checked.
    uuid: str
    # Our custom name with the vpntray_ prefix.
    vpntray_name: str
    # Last known connection state.
    status: VPNStatus
    # Network device the connection is bound to, or None when inactive.
    device: Optional[str] = None
    # List of routes added, or None when not populated. The default is None,
    # so the annotation must be Optional.
    routes: Optional[List[str]] = None
class VPNManager:
    """Enhanced VPN manager with VPNTray naming and route management.

    Every NetworkManager profile this class creates is named
    ``VPNTRAY_PREFIX + <sanitized config file stem>``, so VPNTray-managed
    profiles can be told apart from the user's own connections. All
    NetworkManager interaction goes through the ``nmcli`` command-line tool.
    VPN config files are looked up in ``VPN_CONFIG_DIR``.
    """

    VPNTRAY_PREFIX = "vpntray_"
    VPN_CONFIG_DIR = Path.home() / ".vpntray" / "vpn"

    # nmcli property arguments whose *following* value is a secret. That value
    # must never appear in logs or exception messages.
    _SECRET_PROPERTIES = frozenset({'vpn.secrets', '+vpn.secrets', '-vpn.secrets'})

    # nmcli reports a successful import as
    # "Connection '<name>' (<uuid>) successfully added."
    _IMPORT_RESULT_RE = re.compile(
        r"Connection '(?P<name>[^']+)'(?: \((?P<uuid>[0-9a-fA-F-]+)\))?")

    # Maps exact GENERAL.STATE values to VPNStatus. An exact lookup is required:
    # substring tests confuse "deactivated" with "activated" and
    # "deactivating" with "activating".
    _NM_STATE_TO_STATUS = {
        'activated': VPNStatus.CONNECTED,
        'activating': VPNStatus.CONNECTING,
        'deactivating': VPNStatus.DISCONNECTING,
        'deactivated': VPNStatus.DISCONNECTED,
    }

    def __init__(self):
        """Initialize VPN manager.

        Raises:
            VPNConnectionError: If nmcli cannot be run.
        """
        self.logger = logging.getLogger(__name__)
        self._check_nmcli_available()
        self._ensure_vpn_config_dir()

    def _check_nmcli_available(self) -> None:
        """Check that the ``nmcli`` executable can be run.

        Raises:
            VPNConnectionError: If nmcli is missing, not executable, exits
                with an error, or does not answer within 10 seconds.
        """
        try:
            subprocess.run(['nmcli', '--version'],
                           capture_output=True, check=True, timeout=10)
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
                OSError) as err:
            raise VPNConnectionError(
                "nmcli is not available. Install NetworkManager.") from err

    def _ensure_vpn_config_dir(self) -> None:
        """Create VPN_CONFIG_DIR (and parents) if it does not exist yet."""
        self.VPN_CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    @classmethod
    def _redact_args(cls, command: List[str]) -> List[str]:
        """Return a copy of *command* with secret property values masked.

        The value after a ``vpn.secrets`` property is rewritten to
        ``key=***``, or ``***`` if it has no ``=``.
        """
        redacted = []
        hide_next = False
        for arg in command:
            if hide_next:
                key, sep, _ = arg.partition('=')
                redacted.append(f"{key}{sep}***" if sep else "***")
                hide_next = False
                continue
            redacted.append(arg)
            hide_next = arg in cls._SECRET_PROPERTIES
        return redacted

    def _run_nmcli(self, args: List[str], check: bool = True, timeout: int = 30) -> subprocess.CompletedProcess:
        """Run nmcli with logging and a timeout.

        Args:
            args: Arguments for nmcli, not including the ``nmcli`` word itself.
            check: If true, a non-zero exit status raises VPNConnectionError.
                If false, the CompletedProcess is returned regardless.
            timeout: Seconds to wait before the command is abandoned.

        Returns:
            The completed process, with stdout/stderr captured as text.

        Raises:
            VPNConnectionError: On timeout, if nmcli cannot be executed, or on
                a non-zero exit status when *check* is true.
        """
        command = ['nmcli'] + list(args)
        # Everything that is logged or put into an exception message uses the
        # redacted form, so VPN passwords never leak.
        command_str = ' '.join(self._redact_args(command))
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                check=check,
                timeout=timeout  # Prevent a stuck nmcli from hanging us
            )
        except subprocess.TimeoutExpired as err:
            self.logger.error(
                f"Command timed out after {timeout}s: {command_str}")
            raise VPNConnectionError(
                f"nmcli command timed out after {timeout} seconds") from err
        except subprocess.CalledProcessError as err:
            self.logger.debug(f"Failed command: {command_str}")
            if err.stdout and err.stdout.strip():
                self.logger.debug(f"Output: {err.stdout.strip()}")
            if err.stderr and err.stderr.strip():
                self.logger.error(f"Error: {err.stderr.strip()}")
            # str(err) would embed the raw argv (including secrets), so fall
            # back to the redacted command line instead.
            error_details = err.stderr or f"command failed: {command_str}"
            raise VPNConnectionError(
                f"nmcli command failed (exit code {err.returncode}): {error_details}") from err
        except OSError as err:
            # e.g. nmcli disappeared after start-up; callers only handle
            # VPNConnectionError.
            raise VPNConnectionError(f"Failed to run nmcli: {err}") from err

        self.logger.debug(f"Command: {command_str}")
        if result.stdout.strip():
            self.logger.debug(f"Output: {result.stdout.strip()}")
        if result.stderr.strip():
            self.logger.warning(f"Stderr: {result.stderr.strip()}")
        if result.returncode == 0:
            self.logger.debug("Command completed successfully")
        else:
            self.logger.error(
                f"Command exited with code: {result.returncode}")
        return result

    def _get_vpntray_connection_name(self, config_filename: str) -> str:
        """Generate the VPNTray connection name for a config file.

        The extension is dropped, and every character outside
        ``[a-zA-Z0-9_-]`` is replaced by ``_`` before the prefix is added.
        """
        base_name = Path(config_filename).stem
        sanitized = re.sub(r'[^a-zA-Z0-9_-]', '_', base_name)
        return f"{self.VPNTRAY_PREFIX}{sanitized}"

    def get_vpn_config_path(self, filename: str) -> Path:
        """Return the full path of *filename* inside VPN_CONFIG_DIR."""
        return self.VPN_CONFIG_DIR / filename

    @staticmethod
    def _split_terse_line(line: str) -> List[str]:
        """Split one line of ``nmcli -t`` output into its fields.

        In terse mode nmcli separates fields with ':'. A backslash escapes a
        literal ':' or backslash inside a value.
        """
        fields = []
        current = []
        chars = iter(line)
        for ch in chars:
            if ch == '\\':
                current.append(next(chars, ''))
            elif ch == ':':
                fields.append(''.join(current))
                current = []
            else:
                current.append(ch)
        fields.append(''.join(current))
        return fields

    def list_vpntray_connections(self) -> List[VPNConnectionInfo]:
        """List all VPNTray-managed connections.

        Returns:
            One VPNConnectionInfo for each connection whose name starts with
            VPNTRAY_PREFIX. Returns an empty list if nmcli fails.
        """
        connections = []
        try:
            # Terse output with fixed fields does not depend on column layout
            # or on names being free of spaces.
            result = self._run_nmcli(
                ['-t', '-f', 'NAME,UUID,TYPE,DEVICE', 'connection', 'show'])
        except VPNConnectionError:
            return connections  # No connections or nmcli error

        for line in result.stdout.splitlines():
            parts = self._split_terse_line(line)
            if len(parts) < 4:
                continue
            name, uuid, _conn_type, device = parts[:4]
            # Match the prefix only. Cleanup deletes what this returns, so a
            # user connection that merely contains "vpntray_" must not match.
            if not name.startswith(self.VPNTRAY_PREFIX):
                continue
            connections.append(VPNConnectionInfo(
                name=name,
                uuid=uuid,
                vpntray_name=name,
                status=self._get_connection_status(name),
                device=device if device not in ('', '--') else None
            ))
        return connections

    def _get_connection_status(self, connection_name: str) -> VPNStatus:
        """Get the status of a specific connection.

        Returns:
            The status parsed from GENERAL.STATE. Returns DISCONNECTED when the
            profile exists but reports no state, and UNKNOWN when nmcli fails
            (for example, no such connection).
        """
        try:
            result = self._run_nmcli(['connection', 'show', connection_name])
        except VPNConnectionError:
            return VPNStatus.UNKNOWN

        for line in result.stdout.splitlines():
            key, sep, value = line.partition(':')
            if sep and key.strip() == 'GENERAL.STATE':
                return self._NM_STATE_TO_STATUS.get(
                    value.strip().lower(), VPNStatus.DISCONNECTED)
        # nmcli only prints the GENERAL.* active-connection section for active
        # connections, so an existing profile without it is not connected.
        return VPNStatus.DISCONNECTED

    def import_vpn_config(self, location: Location) -> str:
        """Import the location's VPN configuration under its VPNTray name.

        If a connection with that name already exists, it is reused.

        Returns:
            The VPNTray connection name.

        Raises:
            VPNConnectionError: If the config file is missing, the VPN type is
                unsupported, or the nmcli import fails.
        """
        config_path = self.get_vpn_config_path(location.vpn_config)
        if not config_path.exists():
            raise VPNConnectionError(f"VPN config not found: {config_path}")
        self.logger.info(
            f"Config file exists: {config_path} ({config_path.stat().st_size} bytes)")
        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
        # Check if already imported
        if self._get_connection_by_name(vpntray_name):
            self.logger.info(f"Connection already imported: {vpntray_name}")
            return vpntray_name
        # Import based on VPN type
        vpn_type = location.vpn_type.value
        self.logger.info(
            f"Importing {vpn_type} config: {config_path.name}")
        if vpn_type == "OpenVPN":
            return self._import_openvpn(config_path, vpntray_name, location)
        if vpn_type == "WireGuard":
            return self._import_wireguard(config_path, vpntray_name, location)
        raise VPNConnectionError(
            f"Unsupported VPN type: {vpn_type}")

    def _import_openvpn(self, config_path: Path, vpntray_name: str, location: Location) -> str:
        """Import an OpenVPN configuration with route control."""
        return self._import_with_nmcli(
            config_path, vpntray_name, location, 'openvpn', 'OpenVPN')

    def _import_wireguard(self, config_path: Path, vpntray_name: str, location: Location) -> str:
        """Import a WireGuard configuration with route control."""
        return self._import_with_nmcli(
            config_path, vpntray_name, location, 'wireguard', 'WireGuard')

    def _import_with_nmcli(self, config_path: Path, vpntray_name: str,
                           location: Location, nm_type: str, label: str) -> str:
        """Import *config_path* as an nmcli *nm_type* connection.

        After the import, the profile is renamed to *vpntray_name*, then
        credentials and routes are applied.

        Args:
            nm_type: The ``type`` argument for ``nmcli connection import``.
            label: Human-readable VPN type used in log messages.

        Returns:
            *vpntray_name*.

        Raises:
            VPNConnectionError: If the import or the rename fails.
        """
        # Import the config file first (nmcli will auto-generate a name)
        import_args = [
            'connection', 'import', 'type', nm_type,
            'file', str(config_path)
        ]
        self.logger.info(f"Running nmcli import: {' '.join(import_args)}")
        try:
            result = self._run_nmcli(import_args)
            match = self._IMPORT_RESULT_RE.search(result.stdout)
            if not match:
                raise VPNConnectionError(
                    "Failed to parse imported connection name from nmcli output")
            auto_generated_name = match.group('name')
            imported_uuid = match.group('uuid')
            self.logger.info(
                f"Config imported with auto name: {auto_generated_name}")
            # Address the new profile by UUID when nmcli reported one. The
            # auto-generated name may collide with an existing connection.
            target = ['uuid', imported_uuid] if imported_uuid else [auto_generated_name]
            self.logger.info(f"Renaming to: {vpntray_name}")
            try:
                self._run_nmcli(
                    ['connection', 'modify', *target, 'connection.id', vpntray_name])
            except VPNConnectionError:
                # Remove the half-imported profile. Otherwise every retry
                # would leave another auto-named duplicate behind.
                try:
                    self._run_nmcli(['connection', 'delete', *target], check=False)
                except VPNConnectionError:
                    pass  # Best effort; report the original rename failure
                raise
            self.logger.info(
                f"{label} config imported as {vpntray_name}")
        except VPNConnectionError as e:
            self.logger.error(f"{label} import failed: {e}")
            raise
        # Configure credentials immediately after import if provided
        if location.vpn_credentials:
            self._configure_credentials(vpntray_name, location)
        # Configure the connection to not route everything by default
        self._configure_connection_routes(vpntray_name, location)
        return vpntray_name

    def _configure_connection_routes(self, connection_name: str, location: Location) -> None:
        """Configure the connection to route only the location's segments.

        Sets ``ipv4.never-default`` and, if the location has network segments,
        sets ``ipv4.routes`` to their CIDRs. Errors are logged and do not fail
        the import.
        """
        try:
            # Disable automatic default route
            self._run_nmcli([
                'connection', 'modify', connection_name,
                'ipv4.never-default', 'true'
            ])
            routes = [segment.cidr
                      for segment in (location.network_segments or [])]
            if routes:
                routes_str = ','.join(routes)
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    'ipv4.routes', routes_str
                ])
                self.logger.info(
                    f"Configured routes for {connection_name}: {routes_str}")
        except VPNConnectionError as e:
            self.logger.error(f"Failed to configure routes: {e}")
            # Don't fail the import, just log the error

    def _get_connection_by_name(self, name: str) -> Optional[VPNConnectionInfo]:
        """Return minimal info if connection *name* exists, otherwise None.

        The returned object has placeholder uuid and status values. Only
        existence is checked.
        """
        try:
            result = self._run_nmcli(['connection', 'show', name], check=False)
        except VPNConnectionError:
            return None
        if result.returncode != 0:
            return None
        return VPNConnectionInfo(
            name=name,
            uuid="unknown",
            vpntray_name=name,
            status=VPNStatus.UNKNOWN  # Status will be checked when needed
        )

    def connect_vpn(self, location: Location) -> bool:
        """Connect to the location's VPN, importing its config first if needed.

        Returns:
            True once ``nmcli connection up`` succeeds, False on any failure
            (failures are logged).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)
            config_path = self.get_vpn_config_path(location.vpn_config)
            self.logger.info(f"VPN config: {config_path}")
            self.logger.info(f"Connection name: {vpntray_name}")
            # Check if config file exists
            if not config_path.exists():
                self.logger.error(f"VPN config file not found: {config_path}")
                return False
            # Import if not already imported
            existing_conn = self._get_connection_by_name(vpntray_name)
            if not existing_conn:
                self.logger.info(
                    "Importing VPN config for first time...")
                try:
                    self.import_vpn_config(location)
                    self.logger.info(
                        "VPN config imported successfully")
                except Exception as import_error:
                    self.logger.error(
                        f"Failed to import VPN config: {import_error}")
                    return False
            else:
                # The status of existing_conn is only a placeholder, so log
                # the name instead.
                self.logger.info(
                    f"Using existing connection: {existing_conn.name}")
            # Credentials were already stored during import
            self.logger.info("Attempting connection...")
            self._run_nmcli(['connection', 'up', vpntray_name], timeout=60)
            self.logger.info(f"Connected to {vpntray_name}")
            return True
        except VPNConnectionError as e:
            self.logger.error(f"VPN connection failed: {e}")
            return False
        except Exception as e:
            # Top-level boundary for the tray UI: log with traceback, never raise.
            self.logger.exception(
                f"Unexpected error during connection: {e}")
            return False

    def disconnect_vpn(self, location: Location) -> bool:
        """Disconnect the location's VPN.

        Returns:
            True on success. False if the connection does not exist or nmcli
            fails (failures are logged).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)
            self.logger.info(f"Disconnecting from {vpntray_name}...")
            # Check if connection exists
            if not self._get_connection_by_name(vpntray_name):
                self.logger.error(
                    f"Connection {vpntray_name} not found")
                return False
            self._run_nmcli(['connection', 'down', vpntray_name])
            self.logger.info(f"Disconnected from {vpntray_name}")
            return True
        except VPNConnectionError as e:
            self.logger.error(f"Failed to disconnect: {e}")
            return False
        except Exception as e:
            self.logger.exception(
                f"Unexpected error during disconnection: {e}")
            return False

    def get_connection_status(self, location: Location) -> VPNStatus:
        """Get the connection status for a location."""
        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
        return self._get_connection_status(vpntray_name)

    def remove_vpn_config(self, location: Location) -> bool:
        """Disconnect (best effort) and delete the location's VPN profile.

        Returns:
            True if the profile was deleted, False if nmcli failed.
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(
                location.vpn_config)
            # First disconnect if connected
            try:
                self._run_nmcli(
                    ['connection', 'down', vpntray_name], check=False)
            except VPNConnectionError:
                pass  # Ignore if already disconnected
            # Remove the connection
            self._run_nmcli(['connection', 'delete', vpntray_name])
            self.logger.info(
                f"Removed VPN configuration {vpntray_name}")
            return True
        except VPNConnectionError as e:
            self.logger.error(f"Failed to remove config: {e}")
            return False

    def cleanup_vpntray_connections(self) -> int:
        """Remove all VPNTray-managed connections.

        Returns:
            The number of connections deleted.
        """
        removed_count = 0
        for conn in self.list_vpntray_connections():
            # Target by UUID so duplicate names cannot hit the wrong profile.
            target = ['uuid', conn.uuid] if conn.uuid else [conn.name]
            try:
                # Disconnect first
                self._run_nmcli(['connection', 'down', *target], check=False)
                self._run_nmcli(['connection', 'delete', *target])
                removed_count += 1
            except VPNConnectionError:
                pass  # Continue with other connections
        if removed_count > 0:
            self.logger.info(
                f"Cleaned up {removed_count} VPNTray connections")
        return removed_count

    def _configure_credentials(self, connection_name: str, location: Location) -> None:
        """Store VPN credentials directly on the connection profile.

        Only dict credentials are handled: ``username`` goes into vpn.data and
        ``password`` into vpn.secrets. Other credential types are ignored.
        Failures are logged, never raised.
        """
        credentials = location.vpn_credentials
        if not credentials:
            self.logger.info(
                f"No credentials provided for {connection_name}")
            return
        if not isinstance(credentials, dict):
            return
        username = credentials.get('username')
        password = credentials.get('password')
        try:
            self.logger.info(
                f"Setting credentials for {connection_name}...")
            if username:
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.data', f'username={username}'
                ])
                self.logger.info(
                    f"Username configured for {connection_name}")
            if password:
                # NOTE(review): the password is passed on nmcli's command line,
                # so other local users can briefly see it in the process list.
                # _run_nmcli masks it in logs. Whether NetworkManager keeps it
                # may depend on the profile's password-flags; confirm.
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.secrets', f'password={password}'
                ])
                self.logger.info(
                    f"Password configured for {connection_name}")
            if username and password:
                self.logger.info(
                    f"Full credentials configured for {connection_name}")
            elif username or password:
                self.logger.info(
                    f"Partial credentials configured for {connection_name}")
        except VPNConnectionError as e:
            self.logger.error(f"Failed to configure credentials: {e}")
            # Don't fail the whole operation for credential issues