Skip to content

Commit bf6ba55

Browse files
MrSidimsjsji
authored andcommitted
Remove legacy joint matrix instructions (#3438)
Brand-new spec: #12497 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com> Co-authored-by: Viktoria Maximova <viktoria.maksimova@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@60d78aa6d1d98cb
1 parent 1bb3ca2 commit bf6ba55

22 files changed

+50
-1239
lines changed

llvm-spirv/lib/SPIRV/OCLUtil.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,6 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
911911
case OpConstantSampler:
912912
case OpTypeSampler:
913913
return SPIRV_SAMPLER_T_ADDR_SPACE;
914-
case internal::OpTypeJointMatrixINTEL:
915-
case internal::OpTypeJointMatrixINTELv2:
916914
case OpTypeCooperativeMatrixKHR:
917915
case internal::OpTypeTaskSequenceINTEL:
918916
return SPIRAS_Global;

llvm-spirv/lib/SPIRV/SPIRVInternal.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,6 @@ const static char ConstantSampler[] = "ConstantSampler";
317317
const static char PipeStorage[] = "PipeStorage";
318318
const static char ConstantPipeStorage[] = "ConstantPipeStorage";
319319
const static char VmeImageINTEL[] = "VmeImageINTEL";
320-
const static char JointMatrixINTEL[] = "JointMatrixINTEL";
321320
const static char CooperativeMatrixKHR[] = "CooperativeMatrixKHR";
322321
const static char BufferSurfaceINTEL[] = "BufferSurfaceINTEL";
323322
} // namespace kSPIRVTypeName
@@ -972,7 +971,6 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
972971
_SPIRV_OP(BufferSurfaceINTEL)
973972
_SPIRV_OP(CooperativeMatrixKHR)
974973
#undef _SPIRV_OP
975-
add("JointMatrixINTEL", internal::OpTypeJointMatrixINTEL);
976974
add("TaskSequenceINTEL", internal::OpTypeTaskSequenceINTEL);
977975
}
978976

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -477,31 +477,6 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
477477
}
478478
return mapType(T, Ty);
479479
}
480-
case internal::OpTypeJointMatrixINTEL: {
481-
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
482-
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
483-
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
484-
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
485-
if (auto *Layout = MT->getLayout())
486-
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
487-
Params.push_back(
488-
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
489-
if (auto *Use = MT->getUse())
490-
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
491-
auto *CTI = MT->getComponentTypeInterpretation();
492-
if (!CTI)
493-
return mapType(
494-
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
495-
transType(MT->getCompType()), Params));
496-
const unsigned CTIValue =
497-
static_cast<SPIRVConstant *>(CTI)->getZExtIntValue();
498-
assert(CTIValue <= internal::InternalJointMatrixCTI::PackedInt4 &&
499-
"Unknown matrix component type interpretation");
500-
Params.push_back(CTIValue);
501-
return mapType(
502-
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
503-
transType(MT->getCompType()), Params));
504-
}
505480
case OpTypeCooperativeMatrixKHR: {
506481
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(T);
507482
unsigned Scope =
@@ -2564,7 +2539,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25642539
auto *Load = new LoadInst(ST, Alloca, "load", false, BB);
25652540
return mapValue(BV, Load);
25662541
}
2567-
case internal::OpTypeJointMatrixINTEL:
25682542
case OpTypeCooperativeMatrixKHR:
25692543
case internal::OpTypeTaskSequenceINTEL:
25702544
return mapValue(BV, transSPIRVBuiltinFromInst(CC, BB));
@@ -2595,9 +2569,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25952569
case OpVectorExtractDynamic: {
25962570
auto *VED = static_cast<SPIRVVectorExtractDynamic *>(BV);
25972571
SPIRVValue *Vec = VED->getVector();
2598-
if (Vec->getType()->getOpCode() == internal::OpTypeJointMatrixINTEL) {
2599-
return mapValue(BV, transSPIRVBuiltinFromInst(VED, BB));
2600-
}
26012572
return mapValue(
26022573
BV, ExtractElementInst::Create(transValue(Vec, F, BB),
26032574
transValue(VED->getIndex(), F, BB),
@@ -2628,9 +2599,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
26282599
case OpVectorInsertDynamic: {
26292600
auto *VID = static_cast<SPIRVVectorInsertDynamic *>(BV);
26302601
SPIRVValue *Vec = VID->getVector();
2631-
if (Vec->getType()->getOpCode() == internal::OpTypeJointMatrixINTEL) {
2632-
return mapValue(BV, transSPIRVBuiltinFromInst(VID, BB));
2633-
}
26342602
return mapValue(
26352603
BV, InsertElementInst::Create(
26362604
transValue(Vec, F, BB), transValue(VID->getComponent(), F, BB),
@@ -3916,7 +3884,6 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
39163884
case OpUDotAccSatKHR:
39173885
case OpSUDotAccSatKHR:
39183886
case OpReadClockKHR:
3919-
case internal::OpJointMatrixLoadINTEL:
39203887
case OpCooperativeMatrixLoadKHR:
39213888
case internal::OpCooperativeMatrixLoadCheckedINTEL:
39223889
case internal::OpCooperativeMatrixLoadOffsetINTEL:

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -596,18 +596,6 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
596596
ArrayRef<unsigned> Ops = TargetTy->int_params();
597597
return mapType(T, BM->addBufferSurfaceINTELType(CastAccess(Ops[0])));
598598
}
599-
case internal::OpTypeJointMatrixINTEL: {
600-
// The expected representation is:
601-
// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%,
602-
// %layout%, %scope%, %use%,
603-
// (optional) %element_type_interpretation%)
604-
auto *ElemTy = transType(TargetTy->getTypeParameter(0));
605-
ArrayRef<unsigned> Ops = TargetTy->int_params();
606-
std::vector<SPIRVValue *> Args;
607-
for (const auto &Op : Ops)
608-
Args.emplace_back(transConstant(getUInt32(M, Op)));
609-
return mapType(T, BM->addJointMatrixINTELType(ElemTy, Args));
610-
}
611599
case OpTypeCooperativeMatrixKHR: {
612600
// The expected representation is:
613601
// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,6 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
8686
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
8787
std::end(Table));
8888

89-
// TODO: To remove this when we make a switch to new version
90-
if (OpCode == internal::OpTypeJointMatrixINTELv2)
91-
OpCode = internal::OpTypeJointMatrixINTEL;
92-
9389
// OpAtomicCompareExchangeWeak is removed starting from SPIR-V 1.4
9490
if (OpCode == OpAtomicCompareExchangeWeak)
9591
OpCode = OpAtomicCompareExchange;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,11 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
206206
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
207207
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
208208
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
209-
{internal::CapabilityJointMatrixINTEL});
210-
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
211-
{internal::CapabilityJointMatrixINTEL});
212-
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
213-
{internal::CapabilityJointMatrixINTEL});
214-
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
215-
{internal::CapabilityJointMatrixINTEL});
216-
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
217-
{internal::CapabilityJointMatrixINTEL});
209+
{CapabilityCooperativeMatrixKHR});
210+
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL,
211+
{CapabilityCooperativeMatrixKHR});
212+
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
213+
{CapabilityCooperativeMatrixKHR});
218214
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
219215
{CapabilityCooperativeMatrixKHR});
220216
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,8 +2126,6 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
21262126
break;
21272127
case OpTypeArray:
21282128
case OpTypeStruct:
2129-
case internal::OpTypeJointMatrixINTEL:
2130-
case internal::OpTypeJointMatrixINTELv2:
21312129
case OpTypeCooperativeMatrixKHR:
21322130
break;
21332131
default:
@@ -2387,8 +2385,7 @@ class SPIRVVectorExtractDynamic : public SPIRVInstruction {
23872385
SPIRVInstruction::validate();
23882386
if (getValue(VectorId)->isForward())
23892387
return;
2390-
assert(getValueType(VectorId)->isTypeVector() ||
2391-
getValueType(VectorId)->isTypeJointMatrixINTEL());
2388+
assert(getValueType(VectorId)->isTypeVector());
23922389
}
23932390
SPIRVId VectorId;
23942391
SPIRVId IndexId;
@@ -2425,8 +2422,7 @@ class SPIRVVectorInsertDynamic : public SPIRVInstruction {
24252422
SPIRVInstruction::validate();
24262423
if (getValue(VectorId)->isForward())
24272424
return;
2428-
assert(getValueType(VectorId)->isTypeVector() ||
2429-
getValueType(VectorId)->isTypeJointMatrixINTEL());
2425+
assert(getValueType(VectorId)->isTypeVector());
24302426
}
24312427
SPIRVId VectorId;
24322428
SPIRVId IndexId;
@@ -3604,8 +3600,9 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
36043600
SPIRVCapVec getRequiredCapability() const override {
36053601
SPIRVType *ResCompTy = this->getType();
36063602
if (ResCompTy->isTypeCooperativeMatrixKHR())
3607-
return getVec(CapabilityBFloat16ConversionINTEL,
3608-
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
3603+
return getVec(
3604+
CapabilityBFloat16ConversionINTEL,
3605+
internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL);
36093606
return getVec(CapabilityBFloat16ConversionINTEL);
36103607
}
36113608

@@ -3700,26 +3697,6 @@ class SPIRVJointMatrixINTELInstBase : public SPIRVInstTemplateBase {
37003697
}
37013698
};
37023699

