Skip to content

Conversation

@rongzha1
Copy link
Contributor

enable RMSNorm in oneDNN Graph

RFC: rfcs: graph api: support rms norm #4291

JIRA: MFDNN-14313
[oneDNN Graph][TF-CPU] Support RMSNorm via OneDNN Graph API

@rongzha1 rongzha1 requested review from a team as code owners November 27, 2025 03:46
@github-actions github-actions bot added documentation A request to change/fix/improve the documentation. Codeowner: @oneapi-src/onednn-doc component:api Codeowner: @oneapi-src/onednn-arch component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch labels Nov 27, 2025
return args;
}

arg_indices_t get_arg_indices_for_lnorm_and_gnorm(const op_t *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe change the function name to get_arg_indices_for_norm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed, thanks for your advice

new_op->set_attr<bool>(op_attr::use_affine, false);
}
// layernorm primitive doesn't support rrms
new_op->set_attr<bool>(op_attr::keep_stats, false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean if it's rms norm, keep_stats means rrms?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

primitive doesn't support rrms. RMSNorm is lowered down to layernorm. In layernorm primitive, for rmsnorm, this means to output mean and variance but not rrms. So disable it.

Comment on lines +449 to +447
p_rmsnorm->append_decision_function(
check_input_ndim_from_offset<0, 2, 5>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess during pattern matching stage, the ndims may not be available

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If usr add the ndims in pattern match stage, it may help.
I remembered this is added to filter out 6D layernorm in pytorch models during ipex integration.

@rongzha1 rongzha1 force-pushed the rzhang/graph_op_rmsnorm branch 2 times, most recently from 8cef641 to 479e108 Compare December 2, 2025 16:05
@rongzha1 rongzha1 force-pushed the rzhang/graph_op_rmsnorm branch from 479e108 to ae4bb7e Compare December 3, 2025 04:02
@rongzha1 rongzha1 force-pushed the rzhang/graph_op_rmsnorm branch from ae4bb7e to 8e22b15 Compare December 3, 2025 04:13
@rongzha1
Copy link
Contributor Author

rongzha1 commented Dec 3, 2025

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@@ -0,0 +1,56 @@
RMSNorm {#dev_guide_op_rmsnorm}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mzhukova , could you please help to review this? As requested, we also need to add the RMS normalization support into Graph API. Thanks.

}

// check function for data_type of RMSNorm.
bool check_rmsn_data_type(const op_t *n) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference of data type constraints between rms norm and layer norm?

@@ -0,0 +1,148 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Copyright 2024 Intel Corporation
* Copyright 2025 Intel Corporation

// RMS = sqrt(mean(x^2, axis=-1) + epsilon), output = (x / RMS) * scale
std::vector<float> ref_dst {0.63245427f, 2.52981707f, 0.8485278f,
2.26274079f, 0.90535731f, 2.17285755f, 0.93126606f, 2.12860815f,
0.94605894f, 2.10235321f, 0.9556189f, 2.08498669f};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness validation should be covered by benchdnn already? Any specific case we need to cover by gtests?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

component:api Codeowner: @oneapi-src/onednn-arch component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch documentation A request to change/fix/improve the documentation. Codeowner: @oneapi-src/onednn-doc

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants