# Files / VPNTray/models.py
# Snapshot: 2025-09-07 23:33:55 +02:00
#
# 346 lines, 12 KiB, Python

from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
class ServiceType(Enum):
    """Kinds of network services a host can expose.

    Each member's value is the human-readable label for that service kind.
    """
    SSH = "SSH"              # secure shell
    WEB_GUI = "Web GUI"      # browser-based admin interface
    RDP = "RDP"              # remote desktop protocol
    VNC = "VNC"              # VNC remote framebuffer
    SMB = "SMB"              # file sharing
    DATABASE = "Database"    # any database server
    FTP = "FTP"              # file transfer protocol
@dataclass
class Service:
    """A single service offered by a host, identified by name and port."""
    # Label of the service; host-level lookups match on this exact string.
    name: str
    # Category of the service (SSH, Web GUI, ...).
    service_type: ServiceType
    # Port the service listens on, on the host itself.
    port: int
class HostType(Enum):
    """Categories of machines or appliances found at a location.

    Values are display labels.
    """
    LINUX = "Linux"
    WINDOWS = "Windows"
    WINDOWS_SERVER = "Windows Server"
    PROXMOX = "Proxmox"          # hypervisor platform
    ESXI = "ESXi"                # hypervisor platform
    ROUTER = "Router"            # network appliance
    SWITCH = "Switch"            # network appliance
class VPNType(Enum):
    """Supported VPN technologies; each value is its display label."""
    OPENVPN = "OpenVPN"
    WIREGUARD = "WireGuard"
    IPSEC = "IPSec"
@dataclass
class NetworkSegment:
    """A named network segment (subnet) at a location, plus descriptive metadata.

    Hosts refer to a segment through ``HostIP.network_segment``, which holds
    the segment's ``name``.
    """
    name: str                       # e.g. "LAN", "DMZ", "Management"
    cidr: str                       # e.g. "192.168.1.0/24"
    vlan_id: Optional[int] = None   # e.g. 100; None when untagged/unknown
    zone: str = "general"           # e.g. "production", "dmz", "management", "guest"
    gateway: Optional[str] = None   # e.g. "192.168.1.1"
    description: str = ""           # free text, e.g. "Main office network"
@dataclass
class PortForwarding:
    """A rule that maps a port on the external address to an internal host:port.

    Rules with ``enabled`` set to False are kept but treated as inactive.
    """
    # Port exposed on the external address (e.g. 8080).
    external_port: int
    # Internal IP the traffic is forwarded to (e.g. "192.168.1.10").
    internal_ip: str
    # Port on the internal host (e.g. 80).
    internal_port: int
    # "tcp", "udp", or "both".
    protocol: str = "tcp"
    # Free text, e.g. "Web server access".
    description: str = ""
    # Whether the rule is currently in effect.
    enabled: bool = True
@dataclass
class HostIP:
    """One IP address of a host, tagged with the network segment it lives in."""
    ip_address: str
    # Name of the NetworkSegment this address belongs to.
    network_segment: str
    # Marks the host's preferred/primary interface.
    is_primary: bool = False
@dataclass
class Host:
    """A physical or virtual machine at a location.

    A hypervisor lists its guest VMs in ``sub_hosts``; those may in turn
    carry their own sub-hosts.
    """
    name: str
    # Every address of this host, each tagged with its network segment.
    ip_addresses: List[HostIP] = field(default_factory=list)
    host_type: HostType = HostType.LINUX
    # Icon name without extension (e.g., 'ubuntu', 'windows')
    icon: Optional[str] = None
    description: str = ""
    services: List[Service] = field(default_factory=list)
    # Guest VMs when this host acts as a hypervisor.
    sub_hosts: List['Host'] = field(default_factory=list)

    def get_service_by_name(self, service_name: str) -> Optional[Service]:
        """Return the first service whose name equals *service_name*, else None."""
        return next(
            (svc for svc in self.services if svc.name == service_name),
            None,
        )

    def is_hypervisor(self) -> bool:
        """Return True when this host carries at least one sub-host (VM)."""
        return bool(self.sub_hosts)

    def get_primary_ip(self) -> str:
        """Return the primary address as a string.

        Uses the first address flagged ``is_primary``. If none is flagged,
        falls back to the first address. Returns "" when the host has no
        addresses.
        """
        if not self.ip_addresses:
            return ""
        chosen = next(
            (entry for entry in self.ip_addresses if entry.is_primary),
            self.ip_addresses[0],
        )
        return chosen.ip_address

    def get_ip_display(self) -> str:
        """Return a short label: "No IP", the single IP, or "primary (+N more)"."""
        count = len(self.ip_addresses)
        if count == 0:
            return "No IP"
        if count == 1:
            return self.ip_addresses[0].ip_address
        return f"{self.get_primary_ip()} (+{count - 1} more)"

    def get_all_ips(self) -> List[str]:
        """Return every address of this host as plain strings, in order."""
        return [entry.ip_address for entry in self.ip_addresses]

    def get_ips_in_segment(self, segment_name: str) -> List[str]:
        """Return the addresses whose network segment name equals *segment_name*."""
        return [
            entry.ip_address
            for entry in self.ip_addresses
            if entry.network_segment == segment_name
        ]
