From e44dbbc6ca5093ba93cf55495462f0e5addf6c40 Mon Sep 17 00:00:00 2001 From: Satvik-Singh192 Date: Sun, 2 Nov 2025 21:41:46 +0530 Subject: [PATCH] fix: firewall shutdown on ctrl+c signal --- main.py | 41 ++++++++++++++++++++++++++++++++--------- src/firewall/core.py | 31 ++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index 127d8b7..346c299 100644 --- a/main.py +++ b/main.py @@ -10,11 +10,15 @@ import signal import traceback from pathlib import Path +import threading +import time # Add src directory to Python path for absolute imports current_dir = Path(__file__).parent src_dir = current_dir / 'src' sys.path.insert(0, str(src_dir)) +_SHUTDOWN_INITIATED = False +_SHUTDOWN_LOCK = threading.Lock() try: from firewall.core import SimpleFirewall @@ -148,15 +152,34 @@ def main(): # Setup signal handlers for graceful shutdown def signal_handler(signum, frame): - print(f"\n{Fore.YELLOW}Signal {signum} received, shutting down...{Style.RESET_ALL}") - firewall.stop() - sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - # Start the firewall - firewall.start() + """Enhanced signal handler with timeout and state tracking""" + global _SHUTDOWN_INITIATED + + with _SHUTDOWN_LOCK: + if _SHUTDOWN_INITIATED: + print(f"\n{Fore.RED}Forced shutdown initiated...{Style.RESET_ALL}") + os._exit(1) + + _SHUTDOWN_INITIATED = True + + print(f"\n{Fore.YELLOW}Signal {signum} received, initiating graceful shutdown...{Style.RESET_ALL}") + + def force_shutdown(): + time.sleep(5) + print(f"\n{Fore.RED}Graceful shutdown timeout, forcing exit...{Style.RESET_ALL}") + os._exit(1) + + force_thread = threading.Thread(target=force_shutdown, daemon=True) + force_thread.start() + + try: + if 'firewall' in locals(): + firewall.stop() + print(f"{Fore.GREEN}Firewall shutdown completed successfully.{Style.RESET_ALL}") + sys.exit(0) + except Exception as e: + print(f"{Fore.RED}Error during shutdown: {e}{Style.RESET_ALL}") + os._exit(1) except KeyboardInterrupt: print(f"\n{Fore.YELLOW}Interrupted by user{Style.RESET_ALL}") diff --git a/src/firewall/core.py b/src/firewall/core.py index cf7d8c9..79f2e3a 100644 --- a/src/firewall/core.py +++ b/src/firewall/core.py @@ -4,7 +4,7 @@ import time import signal import sys -from scapy.all import sniff +from scapy.all import sniff,AsyncSniffer from colorama import Fore, Style from typing import Optional @@ -41,6 +41,9 @@ def __init__(self, interface: str = None, config_file: str = None): # Control flags self.running = False self._threads = [] + + self._sniffer = None + self._stop_event = threading.Event() self.logger.info("Simple Firewall initialized successfully") @@ -167,26 +170,32 @@ def start(self): self.stop() def stop(self): - """Stop the firewall and cleanup""" + """Stop the firewall and cleanup - ENHANCED VERSION""" + if not self.running: + return + print(f"\n{Fore.YELLOW}Stopping firewall...{Style.RESET_ALL}") self.running = False + self._stop_event.set() + + if self._sniffer and self._sniffer.running: + self._sniffer.stop() + + timeout = 3.0 + start_time = time.time() - # Wait for threads to finish (with timeout) for thread in self._threads: if thread.is_alive(): - thread.join(timeout=2.0) + remaining_time = timeout - (time.time() - start_time) + if remaining_time > 0: + thread.join(timeout=remaining_time) - # Display final stats - self._display_stats() - - # Cleanup firewall rules - print(f"{Fore.YELLOW}Cleaning up firewall rules...{Style.RESET_ALL}") cleaned_ips = self.blocker.cleanup_all_blocks() if cleaned_ips: self.logger.info(f"Cleaned up blocks for {len(cleaned_ips)} IPs") - print(f"{Fore.GREEN}Firewall stopped.{Style.RESET_ALL}") - self.logger.info("Simple Firewall stopped") + self._display_stats() + print(f"{Fore.GREEN}Firewall stopped successfully.{Style.RESET_ALL}") def get_status(self) -> dict: """Get current firewall status"""