|
| 1 | +import logging |
| 2 | +from abc import ABC |
| 3 | +from pathlib import Path |
| 4 | +from inspect import getmembers, isclass |
| 5 | +from typing import List, Type, Generator, Any |
| 6 | +from importlib.util import spec_from_file_location, module_from_spec |
| 7 | + |
| 8 | +from bot import OperationalException, DependencyException |
| 9 | + |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | + |
| 14 | +class RemoteLoader(ABC): |
| 15 | + """ |
| 16 | + RemoteLoader class: abstract class with some util functions to handle searching for modules, |
| 17 | + classes and instantiating classes |
| 18 | + """ |
| 19 | + |
| 20 | + @staticmethod |
| 21 | + def locate_python_modules(dir_path: Path) -> List[Path]: |
| 22 | + """ |
| 23 | + Functions that will search through all the files in a directory to locate the class matching the |
| 24 | + given class name. |
| 25 | + """ |
| 26 | + |
| 27 | + if not dir_path.is_dir(): |
| 28 | + raise OperationalException("Given directory path is not a directory") |
| 29 | + |
| 30 | + modules: List[Path] = [] |
| 31 | + |
| 32 | + logger.info("Searching in directory {} ...".format(dir_path)) |
| 33 | + |
| 34 | + for entry in dir_path.iterdir(): |
| 35 | + |
| 36 | + if not str(entry).endswith('.py'): |
| 37 | + continue |
| 38 | + |
| 39 | + logger.info("Found module: {}, appending it to search paths".format(str(entry))) |
| 40 | + modules.append(entry) |
| 41 | + |
| 42 | + return modules |
| 43 | + |
| 44 | + @staticmethod |
| 45 | + def locate_class(modules: List[Path], class_name: str) -> Path: |
| 46 | + """ |
| 47 | + Function that will search all the given modules and will return the corresponding path where |
| 48 | + the class is located. |
| 49 | + """ |
| 50 | + |
| 51 | + logger.info("Searching for class: {}".format(class_name)) |
| 52 | + |
| 53 | + for module_path in modules: |
| 54 | + spec = spec_from_file_location(module_path.stem, str(module_path)) |
| 55 | + module = module_from_spec(spec) |
| 56 | + |
| 57 | + try: |
| 58 | + spec.loader.exec_module(module) |
| 59 | + except (ModuleNotFoundError, SyntaxError) as err: |
| 60 | + # Catch errors in case a specific module is not installed |
| 61 | + logger.warning(f"Could not import {module_path} due to '{err}'") |
| 62 | + |
| 63 | + if hasattr(module, class_name): |
| 64 | + logger.info("Found class {} in module {}".format(class_name, str(module_path))) |
| 65 | + return module_path |
| 66 | + |
| 67 | + raise DependencyException("Could not find given class in selection of modules") |
| 68 | + |
| 69 | + @staticmethod |
| 70 | + def create_class_generators(module_path: Path, class_name: str, |
| 71 | + class_type: Type[Any]) -> Generator[Any, None, None]: |
| 72 | + """ |
| 73 | + Function that creates a generator for a given module path and class name |
| 74 | + """ |
| 75 | + spec = spec_from_file_location(module_path.stem, str(module_path)) |
| 76 | + module = module_from_spec(spec) |
| 77 | + |
| 78 | + try: |
| 79 | + spec.loader.exec_module(module) |
| 80 | + except (ModuleNotFoundError, SyntaxError) as err: |
| 81 | + # Catch errors in case a specific module is not installed |
| 82 | + logger.warning(f"Could not import {module_path} due to '{err}'") |
| 83 | + |
| 84 | + object_generators = ( |
| 85 | + obj for name, obj in getmembers(module, isclass) if (class_name is None or class_name == name) |
| 86 | + and class_type in obj.__bases__ |
| 87 | + ) |
| 88 | + |
| 89 | + return object_generators |
0 commit comments