346 lines
12 KiB
Python
346 lines
12 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import List, Optional
|
|
from enum import Enum
|
|
|
|
|
|
class ServiceType(Enum):
    """Kinds of network service that a host can expose.

    Each member's value is a human-readable label for the service kind.
    """

    SSH = "SSH"  # Secure Shell
    WEB_GUI = "Web GUI"  # Browser-based management interface
    RDP = "RDP"  # Remote Desktop Protocol
    VNC = "VNC"  # Virtual Network Computing
    SMB = "SMB"  # Server Message Block file sharing
    DATABASE = "Database"  # Generic database service
    FTP = "FTP"  # File Transfer Protocol
|
|
|
|
|
|
@dataclass
class Service:
    """A single network service running on a host.

    Attributes:
        name: Display name of the service.
        service_type: Which kind of service this is.
        port: Port number associated with the service.
    """

    name: str
    service_type: ServiceType
    port: int
|
|
|
|
|
|
class HostType(Enum):
    """Kinds of machine (operating system or appliance) a host can be.

    Each member's value is a human-readable label for the host kind.
    """

    LINUX = "Linux"
    WINDOWS = "Windows"  # Windows desktop
    WINDOWS_SERVER = "Windows Server"
    PROXMOX = "Proxmox"  # Proxmox VE hypervisor
    ESXI = "ESXi"  # VMware ESXi hypervisor
    ROUTER = "Router"
    SWITCH = "Switch"
|
|
|
|
|
|
class VPNType(Enum):
    """VPN technologies a location can be reached through.

    Each member's value is a human-readable label for the VPN kind.
    """

    OPENVPN = "OpenVPN"
    WIREGUARD = "WireGuard"
    IPSEC = "IPSec"
|
|
|
|
|
|
@dataclass
class NetworkSegment:
    """A named network segment (subnet) together with descriptive metadata.

    ``HostIP.network_segment`` refers to a segment by its ``name``.
    """

    # Segment label, e.g. "LAN", "DMZ", "Management"
    name: str
    # Address range in CIDR notation, e.g. "192.168.1.0/24"
    cidr: str
    # Optional VLAN tag, e.g. 100
    vlan_id: Optional[int] = None
    # Zone label, e.g. "production", "dmz", "management", "guest"
    zone: str = "general"
    # Optional gateway address, e.g. "192.168.1.1"
    gateway: Optional[str] = None
    # Free-text description, e.g. "Main office network"
    description: str = ""
|
|
|
|
|
|
@dataclass
class PortForwarding:
    """A rule forwarding an external port to an internal address and port.

    Attributes:
        external_port: Port on the external address, e.g. 8080.
        internal_ip: Internal target IP, e.g. "192.168.1.10".
        internal_port: Internal target port, e.g. 80.
        protocol: "tcp", "udp" or "both" (not validated here).
        description: Free-text note, e.g. "Web server access".
        enabled: Whether the rule is currently active.
    """

    external_port: int
    internal_ip: str
    internal_port: int
    protocol: str = "tcp"
    description: str = ""
    enabled: bool = True
|
|
|
|
|
|
@dataclass
class HostIP:
    """One IP address of a host, tagged with the network segment it lives in.

    Attributes:
        ip_address: The address itself.
        network_segment: Name of the ``NetworkSegment`` this address belongs to.
        is_primary: True if this is the host's preferred interface.
    """

    ip_address: str
    network_segment: str
    is_primary: bool = False
|
|
|
|
|
|
@dataclass
class Host:
    """A physical or virtual machine at a location.

    Hosts can nest: a hypervisor lists its guest VMs in ``sub_hosts``, and
    each guest is itself a ``Host``.
    """

    name: str
    ip_addresses: List[HostIP] = field(default_factory=list)
    host_type: HostType = HostType.LINUX
    icon: Optional[str] = None  # Icon name without extension, e.g. 'ubuntu', 'windows'
    description: str = ""
    services: List[Service] = field(default_factory=list)
    # Child hosts, e.g. VMs running under a hypervisor
    sub_hosts: List['Host'] = field(default_factory=list)

    def get_service_by_name(self, service_name: str) -> Optional[Service]:
        """Return the first service whose name equals *service_name*, else None."""
        return next(
            (svc for svc in self.services if svc.name == service_name), None)

    def is_hypervisor(self) -> bool:
        """Return True when this host has at least one sub-host (VM)."""
        return len(self.sub_hosts) != 0

    def get_primary_ip(self) -> str:
        """Return the host's primary IP address.

        The first entry flagged ``is_primary`` wins; if none is flagged, the
        first address is used. An empty string is returned when the host has
        no addresses.
        """
        if not self.ip_addresses:
            return ""
        chosen = next(
            (entry for entry in self.ip_addresses if entry.is_primary),
            self.ip_addresses[0],
        )
        return chosen.ip_address

    def get_ip_display(self) -> str:
        """Return a short textual summary of the host's addresses.

        "No IP" when there are none, the bare address when there is exactly
        one, otherwise "<primary> (+N more)".
        """
        total = len(self.ip_addresses)
        if total == 0:
            return "No IP"
        if total == 1:
            return self.ip_addresses[0].ip_address
        return f"{self.get_primary_ip()} (+{total - 1} more)"

    def get_all_ips(self) -> List[str]:
        """Return every IP address of this host as plain strings, in stored order."""
        return [entry.ip_address for entry in self.ip_addresses]

    def get_ips_in_segment(self, segment_name: str) -> List[str]:
        """Return the addresses whose ``network_segment`` equals *segment_name*."""
        return [
            entry.ip_address
            for entry in self.ip_addresses
            if entry.network_segment == segment_name
        ]
