[llvm] [SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (PR #185418)
Arseniy Obolenskiy via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 9 06:40:24 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/185418
>From f8dc54953c9ad2317aebfd8358ac951223cfa2d8 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 9 Mar 2026 14:29:49 +0100
Subject: [PATCH] [SPIRV] Add support for SPV_INTEL_masked_gather_scatter
extension
---
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 8 ++
llvm/lib/Target/SPIRV/CMakeLists.txt | 1 +
llvm/lib/Target/SPIRV/SPIRV.h | 3 +
llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 +
.../SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp | 136 ++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 29 ++--
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 6 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 68 +++++++++
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 21 ++-
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 20 +++
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 +
llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 2 +
.../masked-gather-scatter-no-extension.ll | 12 ++
.../masked-gather-scatter.ll | 103 +++++++++++++
.../vector-of-pointers-no-extension.ll | 13 ++
.../vector-of-pointers-ptrtoint.ll | 33 +++++
llvm/test/CodeGen/SPIRV/llc-pipeline.ll | 2 +
17 files changed, 449 insertions(+), 12 deletions(-)
create mode 100644 llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3fc18a254f672..df3dceede04b1 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -44,6 +44,14 @@ let TargetPrefix = "spv" in {
def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
def int_spv_inline_asm : Intrinsic<[], [llvm_metadata_ty, llvm_metadata_ty, llvm_vararg_ty]>;
+ // Masked Gather/Scatter (SPV_INTEL_masked_gather_scatter); the i32 operand is an immediate (ImmArg) alignment
+ def int_spv_masked_gather : Intrinsic<[llvm_any_ty],
+ [llvm_any_ty, llvm_i32_ty, llvm_any_ty, llvm_any_ty], // ptrs, alignment, mask, passthru
+ [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
+ def int_spv_masked_scatter : Intrinsic<[],
+ [llvm_any_ty, llvm_any_ty, llvm_i32_ty, llvm_any_ty], // values, ptrs, alignment, mask
+ [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>;
+
// Expect, Assume Intrinsics
def int_spv_assume : Intrinsic<[], [llvm_i1_ty]>;
def int_spv_expect : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 58989237ad3ea..f492f4bb6507e 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVCallLowering.cpp
SPIRVInlineAsmLowering.cpp
SPIRVCommandLine.cpp
+ SPIRVConvertMaskedMemIntrinsics.cpp
SPIRVEmitIntrinsics.cpp
SPIRVGlobalRegistry.cpp
SPIRVInstrInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index da4da1a3fe83b..b36d4f9cab31c 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -32,6 +32,8 @@ FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerCombiner();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVPostLegalizerPass();
+FunctionPass *
+createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM);
ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
ModulePass *createSPIRVPrepareGlobalsPass();
MachineFunctionPass *createSPIRVEmitNonSemanticDIPass(SPIRVTargetMachine *TM);
@@ -59,6 +61,7 @@ void initializeSPIRVPrepareGlobalsPass(PassRegistry &);
void initializeSPIRVStripConvergentIntrinsicsPass(PassRegistry &);
void initializeSPIRVLegalizeImplicitBindingPass(PassRegistry &);
void initializeSPIRVLegalizeZeroSizeArraysLegacyPass(PassRegistry &);
+void initializeSPIRVConvertMaskedMemIntrinsicsPass(PassRegistry &);
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 33e1b52b724e6..734a03ff60141 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -86,6 +86,8 @@ static const StringMap<SPIRV::Extension::Extension> SPIRVExtensionMap = {
SPIRV::Extension::Extension::SPV_INTEL_memory_access_aliasing},
{"SPV_INTEL_joint_matrix",
SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
+ {"SPV_INTEL_masked_gather_scatter",
+ SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter},
{"SPV_KHR_16bit_storage",
SPIRV::Extension::Extension::SPV_KHR_16bit_storage},
{"SPV_KHR_device_group", SPIRV::Extension::Extension::SPV_KHR_device_group},
diff --git a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
new file mode 100644
index 0000000000000..aa1a83657bff8
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
@@ -0,0 +1,136 @@
+//===- SPIRVConvertMaskedMemIntrinsics.cpp - Convert masked mem intrinsics ==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts llvm.masked.gather/scatter to spv.masked.gather/scatter
+// to prevent them from being scalarized by the generic scalarization pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spirv-convert-masked-mem-intrinsics"
+
+namespace {
+
+class SPIRVConvertMaskedMemIntrinsics
+    : public FunctionPass,
+      public InstVisitor<SPIRVConvertMaskedMemIntrinsics> {
+  const SPIRVTargetMachine *TM = nullptr;
+  // Replaced intrinsic calls; erased once the visitor has finished.
+  SmallVector<Instruction *, 4> ToErase;
+
+  // Reports a fatal error unless the subtarget can use the extension.
+  void requireExtension(const IntrinsicInst &I, const char *Msg) const;
+  void convertGather(IntrinsicInst &I);
+  void convertScatter(IntrinsicInst &I);
+
+public:
+  static char ID;
+
+  explicit SPIRVConvertMaskedMemIntrinsics(
+      const SPIRVTargetMachine *TM = nullptr)
+      : FunctionPass(ID), TM(TM) {
+    initializeSPIRVConvertMaskedMemIntrinsicsPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void visitIntrinsicInst(IntrinsicInst &I);
+
+  // Calls are rewritten in place; blocks and edges are never modified.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+  StringRef getPassName() const override {
+    return "SPIRV convert masked memory intrinsics";
+  }
+};
+
+} // namespace
+
+char SPIRVConvertMaskedMemIntrinsics::ID = 0;
+
+INITIALIZE_PASS(SPIRVConvertMaskedMemIntrinsics,
+                "spirv-convert-masked-mem-intrinsics",
+                "Convert masked memory intrinsics for SPIR-V", false, false)
+
+bool SPIRVConvertMaskedMemIntrinsics::runOnFunction(Function &F) {
+  // Without a target machine there is no subtarget to query; do nothing.
+  if (!TM)
+    return false;
+  visit(F);
+  // Erase only after visiting so the visitor never walks a deleted call.
+  bool Changed = !ToErase.empty();
+  for (Instruction *I : ToErase)
+    I->eraseFromParent();
+  ToErase.clear();
+  return Changed;
+}
+
+void SPIRVConvertMaskedMemIntrinsics::requireExtension(const IntrinsicInst &I,
+                                                       const char *Msg) const {
+  const SPIRVSubtarget &ST =
+      TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+    report_fatal_error(Msg);
+}
+
+void SPIRVConvertMaskedMemIntrinsics::convertGather(IntrinsicInst &I) {
+  requireExtension(
+      I, "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
+  // Operands: (ptrs, mask, passthru); alignment is the align attr of arg 0.
+  Value *Ptrs = I.getArgOperand(0);
+  Value *Mask = I.getArgOperand(1);
+  Value *Passthru = I.getArgOperand(2);
+  uint32_t Alignment = I.getParamAlign(0).valueOrOne().value();
+  IRBuilder<> B(&I);
+  CallInst *NewI = B.CreateIntrinsic(
+      Intrinsic::spv_masked_gather,
+      {I.getType(), Ptrs->getType(), Mask->getType(), Passthru->getType()},
+      {Ptrs, B.getInt32(Alignment), Mask, Passthru});
+  NewI->takeName(&I);
+  I.replaceAllUsesWith(NewI);
+  ToErase.push_back(&I);
+}
+
+void SPIRVConvertMaskedMemIntrinsics::convertScatter(IntrinsicInst &I) {
+  requireExtension(
+      I, "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
+  // Operands: (values, ptrs, mask); alignment is the align attr of arg 1.
+  Value *Values = I.getArgOperand(0);
+  Value *Ptrs = I.getArgOperand(1);
+  Value *Mask = I.getArgOperand(2);
+  uint32_t Alignment = I.getParamAlign(1).valueOrOne().value();
+  IRBuilder<> B(&I);
+  B.CreateIntrinsic(Intrinsic::spv_masked_scatter,
+                    {Values->getType(), Ptrs->getType(), Mask->getType()},
+                    {Values, Ptrs, B.getInt32(Alignment), Mask});
+  ToErase.push_back(&I);
+}
+
+void SPIRVConvertMaskedMemIntrinsics::visitIntrinsicInst(IntrinsicInst &I) {
+  if (I.getIntrinsicID() == Intrinsic::masked_gather)
+    convertGather(I);
+  else if (I.getIntrinsicID() == Intrinsic::masked_scatter)
+    convertScatter(I);
+}
+
+FunctionPass *
+llvm::createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM) {
+  return new SPIRVConvertMaskedMemIntrinsics(TM);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9a85634c82626..c60663b43bdc4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -318,17 +318,26 @@ SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVTypeInst ElemType,
auto EleOpc = ElemType->getOpcode();
(void)EleOpc;
assert(NumElems >= 2 && "SPIR-V OpTypeVector requires at least 2 components");
- assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
- EleOpc == SPIRV::OpTypeBool) &&
- "Invalid vector element type");
- return createConstOrTypeAtFunctionEntry(MIRBuilder, [&](MachineIRBuilder
- &MIRBuilder) {
- return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
- .addDef(createTypeVReg(MIRBuilder))
- .addUse(getSPIRVTypeID(ElemType))
- .addImm(NumElems);
- });
+  if (EleOpc == SPIRV::OpTypePointer) {
+    // Checked in all build modes; an assert would vanish in release builds.
+    if (!MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>().canUseExtension(
+            SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      report_fatal_error("Vector of pointers requires "
+                         "SPV_INTEL_masked_gather_scatter extension");
+  } else {
+    assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
+            EleOpc == SPIRV::OpTypeBool) &&
+           "Invalid vector element type");
+  }
+
+  return createConstOrTypeAtFunctionEntry(
+      MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
+            .addDef(createTypeVReg(MIRBuilder))
+            .addUse(getSPIRVTypeID(ElemType))
+            .addImm(NumElems);
+      });
}
Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index d2f81bc30e949..819cdd6107d0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -1131,3 +1131,9 @@ def OpFixedLogALTERA: Op<5932, (outs ID:$res), (ins TYPE:$result_type, ID:$input
"$res = OpFixedLogALTERA $result_type $input $sign $l $rl $q $o">;
def OpFixedExpALTERA: Op<5933, (outs ID:$res), (ins TYPE:$result_type, ID:$input, i32imm:$sign, i32imm:$l, i32imm:$rl, i32imm:$q, i32imm:$o),
"$res = OpFixedExpALTERA $result_type $input $sign $l $rl $q $o">;
+
+// SPV_INTEL_masked_gather_scatter: Alignment is a literal operand, not an <id>.
+def OpMaskedGatherINTEL: Op<6428, (outs ID:$res), (ins TYPE:$resType, ID:$ptrs, i32imm:$alignment, ID:$mask, ID:$fillEmpty),
+                  "$res = OpMaskedGatherINTEL $resType $ptrs $alignment $mask $fillEmpty">;
+def OpMaskedScatterINTEL: Op<6429, (outs), (ins ID:$values, ID:$ptrs, i32imm:$alignment, ID:$mask),
+                  "OpMaskedScatterINTEL $values $ptrs $alignment $mask">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7b4c047593a3a..7896b1877362c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -300,6 +300,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectGEP(Register ResVReg, SPIRVTypeInst ResType,
MachineInstr &I) const;
+  bool selectMaskedGather(Register ResVReg, SPIRVTypeInst ResType,
+                          MachineInstr &I) const;
+  bool selectMaskedScatter(MachineInstr &I) const;
+
bool selectFrameIndex(Register ResVReg, SPIRVTypeInst ResType,
MachineInstr &I) const;
bool selectAllocaArray(Register ResVReg, SPIRVTypeInst ResType,
@@ -1722,6 +1726,60 @@ bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const {
return true;
}
+bool SPIRVInstructionSelector::selectMaskedGather(Register ResVReg,
+                                                  SPIRVTypeInst ResType,
+                                                  MachineInstr &I) const {
+  assert(I.getNumExplicitDefs() == 1 && "Expected single def for gather");
+  // Operands of the spv_masked_gather intrinsic instruction:
+  //   0: result, 1: intrinsic ID, 2: vector of pointers,
+  //   3: alignment (i32 immediate), 4: mask (vector of i1),
+  //   5: passthru/fill value
+  Register PtrsReg = I.getOperand(2).getReg();
+  uint32_t Alignment = I.getOperand(3).getImm();
+  Register MaskReg = I.getOperand(4).getReg();
+  Register PassthruReg = I.getOperand(5).getReg();
+
+  // SPV_INTEL_masked_gather_scatter encodes Alignment as a literal operand,
+  // so it is emitted as an immediate rather than as an OpConstant <id>.
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMaskedGatherINTEL))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(PtrsReg)
+          .addImm(Alignment)
+          .addUse(MaskReg)
+          .addUse(PassthruReg);
+  MIB.constrainAllUses(TII, TRI, RBI);
+  return true;
+}
+
+bool SPIRVInstructionSelector::selectMaskedScatter(MachineInstr &I) const {
+  assert(I.getNumExplicitDefs() == 0 && "Expected no defs for scatter");
+  // Operands of the spv_masked_scatter intrinsic instruction (no defs):
+  //   0: intrinsic ID
+  //   1: value vector
+  //   2: vector of pointers
+  //   3: alignment (i32 immediate)
+  //   4: mask (vector of i1)
+  Register ValuesReg = I.getOperand(1).getReg();
+  Register PtrsReg = I.getOperand(2).getReg();
+  uint32_t Alignment = I.getOperand(3).getImm();
+  Register MaskReg = I.getOperand(4).getReg();
+
+  // OpMaskedScatterINTEL takes InputVector, PtrVector, literal Alignment and
+  // Mask, i.e. the same order as the LLVM intrinsic.
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMaskedScatterINTEL))
+          .addUse(ValuesReg)
+          .addUse(PtrsReg)
+          .addImm(Alignment)
+          .addUse(MaskReg);
+  MIB.constrainAllUses(TII, TRI, RBI);
+  return true;
+}
+
bool SPIRVInstructionSelector::selectStackSave(Register ResVReg,
SPIRVTypeInst ResType,
MachineInstr &I) const {
@@ -4319,6 +4377,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectDerivativeInst(ResVReg, ResType, I, SPIRV::OpDPdyFine);
case Intrinsic::spv_fwidth:
return selectDerivativeInst(ResVReg, ResType, I, SPIRV::OpFwidth);
+  case Intrinsic::spv_masked_gather:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      return selectMaskedGather(ResVReg, ResType, I);
+    report_fatal_error(
+        "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
+  case Intrinsic::spv_masked_scatter:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      return selectMaskedScatter(I);
+    report_fatal_error(
+        "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 93e82750c4f32..92b900be17642 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -356,6 +356,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.legalFor({s1, s128})
.legalFor(allFloatAndIntScalarsAndPtrs)
.legalFor(allowedVectorTypes)
+      .legalIf([](const LegalityQuery &Query) {
+        return Query.Types[0].isPointerVector();
+      })
.moreElementsToNextPow2(0)
.fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
LegalizeMutations::changeElementCountTo(
@@ -366,11 +369,25 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_INTTOPTR)
.legalForCartesianProduct(allPtrs, allIntScalars)
.legalIf(
- all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
+          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)))
+      .legalIf([](const LegalityQuery &Query) {
+        // Integer vector -> pointer vector with matching element counts.
+        const LLT DstTy = Query.Types[0], SrcTy = Query.Types[1];
+        return DstTy.isPointerVector() && SrcTy.isVector() &&
+               !SrcTy.isPointerVector() &&
+               DstTy.getNumElements() == SrcTy.getNumElements();
+      });
getActionDefinitionsBuilder(G_PTRTOINT)
.legalForCartesianProduct(allIntScalars, allPtrs)
.legalIf(
- all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
+          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)))
+      .legalIf([](const LegalityQuery &Query) {
+        // Pointer vector -> integer vector with matching element counts.
+        const LLT DstTy = Query.Types[0], SrcTy = Query.Types[1];
+        return SrcTy.isPointerVector() && DstTy.isVector() &&
+               !DstTy.isPointerVector() &&
+               DstTy.getNumElements() == SrcTy.getNumElements();
+      });
getActionDefinitionsBuilder(G_PTR_ADD)
.legalForCartesianProduct(allPtrs, allIntScalars)
.legalIf(
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 923f59917cb10..3b86bde347287 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1496,6 +1496,15 @@ void addInstrRequirements(const MachineInstr &MI,
unsigned NumComponents = MI.getOperand(2).getImm();
if (NumComponents == 8 || NumComponents == 16)
Reqs.addCapability(SPIRV::Capability::Vector16);
+
+ assert(MI.getOperand(1).isReg()); // component type <id>
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer && // vector of pointers
+ ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+ Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+ }
break;
}
case SPIRV::OpTypePointer: {
@@ -1864,6 +1873,17 @@ void addInstrRequirements(const MachineInstr &MI,
case SPIRV::OpAtomicFMaxEXT:
AddAtomicFloatRequirements(MI, Reqs, ST);
break;
+ case SPIRV::OpConvertPtrToU:
+ case SPIRV::OpConvertUToPtr: {
+ const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+ SPIRVTypeInst ResultType = MRI.getVRegDef(MI.getOperand(1).getReg());
+ if (ResultType->getOpcode() == SPIRV::OpTypeVector && // vector conversion
+ ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+ Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+ }
+ break;
+ }
case SPIRV::OpConvertBF16ToFINTEL:
case SPIRV::OpConvertFToBF16INTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index e1a786ea16043..f1d115b424c97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -397,6 +397,7 @@ defm SPV_NV_shader_atomic_fp16_vector
defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>;
defm SPV_ALTERA_arbitrary_precision_floating_point: ExtensionOperand<134, [EnvOpenCL]>;
defm SPV_KHR_fma : ExtensionOperand<135, [EnvVulkan, EnvOpenCL]>;
+defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<136, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -620,6 +621,7 @@ defm PredicatedIOINTEL : CapabilityOperand<6257, 0, 0, [SPV_INTEL_predicated_io]
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>;
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 1759b34af3e90..57da3e92ec582 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -61,6 +61,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
initializeSPIRVPreLegalizerPass(PR);
initializeSPIRVPostLegalizerPass(PR);
initializeSPIRVMergeRegionExitTargetsPass(PR);
+ initializeSPIRVConvertMaskedMemIntrinsicsPass(PR);
initializeSPIRVEmitIntrinsicsPass(PR);
initializeSPIRVEmitNonSemanticDIPass(PR);
initializeSPIRVPrepareFunctionsPass(PR);
@@ -175,6 +176,7 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
void SPIRVPassConfig::addIRPasses() {
addPass(createAtomicExpandLegacyPass());
+ addPass(createSPIRVConvertMaskedMemIntrinsicsPass(&TM)); // before the generic IR passes below
TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
new file mode 100644
index 0000000000000..ec3a670ed85d5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
@@ -0,0 +1,12 @@
+; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 2>&1 | FileCheck %s
+
+; CHECK: LLVM ERROR: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter
+
+define spir_kernel void @test_gather_no_ext(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
+entry:
+ %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+ %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
+ ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
new file mode 100644
index 0000000000000..ad836f74c6d73
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
@@ -0,0 +1,103 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: re-enable spirv-val once it accepts vector operands in OpConvertPtrToU and OpConvertUToPtr with SPV_INTEL_masked_gather_scatter
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability MaskedGatherScatterINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+define spir_kernel void @test_gather_undef() {
+; CHECK-LABEL: Begin function test_gather_undef
+; CHECK: OpMaskedGatherINTEL
+entry:
+ %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> poison, i32 4, <4 x i1> poison, <4 x i32> poison)
+ ret void
+}
+
+define spir_kernel void @test_scatter_undef() {
+; CHECK-LABEL: Begin function test_scatter_undef
+; CHECK: OpMaskedScatterINTEL
+entry:
+ call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> poison, <4 x ptr addrspace(1)> poison, i32 4, <4 x i1> poison)
+ ret void
+}
+
+define spir_kernel void @test_gather_v4i32(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v4i32
+; CHECK: OpMaskedGatherINTEL
+entry:
+ %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+ %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
+ ret void
+}
+
+define spir_kernel void @test_scatter_v4i32(<4 x i32> %data, <4 x i64> %addrs, <4 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v4i32
+; CHECK: OpMaskedScatterINTEL
+entry:
+ %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+ call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+ ret void
+}
+
+define spir_kernel void @test_gather_v2i64(<2 x i64> %addrs, <2 x i1> %mask, <2 x i64> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v2i64
+; CHECK: OpMaskedGatherINTEL
+entry:
+ %ptrs = inttoptr <2 x i64> %addrs to <2 x ptr addrspace(1)>
+ %data = call <2 x i64> @llvm.masked.gather.v2i64.v2p1(<2 x ptr addrspace(1)> %ptrs, i32 8, <2 x i1> %mask, <2 x i64> %passthru)
+ ret void
+}
+
+define spir_kernel void @test_scatter_v2i64(<2 x i64> %data, <2 x i64> %addrs, <2 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v2i64
+; CHECK: OpMaskedScatterINTEL
+entry:
+ %ptrs = inttoptr <2 x i64> %addrs to <2 x ptr addrspace(1)>
+ call void @llvm.masked.scatter.v2i64.v2p1(<2 x i64> %data, <2 x ptr addrspace(1)> %ptrs, i32 8, <2 x i1> %mask)
+ ret void
+}
+
+define spir_kernel void @test_gather_v8i32(<8 x i64> %addrs, <8 x i1> %mask, <8 x i32> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v8i32
+; CHECK: OpMaskedGatherINTEL
+entry:
+ %ptrs = inttoptr <8 x i64> %addrs to <8 x ptr addrspace(1)>
+ %data = call <8 x i32> @llvm.masked.gather.v8i32.v8p1(<8 x ptr addrspace(1)> %ptrs, i32 4, <8 x i1> %mask, <8 x i32> %passthru)
+ ret void
+}
+
+define spir_kernel void @test_scatter_v8i32(<8 x i32> %data, <8 x i64> %addrs, <8 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v8i32
+; CHECK: OpMaskedScatterINTEL
+entry:
+ %ptrs = inttoptr <8 x i64> %addrs to <8 x ptr addrspace(1)>
+ call void @llvm.masked.scatter.v8i32.v8p1(<8 x i32> %data, <8 x ptr addrspace(1)> %ptrs, i32 4, <8 x i1> %mask)
+ ret void
+}
+
+define spir_kernel void @test_gather_v4f32(<4 x i64> %addrs, <4 x i1> %mask, <4 x float> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v4f32
+; CHECK: OpMaskedGatherINTEL
+entry:
+ %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+ %data = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x float> %passthru)
+ ret void
+}
+
+define spir_kernel void @test_scatter_v4f32(<4 x float> %data, <4 x i64> %addrs, <4 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v4f32
+; CHECK: OpMaskedScatterINTEL
+entry:
+ %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+ call void @llvm.masked.scatter.v4f32.v4p1(<4 x float> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+ ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
+declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
+declare <2 x i64> @llvm.masked.gather.v2i64.v2p1(<2 x ptr addrspace(1)>, i32, <2 x i1>, <2 x i64>)
+declare void @llvm.masked.scatter.v2i64.v2p1(<2 x i64>, <2 x ptr addrspace(1)>, i32, <2 x i1>)
+declare <8 x i32> @llvm.masked.gather.v8i32.v8p1(<8 x ptr addrspace(1)>, i32, <8 x i1>, <8 x i32>)
+declare void @llvm.masked.scatter.v8i32.v8p1(<8 x i32>, <8 x ptr addrspace(1)>, i32, <8 x i1>)
+declare <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x float>)
+declare void @llvm.masked.scatter.v4f32.v4p1(<4 x float>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
new file mode 100644
index 0000000000000..f84d3b655583b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
@@ -0,0 +1,13 @@
+; REQUIRES: asserts
+; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s 2>&1 | FileCheck %s
+
+; CHECK: Vector of pointers requires SPV_INTEL_masked_gather_scatter
+
+declare spir_func void @foo(<2 x i64>)
+
+define spir_kernel void @test_ptrtoint(<2 x ptr addrspace(1)> %p) {
+entry:
+ %addr = ptrtoint <2 x ptr addrspace(1)> %p to <2 x i64>
+ call void @foo(<2 x i64> %addr)
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll
new file mode 100644
index 0000000000000..74988e07b537b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: spirv-val does not support vector operands in OpConvertPtrToU and OpConvertUToPtr with SPV_INTEL_masked_gather_scatter; re-enable the disabled spirv-val line below once it does
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpCapability MaskedGatherScatterINTEL
+; CHECK: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#PtrTy:]] = OpTypePointer CrossWorkgroup
+; CHECK-DAG: %[[#Vec2Ptr:]] = OpTypeVector %[[#PtrTy]] 2
+; CHECK-DAG: %[[#Vec2Int64:]] = OpTypeVector %[[#Int64]] 2
+
+declare spir_func void @foo(<2 x i64>)
+
+define spir_kernel void @test_ptrtoint_vec2(<2 x ptr addrspace(1)> %p) {
+; CHECK-LABEL: Begin function test_ptrtoint_vec2
+; CHECK: OpConvertPtrToU %[[#Vec2Int64]]
+entry:
+ %addr = ptrtoint <2 x ptr addrspace(1)> %p to <2 x i64>
+ call void @foo(<2 x i64> %addr)
+ ret void
+}
+
+declare spir_func void @bar(<2 x ptr addrspace(1)>)
+
+define spir_kernel void @test_inttoptr_vec2(<2 x i64> %addr) {
+; CHECK-LABEL: Begin function test_inttoptr_vec2
+; CHECK: OpConvertUToPtr %[[#Vec2Ptr]]
+entry:
+ %p = inttoptr <2 x i64> %addr to <2 x ptr addrspace(1)>
+ call void @bar(<2 x ptr addrspace(1)> %p)
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
index eb1128ac5417a..f358eba72d84d 100644
--- a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
+++ b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
@@ -25,6 +25,7 @@
; SPIRV-O0-NEXT: FunctionPass Manager
; SPIRV-O0-NEXT: Expand IR instructions
; SPIRV-O0-NEXT: Expand Atomic instructions
+; SPIRV-O0-NEXT: SPIRV convert masked memory intrinsics
; SPIRV-O0-NEXT: Lower Garbage Collection Instructions
; SPIRV-O0-NEXT: Shadow Stack GC Lowering
; SPIRV-O0-NEXT: Remove unreachable blocks from the CFG
@@ -105,6 +106,7 @@
; SPIRV-Opt-NEXT: FunctionPass Manager
; SPIRV-Opt-NEXT: Expand IR instructions
; SPIRV-Opt-NEXT: Expand Atomic instructions
+; SPIRV-Opt-NEXT: SPIRV convert masked memory intrinsics
; SPIRV-Opt-NEXT: Dominator Tree Construction
; SPIRV-Opt-NEXT: Basic Alias Analysis (stateless AA impl)
; SPIRV-Opt-NEXT: Natural Loop Information
More information about the llvm-commits
mailing list