508 lines
19 KiB
Python
508 lines
19 KiB
Python
"""Enhanced VPN management with VPNTray naming and route control."""
|
|
|
|
import logging
import os
import re
import subprocess
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional

from models import Location
|
|
|
|
|
|
class VPNStatus(Enum):
    """Lifecycle state of a VPN connection.

    Member order is significant: iterating the enum yields the states in the
    order declared below.
    """

    CONNECTED = "connected"            # connection is up
    DISCONNECTED = "disconnected"      # connection is down
    CONNECTING = "connecting"          # coming up
    DISCONNECTING = "disconnecting"    # going down
    FAILED = "failed"                  # connection attempt failed
    UNKNOWN = "unknown"                # state could not be determined
|
|
|
|
|
|
class VPNConnectionError(Exception):
    """Raised when a VPN operation driven through nmcli cannot be completed."""
|
|
|
|
|
|
@dataclass
class VPNConnectionInfo:
    """Information about a VPN connection (a NetworkManager profile)."""

    name: str  # profile name as shown by nmcli
    uuid: str  # profile UUID; may be the placeholder "unknown"
    vpntray_name: str  # Our custom name with vpntray_ prefix
    status: VPNStatus
    device: Optional[str] = None  # None when nmcli reports no device
    # List of routes added. Defaults to None (not an empty list), so the
    # annotation must be Optional to match the actual default.
    routes: Optional[List[str]] = None
