Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
import signal
import traceback
from pathlib import Path
import threading
import time

# Add src directory to Python path for absolute imports
current_dir = Path(__file__).parent
src_dir = current_dir / 'src'
sys.path.insert(0, str(src_dir))
_SHUTDOWN_INITIATED = False
_SHUTDOWN_LOCK = threading.Lock()

try:
from firewall.core import SimpleFirewall
Expand Down Expand Up @@ -148,15 +152,34 @@ def main():

# Setup signal handlers for graceful shutdown
def signal_handler(signum, frame):
print(f"\n{Fore.YELLOW}Signal {signum} received, shutting down...{Style.RESET_ALL}")
firewall.stop()
sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

# Start the firewall
firewall.start()
"""Enhanced signal handler with timeout and state tracking"""
global _SHUTDOWN_INITIATED

with _SHUTDOWN_LOCK:
if _SHUTDOWN_INITIATED:
print(f"\n{Fore.RED}Forced shutdown initiated...{Style.RESET_ALL}")
os._exit(1)

_SHUTDOWN_INITIATED = True

print(f"\n{Fore.YELLOW}Signal {signum} received, initiating graceful shutdown...{Style.RESET_ALL}")

def force_shutdown():
time.sleep(5)
print(f"\n{Fore.RED}Graceful shutdown timeout, forcing exit...{Style.RESET_ALL}")
os._exit(1)

force_thread = threading.Thread(target=force_shutdown, daemon=True)
force_thread.start()

try:
if 'firewall' in locals():
firewall.stop()
print(f"{Fore.GREEN}Firewall shutdown completed successfully.{Style.RESET_ALL}")
sys.exit(0)
except Exception as e:
print(f"{Fore.RED}Error during shutdown: {e}{Style.RESET_ALL}")
os._exit(1)

except KeyboardInterrupt:
print(f"\n{Fore.YELLOW}Interrupted by user{Style.RESET_ALL}")
Expand Down
31 changes: 20 additions & 11 deletions src/firewall/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import signal
import sys
from scapy.all import sniff
from scapy.all import sniff,AsyncSniffer
from colorama import Fore, Style
from typing import Optional

Expand Down Expand Up @@ -41,6 +41,9 @@ def __init__(self, interface: str = None, config_file: str = None):
# Control flags
self.running = False
self._threads = []

self._sniffer = None
self._stop_event = threading.Event()

self.logger.info("Simple Firewall initialized successfully")

Expand Down Expand Up @@ -167,26 +170,32 @@ def start(self):
self.stop()

def stop(self):
"""Stop the firewall and cleanup"""
"""Stop the firewall and cleanup - ENHANCED VERSION"""
if not self.running:
return

print(f"\n{Fore.YELLOW}Stopping firewall...{Style.RESET_ALL}")
self.running = False
self._stop_event.set()

if self._sniffer and self._sniffer.running:
self._sniffer.stop()

timeout = 3.0
start_time = time.time()

# Wait for threads to finish (with timeout)
for thread in self._threads:
if thread.is_alive():
thread.join(timeout=2.0)
remaining_time = timeout - (time.time() - start_time)
if remaining_time > 0:
thread.join(timeout=remaining_time)

# Display final stats
self._display_stats()

# Cleanup firewall rules
print(f"{Fore.YELLOW}Cleaning up firewall rules...{Style.RESET_ALL}")
cleaned_ips = self.blocker.cleanup_all_blocks()
if cleaned_ips:
self.logger.info(f"Cleaned up blocks for {len(cleaned_ips)} IPs")

print(f"{Fore.GREEN}Firewall stopped.{Style.RESET_ALL}")
self.logger.info("Simple Firewall stopped")
self._display_stats()
print(f"{Fore.GREEN}Firewall stopped successfully.{Style.RESET_ALL}")

def get_status(self) -> dict:
"""Get current firewall status"""
Expand Down
Loading