|
|
|
|
|
|
@dataclass
class Location:
    """A customer location (site) and the infrastructure reachable there.

    Bundles the VPN settings used to reach the site, its host inventory
    (hypervisors list guest VMs in ``Host.sub_hosts``), its network
    segments, and the external addresses / port forwarding rules that
    expose internal services to the outside.
    """

    name: str
    vpn_type: VPNType
    connected: bool = False
    active: bool = False
    vpn_config: str = ""  # Path to VPN config or connection details
    hosts: List[Host] = field(default_factory=list)

    # Enhanced network configuration
    network_segments: List[NetworkSegment] = field(
        default_factory=list)  # Network segments with rich metadata
    external_addresses: List[str] = field(
        default_factory=list)  # External VPN endpoints
    port_forwardings: List[PortForwarding] = field(
        default_factory=list)  # Port forwarding rules

    # Legacy field for backward compatibility (will be deprecated)
    # Simple network list (legacy)
    networks: List[str] = field(default_factory=list)

    # VPN connection management fields
    # NetworkManager connection name
    nmcli_connection_name: Optional[str] = None
    auto_import: bool = True  # Auto-import .ovpn file if not in NetworkManager

    # Credential storage - can be:
    # - Passbolt UUID string (for future use)
    # - Dict with 'username' and 'password' keys
    # - None if no credentials needed
    vpn_credentials: Optional[dict | str] = None

    def get_host_by_name(self, host_name: str) -> Optional[Host]:
        """Return the first host named *host_name*, searching sub-hosts too.

        The search is depth-first: each host's sub-hosts are searched before
        its next sibling. Returns ``None`` when no host matches.
        """
        def search_hosts(hosts_list: List[Host]) -> Optional[Host]:
            for host in hosts_list:
                if host.name == host_name:
                    return host
                # Descend into this host's VMs before moving to the sibling.
                sub_result = search_hosts(host.sub_hosts)
                if sub_result is not None:
                    return sub_result
            return None

        return search_hosts(self.hosts)

    def get_all_hosts_flat(self) -> List[Host]:
        """Return all hosts, including nested sub-hosts, as one flat list.

        Order is pre-order: every host appears immediately before its own
        sub-hosts.
        """
        def collect_hosts(hosts_list: List[Host]) -> List[Host]:
            result: List[Host] = []
            for host in hosts_list:
                result.append(host)
                result.extend(collect_hosts(host.sub_hosts))
            return result

        return collect_hosts(self.hosts)

    def get_hypervisors(self) -> List[Host]:
        """Return every host, at any depth, that has sub-hosts."""
        return [host for host in self.get_all_hosts_flat() if host.is_hypervisor()]

    def get_segment_by_name(self, segment_name: str) -> Optional[NetworkSegment]:
        """Return the first network segment named *segment_name*, or ``None``."""
        return next(
            (seg for seg in self.network_segments if seg.name == segment_name),
            None)

    def get_hosts_in_segment(self, segment_name: str) -> List[Host]:
        """Return hosts (at any depth) with at least one IP in *segment_name*."""
        return [
            host for host in self.get_all_hosts_flat()
            if any(host_ip.network_segment == segment_name
                   for host_ip in host.ip_addresses)
        ]

    def get_segments_by_zone(self, zone: str) -> List[NetworkSegment]:
        """Return all network segments whose ``zone`` equals *zone*."""
        return [seg for seg in self.network_segments if seg.zone == zone]

    def get_port_forwardings_for_host(self, host_ip: str) -> List[PortForwarding]:
        """Return the *enabled* port forwardings whose target IP is *host_ip*.

        Disabled rules are not included.
        """
        return [pf for pf in self.port_forwardings
                if pf.internal_ip == host_ip and pf.enabled]

    def get_externally_accessible_services(self) -> List[tuple]:
        """Get all services accessible from external addresses via port forwarding.

        Returns:
            A list of ``(external_address, external_port, host, service,
            port_forwarding)`` tuples, one per external address and enabled
            rule, ordered by external address and then by the order of
            ``port_forwardings``. ``host`` is the first host (in pre-order
            over the host tree) owning the rule's ``internal_ip``.
            ``service`` is that host's first service on ``internal_port``, or
            ``None`` if no matching service is defined. Rules whose target IP
            belongs to no known host are omitted.
        """
        if not self.external_addresses:
            return []

        # Build the IP -> owning-host map once. setdefault keeps the first
        # owner in pre-order, matching a linear scan of get_all_hosts_flat().
        # This avoids re-flattening the host tree for every
        # (external address, rule) pair.
        ip_to_host: dict[str, Host] = {}
        for host in self.get_all_hosts_flat():
            for host_ip in host.ip_addresses:
                ip_to_host.setdefault(host_ip.ip_address, host)

        # A rule's target does not depend on the external address, so
        # resolve each enabled rule exactly once.
        resolved = []
        for port_forward in self.port_forwardings:
            if not port_forward.enabled:
                continue
            target_host = ip_to_host.get(port_forward.internal_ip)
            if target_host is None:
                continue
            target_service = next(
                (service for service in target_host.services
                 if service.port == port_forward.internal_port),
                None)  # May be None if no matching service defined
            resolved.append((port_forward, target_host, target_service))

        return [
            (external_addr, port_forward.external_port,
             target_host, target_service, port_forward)
            for external_addr in self.external_addresses
            for port_forward, target_host, target_service in resolved
        ]

    def is_service_externally_accessible(self, host_ip: str, service_port: int) -> bool:
        """Return True if an enabled rule forwards to *host_ip*:*service_port*."""
        return any(
            pf.enabled
            and pf.internal_ip == host_ip
            and pf.internal_port == service_port
            for pf in self.port_forwardings
        )

    def is_service_reachable(self, host: 'Host', service: Service) -> bool:
        """Check if a service is reachable (either via VPN connection or port forwarding).

        Returns True if:
        - the VPN is connected (all internal services become reachable), or
        - an enabled port forwarding targets one of the host's IPs on the
          service's port.
        """
        if self.connected:
            return True
        return any(
            self.is_service_externally_accessible(host_ip.ip_address, service.port)
            for host_ip in host.ip_addresses
        )

    def get_external_url_for_service(self, host: 'Host', service: Service) -> Optional[str]:
        """Get the external URL for a service if it has port forwarding.

        Uses the first entry of ``external_addresses`` and the first enabled
        rule (host IPs in order, then rules in order) that targets one of the
        host's IPs on the service's port. Returns the URL
        (e.g. "https://vpn.example.com:8006"), or ``None`` when there is no
        external address or no matching enabled rule.
        """
        # Without an external address no URL can be built.
        if not self.external_addresses:
            return None

        for host_ip in host.ip_addresses:
            for pf in self.port_forwardings:
                if (pf.enabled and
                        pf.internal_ip == host_ip.ip_address and
                        pf.internal_port == service.port):
                    # Scheme is guessed from the internal service port.
                    # NOTE(review): 8080 is frequently plain HTTP; confirm it
                    # is meant to map to https here.
                    protocol = "https" if service.port in (443, 8006, 8080) else "http"
                    return f"{protocol}://{self.external_addresses[0]}:{pf.external_port}"
        return None
|
|
|
|
|
|
@dataclass
class CustomerService:
    """A cloud or web service used by a customer, identified by its URL.

    Attributes:
        name: Display name of the service.
        url: Address of the service.
        service_type: Free-form category, e.g. "Email", "Phone System",
            "CRM", "ERP".
        description: Optional free-text note.
    """

    name: str
    url: str
    service_type: str
    description: str = ""
|
|
|
|
|
|
@dataclass
class Customer:
    """A customer, with their cloud/web services and their physical locations."""

    name: str

    # Cloud/web services, available regardless of location
    services: List[CustomerService] = field(default_factory=list)

    # Locations together with the infrastructure at each
    locations: List[Location] = field(default_factory=list)

    def get_location_by_name(self, location_name: str) -> Optional[Location]:
        """Return the first location whose name equals *location_name*, else None."""
        return next(
            (site for site in self.locations if site.name == location_name), None)

    def get_active_locations(self) -> List[Location]:
        """Return the locations whose ``active`` flag is set, in stored order."""
        return list(filter(lambda site: site.active, self.locations))

    def get_inactive_locations(self) -> List[Location]:
        """Return the locations whose ``active`` flag is not set, in stored order."""
        return [site for site in self.locations if not site.active]

    def has_active_locations(self) -> bool:
        """Return True if at least one location is marked active."""
        return any(site.active for site in self.locations)

    def has_connected_locations(self) -> bool:
        """Return True if at least one location is marked connected."""
        return any(site.connected for site in self.locations)

    def get_all_hosts_flat(self) -> List[Host]:
        """Return every host of every location (sub-hosts included) in one list.

        Locations are visited in stored order; within each, hosts come in the
        order produced by ``Location.get_all_hosts_flat``.
        """
        return [
            host
            for site in self.locations
            for host in site.get_all_hosts_flat()
        ]
|