3703-
class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
3704-
SPIRVCapVec getRequiredCapability() const override {
3705-
return getVec(internal::CapabilityJointMatrixINTEL);
3706-
}
3707-
};
3708-
3709-
#define _SPIRV_OP(x, ...) \
3710-
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELInst, internal::Op##x##INTEL, \
3711-
__VA_ARGS__> \
3712-
SPIRV##x##INTEL;
3713-
_SPIRV_OP(JointMatrixLoad, true, 6, true)
3714-
_SPIRV_OP(JointMatrixStore, false, 5, true)
3715-
_SPIRV_OP(JointMatrixMad, true, 6, true)
3716-
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
3717-
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
3718-
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
3719-
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
3720-
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
3721-
#undef _SPIRV_OP
3722-
37233700
class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
37243701
protected:
37253702
SPIRVCapVec getRequiredCapability() const override {
@@ -4031,8 +4008,9 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
40314008
SPIRVCapVec getRequiredCapability() const override {
40324009
SPIRVType *ResCompTy = this->getType();
40334010
if (ResCompTy->isTypeCooperativeMatrixKHR())
4034-
return getVec(CapabilityTensorFloat32RoundingINTEL,
4035-
internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
4011+
return getVec(
4012+
CapabilityTensorFloat32RoundingINTEL,
4013+
internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL);
40364014
return getVec(CapabilityTensorFloat32RoundingINTEL);
40374015
}
40384016

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,6 @@ class SPIRVModuleImpl : public SPIRVModule {
328328
SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) override;
329329
void closeStructType(SPIRVTypeStruct *T, bool) override;
330330
SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) override;
331-
SPIRVTypeJointMatrixINTEL *
332-
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
333331
SPIRVTypeCooperativeMatrixKHR *
334332
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
335333
SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() override;
@@ -1170,12 +1168,6 @@ SPIRVTypeVector *SPIRVModuleImpl::addVectorType(SPIRVType *CompType,
11701168
return addType(Ty);
11711169
}
11721170

1173-
SPIRVTypeJointMatrixINTEL *
1174-
SPIRVModuleImpl::addJointMatrixINTELType(SPIRVType *CompType,
1175-
std::vector<SPIRVValue *> Args) {
1176-
return addType(new SPIRVTypeJointMatrixINTEL(this, getId(), CompType, Args));
1177-
}
1178-
11791171
SPIRVTypeCooperativeMatrixKHR *
11801172
SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
11811173
std::vector<SPIRVValue *> Args) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ class SPIRVModule {
289289
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
290290
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
291291
virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord) = 0;
292-
virtual SPIRVTypeJointMatrixINTEL *
293-
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
294292
virtual SPIRVTypeCooperativeMatrixKHR *
295293
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
296294
virtual SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() = 0;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -672,23 +672,18 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
672672
// From spirv_internal.hpp
673673
add(internal::CapabilityOptNoneINTEL, "OptNoneINTEL");
674674
add(internal::CapabilityTokenTypeINTEL, "TokenTypeINTEL");
675-
add(internal::CapabilityJointMatrixINTEL, "JointMatrixINTEL");
676675
add(internal::CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
677676
add(internal::CapabilityGlobalVariableDecorationsINTEL,
678677
"GlobalVariableDecorationsINTEL");
679678
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
680679
add(CapabilityTensorFloat32RoundingINTEL,
681680
"TensorFloat32RoundingINTEL");
682681
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
683-
"JointMatrixWIInstructionsINTEL");
684-
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
685-
"JointMatrixTF32ComponentTypeINTEL");
686-
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
687-
"JointMatrixBF16ComponentTypeINTEL");
688-
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
689-
"JointMatrixPackedInt2ComponentTypeINTEL");
690-
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
691-
"JointMatrixPackedInt4ComponentTypeINTEL");
682+
"CooperativeMatrixInvocationInstructionsINTEL");
683+
add(internal::CapabilityCooperativeMatrixTF32ComponentTypeINTEL,
684+
"CooperativeMatrixTF32ComponentTypeINTEL");
685+
add(internal::CapabilityCooperativeMatrixBFloat16ComponentTypeINTEL,
686+
"CooperativeMatrixBFloat16ComponentTypeINTEL");
692687
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
693688
"CooperativeMatrixPrefetchINTEL");
694689
add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,

0 commit comments

Comments
 (0)