@@ -123,17 +123,6 @@ at::Tensor dil_qembeddingbag(
123123} // namespace cpu
124124} // namespace torch_ipex
125125
126- namespace {
127- TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
128- m.def (
129- torch::schema (
130- " torch_ipex::embedding_bag(Tensor weight, Tensor indices, Tensor "
131- " offsets, bool sparse, bool include_last_offset) -> Tensor" ,
132- c10::AliasAnalysisKind::PURE_FUNCTION),
133- torch_ipex::embedding_bag);
134- }
135- } // namespace
136-
137126namespace torch_ipex {
138127namespace autocast {
139128
@@ -156,10 +145,6 @@ at::Tensor embedding_bag(
156145 return op.call (casted_weight, indices, offsets, sparse, include_last_offset);
157146}
158147
159- TORCH_LIBRARY_IMPL (torch_ipex, AutocastCPU, m) {
160- m.impl (" embedding_bag" , torch_ipex::autocast::embedding_bag);
161- }
162-
163148} // namespace autocast
164149} // namespace torch_ipex
165150
@@ -179,4 +164,36 @@ at::Tensor embedding_bag(
179164 weight, indices, offsets, sparse, include_last_offset);
180165}
181166
167+ at::Tensor embedding_bag_meta (
168+ const at::Tensor& weight,
169+ const at::Tensor& indices,
170+ const at::Tensor& offsets,
171+ bool sparse,
172+ bool include_last_offset) {
173+ auto num_bags = offsets.sym_size (0 );
174+ if (indices.dim () == 2 ) {
175+ num_bags = indices.sym_size (0 );
176+ }
177+ c10::SymDimVector output_size (2 );
178+ output_size[0 ] = num_bags;
179+ output_size[1 ] = weight.sym_size (1 );
180+ auto output = at::empty_symint (output_size, weight.options ());
181+ return output;
182+ }
183+
182184} // namespace torch_ipex
185+
186+ namespace {
187+ TORCH_LIBRARY_FRAGMENT (torch_ipex, m) {
188+ m.def (
189+ " embedding_bag(Tensor weight, Tensor indices, Tensor "
190+ " offsets, bool sparse, bool include_last_offset) -> Tensor" );
191+ m.impl (" embedding_bag" , c10::DispatchKey::CPU, torch_ipex::embedding_bag);
192+ m.impl (
193+ " embedding_bag" , c10::DispatchKey::Meta, torch_ipex::embedding_bag_meta);
194+ m.impl (
195+ " embedding_bag" ,
196+ c10::DispatchKey::AutocastCPU,
197+ torch_ipex::autocast::embedding_bag);
198+ }
199+ } // namespace
0 commit comments