Skip to content

Commit 881c6fe

Browse files
authored
add meta backend for EmbeddingBag (#1525) (#1568)
* add meta backend for EmbeddingBag * add UT * modify UT * fix UT
1 parent 046f7df commit 881c6fe

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

csrc/cpu/aten/EmbeddingBag.cpp

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,6 @@ at::Tensor dil_qembeddingbag(
123123
} // namespace cpu
124124
} // namespace torch_ipex
125125

126-
namespace {
127-
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
128-
m.def(
129-
torch::schema(
130-
"torch_ipex::embedding_bag(Tensor weight, Tensor indices, Tensor "
131-
"offsets, bool sparse, bool include_last_offset) -> Tensor",
132-
c10::AliasAnalysisKind::PURE_FUNCTION),
133-
torch_ipex::embedding_bag);
134-
}
135-
} // namespace
136-
137126
namespace torch_ipex {
138127
namespace autocast {
139128

@@ -156,10 +145,6 @@ at::Tensor embedding_bag(
156145
return op.call(casted_weight, indices, offsets, sparse, include_last_offset);
157146
}
158147

159-
TORCH_LIBRARY_IMPL(torch_ipex, AutocastCPU, m) {
160-
m.impl("embedding_bag", torch_ipex::autocast::embedding_bag);
161-
}
162-
163148
} // namespace autocast
164149
} // namespace torch_ipex
165150

@@ -179,4 +164,36 @@ at::Tensor embedding_bag(
179164
weight, indices, offsets, sparse, include_last_offset);
180165
}
181166

167+
at::Tensor embedding_bag_meta(
168+
const at::Tensor& weight,
169+
const at::Tensor& indices,
170+
const at::Tensor& offsets,
171+
bool sparse,
172+
bool include_last_offset) {
173+
auto num_bags = offsets.sym_size(0);
174+
if (indices.dim() == 2) {
175+
num_bags = indices.sym_size(0);
176+
}
177+
c10::SymDimVector output_size(2);
178+
output_size[0] = num_bags;
179+
output_size[1] = weight.sym_size(1);
180+
auto output = at::empty_symint(output_size, weight.options());
181+
return output;
182+
}
183+
182184
} // namespace torch_ipex
185+
186+
namespace {
187+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
188+
m.def(
189+
"embedding_bag(Tensor weight, Tensor indices, Tensor "
190+
"offsets, bool sparse, bool include_last_offset) -> Tensor");
191+
m.impl("embedding_bag", c10::DispatchKey::CPU, torch_ipex::embedding_bag);
192+
m.impl(
193+
"embedding_bag", c10::DispatchKey::Meta, torch_ipex::embedding_bag_meta);
194+
m.impl(
195+
"embedding_bag",
196+
c10::DispatchKey::AutocastCPU,
197+
torch_ipex::autocast::embedding_bag);
198+
}
199+
} // namespace

csrc/cpu/aten/EmbeddingBag.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ at::Tensor embedding_bag(
1111
bool sparse,
1212
bool include_last_offset);
1313

14+
at::Tensor embedding_bag_meta(
15+
const at::Tensor& weight,
16+
const at::Tensor& indices,
17+
const at::Tensor& offsets,
18+
bool sparse,
19+
bool include_last_offset);
20+
1421
} // namespace torch_ipex
1522

1623
namespace torch_ipex {

tests/cpu/test_emb.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
ipex_emb_fn = ipex.nn.functional._embeddingbag._embeddingbag
1010
aten_emb_fn = ipex.nn.functional._embeddingbag.torch_embedding_bag
1111

12+
class Embeddingbag(torch.nn.Module):
13+
def __init__(self):
14+
super(Embeddingbag, self).__init__()
15+
self.embeddingbag = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
16+
17+
def forward(self, input, offsets):
18+
return self.embeddingbag(input, offsets)
19+
1220
class TestEMB(TestCase):
1321

1422
def _test_emb(
@@ -96,5 +104,19 @@ def test_emb_jit_scriptable(self):
96104
out = script_emb(input, offsets)
97105
self.assertEqual(out, ref_out)
98106

107+
def test_emb_torch_compile(self):
108+
emb = Embeddingbag().eval()
109+
input = torch.LongTensor([1,2,4,5,4,3,2,9])
110+
offsets = torch.LongTensor([0,1,2,3,4,5,6,7])
111+
# TODO: add dynamic tests when 'ipex' backend supports it.
112+
for dtype, backend, dynamic in itertools.product([torch.float32, torch.bfloat16], ['ipex', 'inductor'], [False]):
113+
torch._dynamo.reset()
114+
emb_torchcompile = torch.compile(emb, backend=backend, dynamic=dynamic)
115+
with torch.cpu.amp.autocast(enabled=(dtype==torch.bfloat16)), torch.no_grad():
116+
y0 = emb(input, offsets)
117+
y1 = emb_torchcompile(input, offsets)
118+
self.assertEqual(y0, y1)
119+
self.assertEqual(y1.dtype, dtype)
120+
99121
if __name__ == '__main__':
100122
test = unittest.main()

0 commit comments

Comments
 (0)