|
|
|
|
|
|
class VPNManager:
    """Enhanced VPN manager with VPNTray naming and route management.

    All operations go through the ``nmcli`` command-line tool. Every profile
    this class imports is renamed to ``VPNTRAY_PREFIX + <sanitized config
    file stem>`` so VPNTray-managed profiles can be found and cleaned up
    later. Config files are looked up in ``VPN_CONFIG_DIR``.
    """

    VPNTRAY_PREFIX = "vpntray_"
    VPN_CONFIG_DIR = Path.home() / ".vpntray" / "vpn"

    # Columns requested from `nmcli -t connection show` when listing profiles.
    _LIST_FIELDS = "NAME,UUID,TYPE,DEVICE"

    # "key=value" arguments whose value must never be written to the logs.
    _SENSITIVE_ARG = re.compile(
        r"^([^=]*(?:password|secret|psk)[^=]*)=.*$", re.IGNORECASE | re.DOTALL)

    # nmcli reports: "Connection 'name' (uuid) successfully added."
    _IMPORTED_RE = re.compile(r"Connection '([^']+)'(?: \(([0-9A-Fa-f-]+)\))?")

    # GENERAL.STATE values (first word, lower-cased) that map to a status;
    # anything else (e.g. "deactivated") is treated as disconnected.
    _STATE_TO_STATUS = {
        'activated': VPNStatus.CONNECTED,
        'activating': VPNStatus.CONNECTING,
        'deactivating': VPNStatus.DISCONNECTING,
    }

    def __init__(self):
        """Initialize VPN manager.

        Raises:
            VPNConnectionError: if ``nmcli`` cannot be executed.
        """
        self.logger = logging.getLogger(__name__)
        self._check_nmcli_available()
        self._ensure_vpn_config_dir()

    def _check_nmcli_available(self) -> None:
        """Verify that ``nmcli --version`` runs successfully.

        Raises:
            VPNConnectionError: if nmcli is missing, not executable, exits
                non-zero, or does not answer within 10 seconds.
        """
        try:
            subprocess.run(['nmcli', '--version'],
                           capture_output=True, check=True, timeout=10)
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
                OSError) as exc:
            raise VPNConnectionError(
                "nmcli is not available. Install NetworkManager.") from exc

    def _ensure_vpn_config_dir(self) -> None:
        """Create ``VPN_CONFIG_DIR`` (and parents) if it does not exist."""
        self.VPN_CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    @classmethod
    def _redact_command(cls, command: List[str]) -> str:
        """Return *command* joined with spaces, masking secret ``key=value`` args."""
        return ' '.join(cls._SENSITIVE_ARG.sub(r'\1=********', arg)
                        for arg in command)

    def _run_nmcli(self, args: List[str], check: bool = True, timeout: int = 30) -> subprocess.CompletedProcess:
        """Run ``nmcli`` with *args*, logging the call and enforcing a timeout.

        nmcli runs with ``LC_ALL=C`` so its messages and state strings are the
        untranslated English ones this class parses. Secret arguments
        (``password=...`` etc.) are masked in every log line and error message.

        Args:
            args: Arguments for nmcli, without the binary name.
            check: When True, a non-zero exit status raises.
            timeout: Seconds to wait before giving up.

        Returns:
            The completed process with text ``stdout``/``stderr``.

        Raises:
            VPNConnectionError: if nmcli cannot be started, times out, or
                (when *check* is True) exits non-zero.
        """
        command = ['nmcli'] + args
        command_str = self._redact_command(command)
        env = {**os.environ, 'LC_ALL': 'C'}

        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                errors='replace',
                check=check,
                timeout=timeout,  # prevent hanging on a stuck NetworkManager
                env=env,
            )
        except subprocess.TimeoutExpired:
            self.logger.error(
                f"Command timed out after {timeout}s: {command_str}")
            # `from None`: the subprocess exception's text embeds the raw
            # argv, which may contain secrets.
            raise VPNConnectionError(
                f"nmcli command timed out after {timeout} seconds") from None
        except subprocess.CalledProcessError as e:
            self.logger.debug(f"Failed command: {command_str}")
            if e.stdout and e.stdout.strip():
                self.logger.debug(f"Output: {e.stdout.strip()}")
            stderr = (e.stderr or '').strip()
            if stderr:
                self.logger.error(f"Error: {stderr}")
            # Never fall back to str(e): it contains the unredacted argv.
            error_details = stderr or f"{command_str} returned {e.returncode}"
            raise VPNConnectionError(
                f"nmcli command failed (exit code {e.returncode}): {error_details}") from None
        except OSError as e:
            raise VPNConnectionError(f"Failed to execute nmcli: {e}") from e

        self.logger.debug(f"Command: {command_str}")
        stdout = result.stdout.strip()
        stderr = result.stderr.strip()
        if stdout:
            self.logger.debug(f"Output: {stdout}")
        if result.returncode == 0:
            if stderr:
                self.logger.warning(f"Stderr: {stderr}")
            self.logger.debug("Command completed successfully")
        else:
            # Only reachable with check=False, i.e. the caller expects and
            # handles failure (probing for a profile, best-effort "down").
            self.logger.debug(
                f"Command exited with code {result.returncode}: {stderr}")
        return result

    def _get_vpntray_connection_name(self, config_filename: str) -> str:
        """Build the VPNTray profile name: prefix + config stem with unsafe chars replaced by '_'."""
        base_name = Path(config_filename).stem
        sanitized = re.sub(r'[^a-zA-Z0-9_-]', '_', base_name)
        return f"{self.VPNTRAY_PREFIX}{sanitized}"

    def get_vpn_config_path(self, filename: str) -> Path:
        """Return the full path of *filename* inside ``VPN_CONFIG_DIR``."""
        return self.VPN_CONFIG_DIR / filename

    @staticmethod
    def _split_terse_line(line: str) -> List[str]:
        """Split one line of ``nmcli -t`` output on unescaped ':' separators.

        Terse mode escapes literal ':' and '\\' in values with a backslash;
        the escapes are removed from the returned fields.
        """
        fields: List[str] = []
        current: List[str] = []
        escaped = False
        for ch in line:
            if escaped:
                current.append(ch)
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == ':':
                fields.append(''.join(current))
                current = []
            else:
                current.append(ch)
        fields.append(''.join(current))
        return fields

    @staticmethod
    def _find_field(output: str, key: str) -> Optional[str]:
        """Return the value of *key* in multi-line ``nmcli connection show <id>`` output, or None."""
        for line in output.splitlines():
            name, sep, value = line.partition(':')
            if sep and name.strip() == key:
                return value.strip()
        return None

    def list_vpntray_connections(self) -> List[VPNConnectionInfo]:
        """List all profiles whose name starts with ``VPNTRAY_PREFIX``.

        Returns:
            One entry per matching profile, with its current status. An empty
            list if there are none or nmcli fails.
        """
        try:
            result = self._run_nmcli(
                ['-t', '-f', self._LIST_FIELDS, 'connection', 'show'])
        except VPNConnectionError:
            return []  # No connections or nmcli error

        connections = []
        for line in result.stdout.splitlines():
            fields = self._split_terse_line(line)
            if len(fields) < 4 or not fields[0].startswith(self.VPNTRAY_PREFIX):
                continue
            name, uuid, _, device = fields[:4]
            connections.append(VPNConnectionInfo(
                name=name,
                uuid=uuid,
                vpntray_name=name,
                status=self._get_connection_status(name),
                # Inactive profiles have no device ('' in terse mode, '--' otherwise).
                device=device if device and device != '--' else None,
            ))
        return connections

    def _get_connection_status(self, connection_name: str) -> VPNStatus:
        """Return the status of profile *connection_name*.

        Returns:
            UNKNOWN if nmcli fails (e.g. the profile does not exist);
            otherwise a status derived from the GENERAL.STATE line.
        """
        try:
            result = self._run_nmcli(['connection', 'show', connection_name])
        except VPNConnectionError:
            return VPNStatus.UNKNOWN

        state = self._find_field(result.stdout, 'GENERAL.STATE')
        if state is None:
            # nmcli prints the GENERAL.* runtime section only while a profile
            # is active, so an existing profile without it is not connected.
            return VPNStatus.DISCONNECTED

        # Exact word lookup: substring tests are wrong here because
        # "deactivating" contains "activating" and "deactivated" contains
        # "activated".
        words = state.lower().split()
        return self._STATE_TO_STATUS.get(words[0] if words else '',
                                         VPNStatus.DISCONNECTED)

    def import_vpn_config(self, location: Location) -> str:
        """Import the VPN config of *location* as a VPNTray-named profile.

        Returns:
            The VPNTray profile name (also when it was already imported).

        Raises:
            VPNConnectionError: if the config file is missing, the VPN type is
                unsupported, or the nmcli import/rename fails.
        """
        config_path = self.get_vpn_config_path(location.vpn_config)
        if not config_path.exists():
            raise VPNConnectionError(f"VPN config not found: {config_path}")

        self.logger.info(
            f"Config file exists: {config_path} ({config_path.stat().st_size} bytes)")

        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
        if self._get_connection_by_name(vpntray_name):
            self.logger.info(f"Connection already imported: {vpntray_name}")
            return vpntray_name

        self.logger.info(
            f"Importing {location.vpn_type.value} config: {config_path.name}")

        importers = {
            "OpenVPN": self._import_openvpn,
            "WireGuard": self._import_wireguard,
        }
        importer = importers.get(location.vpn_type.value)
        if importer is None:
            raise VPNConnectionError(
                f"Unsupported VPN type: {location.vpn_type.value}")
        return importer(config_path, vpntray_name, location)

    def _import_openvpn(self, config_path: Path, vpntray_name: str, location: Location) -> str:
        """Import an OpenVPN config as *vpntray_name*, then apply credentials and routes."""
        return self._import_config(
            config_path, vpntray_name, location, 'openvpn', 'OpenVPN')

    def _import_wireguard(self, config_path: Path, vpntray_name: str, location: Location) -> str:
        """Import a WireGuard config as *vpntray_name*, then apply credentials and routes."""
        return self._import_config(
            config_path, vpntray_name, location, 'wireguard', 'WireGuard')

    def _import_config(self, config_path: Path, vpntray_name: str,
                       location: Location, nm_type: str, label: str) -> str:
        """Run ``nmcli connection import``, rename the result, configure it.

        Args:
            nm_type: nmcli import type (``openvpn`` / ``wireguard``).
            label: Human-readable VPN type used in log messages.

        Returns:
            *vpntray_name*.

        Raises:
            VPNConnectionError: if the import or the rename fails. Credential
                and route failures are only logged.
        """
        import_args = [
            'connection', 'import', 'type', nm_type,
            'file', str(config_path),
        ]
        self.logger.info(f"Running nmcli import: {' '.join(import_args)}")

        try:
            result = self._run_nmcli(import_args)

            # nmcli auto-generates the profile name; extract it (and the UUID).
            match = self._IMPORTED_RE.search(result.stdout)
            if not match:
                raise VPNConnectionError(
                    "Failed to parse imported connection name from nmcli output")
            auto_generated_name, imported_uuid = match.group(1), match.group(2)
            self.logger.info(
                f"Config imported with auto name: {auto_generated_name}")

            # Prefer the UUID: the auto-generated name may collide with an
            # existing profile.
            target = imported_uuid or auto_generated_name
            self.logger.info(f"Renaming to: {vpntray_name}")
            try:
                self._run_nmcli(
                    ['connection', 'modify', target, 'connection.id', vpntray_name])
            except VPNConnectionError:
                # Un-renamed profiles are invisible to VPNTray and would be
                # re-imported (duplicated) on every connect, so drop it.
                self._delete_quietly(target)
                raise
            self.logger.info(f"{label} config imported as {vpntray_name}")

        except VPNConnectionError as e:
            self.logger.error(f"{label} import failed: {e}")
            raise

        # Configure credentials immediately after import if provided
        if location.vpn_credentials:
            self._configure_credentials(vpntray_name, location)

        # Configure the connection to not route everything by default
        self._configure_connection_routes(vpntray_name, location)
        return vpntray_name

    def _delete_quietly(self, identifier: str) -> None:
        """Best-effort ``nmcli connection delete``; failures are logged, not raised."""
        try:
            self._run_nmcli(['connection', 'delete', identifier], check=False)
        except VPNConnectionError as e:
            self.logger.warning(f"Could not remove profile {identifier}: {e}")

    def _configure_connection_routes(self, connection_name: str, location: Location) -> None:
        """Restrict the profile to the location's network segments.

        Sets ``ipv4.never-default`` so the VPN never becomes the default
        route, then sets ``ipv4.routes`` to the CIDRs of
        ``location.network_segments``. Errors are logged, not raised, so the
        import is not aborted.
        """
        try:
            self._run_nmcli([
                'connection', 'modify', connection_name,
                'ipv4.never-default', 'true',
            ])

            routes = [segment.cidr for segment in (location.network_segments or [])]
            if routes:
                routes_str = ','.join(routes)
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    'ipv4.routes', routes_str,
                ])
                self.logger.info(
                    f"Configured routes for {connection_name}: {routes_str}")

        except VPNConnectionError as e:
            self.logger.error(f"Failed to configure routes: {e}")
            # Don't fail the import, just log the error

    def _get_connection_by_name(self, name: str) -> Optional[VPNConnectionInfo]:
        """Return minimal info for profile *name*, or None if it does not exist.

        The status is always UNKNOWN here. The UUID is read from the same
        nmcli output, and is "unknown" if it is not present.
        """
        try:
            result = self._run_nmcli(['connection', 'show', name], check=False)
        except VPNConnectionError:
            return None
        if result.returncode != 0:
            return None

        uuid = self._find_field(result.stdout, 'connection.uuid') or "unknown"
        return VPNConnectionInfo(
            name=name,
            uuid=uuid,
            vpntray_name=name,
            status=VPNStatus.UNKNOWN,  # Status will be checked when needed
        )

    def connect_vpn(self, location: Location) -> bool:
        """Bring up the VPN for *location*, importing its config on first use.

        Returns:
            True if ``nmcli connection up`` succeeded, False otherwise (all
            errors are logged, never raised).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
            config_path = self.get_vpn_config_path(location.vpn_config)
            self.logger.info(f"VPN config: {config_path}")
            self.logger.info(f"Connection name: {vpntray_name}")

            if not config_path.exists():
                self.logger.error(f"VPN config file not found: {config_path}")
                return False

            existing_conn = self._get_connection_by_name(vpntray_name)
            if existing_conn is None:
                self.logger.info("Importing VPN config for first time...")
                try:
                    self.import_vpn_config(location)
                except Exception as import_error:  # API boundary: report as False
                    self.logger.error(
                        f"Failed to import VPN config: {import_error}")
                    return False
                self.logger.info("VPN config imported successfully")
            else:
                self.logger.info(
                    f"Using existing connection: {existing_conn.name}")

            # Credentials were stored on the profile during import.
            self.logger.info("Attempting connection...")
            self._run_nmcli(['connection', 'up', vpntray_name], timeout=60)
            self.logger.info(f"Connected to {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"VPN connection failed: {e}")
            return False
        except Exception as e:
            self.logger.exception(f"Unexpected error during connection: {e}")
            return False

    def disconnect_vpn(self, location: Location) -> bool:
        """Bring down the VPN profile for *location*.

        Returns:
            True on success; False if the profile does not exist or nmcli
            fails (errors are logged, never raised).
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
            self.logger.info(f"Disconnecting from {vpntray_name}...")

            if not self._get_connection_by_name(vpntray_name):
                self.logger.error(f"Connection {vpntray_name} not found")
                return False

            self._run_nmcli(['connection', 'down', vpntray_name])
            self.logger.info(f"Disconnected from {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"Failed to disconnect: {e}")
            return False
        except Exception as e:
            self.logger.exception(f"Unexpected error during disconnection: {e}")
            return False

    def get_connection_status(self, location: Location) -> VPNStatus:
        """Return the status of the VPNTray profile belonging to *location*."""
        vpntray_name = self._get_vpntray_connection_name(location.vpn_config)
        return self._get_connection_status(vpntray_name)

    def remove_vpn_config(self, location: Location) -> bool:
        """Disconnect (best effort) and delete the profile for *location*.

        Returns:
            True if the profile was deleted, False if nmcli failed.
        """
        try:
            vpntray_name = self._get_vpntray_connection_name(location.vpn_config)

            # First disconnect if connected; failure just means it wasn't up.
            try:
                self._run_nmcli(
                    ['connection', 'down', vpntray_name], check=False)
            except VPNConnectionError:
                pass  # Ignore if already disconnected

            self._run_nmcli(['connection', 'delete', vpntray_name])
            self.logger.info(f"Removed VPN configuration {vpntray_name}")
            return True

        except VPNConnectionError as e:
            self.logger.error(f"Failed to remove config: {e}")
            return False

    def cleanup_vpntray_connections(self) -> int:
        """Remove all VPNTray-managed profiles.

        Each profile is addressed by UUID when known, so profiles sharing a
        name are handled individually. Per-profile failures are logged and
        skipped.

        Returns:
            The number of profiles deleted.
        """
        removed_count = 0
        for conn in self.list_vpntray_connections():
            identifier = conn.uuid or conn.name
            try:
                self._run_nmcli(['connection', 'down', identifier], check=False)
                self._run_nmcli(['connection', 'delete', identifier])
                removed_count += 1
            except VPNConnectionError as e:
                self.logger.warning(f"Failed to remove {conn.name}: {e}")
                # Continue with other connections

        if removed_count:
            self.logger.info(
                f"Cleaned up {removed_count} VPNTray connections")
        return removed_count

    def _configure_credentials(self, connection_name: str, location: Location) -> None:
        """Store username/password from ``location.vpn_credentials`` on the profile.

        Only dict credentials with ``'username'``/``'password'`` keys are
        supported; any other format is logged and ignored. nmcli errors are
        logged, not raised, so the surrounding operation continues.
        """
        credentials = location.vpn_credentials
        if not credentials:
            self.logger.info(f"No credentials provided for {connection_name}")
            return
        if not isinstance(credentials, dict):
            self.logger.warning(
                f"Unsupported credentials format for {connection_name}: "
                f"{type(credentials).__name__}")
            return

        username = credentials.get('username')
        password = credentials.get('password')
        self.logger.info(f"Setting credentials for {connection_name}...")

        try:
            if username:
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.data', f'username={username}',
                ])
                self.logger.info(f"Username configured for {connection_name}")

            if password:
                # NOTE: the password is passed on nmcli's command line, which
                # may be visible to other local users (e.g. via `ps`) while it
                # runs; _run_nmcli masks it in logs and error messages.
                self._run_nmcli([
                    'connection', 'modify', connection_name,
                    '+vpn.secrets', f'password={password}',
                ])
                self.logger.info(f"Password configured for {connection_name}")

            if username and password:
                self.logger.info(
                    f"Full credentials configured for {connection_name}")
            elif username or password:
                self.logger.info(
                    f"Partial credentials configured for {connection_name}")

        except VPNConnectionError as e:
            self.logger.error(f"Failed to configure credentials: {e}")
            # Don't fail the whole operation for credential issues