[mlir-gen] Add mlir builders for llama3.1 and tests #13
Conversation

Putting up this dirty draft for early feedback/questions. I'm putting together some tests to run an e2e llama3.1 going through linalg on tensors. The goal is to generate nice linalg that is optimization friendly. At the moment there are just functional blocks and pieces that are smoke-tested. These include naive implementations of rotary embeddings, feed forward, RMS norm, and a bunch of other small snippets that are useful for implementing the model. These are already enough to put an attention block together. It'd be nice to test it against the original implementation, but that'd require fairscale as a dependency. For now I only added pytest and kept the pipeline as simple as possible. I also reused the example with the schedule, so now it is a part of every test.
Should this be in …?

The e2e should be, yup, but this is mostly tests and getters.

I moved the whole thing to examples and added attention to the list of tests.
rengolin left a comment
Thanks!
rolfmorel left a comment
Nice! Have left some comments inline.
```python
    [xq_scores_map, keys_scores_map, scores_map],
    [parallel, parallel, parallel, parallel, reduction],
)
def compute_scores(q_val, k_val, score_val):
```
Could be written as a linalg.contract, right?
We can move generics to contract and elementwise later. TPP-MLIR has linalg generalization because some passes don't work with the new ops.
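For reference, a sketch of what the scores generic could look like as a linalg.contract. Shapes, indexing maps, and the function name below are illustrative placeholders, not the ones in this PR:

```python
# Illustrative only: the q @ k^T scores contraction expressed as
# linalg.contract instead of linalg.generic. d is the single reduction
# dim; b, h, q, k stay parallel, matching the
# [parallel, parallel, parallel, parallel, reduction] iterators above.
SCORES_AS_CONTRACT = """
#map_q = affine_map<(b, h, q, k, d) -> (b, h, q, d)>
#map_k = affine_map<(b, h, q, k, d) -> (b, h, k, d)>
#map_s = affine_map<(b, h, q, k, d) -> (b, h, q, k)>
func.func @scores(%q: tensor<1x8x64x128xf32>, %k: tensor<1x8x64x128xf32>,
                  %acc: tensor<1x8x64x64xf32>) -> tensor<1x8x64x64xf32> {
  %0 = linalg.contract indexing_maps = [#map_q, #map_k, #map_s]
      ins(%q, %k : tensor<1x8x64x128xf32>, tensor<1x8x64x128xf32>)
      outs(%acc : tensor<1x8x64x64xf32>) -> tensor<1x8x64x64xf32>
  return %0 : tensor<1x8x64x64xf32>
}
"""
```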
python/examples/llama/test_llama3.py (Outdated)
```python
module = generate_module(ctx, ir_type)
bufferize_module(ctx, module)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)
pm = create_pass_pipeline(ctx)
pm.run(module.operation)
```
Suggested change:

```diff
 module = generate_module(ctx, ir_type)
-bufferize_module(ctx, module)
 schedule = create_schedule(ctx)
 apply_schedule(module, schedule)
-pm = create_pass_pipeline(ctx)
-pm.run(module.operation)
```
Just move the passes from inside bufferize_module(ctx, module) and create_pass_pipeline(ctx) into the start and end of the schedule, i.e. with transform.apply_registered_pass.
I know this antipattern originates in an example script we merged, but we should not let this proliferate. It clearly is already confusing people.
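A sketch of that direction, reusing the interpreter call already in this file. The schedule is written textually here for brevity, the import paths assume the upstream mlir Python package, and the pass names and the run_schedule helper are placeholders for whatever bufferize_module and create_pass_pipeline run today:

```python
from mlir import ir
from mlir.dialects.transform import interpreter

# Sketch: fold the surrounding pass pipelines into the schedule itself via
# transform.apply_registered_pass, so the payload is only ever touched
# through the transform interpreter. Pass names below are placeholders.
SCHEDULE_ASM = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %mod: !transform.any_op {transform.consumed}) {
    %bufferized = transform.apply_registered_pass "one-shot-bufferize" to %mod
        : (!transform.any_op) -> !transform.any_op
    // ... the existing tiling/lowering steps of the schedule ...
    %lowered = transform.apply_registered_pass "convert-linalg-to-loops" to %bufferized
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
"""

def run_schedule(kernel: ir.Module) -> None:
    # Parse the schedule in the payload's context and run it.
    with kernel.context:
        schedule = ir.Module.parse(SCHEDULE_ASM)
        interpreter.apply_named_sequence(
            payload_root=kernel,
            transform_root=schedule.body.operations[0],
            transform_module=schedule,
        )
```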
```python
    return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=schedule.body.operations[0],
        transform_module=schedule,
    )
```
Suggested change:

```diff
-    return schedule
-def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
-    interpreter.apply_named_sequence(
-        payload_root=kernel,
-        transform_root=schedule.body.operations[0],
-        transform_module=schedule,
-    )
+    return named_seq
```
If we do this, you can simply do:

```python
schedule = create_schedule()
schedule.apply(module)
```

If you need access to the Module around the named_sequence, just ask for its .parent.
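A sketch of such a helper, reusing the interpreter call already in the file. The apply function name is hypothetical, and the parent module is reached through .operation.parent in the Python bindings:

```python
from mlir import ir
from mlir.dialects.transform import interpreter

# Hypothetical helper matching the usage sketched above: `named_seq` is
# the transform.named_sequence op returned by create_schedule(), and the
# transform module wrapping it is recovered via its parent, as suggested.
def apply(named_seq, kernel: ir.Module) -> None:
    transform_module = named_seq.operation.parent  # enclosing transform module
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=named_seq,
        transform_module=transform_module,
    )
```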
```diff
@@ -1,2 +1,2 @@
 import ctypes
 import torch
```
I know this PR didn't introduce it, but looking at it now, I feel we should compartmentalize code with heavy dependencies a bit more. That is, not keep it in the same module as code that doesn't need the dependency, e.g. get_packed_arg.
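For example, a possible split along those lines. The module names and function bodies are hypothetical; only the pattern matters:

```python
# runtime_utils.py — hypothetical torch-free module: get_packed_arg only
# needs ctypes, so it stays importable without torch installed.
import ctypes

def get_packed_arg(arg):
    # ctypes-only packing logic lives here (body elided).
    ...

# torch_inputs.py — hypothetical module owning the heavy dependency;
# only the tests that actually build tensors import this one.
import torch

def random_input(shape):
    return torch.rand(*shape)
```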
python/examples/llama/test_llama3.py (Outdated)
```python
def create_schedule(ctx: ir.Context) -> ir.Module:
    """
    Create an MLIR module containing transformation schedule.

    The schedule provides partial lowering to scalar operations.

    Args:
        ctx: MLIR context.
    """
    with ctx, ir.Location.unknown(context=ctx):
        # Create transform module.
        schedule = ir.Module.create()
```
Suggested change:

```diff
-def create_schedule(ctx: ir.Context) -> ir.Module:
-    """
-    Create an MLIR module containing transformation schedule.
-    The schedule provides partial lowering to scalar operations.
-    Args:
-        ctx: MLIR context.
-    """
-    with ctx, ir.Location.unknown(context=ctx):
-        # Create transform module.
-        schedule = ir.Module.create()
+def create_schedule() -> ir.Module:
+    schedule = ir.Module.create()
```
And just de-indent the rest of the function.
python/examples/llama/test_llama3.py (Outdated)
```python
def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
    with ctx:
        pm = PassManager("builtin.module")
```
Suggested change:

```diff
-def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
-    with ctx:
-        pm = PassManager("builtin.module")
+def bufferize_module(kernel: ir.Module) -> None:
+    pm = PassManager("builtin.module")
```
python/examples/llama/test_llama3.py (Outdated)
```python
def to_ir_type(type_str, ctx):
    if type_str == "f32":
        return ir.F32Type.get(context=ctx)
    elif type_str == "f64":
        return ir.F64Type.get(context=ctx)
```
Suggested change:

```diff
-def to_ir_type(type_str, ctx):
-    if type_str == "f32":
-        return ir.F32Type.get(context=ctx)
-    elif type_str == "f64":
-        return ir.F64Type.get(context=ctx)
+def to_ir_type(type_str):
+    if type_str == "f32":
+        return ir.F32Type.get()
+    elif type_str == "f64":
+        return ir.F64Type.get()
```
In effect, these .get() methods do an ir.Context.current lookup under the hood when you don't pass a context explicitly (just like the op builders).
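For example:

```python
from mlir import ir

# Inside `with ctx:` the explicit argument is redundant: .get() falls
# back to ir.Context.current, so both calls return the same type.
with ir.Context() as ctx:
    assert ir.F32Type.get() == ir.F32Type.get(context=ctx)
```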
rengolin left a comment
I think there's a lot of smaller comments that we can leave for post-merge review. This is an example, and a complicated one at that, and we don't want to over-engineer something that will soon move to a better program.
Putting up this dirty draft for early feedback/questions. I'm putting together some tests to run a e2e llama3.1 going through linalg on tensors. The goal is to generate some nice linalg that would be optimization friendly. At the moment, there are just functional blocks and pieces that are just smoke-tested. These include naive implementations for rotary embeddings, feed forward, rms, and a bunch of other small snippets that are useful to implement the model. These are already enough to put an attention block together. It'd be nice to test it against the original implementation, but that'd require fairscale as a dependency. For now I only added pytest and kept the pipeline as simple as possible. I also reused the example with the schedule, so now it is a part of every test.