graph: enable RMSNorm OP #4392
        return args;
    }

    arg_indices_t get_arg_indices_for_lnorm_and_gnorm(const op_t *op) {
Maybe rename the function to get_arg_indices_for_norm?
Changed, thanks for the advice.
        new_op->set_attr<bool>(op_attr::use_affine, false);
    }
    // layernorm primitive doesn't support rrms
    new_op->set_attr<bool>(op_attr::keep_stats, false);
Does this mean that for RMSNorm, keep_stats refers to rrms?
The primitive doesn't support rrms. RMSNorm is lowered to the layernorm primitive, and in that primitive keep_stats means outputting mean and variance, not rrms. So it is disabled for RMSNorm.
    p_rmsnorm->append_decision_function(
            check_input_ndim_from_offset<0, 2, 5>);
I guess the ndims may not be available during the pattern matching stage.
If the user provides the ndims at the pattern matching stage, the check may help.
As I remember, it was added to filter out 6D layernorm in PyTorch models during the IPEX integration.
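To make the exchange above concrete, here is a minimal, self-contained sketch of what an ndims decision function of this shape could look like. It is not the oneDNN implementation: `toy_op_t` and `toy_check_input_ndim` are hypothetical stand-ins, the inclusive `[MIN_NDIMS, MAX_NDIMS]` range is an assumption, and letting an unknown rank pass is one possible answer to the reviewer's concern, not a statement of how `check_input_ndim_from_offset` behaves.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an op: it only records the rank of each input.
// A negative value means the rank is not known yet (assumption).
struct toy_op_t {
    std::vector<int> input_ndims;
};

// Accept the op when input number OFFSET has a rank in the inclusive range
// [MIN_NDIMS, MAX_NDIMS]. Assumption: an unknown rank is not rejected, so a
// pattern can still match when shapes are missing at match time.
template <std::size_t OFFSET, int MIN_NDIMS, int MAX_NDIMS>
bool toy_check_input_ndim(const toy_op_t *op) {
    if (OFFSET >= op->input_ndims.size()) return false; // no such input
    const int ndims = op->input_ndims[OFFSET];
    if (ndims < 0) return true; // rank unknown: do not filter the op out
    return ndims >= MIN_NDIMS && ndims <= MAX_NDIMS;
}
```

With `<0, 2, 5>`, a 4D input passes and a 6D input is filtered out, which mirrors the 6D layernorm case mentioned in the reply.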
8cef641 to 479e108
479e108 to ae4bb7e
ae4bb7e to 8e22b15
make test
    @@ -0,0 +1,56 @@
    RMSNorm {#dev_guide_op_rmsnorm}
@mzhukova, could you please help review this? As requested, we also need to add RMS normalization support to the Graph API. Thanks.
    }

    // check function for data_type of RMSNorm.
    bool check_rmsn_data_type(const op_t *n) {
What is the difference in data type constraints between RMSNorm and LayerNorm?
    @@ -0,0 +1,148 @@
    /*******************************************************************************
    * Copyright 2024 Intel Corporation
Suggested change:
    - * Copyright 2024 Intel Corporation
    + * Copyright 2025 Intel Corporation
    // RMS = sqrt(mean(x^2, axis=-1) + epsilon), output = (x / RMS) * scale
    std::vector<float> ref_dst {0.63245427f, 2.52981707f, 0.8485278f,
            2.26274079f, 0.90535731f, 2.17285755f, 0.93126606f, 2.12860815f,
            0.94605894f, 2.10235321f, 0.9556189f, 2.08498669f};
Shouldn't correctness validation already be covered by benchdnn? Is there a specific case we need to cover with gtests?
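For readers checking the reference values in the test snippet, the formula in its comment can be evaluated with a small standalone routine. This is a sketch, not code from the PR: `rmsnorm_ref` and `example_dst` are hypothetical names, and the inputs (source values 1 through 12 laid out as rows of 2, scale {1, 2}, epsilon 1e-5) are assumptions inferred from the listed `ref_dst` numbers, since the snippet does not show the test's actual inputs.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm over the last axis of a row-major [rows x cols] buffer:
//   RMS = sqrt(mean(x^2, axis=-1) + epsilon), output = (x / RMS) * scale
// Assumes cols > 0, src.size() is a multiple of cols, scale.size() == cols.
std::vector<float> rmsnorm_ref(const std::vector<float> &src,
        const std::vector<float> &scale, std::size_t cols, float epsilon) {
    std::vector<float> dst(src.size());
    const std::size_t rows = src.size() / cols;
    for (std::size_t r = 0; r < rows; ++r) {
        const float *x = src.data() + r * cols;
        double sum_sq = 0.0; // accumulate in double to limit rounding error
        for (std::size_t c = 0; c < cols; ++c)
            sum_sq += static_cast<double>(x[c]) * x[c];
        const double rms = std::sqrt(sum_sq / cols + epsilon);
        for (std::size_t c = 0; c < cols; ++c)
            dst[r * cols + c] = static_cast<float>(x[c] / rms * scale[c]);
    }
    return dst;
}

// Assumed inputs (not taken from the PR): src = 1..12, scale = {1, 2},
// epsilon = 1e-5, normalized over a last axis of length 2.
std::vector<float> example_dst() {
    std::vector<float> src(12);
    for (std::size_t i = 0; i < src.size(); ++i)
        src[i] = static_cast<float>(i + 1);
    return rmsnorm_ref(src, {1.0f, 2.0f}, 2, 1e-5f);
}
```

Under these assumed inputs, the first row {1, 2} has mean(x^2) = 2.5, so RMS is about 1.5811 and the scaled outputs are about 0.6325 and 2.5298, in line with the first two entries of `ref_dst`.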
enable RMSNorm in oneDNN Graph
RFC: rfcs: graph api: support rms norm #4291
JIRA: MFDNN-14313
[oneDNN Graph][TF-CPU] Support RMSNorm via OneDNN Graph API