diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c2854bf --- /dev/null +++ b/.gitignore @@ -0,0 +1,165 @@ +# Custom +examples/data/shakespeare/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/README.md b/README.md index 7897e2c..ab3af29 100644 --- a/README.md +++ b/README.md @@ -26,10 +26,38 @@ python examples/data/shakespeare.py And finally, let's train a GPT: ```bash -python examples/train-gpt.py +python examples/gpt.py train ``` -This runs on CPU and should get train loss: 1.65 and test loss: 1.80 after 2000 iterations. +This runs on CPU and should get train loss: 1.65 and test loss: 1.80 after 2000 iterations (took a few minutes). + +The trained weights will be saved to `examples/data/shakespeare/weights.pt`. You can now run inference: +```bash +python examples/gpt.py inference "ROMEO:" +``` + +You could also generate your own dataset and train your own GPT! See `examples/data/shakespeare.py` and change the +source text files, then train your new model: + +```bash +python examples/gpt.py train \ + --train=mydataset/train.bin \ + --validation=mydataset/val.bin \ + --weights=mydataset/weights.pt +``` + +Now you can run inference with our fresh weights: + +```bash +python examples/gpt.py inference \ + --weights=mydataset/weights.pt \ + "JULIET:" +``` + +> Note: you may need to change the `chars` in `examples/gpt.py` to match the chars of your dataset. +> If you want a more generic approach, consider using something like: +> `chars = list(string.ascii_letters + string.digits + string.punctuation + string.whitespace)` + ## Project roadmap @@ -86,7 +114,7 @@ for step in range(steps:=20): mlp.normalize(grad := weights.grad()) # normalize the gradient in the modular norm weights -= 0.1 * grad weights.zero_grad() - + mlp.regularize(weights, strength = 0.01) # regularize the weight vector print(step, loss.item()) diff --git a/examples/train-gpt.py b/examples/gpt.py similarity index 55% rename from examples/train-gpt.py rename to examples/gpt.py index a05587e..cc0e29d 100644 --- a/examples/train-gpt.py +++ b/examples/gpt.py @@ -1,8 +1,11 @@ +import time + import torch import numpy as np # Karpathy's smallest GPT config +chars = list("\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") vocab_size = 65 context = 64 num_heads = 4 @@ -10,6 +13,8 @@ d_query = 32 d_value = 32 num_blocks = 4 +assert len(chars) == vocab_size, "`chars` must be aligned to `vocab_size`" + # training hparams @@ -20,6 +25,19 @@ eval_steps = 100 log_interval = 200 +# encoding/decoding + +stoi = {ch: i for i, ch in enumerate(chars)} +itos = {i: ch for i, ch in enumerate(chars)} + +def encode(s): + global stoi + return [stoi[c] for c in s] + +def decode(l): + global itos + return ''.join([itos[i] for i in l]) + # let's start by defining our GPT architecture # (we could instead just import GPT from modula.compound) @@ -80,23 +98,22 @@ def __len__(self): # now let's start doing stuff -if __name__ == "__main__": - +def train(train_filename, validation_filename): # load the data - trainset = SimpleLLMDataset(np.memmap("examples/data/shakespeare/train.bin", dtype=np.uint16, mode='r'), context) - testset = SimpleLLMDataset(np.memmap("examples/data/shakespeare/val.bin", dtype=np.uint16, mode='r'), context) + trainset = SimpleLLMDataset(np.memmap(train_filename, dtype=np.uint16, mode="r"), context) + testset = SimpleLLMDataset(np.memmap(validation_filename, dtype=np.uint16, mode="r"), context) train_sampler = RandomSampler(trainset, batch_size) test_sampler = RandomSampler(testset, batch_size) - train_loader = torch.utils.data.DataLoader( trainset, num_workers=1, pin_memory=True, batch_sampler=train_sampler) - test_loader = torch.utils.data.DataLoader( testset, num_workers=1, pin_memory=True, batch_sampler=test_sampler) + train_loader = torch.utils.data.DataLoader(trainset, num_workers=1, pin_memory=True, batch_sampler=train_sampler) + test_loader = torch.utils.data.DataLoader(testset, num_workers=1, pin_memory=True, batch_sampler=test_sampler) train_iterator = iter(train_loader) test_iterator = iter(test_loader) - getBatch = lambda train: next(train_iterator if train else test_iterator) + get_batch = lambda train: next(train_iterator if train else test_iterator) # load the model @@ -114,12 +131,13 @@ def __len__(self): # train the model + start = time.time() for step in range(steps): if step % log_interval == 0: test_loss = test_acc = 0 for eval_step in range(eval_steps): - data, target = getBatch(train = False) + data, target = get_batch(train=False) output = gpt.forward(data, weights) output = output.view(-1, output.size(-1)) target = target.view(-1) @@ -131,7 +149,7 @@ def __len__(self): test_loss /= eval_steps test_acc /= eval_steps - data, target = getBatch(train = True) + data, target = get_batch(train=True) output = gpt.forward(data, weights) output = output.view(-1, output.size(-1)) target = target.view(-1) @@ -160,6 +178,70 @@ def __len__(self): weights.zero_grad() if step % log_interval == 0: - print( "step:", step, - "\t train loss:", "%.2f" % train_loss.item(), - "\t test loss:", "%.2f" % test_loss.item() ) + print( "step:", step, + "\t train loss:", "%.2f" % train_loss.item(), + "\t test loss:", "%.2f" % test_loss.item() , + f"\t took: {time.time() - start:.2f}s") + start = time.time() + + return weights + + +def inference(weights, input_text, chars_to_generate): + gpt = GPT(vocab_size, context, num_heads, d_embed, d_query, d_value, num_blocks) + print(input_text, end="", flush=True) + context_tokens = torch.tensor(encode(input_text)).unsqueeze(0) + for _ in range(chars_to_generate): + with torch.no_grad(): + output = gpt.forward(context_tokens, weights) + logits = output[0, -1, :] + probs = torch.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).item() + print(decode([next_token]), end="", flush=True) + context_tokens = torch.cat([context_tokens, torch.tensor([[next_token]])], dim=1) + if context_tokens.shape[1] > context: + context_tokens = context_tokens[:, -context:] + + +if __name__ == "__main__": + import argparse + from pathlib import Path + + + data_path = Path(__file__).parent / "data" / "shakespeare" + default_weights_filename = data_path / "weights.pt" + default_train_filename = data_path / "train.bin" + default_validation_filename = data_path / "val.bin" + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="mode", required=True) + + parser_train = subparsers.add_parser("train") + parser_train.add_argument("--weights", "-w", type=Path, default=default_weights_filename, help="Weights filename to save") + parser_train.add_argument("--train", "-t", type=Path, default=default_train_filename, help="Train dataset filename") + parser_train.add_argument("--validation", "-v", type=Path, default=default_validation_filename, help="Validation dataset filename") + + parser_inference = subparsers.add_parser("inference") + parser_inference.add_argument("--weights", "-w", type=Path, default=default_weights_filename, help="Weights filename to load") + parser_inference.add_argument("--chars", "-c", type=int, default=1024, help="Number of chars to generate") + parser_inference.add_argument("input", type=str, help="Text to be feed into the model") + + args = parser.parse_args() + + if args.mode == "train": + weights_filename = args.weights + train_filename = args.train + validation_filename = args.validation + + weights = train(train_filename, validation_filename) + torch.save(weights, weights_filename) + print(f"Weights saved to {weights_filename}") + + elif args.mode == "inference": + weights_filename = args.weights + input_text = args.input + chars_to_generate = args.chars + + print(f"Loading weights from {weights_filename}") + weights = torch.load(weights_filename) + print() + inference(weights, input_text, chars_to_generate)