@dataclass
class Location:
    """Represents a customer location reached through a VPN.

    Holds the location's hosts (possibly nested under hypervisors), its
    network layout, and the port-forwarding rules. Enabled rules make
    selected internal services reachable via the external addresses even
    while the VPN is not connected.
    """
    name: str
    vpn_type: VPNType
    # When True, every internal service is treated as reachable
    # (see is_service_reachable).
    connected: bool = False
    active: bool = False
    vpn_config: str = ""  # Path to VPN config or connection details
    hosts: List[Host] = field(default_factory=list)
    # Enhanced network configuration
    network_segments: List[NetworkSegment] = field(
        default_factory=list)  # Network segments with rich metadata
    external_addresses: List[str] = field(
        default_factory=list)  # External VPN endpoints
    port_forwardings: List[PortForwarding] = field(
        default_factory=list)  # Port forwarding rules
    # Legacy field for backward compatibility (will be deprecated):
    # simple list of networks.
    networks: List[str] = field(default_factory=list)
    # VPN connection management fields
    # NetworkManager connection name
    nmcli_connection_name: Optional[str] = None
    auto_import: bool = True  # Auto-import .ovpn file if not in NetworkManager
    # Credential storage - can be:
    # - Passbolt UUID string (for future use)
    # - Dict with 'username' and 'password' keys
    # - None if no credentials needed
    vpn_credentials: Optional[dict | str] = None

    def get_host_by_name(self, host_name: str) -> Optional[Host]:
        """Return the first host named *host_name*, searching sub-hosts too.

        The search is depth-first in declaration order. A host is checked
        before its own sub-hosts, and those before its next sibling.
        Returns None when no host matches.
        """
        def search_hosts(hosts_list: List[Host]) -> Optional[Host]:
            for host in hosts_list:
                if host.name == host_name:
                    return host
                sub_result = search_hosts(host.sub_hosts)
                # Test identity explicitly rather than relying on the
                # truthiness of a dataclass instance.
                if sub_result is not None:
                    return sub_result
            return None
        return search_hosts(self.hosts)

    def get_all_hosts_flat(self) -> List[Host]:
        """Return all hosts, including nested sub-hosts, as one flat list.

        Order is depth-first: each host is followed by its sub-hosts.
        """
        def collect_hosts(hosts_list: List[Host]) -> List[Host]:
            result: List[Host] = []
            for host in hosts_list:
                result.append(host)
                result.extend(collect_hosts(host.sub_hosts))
            return result
        return collect_hosts(self.hosts)

    def get_hypervisors(self) -> List[Host]:
        """Return every host (at any depth) that has at least one sub-host."""
        return [host for host in self.get_all_hosts_flat() if host.is_hypervisor()]

    def get_segment_by_name(self, segment_name: str) -> Optional[NetworkSegment]:
        """Return the first network segment named *segment_name*, or None."""
        return next((seg for seg in self.network_segments if seg.name == segment_name), None)

    def get_hosts_in_segment(self, segment_name: str) -> List[Host]:
        """Return all hosts (at any depth) with an address in *segment_name*."""
        return [
            host for host in self.get_all_hosts_flat()
            if any(host_ip.network_segment == segment_name
                   for host_ip in host.ip_addresses)
        ]

    def get_segments_by_zone(self, zone: str) -> List[NetworkSegment]:
        """Return all network segments whose ``zone`` equals *zone*."""
        return [seg for seg in self.network_segments if seg.zone == zone]

    def get_port_forwardings_for_host(self, host_ip: str) -> List[PortForwarding]:
        """Return the enabled port forwardings whose target IP is *host_ip*."""
        return [pf for pf in self.port_forwardings if pf.internal_ip == host_ip and pf.enabled]

    def get_externally_accessible_services(
        self,
    ) -> List[tuple[str, int, Host, Optional[Service], PortForwarding]]:
        """Return every service reachable from an external address via forwarding.

        Returns a list of tuples:
        ``(external_address, external_port, host, service, port_forwarding)``.
        There is one tuple per external address × enabled forwarding whose
        target IP belongs to a known host. ``service`` is the first service
        on that host listening on the forwarded internal port, or None if
        none is defined. Forwardings whose target IP matches no host are
        skipped.
        """
        if not self.external_addresses:
            return []

        # Build the IP -> owning-host map once instead of re-flattening the
        # host tree for every (address, forwarding) pair. setdefault keeps
        # the first host in flat order, matching a first-match linear scan.
        owner_by_ip: dict[str, Host] = {}
        for host in self.get_all_hosts_flat():
            for host_ip in host.ip_addresses:
                owner_by_ip.setdefault(host_ip.ip_address, host)

        # Resolve each enabled forwarding to (host, service) once; the
        # resolution does not depend on the external address.
        resolved: List[tuple[PortForwarding, Host, Optional[Service]]] = []
        for port_forward in self.port_forwardings:
            if not port_forward.enabled:
                continue
            target_host = owner_by_ip.get(port_forward.internal_ip)
            if target_host is None:
                continue
            target_service = next(
                (svc for svc in target_host.services
                 if svc.port == port_forward.internal_port),
                None,  # No matching service defined on the host
            )
            resolved.append((port_forward, target_host, target_service))

        return [
            (external_addr, port_forward.external_port, target_host,
             target_service, port_forward)
            for external_addr in self.external_addresses
            for port_forward, target_host, target_service in resolved
        ]

    def is_service_externally_accessible(self, host_ip: str, service_port: int) -> bool:
        """Return True if an enabled forwarding targets *host_ip*:*service_port*."""
        return any(
            pf.enabled
            and pf.internal_ip == host_ip
            and pf.internal_port == service_port
            for pf in self.port_forwardings
        )

    def is_service_reachable(self, host: 'Host', service: Service) -> bool:
        """Check if a service is reachable (either via VPN connection or port forwarding).

        Returns True if:
        - the VPN is connected (all internal services become reachable), or
        - any of the host's addresses has an enabled port forwarding to the
          service's port.
        """
        if self.connected:
            return True
        return any(
            self.is_service_externally_accessible(host_ip.ip_address, service.port)
            for host_ip in host.ip_addresses
        )

    def get_external_url_for_service(self, host: 'Host', service: Service) -> Optional[str]:
        """Return the external URL for a forwarded service, or None.

        Uses the first external address and the external port of the first
        enabled forwarding matching one of the host's IPs (in host address
        order) and the service's port, e.g. "https://vpn.example.com:8006".
        Returns None when there is no external address or no matching
        enabled forwarding.
        """
        if not self.external_addresses:
            return None
        for host_ip in host.ip_addresses:
            for pf in self.port_forwardings:
                if (pf.enabled and
                        pf.internal_ip == host_ip.ip_address and
                        pf.internal_port == service.port):
                    # Heuristic: these internal ports are assumed to serve TLS.
                    scheme = "https" if service.port in (443, 8006, 8080) else "http"
                    return f"{scheme}://{self.external_addresses[0]}:{pf.external_port}"
        return None
