Skip to content

Commit ed1becb

Browse files
MrSidimsjsji
authored andcommitted
Implement SPV_INTEL_float4 and SPV_INTEL_fp_conversions extensions (#3419)
As well as their appropriate conversions via __builtin_spirv mechanism. Specification: #20467 Signed-off-by: Dmitry Sidorov <dmitry.sidorov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@af594c0b45250fe
1 parent 548f339 commit ed1becb

File tree

15 files changed

+1123
-106
lines changed

15 files changed

+1123
-106
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,5 @@ EXT(SPV_INTEL_shader_atomic_bfloat16)
8484
EXT(SPV_EXT_float8)
8585
EXT(SPV_INTEL_predicated_io)
8686
EXT(SPV_INTEL_sigmoid)
87+
EXT(SPV_INTEL_float4)
88+
EXT(SPV_INTEL_fp_conversions)

llvm-spirv/lib/SPIRV/SPIRVInternal.h

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,7 @@ enum FPEncodingWrap {
10491049
BF16 = FPEncoding::FPEncodingBFloat16KHR,
10501050
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
10511051
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1052+
E2M1 = internal::FPEncodingFloat4E2M1INTEL,
10521053
};
10531054

10541055
// Structure describing non-trivial conversions (FP8 and int4)
@@ -1077,36 +1078,117 @@ typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
10771078

10781079
// clang-format off
10791080
template <> inline void FPConvertToEncodingMap::init() {
1080-
// 8-bit conversions
1081-
add("ConvertE4M3ToFP16EXT",
1082-
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1083-
add("ConvertE5M2ToFP16EXT",
1084-
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1085-
add("ConvertE4M3ToBF16EXT",
1086-
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1087-
add("ConvertE5M2ToBF16EXT",
1088-
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1089-
add("ConvertFP16ToE4M3EXT",
1090-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1091-
add("ConvertFP16ToE5M2EXT",
1092-
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1093-
add("ConvertBF16ToE4M3EXT",
1094-
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1095-
add("ConvertBF16ToE5M2EXT",
1096-
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1097-
1098-
add("ConvertInt4ToE4M3INTEL",
1099-
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1100-
add("ConvertInt4ToE5M2INTEL",
1101-
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1102-
add("ConvertInt4ToFP16INTEL",
1103-
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1104-
add("ConvertInt4ToBF16INTEL",
1105-
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1106-
add("ConvertFP16ToInt4INTEL",
1107-
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1108-
add("ConvertBF16ToInt4INTEL",
1109-
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1081+
// 4-bit conversions
1082+
add("ConvertE2M1ToE4M3INTEL",
1083+
{FPEncodingWrap::E2M1, FPEncodingWrap::E4M3, OpFConvert});
1084+
add("ConvertE2M1ToE5M2INTEL",
1085+
{FPEncodingWrap::E2M1, FPEncodingWrap::E5M2, OpFConvert});
1086+
add("ConvertE2M1ToFP16INTEL",
1087+
{FPEncodingWrap::E2M1, FPEncodingWrap::IEEE754, OpFConvert});
1088+
add("ConvertE2M1ToBF16INTEL",
1089+
{FPEncodingWrap::E2M1, FPEncodingWrap::BF16, OpFConvert});
1090+
1091+
add("ConvertInt4ToE4M3INTEL",
1092+
{FPEncodingWrap::Integer, FPEncodingWrap::E4M3, OpConvertSToF});
1093+
add("ConvertInt4ToE5M2INTEL",
1094+
{FPEncodingWrap::Integer, FPEncodingWrap::E5M2, OpConvertSToF});
1095+
add("ConvertInt4ToFP16INTEL",
1096+
{FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF});
1097+
add("ConvertInt4ToBF16INTEL",
1098+
{FPEncodingWrap::Integer, FPEncodingWrap::BF16, OpConvertSToF});
1099+
add("ConvertInt4ToInt8INTEL",
1100+
{FPEncodingWrap::Integer, FPEncodingWrap::Integer, OpSConvert});
1101+
1102+
add("ConvertFP16ToE2M1INTEL",
1103+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1, OpFConvert});
1104+
add("ConvertBF16ToE2M1INTEL",
1105+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1, OpFConvert});
1106+
add("ConvertFP16ToInt4INTEL",
1107+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS});
1108+
add("ConvertBF16ToInt4INTEL",
1109+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer, OpConvertFToS});
1110+
1111+
// 8-bit conversions
1112+
add("ConvertE4M3ToFP16EXT",
1113+
{FPEncodingWrap::E4M3, FPEncodingWrap::IEEE754, OpFConvert});
1114+
add("ConvertE5M2ToFP16EXT",
1115+
{FPEncodingWrap::E5M2, FPEncodingWrap::IEEE754, OpFConvert});
1116+
add("ConvertE4M3ToBF16EXT",
1117+
{FPEncodingWrap::E4M3, FPEncodingWrap::BF16, OpFConvert});
1118+
add("ConvertE5M2ToBF16EXT",
1119+
{FPEncodingWrap::E5M2, FPEncodingWrap::BF16, OpFConvert});
1120+
add("ConvertFP16ToE4M3EXT",
1121+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3, OpFConvert});
1122+
add("ConvertFP16ToE5M2EXT",
1123+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2, OpFConvert});
1124+
add("ConvertBF16ToE4M3EXT",
1125+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3, OpFConvert});
1126+
add("ConvertBF16ToE5M2EXT",
1127+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2, OpFConvert});
1128+
1129+
// SPV_INTEL_fp_conversions
1130+
add("ClampConvertFP16ToE2M1INTEL",
1131+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1132+
internal::OpClampConvertFToFINTEL});
1133+
add("ClampConvertBF16ToE2M1INTEL",
1134+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1135+
internal::OpClampConvertFToFINTEL});
1136+
add("ClampConvertFP16ToE4M3INTEL",
1137+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1138+
internal::OpClampConvertFToFINTEL});
1139+
add("ClampConvertBF16ToE4M3INTEL",
1140+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1141+
internal::OpClampConvertFToFINTEL});
1142+
add("ClampConvertFP16ToE5M2INTEL",
1143+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1144+
internal::OpClampConvertFToFINTEL});
1145+
add("ClampConvertBF16ToE5M2INTEL",
1146+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1147+
internal::OpClampConvertFToFINTEL});
1148+
add("ClampConvertFP16ToInt4INTEL",
1149+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1150+
internal::OpClampConvertFToSINTEL});
1151+
add("ClampConvertBF16ToInt4INTEL",
1152+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1153+
internal::OpClampConvertFToSINTEL});
1154+
1155+
add("StochasticRoundFP16ToE5M2INTEL",
1156+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1157+
internal::OpStochasticRoundFToFINTEL});
1158+
add("StochasticRoundFP16ToE4M3INTEL",
1159+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1160+
internal::OpStochasticRoundFToFINTEL});
1161+
add("StochasticRoundBF16ToE5M2INTEL",
1162+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1163+
internal::OpStochasticRoundFToFINTEL});
1164+
add("StochasticRoundBF16ToE4M3INTEL",
1165+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1166+
internal::OpStochasticRoundFToFINTEL});
1167+
add("StochasticRoundFP16ToE2M1INTEL",
1168+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E2M1,
1169+
internal::OpStochasticRoundFToFINTEL});
1170+
add("StochasticRoundBF16ToE2M1INTEL",
1171+
{FPEncodingWrap::BF16, FPEncodingWrap::E2M1,
1172+
internal::OpStochasticRoundFToFINTEL});
1173+
add("ClampStochasticRoundFP16ToInt4INTEL",
1174+
{FPEncodingWrap::IEEE754, FPEncodingWrap::Integer,
1175+
internal::OpClampStochasticRoundFToSINTEL});
1176+
add("ClampStochasticRoundBF16ToInt4INTEL",
1177+
{FPEncodingWrap::BF16, FPEncodingWrap::Integer,
1178+
internal::OpClampStochasticRoundFToSINTEL});
1179+
1180+
add("ClampStochasticRoundFP16ToE5M2INTEL",
1181+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,
1182+
internal::OpClampStochasticRoundFToFINTEL});
1183+
add("ClampStochasticRoundFP16ToE4M3INTEL",
1184+
{FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,
1185+
internal::OpClampStochasticRoundFToFINTEL});
1186+
add("ClampStochasticRoundBF16ToE5M2INTEL",
1187+
{FPEncodingWrap::BF16, FPEncodingWrap::E5M2,
1188+
internal::OpClampStochasticRoundFToFINTEL});
1189+
add("ClampStochasticRoundBF16ToE4M3INTEL",
1190+
{FPEncodingWrap::BF16, FPEncodingWrap::E4M3,
1191+
internal::OpClampStochasticRoundFToFINTEL});
11101192
}
11111193

