[llvm] 57f7937 - [SPIR-V]: Add SPIR-V extension: SPV_KHR_cooperative_matrix (#96091)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 00:57:31 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-06-24T09:57:27+02:00
New Revision: 57f79371a5c08e1328e85b68b757cd5547f2bf62
URL: https://github.com/llvm/llvm-project/commit/57f79371a5c08e1328e85b68b757cd5547f2bf62
DIFF: https://github.com/llvm/llvm-project/commit/57f79371a5c08e1328e85b68b757cd5547f2bf62.diff
LOG: [SPIR-V]: Add SPIR-V extension: SPV_KHR_cooperative_matrix (#96091)
This PR adds the SPIR-V extension SPV_KHR_cooperative_matrix, which "adds a
new set of types known as 'cooperative matrix' types, where the storage
for and computations performed on the matrix are spread across a set of
invocations such as a subgroup" (see
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc).
This PR also fixes https://github.com/llvm/llvm-project/issues/96170; a
new test case is attached
(llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll).
Added:
llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
llvm/lib/Target/SPIRV/SPIRVBuiltins.td
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c14e5098be711..f5f36075d4a31 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -558,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,
static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
const SPIRV::IncomingCall *Call,
- Register TypeReg = Register(0)) {
+ Register TypeReg,
+ ArrayRef<uint32_t> ImmArgs = {}) {
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
auto MIB = MIRBuilder.buildInstr(Opcode);
if (TypeReg.isValid())
MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
- for (Register ArgReg : Call->Arguments) {
+ unsigned Sz = Call->Arguments.size() - ImmArgs.size();
+ for (unsigned i = 0; i < Sz; ++i) {
+ Register ArgReg = Call->Arguments[i];
if (!MRI->getRegClassOrNull(ArgReg))
MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
MIB.addUse(ArgReg);
}
+ for (uint32_t ImmArg : ImmArgs)
+ MIB.addImm(ImmArg);
return true;
}
@@ -575,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call);
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
assert(Call->Arguments.size() == 2 &&
"Need 2 arguments for atomic init translation");
@@ -633,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call);
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
Register ScopeRegister =
buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -870,7 +875,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, Opcode, Call);
+ return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
@@ -1824,6 +1829,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call,
return true;
}
+static bool generateConstructInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
+ GR->getSPIRVTypeID(Call->ReturnType));
+}
+
+static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ unsigned Opcode =
+ SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+ bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
+ unsigned ArgSz = Call->Arguments.size();
+ unsigned LiteralIdx = 0;
+ if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
+ LiteralIdx = 3;
+ else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
+ LiteralIdx = 4;
+ SmallVector<uint32_t, 1> ImmArgs;
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ if (LiteralIdx > 0)
+ ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
+ Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+ if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
+ SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+ if (!CoopMatrType)
+ report_fatal_error("Can't find a register's type definition");
+ MIRBuilder.buildInstr(Opcode)
+ .addDef(Call->ReturnRegister)
+ .addUse(TypeReg)
+ .addUse(CoopMatrType->getOperand(0).getReg());
+ return true;
+ }
+ return buildOpFromWrapper(MIRBuilder, Opcode, Call,
+ IsSet ? TypeReg : Register(0), ImmArgs);
+}
+
static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -2382,6 +2426,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
case SPIRV::Select:
return generateSelectInst(Call.get(), MIRBuilder);
+ case SPIRV::Construct:
+ return generateConstructInst(Call.get(), MIRBuilder, GR);
case SPIRV::SpecConstant:
return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
case SPIRV::Enqueue:
@@ -2400,6 +2446,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
case SPIRV::KernelClock:
return generateKernelClockInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::CoopMatr:
+ return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
}
return false;
}
@@ -2524,6 +2572,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
ExtensionType->getIntParameter(0)));
}
+static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ assert(ExtensionType->getNumIntParameters() == 4 &&
+ "Invalid number of parameters for SPIR-V coop matrices builtin!");
+ assert(ExtensionType->getNumTypeParameters() == 1 &&
+ "SPIR-V coop matrices builtin type must have a type parameter!");
+ const SPIRVType *ElemType =
+ GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
+ // Create or get an existing type from GlobalRegistry.
+ return GR->getOrCreateOpTypeCoopMatr(
+ MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
+ ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
+ ExtensionType->getIntParameter(3));
+}
+
static SPIRVType *
getImageType(const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
@@ -2654,6 +2718,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
case SPIRV::OpTypeSampledImage:
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
break;
+ case SPIRV::OpTypeCooperativeMatrixKHR:
+ TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
+ break;
default:
TargetType =
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2b8e6d856686a..4bd1104103664 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -60,6 +60,8 @@ def AtomicFloating : BuiltinGroup;
def GroupUniform : BuiltinGroup;
def KernelClock : BuiltinGroup;
def CastToPtr : BuiltinGroup;
+def Construct : BuiltinGroup;
+def CoopMatr : BuiltinGroup;
//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
@@ -114,6 +116,9 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage
// Select builtin record:
def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;
+// Composite Construct builtin record:
+def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>;
+
//===----------------------------------------------------------------------===//
// Class defining an extended builtin record used for lowering into an
// OpExtInst instruction.
@@ -608,6 +613,12 @@ defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToGlobal", Ope
defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
+// Cooperative Matrix builtin records:
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
+
//===----------------------------------------------------------------------===//
// Class defining a work/sub group builtin that should be translated into a
// SPIR-V instruction using the defined properties.
@@ -1436,7 +1447,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
def : BuiltinType<"spirv.Image", OpTypeImage>;
def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
def : BuiltinType<"spirv.Pipe", OpTypePipe>;
-
+def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;
//===----------------------------------------------------------------------===//
// Class matching an OpenCL builtin type name to an equivalent SPIR-V
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 75aa1823b11f2..c7c244cfa8977 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
{"SPV_KHR_shader_clock",
SPIRV::Extension::Extension::SPV_KHR_shader_clock},
+ {"SPV_KHR_cooperative_matrix",
+ SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
};
bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index b22d2a04f75b1..b8710d24bff94 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1080,12 +1080,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
return IntType && IntType->getOperand(2).getImm() != 0;
}
+SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
+ return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
+ ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
+ : nullptr;
+}
+
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
- SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg);
- SPIRVType *ElemType =
- PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
- ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
- : nullptr;
+ SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
return ElemType ? ElemType->getOpcode() : 0;
}
@@ -1189,6 +1191,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
.addUse(getSPIRVTypeID(ImageType));
}
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
+ MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
+ const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
+ uint32_t Use) {
+ Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
+ if (ResVReg.isValid())
+ return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
+ ResVReg = createTypeVReg(MIRBuilder);
+ SPIRVType *SpirvTy =
+ MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
+ .addDef(ResVReg)
+ .addUse(getSPIRVTypeID(ElemType))
+ .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
+ .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
+ DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
+ return SpirvTy;
+}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index db01f68f48de9..cc4e20b8247cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -292,6 +292,8 @@ class SPIRVGlobalRegistry {
return Res->second;
}
+ // Return a pointee's type, or nullptr otherwise.
+ SPIRVType *getPointeeType(SPIRVType *PtrType);
// Return a pointee's type op code, or 0 otherwise.
unsigned getPointeeTypeOp(Register PtrReg);
@@ -514,7 +516,11 @@ class SPIRVGlobalRegistry {
SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
MachineIRBuilder &MIRBuilder);
-
+ SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder,
+ const TargetExtType *ExtensionType,
+ const SPIRVType *ElemType,
+ uint32_t Scope, uint32_t Rows,
+ uint32_t Columns, uint32_t Use);
SPIRVType *
getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index dedfd5e6e32db..63549b06e9670 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+ "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
// 3.42.7 Constant-Creation Instructions
@@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta
"$res = OpAsmINTEL $type $asm_type $target $asm">;
def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
"$res = OpAsmCallINTEL $type $asm">;
+
+// SPV_KHR_cooperative_matrix
+def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
+ (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops),
+ "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
+def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
+ (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
+ "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">;
+def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
+ (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
+ "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
+def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
+ "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d7b96b28445d6..a7c6c713a110c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1117,7 +1117,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
- SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+ GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
MachineBasicBlock &BB = *I.getParent();
const DebugLoc &DL = I.getDebugLoc();
bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 30a6c474f467a..ac0aa682ea4be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
+ case SPIRV::OpTypeCooperativeMatrixKHR:
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+ report_fatal_error(
+ "OpTypeCooperativeMatrixKHR type requires the "
+ "following SPIR-V extension: SPV_KHR_cooperative_matrix",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+ Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+ break;
default:
break;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 318c5cebb7a43..96601dd8796c6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -302,6 +302,7 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>;
defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
+defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -478,6 +479,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
+defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
new file mode 100644
index 0000000000000..1c41c7331cda8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -0,0 +1,54 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeCooperativeMatrixKHR type requires the following SPIR-V extension: SPV_KHR_cooperative_matrix
+
+; CHECK: OpCapability CooperativeMatrixKHR
+; CHECK: OpExtension "SPV_KHR_cooperative_matrix"
+
+; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#Const12:]] = OpConstant %[[#Int32Ty]] 12
+; CHECK-DAG: %[[#Const48:]] = OpConstant %[[#Int32Ty]] 48
+; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int32Ty]] 3
+; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2
+; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int32Ty]] 1
+; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0
+; CHECK-DAG: %[[#MatTy1:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const12]] %[[#Const2]]
+; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const48]] %[[#Const0]]
+; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const48]] %[[#Const12]] %[[#Const1]]
+; CHECK: OpCompositeConstruct %[[#MatTy1]]
+; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]]
+; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#MatTy2:]]
+; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]]
+; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]]
+; CHECK: OpCooperativeMatrixStoreKHR
+
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+ %addr1 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+ %res = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+ %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 0)
+ store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m1, ptr %addr1, align 8
+ %accA3 = addrspacecast ptr addrspace(1) %_arg_accA to ptr addrspace(3)
+ %m2 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3) %accA3, i32 0, i64 %_arg_K, i32 1)
+ %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2)
+ %accB3 = addrspacecast ptr addrspace(1) %_arg_accB to ptr addrspace(3)
+ %m3 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3) %accB3, i32 0, i64 0)
+ %m4 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8
+ %m5 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2, target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) %m3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m4, i32 12)
+ store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m5, ptr %res, align 8
+ %r = load i64, ptr %res, align 8
+ store i64 %r, ptr %addr1, align 8
+ %accC3 = addrspacecast ptr addrspace(1) %_arg_accC to ptr addrspace(3)
+ %m6 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8
+ tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3) %accC3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m6, i32 0, i64 %_arg_N, i32 1)
+ ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32)
+declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0))
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32)
+declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll
new file mode 100644
index 0000000000000..818243ab19e41
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll
@@ -0,0 +1,30 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
+; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
+; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Arg1:]] = OpFunctionParameter %[[#GlobalCharPtr]]
+; CHECK-SPIRV: %[[#Ptr1:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg1]]
+; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr1]]
+; CHECK-SPIRV: OpFunctionEnd
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Arg2:]] = OpFunctionParameter %[[#GlobalCharPtr]]
+; CHECK-SPIRV: %[[#Ptr2:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg2]]
+; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr2]]
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) {
+entry:
+ %p = addrspacecast ptr addrspace(1) %arg to ptr addrspace(3)
+ ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %arg) {
+entry:
+ %p1 = addrspacecast ptr addrspace(1) %arg to ptr addrspace(4)
+ %p2 = addrspacecast ptr addrspace(4) %p1 to ptr addrspace(3)
+ ret void
+}
More information about the llvm-commits
mailing list