@dataclass
class CustomerService:
    """A cloud or web service used by a customer, independent of any location."""
    name: str
    url: str
    # Free-form category, e.g. "Email", "Phone System", "CRM", "ERP".
    service_type: str
    description: str = ""
@dataclass
class Customer:
    """A customer with their cloud services and their VPN-reachable locations."""
    name: str
    # Cloud/web services, available no matter which location is in use.
    services: List[CustomerService] = field(default_factory=list)
    # Sites with their own infrastructure.
    locations: List[Location] = field(default_factory=list)

    def get_location_by_name(self, location_name: str) -> Optional[Location]:
        """Return the first location named *location_name*, or None."""
        matches = (site for site in self.locations if site.name == location_name)
        return next(matches, None)

    def get_active_locations(self) -> List[Location]:
        """Return, in order, the locations whose ``active`` flag is set."""
        return [site for site in self.locations if site.active]

    def get_inactive_locations(self) -> List[Location]:
        """Return, in order, the locations whose ``active`` flag is not set."""
        return [site for site in self.locations if not site.active]

    def has_active_locations(self) -> bool:
        """Return True if at least one location is marked active."""
        return any(site.active for site in self.locations)

    def has_connected_locations(self) -> bool:
        """Return True if at least one location's VPN is connected."""
        return any(site.connected for site in self.locations)

    def get_all_hosts_flat(self) -> List[Host]:
        """Return every host of every location (sub-hosts included) in one list."""
        return [
            host
            for site in self.locations
            for host in site.get_all_hosts_flat()
        ]