diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..60e9630 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,129 @@ +name: CI + +on: + push: + branches: + - "main" + pull_request: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + lint-python: + name: Lint Python + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + cache: "pip" + + - name: Run flake8 + uses: py-actions/flake8@v2 + + validate-compute-block: + name: Validate Compute Block Config + runs-on: ubuntu-latest + needs: lint-python + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + + - name: Intall dependencies + run: | + pip install -r requirements.txt + + - name: Check cbcs + run: | + python3 - <<'EOF' + import main + + from scystream.sdk.config import load_config, get_compute_block + from scystream.sdk.config.config_loader import _compare_configs + from pathlib import Path + + CBC_PATH = Path("cbc.yaml") + + if not CBC_PATH.exists(): + raise FileNotFoundError("cbc.yaml not found in repo root.") + + block_from_code = get_compute_block() + block_from_yaml = load_config(str(CBC_PATH)) + + _compare_configs(block_from_code, block_from_yaml) + + print("cbc.yaml matches python code definition") + EOF + + run-test: + name: Run Tests + runs-on: ubuntu-latest + needs: validate-compute-block + services: + minio: + image: lazybit/minio + ports: + - 9000:9000 + env: + MINIO_ROOT_USER: minioadmin + MINIO_ROOT_PASSWORD: minioadmin + options: >- + --health-cmd "curl -f http://localhost:9000/minio/health/live || exit 1" + --health-interval 5s + --health-retries 5 + --health-timeout 5s + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + cache: "pip" + + - name: Install dependencies + run: | + pip install -r requirements.txt + + - name: Run Tests + run: pytest -vv + + build: + name: Build Docker Image + runs-on: ubuntu-latest + needs: run-test + permissions: + contents: read + packages: write + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for docker + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/language-preprocessing + tags: | + type=ref, event=pr + type=raw, value=latest, enable=${{ (github.ref == format('refs/heads/{0}', 'main')) }} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml deleted file mode 100644 index df0d4cf..0000000 --- a/.github/workflows/docker.yaml +++ /dev/null @@ -1,44 +0,0 @@ -name: Docker -on: - push: - branches: - - "main" - pull_request: - -env: - REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} - -jobs: - build: - name: Build docker image - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - - name: Log in to Docker Hub - uses: docker/login-action@v3 - with: - registry: ${{ env.REGISTRY }} - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract metadata for docker - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/language-preprocessing - tags: | - type=ref, event=pr - type=raw, value=latest, enable=${{ (github.ref == format('refs/heads/{0}', 'main')) }} - - - name: Build and push Docker image - uses: docker/build-push-action@v5 - with: - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} diff --git a/cbc.yaml b/cbc.yaml index 0932961..d20bcdf 100644 --- a/cbc.yaml +++ b/cbc.yaml @@ -1,15 +1,16 @@ author: Paul Kalhorn description: Language preprocessing for .txt or .bib files -docker_image: ghcr.io/rwth-time/language-preprocessing/language-preprocessing +docker_image: ghcr.io/rwth-time/language-preprocessing/language-preprocessing entrypoints: preprocess_bib_file: - description: Entrypoint for preprocessing a .bib file + description: Entrypoint for preprocessing a .bib file envs: + BIB_DOWNLOAD_PATH: /tmp/input.bib FILTER_STOPWORDS: true LANGUAGE: en NGRAM_MAX: 3 NGRAM_MIN: 2 - UNIGRAM_NORMALIZER: porter + UNIGRAM_NORMALIZER: lemma USE_NGRAMS: true inputs: bib_input: @@ -23,7 +24,7 @@ entrypoints: bib_file_S3_PORT: null bib_file_S3_SECRET_KEY: null bib_file_SELECTED_ATTRIBUTE: Abstract - description: The bib file, aswell as one attribute selected for preprocessing + description: The bib file, aswell as one attribute selected for preprocessing type: file outputs: dtm_output: @@ -36,7 +37,7 @@ entrypoints: dtm_output_S3_HOST: null dtm_output_S3_PORT: null dtm_output_S3_SECRET_KEY: null - description: Numpy representation of document-term matrix as .pkl file + description: Numpy representation of document-term matrix as .pkl file type: file vocab_output: config: @@ -57,7 +58,8 @@ entrypoints: LANGUAGE: en NGRAM_MAX: 3 NGRAM_MIN: 2 - UNIGRAM_NORMALIZER: porter + TXT_DOWNLOAD_PATH: /tmp/input.txt + UNIGRAM_NORMALIZER: lemma USE_NGRAMS: true inputs: txt_input: @@ -70,7 +72,7 @@ entrypoints: txt_file_S3_HOST: null txt_file_S3_PORT: null txt_file_S3_SECRET_KEY: null - description: A .txt file + description: A .txt file type: file outputs: dtm_output: diff --git a/main.py b/main.py index 230d15f..e599750 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,25 @@ import pickle import tempfile +import logging from scystream.sdk.core import entrypoint from scystream.sdk.env.settings import ( - EnvSettings, - InputSettings, - OutputSettings, - FileSettings + EnvSettings, + InputSettings, + OutputSettings, + FileSettings ) from scystream.sdk.file_handling.s3_manager import S3Operations from preprocessing.core import Preprocessor from preprocessing.loader import TxtLoader, BibLoader +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + class DTMFileOutput(FileSettings, OutputSettings): __identifier__ = "dtm_output" @@ -46,6 +53,8 @@ class PreprocessTXT(EnvSettings): NGRAM_MIN: int = 2 NGRAM_MAX: int = 3 + TXT_DOWNLOAD_PATH: str = "/tmp/input.txt" + txt_input: TXTFileInput dtm_output: DTMFileOutput vocab_output: VocabFileOutput @@ -59,6 +68,8 @@ class PreprocessBIB(EnvSettings): NGRAM_MIN: int = 2 NGRAM_MAX: int = 3 + BIB_DOWNLOAD_PATH: str = "/tmp/input.bib" + bib_input: BIBFileInput dtm_output: DTMFileOutput vocab_output: VocabFileOutput @@ -66,6 +77,8 @@ class PreprocessBIB(EnvSettings): def _preprocess_and_store(texts, settings): """Shared preprocessing logic for TXT and BIB.""" + logger.info(f"Starting preprocessing with {len(texts)} documents") + pre = Preprocessor( language=settings.LANGUAGE, filter_stopwords=settings.FILTER_STOPWORDS, @@ -74,14 +87,16 @@ def _preprocess_and_store(texts, settings): ngram_min=settings.NGRAM_MIN, ngram_max=settings.NGRAM_MAX, ) - pre.texts = texts + pre.texts = texts pre.analyze_texts() + pre.generate_bag_of_words() + dtm, vocab = pre.generate_document_term_matrix() with tempfile.NamedTemporaryFile(suffix="_dtm.pkl") as tmp_dtm, \ - tempfile.NamedTemporaryFile(suffix="_vocab.pkl") as tmp_vocab: + tempfile.NamedTemporaryFile(suffix="_vocab.pkl") as tmp_vocab: pickle.dump(dtm, tmp_dtm) tmp_dtm.flush() @@ -89,59 +104,32 @@ def _preprocess_and_store(texts, settings): pickle.dump(vocab, tmp_vocab) tmp_vocab.flush() + logger.info("Uploading DTM to S3...") S3Operations.upload(settings.dtm_output, tmp_dtm.name) + + logger.info("Uploading vocabulary to S3...") S3Operations.upload(settings.vocab_output, tmp_vocab.name) + logger.info("Preprocessing completed successfully.") + @entrypoint(PreprocessTXT) def preprocess_txt_file(settings): - S3Operations.download(settings.txt_input, "input.txt") - texts = TxtLoader.load("./input.txt") + logger.info("Downloading TXT input from S3...") + S3Operations.download(settings.txt_input, settings.TXT_DOWNLOAD_PATH) + + texts = TxtLoader.load(settings.TXT_DOWNLOAD_PATH) + _preprocess_and_store(texts, settings) @entrypoint(PreprocessBIB) def preprocess_bib_file(settings): - S3Operations.download(settings.bib_input, "input.bib") + logger.info("Downloading BIB input from S3...") + S3Operations.download(settings.bib_input, settings.BIB_DOWNLOAD_PATH) + texts = BibLoader.load( - "./input.bib", + settings.BIB_DOWNLOAD_PATH, attribute=settings.bib_input.SELECTED_ATTRIBUTE, ) _preprocess_and_store(texts, settings) - - -""" -if __name__ == "__main__": - test = PreprocessBIB( - bib_input=BIBFileInput( - S3_HOST="http://localhost", - S3_PORT="9000", - S3_ACCESS_KEY="minioadmin", - S3_SECRET_KEY="minioadmin", - BUCKET_NAME="input-bucket", - FILE_PATH="input_file_path", - FILE_NAME="wos_export", - SELECTED_ATTRIBUTE="abstract" - ), - dtm_output=DTMFileOutput( - S3_HOST="http://localhost", - S3_PORT="9000", - S3_ACCESS_KEY="minioadmin", - S3_SECRET_KEY="minioadmin", - BUCKET_NAME="output-bucket", - FILE_PATH="output_file_path", - FILE_NAME="dtm_file_bib" - ), - vocab_output=VocabFileOutput( - S3_HOST="http://localhost", - S3_PORT="9000", - S3_ACCESS_KEY="minioadmin", - S3_SECRET_KEY="minioadmin", - BUCKET_NAME="output-bucket", - FILE_PATH="output_file_path", - FILE_NAME="vocab_file_bib" - ) - ) - - preprocess_bib_file(test) -""" diff --git a/preprocessing/__init__.py b/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/preprocessing/core.py b/preprocessing/core.py index 4db4585..dba2a8d 100644 --- a/preprocessing/core.py +++ b/preprocessing/core.py @@ -1,3 +1,4 @@ +import logging import spacy import numpy as np @@ -9,6 +10,7 @@ "en": "en_core_web_sm", "de": "de_core_news_sm" } +logger = logging.getLogger(__name__) class Preprocessor: @@ -21,6 +23,12 @@ def __init__( ngram_min: int = 2, ngram_max: int = 3, ): + logger.info( + "Init Preprocessor (lang=%s, filter_stopwords=%s, ngrams=%s)", + language, + filter_stopwords, + use_ngrams, + ) self.language = language self.filter_stopwords = filter_stopwords self.unigram_normalizer = unigram_normalizer @@ -58,6 +66,7 @@ def filter_tokens( ] def analyze_texts(self): + logger.info(f"Analyzing {len(self.texts)} texts...") porter = PorterStemmer() for text in self.texts: doc = self.nlp(text) @@ -67,8 +76,8 @@ def analyze_texts(self): for sentence in doc.sents: filtered_tokens = self.filter_tokens( - list(sentence), - self.filter_stopwords + list(sentence), + self.filter_stopwords ) normalized_tokens = [ self.normalize_token(t, porter) for t in filtered_tokens @@ -93,6 +102,10 @@ def analyze_texts(self): if ngram_list: self.ngram_frequency.update(ngram_list) self.ngram_document_frequency.update(set(ngram_list)) + logger.info( + f"Finished analyzing texts: {self.token_frequency} unigrams, { + self.ngram_frequency} n-grams", + ) def normalize_token( self, @@ -110,6 +123,7 @@ def normalize_token( return word def generate_bag_of_words(self): + logger.info("Generating bag-of-words...") porter = PorterStemmer() self.bag_of_words = [] @@ -177,7 +191,7 @@ def generate_document_term_matrix(self) -> (np.ndarray, dict): dtm (np.ndarray): shape = (num_docs, num_terms) vocab (dict): mapping term -> column index """ - + logger.info("Building document-term-matrix...") all_terms = set() for doc in self.bag_of_words: for t in doc: @@ -194,4 +208,5 @@ def generate_document_term_matrix(self) -> (np.ndarray, dict): term_idx = vocab[token["term"]] dtm[doc_idx, term_idx] += 1 + logger.info(f"Matrix shape: {dtm.shape} | Vocab size: {len(vocab)}") return dtm, vocab diff --git a/preprocessing/loader.py b/preprocessing/loader.py index d55aac3..50d0177 100644 --- a/preprocessing/loader.py +++ b/preprocessing/loader.py @@ -1,21 +1,24 @@ +import logging import re import bibtexparser +logger = logging.getLogger(__name__) + def normalize_text(text: str) -> str: if not text: return "" - # Remove curly braces - text = re.sub(r"[{}]", "", text) - # Remove LaTeX commands - text = re.sub(r"\\[a-zA-Z]+\s*(\{[^}]*\})?", "", text) + text = re.sub(r"\\[a-zA-Z]+\{([^}]*)\}", r"\1", text) + + text = re.sub(r"\\[a-zA-Z]+", "", text) - # Remove LaTeX escaped quotes/accents - text = re.sub(r"\\""[a-zA-Z]", lambda m: m.group(0)[-1], text) + text = re.sub(r"[{}]", "", text) + + text = re.sub(r'\\"([a-zA-Z])', r'\1', text) text = re.sub(r"\\'", "", text) - text = text.replace("'", "") + text = re.sub(r"\s+", " ", text) return text.strip() @@ -24,6 +27,7 @@ def normalize_text(text: str) -> str: class TxtLoader: @staticmethod def load(file_path: str) -> list[str]: + logger.info("Loading TXT file...") with open(file_path, "r", encoding="utf-8") as f: lines = f.readlines() return [normalize_text(line) for line in lines] @@ -32,6 +36,7 @@ def load(file_path: str) -> list[str]: class BibLoader: @staticmethod def load(file_path: str, attribute: str) -> list[str]: + logger.info(f"Loading BIB file (attribute={attribute})...") with open(file_path, "r", encoding="utf-8") as f: bib_database = bibtexparser.load(f) diff --git a/requirements.txt b/requirements.txt index 3493ec1..e7737db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ spacy==3.8.7 nltk==3.9.1 numpy==2.3.3 bibtexparser==1.4.3 +pytest==9.0.1 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..5ce5cca --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,20 @@ +import pytest + +from preprocessing.core import Preprocessor + + +@pytest.fixture +def simple_texts(): + return ["This is a test sentence.", "Another test sentence."] + + +@pytest.fixture +def preprocessor(): + return Preprocessor( + language="en", + filter_stopwords=True, + unigram_normalizer="porter", + use_ngrams=True, + ngram_min=2, + ngram_max=3, + ) diff --git a/test/files/expected_dtm_from_bib.pkl b/test/files/expected_dtm_from_bib.pkl new file mode 100644 index 0000000..8127603 Binary files /dev/null and b/test/files/expected_dtm_from_bib.pkl differ diff --git a/test/files/expected_dtm_from_txt.pkl b/test/files/expected_dtm_from_txt.pkl new file mode 100644 index 0000000..6ca1a74 Binary files /dev/null and b/test/files/expected_dtm_from_txt.pkl differ diff --git a/test/files/expected_vocab_from_bib.pkl b/test/files/expected_vocab_from_bib.pkl new file mode 100644 index 0000000..0641a70 Binary files /dev/null and b/test/files/expected_vocab_from_bib.pkl differ diff --git a/test/files/expected_vocab_from_txt.pkl b/test/files/expected_vocab_from_txt.pkl new file mode 100644 index 0000000..0698bc3 Binary files /dev/null and b/test/files/expected_vocab_from_txt.pkl differ diff --git a/test/files/test.txt b/test/files/input.txt similarity index 100% rename from test/files/test.txt rename to test/files/input.txt diff --git a/test/test_full.py b/test/test_full.py new file mode 100644 index 0000000..c97b370 --- /dev/null +++ b/test/test_full.py @@ -0,0 +1,217 @@ +import os +import boto3 +import pytest +import pickle +import numpy as np + +from pathlib import Path +from main import preprocess_bib_file, preprocess_txt_file +from botocore.exceptions import ClientError + +MINIO_USER = "minioadmin" +MINIO_PWD = "minioadmin" +BUCKET_NAME = "testbucket" + + +def ensure_bucket(s3, bucket): + try: + s3.head_bucket(Bucket=bucket) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code in ("404", "NoSuchBucket"): + s3.create_bucket(Bucket=bucket) + else: + raise + + +def download_to_tmp(s3, bucket, key): + tmp_path = Path("/tmp") / key.replace("/", "_") + s3.download_file(bucket, key, str(tmp_path)) + return tmp_path + + +@pytest.fixture +def s3_minio(): + client = boto3.client( + "s3", + endpoint_url="http://localhost:9000", + aws_access_key_id=MINIO_USER, + aws_secret_access_key=MINIO_PWD + ) + ensure_bucket(client, BUCKET_NAME) + return client + + +def test_full_bib(s3_minio): + input_file_name = "input" + dtm_output_file_name = "dtm_file" + vocab_output_file_name = "vocab_file" + + bib_path = Path(__file__).parent / "files" / f"{input_file_name}.bib" + bib_bytes = bib_path.read_bytes() + + s3_minio.put_object( + Bucket=BUCKET_NAME, + Key=f"{input_file_name}.bib", + Body=bib_bytes + ) + + env = { + "UNIGRAM_NORMALIZER": "porter", + + "bib_file_S3_HOST": "http://127.0.0.1", + "bib_file_S3_PORT": "9000", + "bib_file_S3_ACCESS_KEY": MINIO_USER, + "bib_file_S3_SECRET_KEY": MINIO_PWD, + "bib_file_BUCKET_NAME": BUCKET_NAME, + "bib_file_FILE_PATH": "", + "bib_file_FILE_NAME": input_file_name, + "bib_file_SELECTED_ATTRIBUTE": "abstract", + + "dtm_output_S3_HOST": "http://127.0.0.1", + "dtm_output_S3_PORT": "9000", + "dtm_output_S3_ACCESS_KEY": MINIO_USER, + "dtm_output_S3_SECRET_KEY": MINIO_PWD, + "dtm_output_BUCKET_NAME": BUCKET_NAME, + "dtm_output_FILE_PATH": "", + "dtm_output_FILE_NAME": dtm_output_file_name, + + "vocab_output_S3_HOST": "http://127.0.0.1", + "vocab_output_S3_PORT": "9000", + "vocab_output_S3_ACCESS_KEY": MINIO_USER, + "vocab_output_S3_SECRET_KEY": MINIO_PWD, + "vocab_output_BUCKET_NAME": BUCKET_NAME, + "vocab_output_FILE_PATH": "", + "vocab_output_FILE_NAME": vocab_output_file_name, + } + + for k, v in env.items(): + os.environ[k] = v + + preprocess_bib_file() + + keys = [ + o["Key"] + for o in s3_minio.list_objects_v2( + Bucket="testbucket").get("Contents", []) + ] + + assert f"{dtm_output_file_name}.pkl" in keys + assert f"{vocab_output_file_name}.pkl" in keys + + dtm_path = download_to_tmp(s3_minio, BUCKET_NAME, f"{ + dtm_output_file_name}.pkl") + vocab_path = download_to_tmp(s3_minio, BUCKET_NAME, f"{ + vocab_output_file_name}.pkl") + + # Load produced results + with open(dtm_path, "rb") as f: + dtm = pickle.load(f) + + with open(vocab_path, "rb") as f: + vocab = pickle.load(f) + + # Load expected snapshot files + expected_vocab_path = Path(__file__).parent / \ + "files" / "expected_vocab_from_bib.pkl" + expected_dtm_path = Path(__file__).parent / "files" / \ + "expected_dtm_from_bib.pkl" + + with open(expected_vocab_path, "rb") as f: + expected_vocab = pickle.load(f) + + with open(expected_dtm_path, "rb") as f: + expected_dtm = pickle.load(f) + + assert vocab == expected_vocab + np.testing.assert_array_equal(dtm, expected_dtm) + + +def test_full_txt(s3_minio): + input_file_name = "input" + dtm_output_file_name = "dtm_txt_file" + vocab_output_file_name = "vocab_txt_file" + + txt_path = Path(__file__).parent / "files" / f"{input_file_name}.txt" + txt_bytes = txt_path.read_bytes() + + s3_minio.put_object( + Bucket=BUCKET_NAME, + Key=f"{input_file_name}.txt", + Body=txt_bytes + ) + + env = { + "UNIGRAM_NORMALIZER": "porter", + + "txt_file_S3_HOST": "http://127.0.0.1", + "txt_file_S3_PORT": "9000", + "txt_file_S3_ACCESS_KEY": MINIO_USER, + "txt_file_S3_SECRET_KEY": MINIO_PWD, + "txt_file_BUCKET_NAME": BUCKET_NAME, + "txt_file_FILE_PATH": "", + "txt_file_FILE_NAME": input_file_name, + + "dtm_output_S3_HOST": "http://127.0.0.1", + "dtm_output_S3_PORT": "9000", + "dtm_output_S3_ACCESS_KEY": MINIO_USER, + "dtm_output_S3_SECRET_KEY": MINIO_PWD, + "dtm_output_BUCKET_NAME": BUCKET_NAME, + "dtm_output_FILE_PATH": "", + "dtm_output_FILE_NAME": dtm_output_file_name, + + "vocab_output_S3_HOST": "http://127.0.0.1", + "vocab_output_S3_PORT": "9000", + "vocab_output_S3_ACCESS_KEY": MINIO_USER, + "vocab_output_S3_SECRET_KEY": MINIO_PWD, + "vocab_output_BUCKET_NAME": BUCKET_NAME, + "vocab_output_FILE_PATH": "", + "vocab_output_FILE_NAME": vocab_output_file_name, + } + + for k, v in env.items(): + os.environ[k] = v + + preprocess_txt_file() + + keys = [ + o["Key"] + for o in s3_minio.list_objects_v2( + Bucket=BUCKET_NAME).get("Contents", []) + ] + + assert f"{dtm_output_file_name}.pkl" in keys + assert f"{vocab_output_file_name}.pkl" in keys + + # Download produced files + dtm_path = download_to_tmp(s3_minio, BUCKET_NAME, f"{ + dtm_output_file_name}.pkl") + vocab_path = download_to_tmp(s3_minio, BUCKET_NAME, f"{ + vocab_output_file_name}.pkl") + + # Load produced results + with open(dtm_path, "rb") as f: + dtm = pickle.load(f) + + with open(vocab_path, "rb") as f: + vocab = pickle.load(f) + + # Load expected snapshot files + expected_vocab_path = Path(__file__).parent / \ + "files" / "expected_vocab_from_txt.pkl" + expected_dtm_path = Path(__file__).parent / \ + "files" / "expected_dtm_from_txt.pkl" + + with open(expected_vocab_path, "rb") as f: + expected_vocab = pickle.load(f) + + with open(expected_dtm_path, "rb") as f: + expected_dtm = pickle.load(f) + + # Assertions + assert vocab == expected_vocab + + if hasattr(dtm, "toarray"): + np.testing.assert_array_equal(dtm.toarray(), expected_dtm.toarray()) + else: + np.testing.assert_array_equal(dtm, expected_dtm) diff --git a/test/test_loaders.py b/test/test_loaders.py new file mode 100644 index 0000000..3c96468 --- /dev/null +++ b/test/test_loaders.py @@ -0,0 +1,33 @@ +import os +import tempfile + +from preprocessing.loader import TxtLoader, BibLoader + + +def test_txt_loader_reads_and_normalizes(): + with tempfile.NamedTemporaryFile("w+", delete=False) as f: + f.write("Hello {World}\nSecond line") + fname = f.name + + result = TxtLoader.load(fname) + os.unlink(fname) + + assert result == ["Hello World", "Second line"] + + +def test_bib_loader_extracts_attribute(): + bib_content = r""" + @article{a, + abstract = {This is {Bib} \textbf{text}.}, + title = {Ignore me} + } + """ + + with tempfile.NamedTemporaryFile("w+", delete=False) as f: + f.write(bib_content) + fname = f.name + + result = BibLoader.load(fname, "abstract") + os.unlink(fname) + + assert result == ["This is Bib text."] diff --git a/test/test_normalize.py b/test/test_normalize.py new file mode 100644 index 0000000..33dd6e5 --- /dev/null +++ b/test/test_normalize.py @@ -0,0 +1,17 @@ +from preprocessing.loader import normalize_text + + +def test_normalize_removes_braces(): + assert normalize_text("{abc}") == "abc" + + +def test_normalize_removes_latex_commands(): + assert normalize_text(r"\textbf{Hello}") == "Hello" + + +def test_normalize_removes_accents(): + assert normalize_text(r"\'a") == "a" + + +def test_normalize_collapses_whitespace(): + assert normalize_text("a b c") == "a b c" diff --git a/test/test_preprocessor_unit.py b/test/test_preprocessor_unit.py new file mode 100644 index 0000000..7829105 --- /dev/null +++ b/test/test_preprocessor_unit.py @@ -0,0 +1,26 @@ +def test_preprocessor_tokenization(preprocessor, simple_texts): + preprocessor.texts = simple_texts + preprocessor.analyze_texts() + + assert len(preprocessor.token_frequency) > 0 + + +def test_preprocessor_bag_of_words(preprocessor, simple_texts): + preprocessor.texts = simple_texts + preprocessor.analyze_texts() + preprocessor.generate_bag_of_words() + + assert len(preprocessor.bag_of_words) == 2 + assert all(len(doc) > 0 for doc in preprocessor.bag_of_words) + + +def test_generate_document_term_matrix(preprocessor, simple_texts): + preprocessor.texts = simple_texts + preprocessor.analyze_texts() + preprocessor.generate_bag_of_words() + + dtm, vocab = preprocessor.generate_document_term_matrix() + + assert dtm.shape[0] == 2 + assert dtm.shape[1] == len(vocab) + assert dtm.sum() > 0