11121194
// clang-format on

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,11 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
299299

300300
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
301301
switch (T->getFloatBitWidth()) {
302+
case 4:
303+
// No LLVM IR counter part for FP4 - map it on i4.
304+
return Type::getIntNTy(*Context, 4);
302305
case 8:
303-
// No LLVM IR counter part for FP8 - map it on i8
306+
// No LLVM IR counter part for FP8 - map it on i8.
304307
return Type::getIntNTy(*Context, 8);
305308
case 16:
306309
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
@@ -1066,11 +1069,12 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10661069
return FPEncodingWrap::IEEE754;
10671070
};
10681071

1069-
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1070-
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1072+
auto IsFP4OrFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1073+
return Encoding == FPEncodingWrap::E4M3 ||
1074+
Encoding == FPEncodingWrap::E5M2 || Encoding == FPEncodingWrap::E2M1;
10711075
};
10721076

1073-
switch (BC->getOpCode()) {
1077+
switch (static_cast<unsigned>(BC->getOpCode())) {
10741078
case OpPtrCastToGeneric:
10751079
case OpGenericCastToPtr:
10761080
case OpPtrCastToCrossWorkgroupINTEL:
@@ -1091,6 +1095,11 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10911095
case OpUConvert:
10921096
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10931097
break;
1098+
case internal::OpClampConvertFToFINTEL:
1099+
case internal::OpClampConvertFToSINTEL:
1100+
case internal::OpStochasticRoundFToFINTEL:
1101+
case internal::OpClampStochasticRoundFToFINTEL:
1102+
case internal::OpClampStochasticRoundFToSINTEL:
10941103
case OpConvertSToF:
10951104
case OpConvertFToS:
10961105
case OpConvertUToF:
@@ -1115,7 +1124,7 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11151124

11161125
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
11171126
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1118-
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1127+
if (IsFP4OrFP8Encoding(SrcEnc) || IsFP4OrFP8Encoding(DstEnc) ||
11191128
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
11201129
FPConversionDesc FPDesc = {SrcEnc, DstEnc, BC->getOpCode()};
11211130
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
@@ -1125,13 +1134,47 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
11251134
std::string BuiltinName =
11261135
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
11271136
BuiltinFuncMangleInfo Info;
1128-
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1137+
std::string MangledName;
1138+
// Translate additional Ops for stochastic conversions.
1139+
if (OC == internal::OpStochasticRoundFToFINTEL ||
1140+
OC == internal::OpClampStochasticRoundFToFINTEL ||
1141+
OC == internal::OpClampStochasticRoundFToSINTEL) {
1142+
// Seed.
1143+
Ops.emplace_back(transValue(SPVOps[1], F, BB, true));
1144+
OpsTys.emplace_back(Ops[1]->getType());
1145+
constexpr unsigned MaxOpsSize = 3;
1146+
if (SPVOps.size() == MaxOpsSize) {
1147+
// New Seed.
1148+
Ops.emplace_back(transValue(SPVOps[2], F, BB, true));
1149+
1150+
// The following mess is needed to create a function with correct
1151+
// mangling.
1152+
SPIRVType *PtrTy = SPVOps[2]->getType();
1153+
const unsigned AS =
1154+
SPIRSPIRVAddrSpaceMap::rmap(PtrTy->getPointerStorageClass());
1155+
Type *ElementTy = transType(PtrTy->getPointerElementType());
1156+
OpsTys.emplace_back(TypedPointerType::get(ElementTy, AS));
1157+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1158+
// But to create function itself we need untyped pointer type.
1159+
OpsTys[2] = opaquifyType(OpsTys[2]);
1160+
}
1161+
}
1162+
1163+
if (MangledName.empty())
1164+
MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
11291165

11301166
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
11311167
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
11321168
return CallInst::Create(Func, Ops, "", BB);
11331169
}
11341170
}
1171+
// These conversions can be done without __builtin_spirv prefixed functions
1172+
// as their operand and result types have native representation in LLVM IR.
1173+
if (OC == internal::OpClampConvertFToFINTEL ||
1174+
OC == internal::OpStochasticRoundFToFINTEL ||
1175+
OC == internal::OpClampStochasticRoundFToFINTEL)
1176+
return mapValue(BV, transSPIRVBuiltinFromInst(
1177+
static_cast<SPIRVInstruction *>(BV), BB));
11351178

11361179
if (OC == OpFConvert) {
11371180
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
@@ -3056,7 +3099,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30563099
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
30573100
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
30583101
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3059-
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3102+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3103+
OutMatrixElementTy->isTypeFloat(
3104+
4, internal::FPEncodingFloat4E2M1INTEL) ||
3105+
InMatrixElementTy->isTypeFloat(4,
3106+
internal::FPEncodingFloat4E2M1INTEL))
30603107
Inst = transConvertInst(BV, F, BB);
30613108
else
30623109
Inst = transSPIRVBuiltinFromInst(BI, BB);
@@ -3065,6 +3112,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30653112
}
30663113
return mapValue(BV, Inst);
30673114
}
3115+
if (isIntelCvtOpCode(OC))
3116+
return mapValue(BV, transConvertInst(BV, F, BB));
30683117
return mapValue(
30693118
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
30703119
}
@@ -3881,6 +3930,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
38813930
case internal::OpTaskSequenceCreateINTEL:
38823931
case internal::OpConvertHandleToImageINTEL:
38833932
case internal::OpConvertHandleToSampledImageINTEL:
3933+
case internal::OpClampConvertFToFINTEL:
3934+
case internal::OpClampConvertFToSINTEL:
3935+
case internal::OpStochasticRoundFToFINTEL:
3936+
case internal::OpClampStochasticRoundFToFINTEL:
3937+
case internal::OpClampStochasticRoundFToSINTEL:
38843938
AddRetTypePostfix = true;
38853939
break;
38863940
default: {

llvm-spirv/lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ void SPIRVToOCLBase::visitCastInst(CastInst &Cast) {
247247
DstVecTy->getScalarSizeInBits() == 1)
248248
return;
249249

250+
// We don't have OpenCL builtins for 4-bit conversions.
251+
if (DstVecTy->getScalarSizeInBits() == 4 || SrcTy->getScalarSizeInBits() == 4)
252+
return;
253+
250254
// Assemble built-in name -> convert_gentypeN
251255
std::string CastBuiltInName(kOCLBuiltinName::ConvertPrefix);
252256
// Check if this is 'floating point -> unsigned integer' cast

0 commit comments

Comments
 (0)