[llvm] [SPIRV] Support for extension SPV_INTEL_masked_gather_scatter (PR #131566)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Mar 16 22:56:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: VISHAKH PRAKASH (VishMCW)
<details>
<summary>Changes</summary>
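Add support for the SPV_INTEL_masked_gather_scatter extension
- Introduce a new SPIR-V IR pass, SPIRVCodeGenPrepare, that runs before the generic IR passes (only when the extension is enabled) and rewrites llvm.masked.gather/llvm.masked.scatter calls into the spv_masked_gather/spv_masked_scatter intrinsics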
---
Full diff: https://github.com/llvm/llvm-project/pull/131566.diff
14 Files Affected:
- (modified) llvm/docs/SPIRVUsage.rst (+2)
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+12)
- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1)
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1)
- (added) llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp (+125)
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+3-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+12-4)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+15-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+6)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+28)
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+8)
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+5-1)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll (+50)
``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 3e19ff881dffc..781a16dff0d0f 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -209,6 +209,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds the ability to declare extended instruction sets that have no semantic impact and can be safely removed from a module.
* - ``SPV_INTEL_fp_max_error``
- Adds the ability to specify the maximum error for floating-point operations.
+ * - ``SPV_INTEL_masked_gather_scatter``
+ - Allows OpTypeVector to have a physical pointer type as its component type and introduces masked gather/scatter instructions.
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index df3e137c80980..5a70350af1804 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -143,4 +143,16 @@ let TargetPrefix = "spv" in {
// FPMaxErrorDecorationINTEL
def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
+
+ // Masked Gather Scatter Intrinsics (SPV_INTEL_masked_gather_scatter).
+ // These take the same operands, in the same order, as llvm.masked.gather and
+ // llvm.masked.scatter; SPIRVCodeGenPrepare retargets those calls to them.
+ //
+ // spv_masked_gather(<N x ptr> Ptrs, i32 Alignment, <N x i1> Mask,
+ //                   <N x T> PassThru) -> <N x T>
+ // Alignment must be an immediate (ImmArg on operand 1).
+ def int_spv_masked_gather
+ :DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+ [LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>],
+ [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
+ // spv_masked_scatter(<N x T> Value, <N x ptr> Ptrs, i32 Alignment,
+ //                    <N x i1> Mask)
+ // Alignment must be an immediate (ImmArg on operand 2).
+ def int_spv_masked_scatter
+ :DefaultAttrsIntrinsic<[],
+ [llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
+ [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>;
}
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 4a2b534b948d6..48ef19b334695 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -46,6 +46,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVTargetMachine.cpp
SPIRVUtils.cpp
SPIRVEmitNonSemanticDI.cpp
+ SPIRVCodeGenPrepare.cpp
LINK_COMPONENTS
Analysis
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index d765dfe370be2..a843a7144f3c2 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -19,6 +19,7 @@ class SPIRVSubtarget;
class InstructionSelector;
class RegisterBankInfo;
+ModulePass *createSPIRVCodeGenPreparePass( const SPIRVTargetMachine &TM);
ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVStructurizerPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp
new file mode 100644
index 0000000000000..ea497880360ff
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp
@@ -0,0 +1,125 @@
+//===-- SPIRVCodeGenPrepare.cpp - Lower masked gather/scatter ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
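+// This pass preserves the @llvm.masked.gather and @llvm.masked.scatter
+// intrinsics by replacing them with the corresponding spv_masked_* intrinsics
+// when SPV_INTEL_masked_gather_scatter is enabled.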
+//===----------------------------------------------------------------------===//
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Pass.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVCodeGenPreparePass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+/// Module pass that rewrites calls to @llvm.masked.gather and
+/// @llvm.masked.scatter into calls to the spv_masked_gather /
+/// spv_masked_scatter intrinsics when SPV_INTEL_masked_gather_scatter is
+/// available, so the instruction selector can emit the INTEL instructions.
+class SPIRVCodeGenPrepare : public ModulePass {
+  const SPIRVTargetMachine &TM;
+
+public:
+  static char ID;
+
+  explicit SPIRVCodeGenPrepare(const SPIRVTargetMachine &TM)
+      : ModulePass(ID), TM(TM) {
+    initializeSPIRVCodeGenPreparePass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override;
+
+  StringRef getPassName() const override {
+    return "SPIRV CodeGen prepare pass";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // The rewrite only changes the callee of existing call instructions; no
+    // instructions or blocks are added, removed or moved, so the CFG survives.
+    AU.setPreservesCFG();
+    ModulePass::getAnalysisUsage(AU);
+  }
+};
+
+} // namespace
+
+char SPIRVCodeGenPrepare::ID = 0;
+INITIALIZE_PASS(SPIRVCodeGenPrepare, "codegen-prepare", "SPIRV codegen prepare",
+                false, false)
+
+/// Redirects the call \p II to the intrinsic \p NewID.
+///
+/// The overloaded types of the new intrinsic are the types of the operands of
+/// \p II whose indices are listed (in order) in \p OpNos; an empty list picks
+/// the non-overloaded declaration. The operands of \p II are left untouched.
+/// Always reports that \p II was changed.
+static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
+                                     ArrayRef<unsigned> OpNos) {
+  SmallVector<Type *> OverloadTys;
+  for (unsigned Idx : OpNos)
+    OverloadTys.push_back(II->getOperand(Idx)->getType());
+
+  // getOrInsertDeclaration with an empty type list is the non-overloaded form.
+  Module *Mod = II->getModule();
+  Function *Callee = Intrinsic::getOrInsertDeclaration(Mod, NewID, OverloadTys);
+  II->setCalledFunction(Callee);
+  return true;
+}
+
+/// Rewrites a masked gather/scatter intrinsic call \p II into the matching
+/// spv_masked_* intrinsic when SPV_INTEL_masked_gather_scatter is usable.
+///
+/// For both forms, the element type of the data vector is recorded in \p GR as
+/// the pointee type of the pointer-vector component type, so the SPIR-V
+/// pointer component can later be emitted with that pointee type.
+/// Returns true if \p II was modified.
+static bool lowerIntrinsicToFunction(IntrinsicInst *II,
+                                     const SPIRVSubtarget &ST,
+                                     SPIRVGlobalRegistry &GR) {
+  if (!ST.canUseExtension(
+          SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter))
+    return false;
+
+  // NOTE(review): the registry map is keyed on the pointer *type*. With opaque
+  // pointers that type is shared by every pointer in the same address space, so
+  // the last registered element type wins for all of them. TODO confirm/fix.
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::masked_scatter: {
+    // llvm.masked.scatter(<N x T> Value, <N x ptr> Ptrs, i32 Align, <N x i1>)
+    // The verifier guarantees the vector operand types, hence cast<>.
+    auto *ValueVecTy = cast<VectorType>(II->getOperand(0)->getType());
+    auto *PtrVecTy = cast<VectorType>(II->getOperand(1)->getType());
+    // Register the pointee type here too, so scatter-only code does not fall
+    // back to an i8 pointee that mismatches the stored component type.
+    GR.addPointerToBaseTypeMap(PtrVecTy->getElementType(),
+                               ValueVecTy->getElementType());
+    return toSpvOverloadedIntrinsic(
+        II, Intrinsic::SPVIntrinsics::spv_masked_scatter, {0, 1});
+  }
+  case Intrinsic::masked_gather: {
+    // llvm.masked.gather(<N x ptr> Ptrs, i32 Align, <N x i1> Mask,
+    //                    <N x T> PassThru) -> <N x T>
+    auto *PtrVecTy = cast<VectorType>(II->getOperand(0)->getType());
+    auto *ResultVecTy = cast<VectorType>(II->getType());
+    GR.addPointerToBaseTypeMap(PtrVecTy->getElementType(),
+                               ResultVecTy->getElementType());
+    // Operand 3 (PassThru) has the result type; operand 0 is the pointers.
+    return toSpvOverloadedIntrinsic(
+        II, Intrinsic::SPVIntrinsics::spv_masked_gather, {3, 0});
+  }
+  default:
+    return false;
+  }
+}
+
+/// Visits every instruction of every function in \p M and hands each intrinsic
+/// call to lowerIntrinsicToFunction. Returns true if any call was rewritten.
+bool SPIRVCodeGenPrepare::runOnModule(Module &M) {
+  bool Modified = false;
+  for (Function &Fn : M) {
+    const SPIRVSubtarget &Subtarget = TM.getSubtarget<SPIRVSubtarget>(Fn);
+    SPIRVGlobalRegistry *Registry = Subtarget.getSPIRVGlobalRegistry();
+    for (BasicBlock &Block : Fn) {
+      for (Instruction &Inst : Block) {
+        auto *Call = dyn_cast<IntrinsicInst>(&Inst);
+        if (!Call)
+          continue;
+        if (lowerIntrinsicToFunction(Call, Subtarget, *Registry))
+          Modified = true;
+      }
+    }
+  }
+  return Modified;
+}
+
+/// Creates a new instance of the SPIR-V codegen-prepare pass for \p TM.
+ModulePass *llvm::createSPIRVCodeGenPreparePass(const SPIRVTargetMachine &TM) {
+  auto *Pass = new SPIRVCodeGenPrepare(TM);
+  return Pass;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 37119bf01545c..357bccbbc2c1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -92,7 +92,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_long_composites",
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
{"SPV_INTEL_fp_max_error",
- SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
+ SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
+ {"SPV_INTEL_masked_gather_scatter",
+ SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter}};
bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
llvm::StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cbec1c95eadc3..829de62ed8213 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -231,11 +231,17 @@ SPIRVType *SPIRVGlobalRegistry::createOpType(
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {
+
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
auto EleOpc = ElemType->getOpcode();
(void)EleOpc;
- assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
- EleOpc == SPIRV::OpTypeBool) &&
- "Invalid vector element type");
+ if (!ST.canUseExtension(
+ SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
+ assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
+ EleOpc == SPIRV::OpTypeBool) &&
+ "Invalid vector element type");
+ }
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
@@ -1060,6 +1066,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
+
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isVoidTy())
@@ -1088,11 +1095,12 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
-
unsigned AddrSpace = typeToAddressSpace(Ty);
SPIRVType *SpvElementType = nullptr;
if (Type *ElemTy = ::getPointeeType(Ty))
SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
+ else if (Type *ElemTy = this->findPointerToBaseTypeMap(Ty))
+ SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
else
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 89599f17ef737..24cc62b3b69a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -78,7 +78,8 @@ class SPIRVGlobalRegistry {
// Holds the maximum ID we have in the module.
unsigned Bound;
-
+ /// maps the pointer type to the base type
+ DenseMap<Type *, Type *> PointerToBaseTypeMap;
// Maps values associated with untyped pointers into deduced element types of
// untyped pointers.
DenseMap<Value *, Type *> DeducedElTys;
@@ -635,6 +636,19 @@ class SPIRVGlobalRegistry {
void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg);
void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);
+
+  /// Records \p BaseTy as the pointee type to use when the pointer type \p PTy
+  /// has no known element type while building its SPIR-V type. A null \p PTy
+  /// is ignored; an existing entry for \p PTy is overwritten.
+  void addPointerToBaseTypeMap(Type *PTy, Type *BaseTy) {
+    if (!PTy)
+      return;
+    assert(PTy->isPointerTy() && "PTy must be a pointer type");
+    PointerToBaseTypeMap[PTy] = BaseTy;
+  }
+
+  /// Returns the pointee type recorded for \p PTy via addPointerToBaseTypeMap,
+  /// or nullptr when nothing was recorded.
+  Type *findPointerToBaseTypeMap(const Type *PTy) {
+    // DenseMap::lookup yields a value-initialized Type* (nullptr) on a miss.
+    return PointerToBaseTypeMap.lookup(PTy);
+  }
+
};
} // end namespace llvm
#endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a8f862271dbab..d16621eff4f3b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -956,3 +956,9 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
"$res = OpAliasScopeDeclINTEL $AliasDomain">;
def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
"$res = OpAliasScopeListDeclINTEL">;
+
+// SPV_INTEL_masked_gather_scatter
+// Selected from spv_masked_gather / spv_masked_scatter. $PtrVector is a vector
+// of pointers, $mask a vector of booleans, and $alignment is emitted as a
+// literal immediate. For the gather, $fillempty carries the pass-through
+// vector operand of the source intrinsic.
+def OpMaskedGatherINTEL: Op<6428, (outs ID:$res) , (ins TYPE:$type, ID:$PtrVector, i32imm:$alignment, ID:$mask, ID:$fillempty),
+ "$res = OpMaskedGatherINTEL $type $PtrVector $alignment $mask $fillempty">;
+def OpMaskedScatterINTEL: Op<6429, (outs) , (ins ID:$inVector, ID:$PtrVector, i32imm:$alignment, ID:$mask),
+ "OpMaskedScatterINTEL $inVector $PtrVector $alignment $mask">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b188f36ca9a9e..1b25470791286 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3192,6 +3192,34 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_discard: {
return selectDiscard(ResVReg, ResType, I);
}
+ case Intrinsic::spv_masked_gather: {
+ Register MemLoc = I.getOperand(2).getReg();
+ int32_t Alignment = I.getOperand(3).getImm();
+ Register Mask = I.getOperand(4).getReg();
+ Register PassThrough = I.getOperand(5).getReg();
+ return BuildMI(*(I.getParent()), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpMaskedGatherINTEL))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(MemLoc)
+ .addImm(Alignment)
+ .addUse(Mask)
+ .addUse(PassThrough)
+ .constrainAllUses(TII, TRI, RBI);
+ }
+ case Intrinsic::spv_masked_scatter: {
+ Register Value = I.getOperand(1).getReg();
+ Register MemLocs = I.getOperand(2).getReg();
+ int32_t Alignment = I.getOperand(3).getImm();
+ Register Mask = I.getOperand(4).getReg();
+ auto MIB = BuildMI(*(I.getParent()), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpMaskedScatterINTEL))
+ .addUse(Value)
+ .addUse(MemLocs)
+ .addImm(Alignment)
+ .addUse(Mask);
+ return MIB.constrainAllUses(TII, TRI, RBI);
+ }
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 63894acacbc73..e487cca7decd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1770,6 +1770,14 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
break;
}
+ case SPIRV::OpMaskedGatherINTEL:
+ case SPIRV::OpMaskedScatterINTEL:
+ if (ST.canUseExtension(
+ SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+ Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+ }
+ break;
default:
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index caee778eddbc4..d3ee45b6591a7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -313,6 +313,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
defm SPV_INTEL_long_composites : ExtensionOperand<117>;
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
+defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<120>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -513,6 +514,7 @@ defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composi
defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
+defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 0aa214dd354ee..ad26ed3b597db 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -175,8 +175,12 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
}
void SPIRVPassConfig::addIRPasses() {
- TargetPassConfig::addIRPasses();
+ if (TM.getSubtargetImpl()->canUseExtension(
+ SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+ addPass(createSPIRVCodeGenPreparePass(TM));
+ }
+ TargetPassConfig::addIRPasses();
if (TM.getSubtargetImpl()->isVulkanEnv()) {
// 1. Simplify loop for subsequent transformations. After this steps, loops
// have the following properties:
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll
new file mode 100644
index 0000000000000..b10c63b700a5a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll
@@ -0,0 +1,50 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOT: Name [[#]] "llvm.masked.gather.v4i32.v4p4"
+; CHECK-NOT: Name [[#]] "llvm.masked.scatter.v4i32.v4p4"
+
+; CHECK-DAG: OpCapability MaskedGatherScatterINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-DAG: %[[#TYPEINT:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYPEPTRINT:]] = OpTypePointer Generic %[[#TYPEINT]]
+; CHECK-DAG: %[[#TYPEVECPTR:]] = OpTypeVector %[[#TYPEPTRINT]] 4
+; CHECK-DAG: %[[#TYPEVECINT:]] = OpTypeVector %[[#TYPEINT]] 4
+
+; CHECK-DAG: %[[#CONST4:]] = OpConstant %[[#TYPEINT]] 4
+; CHECK-DAG: %[[#CONST0:]] = OpConstant %[[#TYPEINT]] 0
+; CHECK-DAG: %[[#CONST1:]] = OpConstant %[[#TYPEINT]] 1
+; CHECK-DAG: %[[#TRUE:]] = OpConstantTrue %[[#]]
+; CHECK-DAG: %[[#FALSE:]] = OpConstantFalse %[[#]]
+; CHECK-DAG: %[[#MASK1:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#FALSE]] %[[#TRUE]] %[[#TRUE]]
+; CHECK-DAG: %[[#FILL:]] = OpConstantComposite %[[#]] %[[#CONST4]] %[[#CONST0]] %[[#CONST1]] %[[#CONST0]]
+; CHECK-DAG: %[[#MASK2:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]]
+
+; CHECK: %[[#VECGATHER:]] = OpLoad %[[#TYPEVECPTR]]
+; CHECK: %[[#VECSCATTER:]] = OpLoad %[[#TYPEVECPTR]]
+; CHECK: %[[#GATHER:]] = OpMaskedGatherINTEL %[[#TYPEVECINT]] %[[#VECGATHER]] 4 %[[#MASK1]] %[[#FILL]]
+; CHECK: OpMaskedScatterINTEL %[[#GATHER]] %[[#VECSCATTER]] 4 %[[#MASK2]]
+
+; Gathers four i32 values through the pointer vector loaded from %arg0 (lane 1
+; is masked off and takes the pass-through value 0) and scatters the result
+; through the pointer vector loaded from %arg1 with every lane enabled.
+; The allocas are never stored to: this test only checks the emitted SPIR-V.
+define spir_kernel void @foo() {
+entry:
+ %arg0 = alloca <4 x ptr addrspace(4)>
+ %arg1 = alloca <4 x ptr addrspace(4)>
+ %0 = load <4 x ptr addrspace(4)>, ptr %arg0
+ %1 = load <4 x ptr addrspace(4)>, ptr %arg1
+ %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>)
+ call void @llvm.masked.scatter.v4i32.v4p4(<4 x i32> %res, <4 x ptr addrspace(4)> %1, i32 4, <4 x i1> splat (i1 true))
+ ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)>, i32, <4 x i1>, <4 x i32>)
+
+declare void @llvm.masked.scatter.v4i32.v4p4(<4 x i32>, <4 x ptr addrspace(4)>, i32, <4 x i1>)
+
+!llvm.module.flags = !{!0}
+!opencl.spir.version = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 2}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/131566
More information about the llvm-commits
mailing list