[llvm] [SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (PR #185418)

Arseniy Obolenskiy via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 10 04:48:43 PDT 2026


https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/185418

>From 5f21ea83fd0be4e207a46b7a5ee1c320daa7f47e Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 9 Mar 2026 14:29:49 +0100
Subject: [PATCH 1/5] [SPIRV] Add support for SPV_INTEL_masked_gather_scatter
 extension

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   8 ++
 llvm/lib/Target/SPIRV/CMakeLists.txt          |   1 +
 llvm/lib/Target/SPIRV/SPIRV.h                 |   3 +
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |   2 +
 .../SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp | 136 ++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  29 ++--
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |   6 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  68 +++++++++
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  21 ++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  20 +++
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |   2 +
 llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp  |   2 +
 .../masked-gather-scatter-no-extension.ll     |  12 ++
 .../masked-gather-scatter.ll                  | 103 +++++++++++++
 .../vector-of-pointers-no-extension.ll        |  13 ++
 .../vector-of-pointers-ptrtoint.ll            |  33 +++++
 llvm/test/CodeGen/SPIRV/llc-pipeline.ll       |   2 +
 17 files changed, 449 insertions(+), 12 deletions(-)
 create mode 100644 llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3fc18a254f672..df3dceede04b1 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -44,6 +44,14 @@ let TargetPrefix = "spv" in {
   def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
   def int_spv_inline_asm : Intrinsic<[], [llvm_metadata_ty, llvm_metadata_ty, llvm_vararg_ty]>;
 
+  // Masked Gather/Scatter (SPV_INTEL_masked_gather_scatter), i32 immarg align
+  def int_spv_masked_gather : Intrinsic<[llvm_any_ty],
+                                [llvm_any_ty, llvm_i32_ty, llvm_any_ty, llvm_any_ty],
+                                [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
+  def int_spv_masked_scatter : Intrinsic<[],
+                                [llvm_any_ty, llvm_any_ty, llvm_i32_ty, llvm_any_ty],
+                                [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>;
+
   // Expect, Assume Intrinsics
   def int_spv_assume : Intrinsic<[], [llvm_i1_ty]>;
   def int_spv_expect : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 58989237ad3ea..f492f4bb6507e 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -22,6 +22,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVCallLowering.cpp
   SPIRVInlineAsmLowering.cpp
   SPIRVCommandLine.cpp
+  SPIRVConvertMaskedMemIntrinsics.cpp
   SPIRVEmitIntrinsics.cpp
   SPIRVGlobalRegistry.cpp
   SPIRVInstrInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index da4da1a3fe83b..b36d4f9cab31c 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -32,6 +32,8 @@ FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerCombiner();
 FunctionPass *createSPIRVPreLegalizerPass();
 FunctionPass *createSPIRVPostLegalizerPass();
+FunctionPass *
+createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM);
 ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 ModulePass *createSPIRVPrepareGlobalsPass();
 MachineFunctionPass *createSPIRVEmitNonSemanticDIPass(SPIRVTargetMachine *TM);
@@ -59,6 +61,7 @@ void initializeSPIRVPrepareGlobalsPass(PassRegistry &);
 void initializeSPIRVStripConvergentIntrinsicsPass(PassRegistry &);
 void initializeSPIRVLegalizeImplicitBindingPass(PassRegistry &);
 void initializeSPIRVLegalizeZeroSizeArraysLegacyPass(PassRegistry &);
+void initializeSPIRVConvertMaskedMemIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 33e1b52b724e6..734a03ff60141 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -86,6 +86,8 @@ static const StringMap<SPIRV::Extension::Extension> SPIRVExtensionMap = {
      SPIRV::Extension::Extension::SPV_INTEL_memory_access_aliasing},
     {"SPV_INTEL_joint_matrix",
      SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
+    {"SPV_INTEL_masked_gather_scatter",
+     SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter},
     {"SPV_KHR_16bit_storage",
      SPIRV::Extension::Extension::SPV_KHR_16bit_storage},
     {"SPV_KHR_device_group", SPIRV::Extension::Extension::SPV_KHR_device_group},
diff --git a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
new file mode 100644
index 0000000000000..aa1a83657bff8
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
@@ -0,0 +1,147 @@
+//===- SPIRVConvertMaskedMemIntrinsics.cpp - Convert masked mem intrinsics ==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts llvm.masked.gather/scatter to spv.masked.gather/scatter
+// to prevent them from being scalarized by the generic scalarization pass.
+// The alignment, which the generic intrinsics carry as a parameter attribute,
+// becomes an explicit i32 operand of the SPIR-V intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spirv-convert-masked-mem-intrinsics"
+
+namespace {
+
+class SPIRVConvertMaskedMemIntrinsics
+    : public FunctionPass,
+      public InstVisitor<SPIRVConvertMaskedMemIntrinsics> {
+  const SPIRVTargetMachine *TM = nullptr;
+
+public:
+  static char ID;
+
+  SPIRVConvertMaskedMemIntrinsics() : FunctionPass(ID) {
+    initializeSPIRVConvertMaskedMemIntrinsicsPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  SPIRVConvertMaskedMemIntrinsics(const SPIRVTargetMachine *TM)
+      : FunctionPass(ID), TM(TM) {
+    initializeSPIRVConvertMaskedMemIntrinsicsPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+  void visitIntrinsicInst(IntrinsicInst &I);
+
+  StringRef getPassName() const override {
+    return "SPIRV convert masked memory intrinsics";
+  }
+
+private:
+  /// Aborts with \p Msg unless SPV_INTEL_masked_gather_scatter can be used in
+  /// the function containing \p I.
+  void requireExtension(IntrinsicInst &I, const char *Msg) const;
+
+  /// Rewritten calls, erased only after the visit so iteration stays valid.
+  SmallVector<Instruction *, 4> ToErase;
+};
+
+} // namespace
+
+char SPIRVConvertMaskedMemIntrinsics::ID = 0;
+
+INITIALIZE_PASS(SPIRVConvertMaskedMemIntrinsics,
+                "spirv-convert-masked-mem-intrinsics",
+                "Convert masked memory intrinsics for SPIR-V", false, false)
+
+bool SPIRVConvertMaskedMemIntrinsics::runOnFunction(Function &F) {
+  // A default-constructed pass has no target machine to query for the
+  // extension, so it leaves the IR untouched.
+  if (!TM)
+    return false;
+
+  ToErase.clear();
+  visit(F);
+
+  for (Instruction *I : ToErase)
+    I->eraseFromParent();
+
+  return !ToErase.empty();
+}
+
+void SPIRVConvertMaskedMemIntrinsics::requireExtension(IntrinsicInst &I,
+                                                       const char *Msg) const {
+  const SPIRVSubtarget &ST =
+      TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+    report_fatal_error(Msg);
+}
+
+void SPIRVConvertMaskedMemIntrinsics::visitIntrinsicInst(IntrinsicInst &I) {
+  switch (I.getIntrinsicID()) {
+  case Intrinsic::masked_gather: {
+    requireExtension(
+        I, "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
+    IRBuilder<> B(&I);
+
+    // Operands are (ptrs, mask, passthru); the alignment is a parameter
+    // attribute on the ptrs operand (arg 0), not a regular operand.
+    Value *Ptrs = I.getArgOperand(0);
+    Value *Mask = I.getArgOperand(1);
+    Value *Passthru = I.getArgOperand(2);
+    uint32_t Alignment = I.getParamAlign(0).valueOrOne().value();
+
+    Value *Args[] = {Ptrs, B.getInt32(Alignment), Mask, Passthru};
+    Type *Types[] = {I.getType(), Ptrs->getType(), Mask->getType(),
+                     Passthru->getType()};
+    auto *NewI = B.CreateIntrinsic(Intrinsic::spv_masked_gather, Types, Args);
+    // Carry the result name over so the rewritten IR stays easy to follow.
+    NewI->takeName(&I);
+    I.replaceAllUsesWith(NewI);
+    ToErase.push_back(&I);
+    break;
+  }
+  case Intrinsic::masked_scatter: {
+    requireExtension(
+        I, "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
+    IRBuilder<> B(&I);
+
+    // Operands are (values, ptrs, mask); the alignment is a parameter
+    // attribute on the ptrs operand (arg 1).
+    Value *Values = I.getArgOperand(0);
+    Value *Ptrs = I.getArgOperand(1);
+    Value *Mask = I.getArgOperand(2);
+    uint32_t Alignment = I.getParamAlign(1).valueOrOne().value();
+
+    Value *Args[] = {Values, Ptrs, B.getInt32(Alignment), Mask};
+    Type *Types[] = {Values->getType(), Ptrs->getType(), Mask->getType()};
+    B.CreateIntrinsic(Intrinsic::spv_masked_scatter, Types, Args);
+    ToErase.push_back(&I);
+    break;
+  }
+  default:
+    break;
+  }
+}
+
+FunctionPass *
+llvm::createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM) {
+  return new SPIRVConvertMaskedMemIntrinsics(TM);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9a85634c82626..c60663b43bdc4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -318,17 +318,28 @@ SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVTypeInst ElemType,
   auto EleOpc = ElemType->getOpcode();
   (void)EleOpc;
   assert(NumElems >= 2 && "SPIR-V OpTypeVector requires at least 2 components");
-  assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
-          EleOpc == SPIRV::OpTypeBool) &&
-         "Invalid vector element type");
 
-  return createConstOrTypeAtFunctionEntry(MIRBuilder, [&](MachineIRBuilder
-                                                              &MIRBuilder) {
-    return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
-        .addDef(createTypeVReg(MIRBuilder))
-        .addUse(getSPIRVTypeID(ElemType))
-        .addImm(NumElems);
-  });
+  if (EleOpc == SPIRV::OpTypePointer) {
+    // Vectors of pointers are only valid SPIR-V with
+    // SPV_INTEL_masked_gather_scatter. Diagnose this in every build mode: an
+    // assert here would let release builds silently emit an invalid module.
+    const auto &ST = cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      report_fatal_error("Vector of pointers requires "
+                         "SPV_INTEL_masked_gather_scatter extension");
+  } else {
+    assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
+            EleOpc == SPIRV::OpTypeBool) &&
+           "Invalid vector element type");
+  }
+
+  return createConstOrTypeAtFunctionEntry(
+      MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
+            .addDef(createTypeVReg(MIRBuilder))
+            .addUse(getSPIRVTypeID(ElemType))
+            .addImm(NumElems);
+      });
 }
 
 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index d2f81bc30e949..819cdd6107d0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -1131,3 +1131,9 @@ def OpFixedLogALTERA: Op<5932, (outs ID:$res), (ins TYPE:$result_type, ID:$input
       "$res = OpFixedLogALTERA $result_type $input $sign $l $rl $q $o">;
 def OpFixedExpALTERA: Op<5933, (outs ID:$res), (ins TYPE:$result_type, ID:$input, i32imm:$sign, i32imm:$l, i32imm:$rl, i32imm:$q, i32imm:$o),
       "$res = OpFixedExpALTERA $result_type $input $sign $l $rl $q $o">;
+
+// SPV_INTEL_masked_gather_scatter (Alignment is a literal, not an <id>)
+def OpMaskedGatherINTEL: Op<6428, (outs ID:$res), (ins TYPE:$resType, ID:$ptrs, i32imm:$alignment, ID:$mask, ID:$fillEmpty),
+                  "$res = OpMaskedGatherINTEL $resType $ptrs $alignment $mask $fillEmpty">;
+def OpMaskedScatterINTEL: Op<6429, (outs), (ins ID:$values, ID:$ptrs, i32imm:$alignment, ID:$mask),
+                  "OpMaskedScatterINTEL $values $ptrs $alignment $mask">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7b4c047593a3a..7896b1877362c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -300,6 +300,10 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectGEP(Register ResVReg, SPIRVTypeInst ResType,
                  MachineInstr &I) const;
 
+  bool selectMaskedGather(Register ResVReg, SPIRVTypeInst ResType,
+                          MachineInstr &I) const;
+  bool selectMaskedScatter(MachineInstr &I) const;
+
   bool selectFrameIndex(Register ResVReg, SPIRVTypeInst ResType,
                         MachineInstr &I) const;
   bool selectAllocaArray(Register ResVReg, SPIRVTypeInst ResType,
@@ -1722,6 +1726,62 @@ bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const {
   return true;
 }
 
+bool SPIRVInstructionSelector::selectMaskedGather(Register ResVReg,
+                                                  SPIRVTypeInst ResType,
+                                                  MachineInstr &I) const {
+  assert(I.getNumExplicitDefs() == 1 && "Expected single def for gather");
+  // Operand indices (after explicit defs):
+  // 0: intrinsic ID
+  // 1: vector of pointers
+  // 2: alignment (i32 immediate)
+  // 3: mask (vector of i1)
+  // 4: passthru/fill value
+  Register PtrsReg = I.getOperand(I.getNumExplicitDefs() + 1).getReg();
+  uint32_t Alignment = I.getOperand(I.getNumExplicitDefs() + 2).getImm();
+  Register MaskReg = I.getOperand(I.getNumExplicitDefs() + 3).getReg();
+  Register PassthruReg = I.getOperand(I.getNumExplicitDefs() + 4).getReg();
+
+  // OpMaskedGatherINTEL takes Alignment as a literal, not as the <id> of a
+  // constant, so it is emitted as an immediate operand.
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMaskedGatherINTEL))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(PtrsReg)
+          .addImm(Alignment)
+          .addUse(MaskReg)
+          .addUse(PassthruReg);
+  MIB.constrainAllUses(TII, TRI, RBI);
+  return true;
+}
+
+bool SPIRVInstructionSelector::selectMaskedScatter(MachineInstr &I) const {
+  assert(I.getNumExplicitDefs() == 0 && "Expected no defs for scatter");
+  // Operand indices (no explicit defs):
+  // 0: intrinsic ID
+  // 1: value vector
+  // 2: vector of pointers
+  // 3: alignment (i32 immediate)
+  // 4: mask (vector of i1)
+  Register ValuesReg = I.getOperand(1).getReg();
+  Register PtrsReg = I.getOperand(2).getReg();
+  uint32_t Alignment = I.getOperand(3).getImm();
+  Register MaskReg = I.getOperand(4).getReg();
+  MachineBasicBlock &BB = *I.getParent();
+
+  // OpMaskedScatterINTEL operand order is InputVector, PtrVector, literal
+  // Alignment, Mask.
+  auto MIB =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMaskedScatterINTEL))
+          .addUse(ValuesReg)
+          .addUse(PtrsReg)
+          .addImm(Alignment)
+          .addUse(MaskReg);
+  MIB.constrainAllUses(TII, TRI, RBI);
+  return true;
+}
+
 bool SPIRVInstructionSelector::selectStackSave(Register ResVReg,
                                                SPIRVTypeInst ResType,
                                                MachineInstr &I) const {
@@ -4319,6 +4379,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectDerivativeInst(ResVReg, ResType, I, SPIRV::OpDPdyFine);
   case Intrinsic::spv_fwidth:
     return selectDerivativeInst(ResVReg, ResType, I, SPIRV::OpFwidth);
+  case Intrinsic::spv_masked_gather:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      return selectMaskedGather(ResVReg, ResType, I);
+    report_fatal_error(
+        "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
+  case Intrinsic::spv_masked_scatter:
+    if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      return selectMaskedScatter(I);
+    report_fatal_error(
+        "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 93e82750c4f32..92b900be17642 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -356,6 +356,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor({s1, s128})
       .legalFor(allFloatAndIntScalarsAndPtrs)
       .legalFor(allowedVectorTypes)
+      .legalIf([](const LegalityQuery &Query) {
+        return Query.Types[0].isPointerVector();
+      })
       .moreElementsToNextPow2(0)
       .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                        LegalizeMutations::changeElementCountTo(
@@ -366,11 +369,29 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   getActionDefinitionsBuilder(G_INTTOPTR)
       .legalForCartesianProduct(allPtrs, allIntScalars)
       .legalIf(
-          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
+          all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)))
+      // <N x iM> -> <N x ptr>. LLT::isPointer() is false for every vector, so
+      // isPointerVector() is what excludes a pointer-vector source.
+      .legalIf([](const LegalityQuery &Query) {
+        const LLT DstTy = Query.Types[0];
+        const LLT SrcTy = Query.Types[1];
+        return DstTy.isPointerVector() && SrcTy.isVector() &&
+               !SrcTy.isPointerVector() &&
+               DstTy.getNumElements() == SrcTy.getNumElements();
+      });
   getActionDefinitionsBuilder(G_PTRTOINT)
       .legalForCartesianProduct(allIntScalars, allPtrs)
       .legalIf(
-          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
+          all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)))
+      // <N x ptr> -> <N x iM>. LLT::isPointer() is false for every vector, so
+      // isPointerVector() is what excludes a pointer-vector destination.
+      .legalIf([](const LegalityQuery &Query) {
+        const LLT DstTy = Query.Types[0];
+        const LLT SrcTy = Query.Types[1];
+        return SrcTy.isPointerVector() && DstTy.isVector() &&
+               !DstTy.isPointerVector() &&
+               DstTy.getNumElements() == SrcTy.getNumElements();
+      });
   getActionDefinitionsBuilder(G_PTR_ADD)
       .legalForCartesianProduct(allPtrs, allIntScalars)
       .legalIf(
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 923f59917cb10..3b86bde347287 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1496,6 +1496,19 @@ void addInstrRequirements(const MachineInstr &MI,
     unsigned NumComponents = MI.getOperand(2).getImm();
     if (NumComponents == 8 || NumComponents == 16)
       Reqs.addCapability(SPIRV::Capability::Vector16);
+
+    // Pointer-typed components are only expressible with
+    // SPV_INTEL_masked_gather_scatter.
+    const MachineOperand &ComponentTyOp = MI.getOperand(1);
+    assert(ComponentTyOp.isReg());
+    SPIRVTypeInst ComponentTy =
+        MI.getMF()->getRegInfo().getVRegDef(ComponentTyOp.getReg());
+    const bool IsPtrVector = ComponentTy->getOpcode() == SPIRV::OpTypePointer;
+    if (IsPtrVector &&
+        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+      Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+    }
     break;
   }
   case SPIRV::OpTypePointer: {
@@ -1864,6 +1877,19 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpAtomicFMaxEXT:
     AddAtomicFloatRequirements(MI, Reqs, ST);
     break;
+  case SPIRV::OpConvertPtrToU:
+  case SPIRV::OpConvertUToPtr: {
+    // A vector result type means a component-wise conversion involving a
+    // vector of pointers.
+    SPIRVTypeInst ResultType =
+        MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
+    if (ResultType->getOpcode() != SPIRV::OpTypeVector ||
+        !ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
+      break;
+    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+    Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+    break;
+  }
   case SPIRV::OpConvertBF16ToFINTEL:
   case SPIRV::OpConvertFToBF16INTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index e1a786ea16043..f1d115b424c97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -397,6 +397,7 @@ defm SPV_NV_shader_atomic_fp16_vector
 defm SPV_EXT_image_raw10_raw12 :ExtensionOperand<133, [EnvOpenCL, EnvVulkan]>;
 defm SPV_ALTERA_arbitrary_precision_floating_point: ExtensionOperand<134, [EnvOpenCL]>;
 defm SPV_KHR_fma : ExtensionOperand<135, [EnvVulkan, EnvOpenCL]>;
+defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<136, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -620,6 +621,7 @@ defm PredicatedIOINTEL : CapabilityOperand<6257, 0, 0, [SPV_INTEL_predicated_io]
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>;
 defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
 defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
 defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 1759b34af3e90..57da3e92ec582 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -61,6 +61,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   initializeSPIRVPreLegalizerPass(PR);
   initializeSPIRVPostLegalizerPass(PR);
   initializeSPIRVMergeRegionExitTargetsPass(PR);
+  initializeSPIRVConvertMaskedMemIntrinsicsPass(PR);
   initializeSPIRVEmitIntrinsicsPass(PR);
   initializeSPIRVEmitNonSemanticDIPass(PR);
   initializeSPIRVPrepareFunctionsPass(PR);
@@ -175,6 +176,7 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
 
 void SPIRVPassConfig::addIRPasses() {
   addPass(createAtomicExpandLegacyPass());
+  addPass(createSPIRVConvertMaskedMemIntrinsicsPass(&TM));
 
   TargetPassConfig::addIRPasses();
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
new file mode 100644
index 0000000000000..ec3a670ed85d5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
@@ -0,0 +1,12 @@
+; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 2>&1 | FileCheck %s
+
+; CHECK: LLVM ERROR: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter
+
+define spir_kernel void @test_gather_no_ext(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
new file mode 100644
index 0000000000000..add08059d0255
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter.ll
@@ -0,0 +1,103 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: re-enable the spirv-val line below once spirv-val accepts vector operands of OpConvertPtrToU/OpConvertUToPtr under SPV_INTEL_masked_gather_scatter
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpCapability MaskedGatherScatterINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+define spir_kernel void @test_gather_undef() {
+; CHECK-LABEL: Begin function test_gather_undef
+; CHECK: OpMaskedGatherINTEL
+entry:
+  %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> poison, i32 4, <4 x i1> poison, <4 x i32> poison)
+  ret void
+}
+
+define spir_kernel void @test_scatter_undef() {
+; CHECK-LABEL: Begin function test_scatter_undef
+; CHECK: OpMaskedScatterINTEL
+entry:
+  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> poison, <4 x ptr addrspace(1)> poison, i32 4, <4 x i1> poison)
+  ret void
+}
+
+define spir_kernel void @test_gather_v4i32(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v4i32
+; CHECK: OpMaskedGatherINTEL
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
+  ret void
+}
+
+define spir_kernel void @test_scatter_v4i32(<4 x i32> %data, <4 x i64> %addrs, <4 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v4i32
+; CHECK: OpMaskedScatterINTEL
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+  ret void
+}
+
+define spir_kernel void @test_gather_v2i64(<2 x i64> %addrs, <2 x i1> %mask, <2 x i64> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v2i64
+; CHECK: OpMaskedGatherINTEL
+entry:
+  %ptrs = inttoptr <2 x i64> %addrs to <2 x ptr addrspace(1)>
+  %data = call <2 x i64> @llvm.masked.gather.v2i64.v2p1(<2 x ptr addrspace(1)> %ptrs, i32 8, <2 x i1> %mask, <2 x i64> %passthru)
+  ret void
+}
+
+define spir_kernel void @test_scatter_v2i64(<2 x i64> %data, <2 x i64> %addrs, <2 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v2i64
+; CHECK: OpMaskedScatterINTEL
+entry:
+  %ptrs = inttoptr <2 x i64> %addrs to <2 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v2i64.v2p1(<2 x i64> %data, <2 x ptr addrspace(1)> %ptrs, i32 8, <2 x i1> %mask)
+  ret void
+}
+
+define spir_kernel void @test_gather_v8i32(<8 x i64> %addrs, <8 x i1> %mask, <8 x i32> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v8i32
+; CHECK: OpMaskedGatherINTEL
+entry:
+  %ptrs = inttoptr <8 x i64> %addrs to <8 x ptr addrspace(1)>
+  %data = call <8 x i32> @llvm.masked.gather.v8i32.v8p1(<8 x ptr addrspace(1)> %ptrs, i32 4, <8 x i1> %mask, <8 x i32> %passthru)
+  ret void
+}
+
+define spir_kernel void @test_scatter_v8i32(<8 x i32> %data, <8 x i64> %addrs, <8 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v8i32
+; CHECK: OpMaskedScatterINTEL
+entry:
+  %ptrs = inttoptr <8 x i64> %addrs to <8 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v8i32.v8p1(<8 x i32> %data, <8 x ptr addrspace(1)> %ptrs, i32 4, <8 x i1> %mask)
+  ret void
+}
+
+define spir_kernel void @test_gather_v4f32(<4 x i64> %addrs, <4 x i1> %mask, <4 x float> %passthru) {
+; CHECK-LABEL: Begin function test_gather_v4f32
+; CHECK: OpMaskedGatherINTEL
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  %data = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x float> %passthru)
+  ret void
+}
+
+define spir_kernel void @test_scatter_v4f32(<4 x float> %data, <4 x i64> %addrs, <4 x i1> %mask) {
+; CHECK-LABEL: Begin function test_scatter_v4f32
+; CHECK: OpMaskedScatterINTEL
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v4f32.v4p1(<4 x float> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
+declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
+declare <2 x i64> @llvm.masked.gather.v2i64.v2p1(<2 x ptr addrspace(1)>, i32, <2 x i1>, <2 x i64>)
+declare void @llvm.masked.scatter.v2i64.v2p1(<2 x i64>, <2 x ptr addrspace(1)>, i32, <2 x i1>)
+declare <8 x i32> @llvm.masked.gather.v8i32.v8p1(<8 x ptr addrspace(1)>, i32, <8 x i1>, <8 x i32>)
+declare void @llvm.masked.scatter.v8i32.v8p1(<8 x i32>, <8 x ptr addrspace(1)>, i32, <8 x i1>)
+declare <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x float>)
+declare void @llvm.masked.scatter.v4f32.v4p1(<4 x float>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
new file mode 100644
index 0000000000000..f84d3b655583b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
@@ -0,0 +1,13 @@
+; REQUIRES: asserts
+; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s 2>&1 | FileCheck %s
+
+; CHECK: Vector of pointers requires SPV_INTEL_masked_gather_scatter
+
+declare spir_func void @foo(<2 x i64>)
+
+define spir_kernel void @test_ptrtoint(<2 x ptr addrspace(1)> %p) {
+entry:
+  %addr = ptrtoint <2 x ptr addrspace(1)> %p to <2 x i64>
+  call void @foo(<2 x i64> %addr)
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll
new file mode 100644
index 0000000000000..74988e07b537b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-ptrtoint.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: re-enable the spirv-val line below once spirv-val accepts vector operands of OpConvertPtrToU/OpConvertUToPtr under SPV_INTEL_masked_gather_scatter
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpCapability MaskedGatherScatterINTEL
+; CHECK: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#PtrTy:]] = OpTypePointer CrossWorkgroup
+; CHECK-DAG: %[[#Vec2Ptr:]] = OpTypeVector %[[#PtrTy]] 2
+; CHECK-DAG: %[[#Vec2Int64:]] = OpTypeVector %[[#Int64]] 2
+
+declare spir_func void @foo(<2 x i64>)
+
+define spir_kernel void @test_ptrtoint_vec2(<2 x ptr addrspace(1)> %p) {
+; CHECK-LABEL: Begin function test_ptrtoint_vec2
+; CHECK: OpConvertPtrToU %[[#Vec2Int64]]
+entry:
+  %addr = ptrtoint <2 x ptr addrspace(1)> %p to <2 x i64>
+  call void @foo(<2 x i64> %addr)
+  ret void
+}
+
+declare spir_func void @bar(<2 x ptr addrspace(1)>)
+
+define spir_kernel void @test_inttoptr_vec2(<2 x i64> %addr) {
+; CHECK-LABEL: Begin function test_inttoptr_vec2
+; CHECK: OpConvertUToPtr %[[#Vec2Ptr]]
+entry:
+  %p = inttoptr <2 x i64> %addr to <2 x ptr addrspace(1)>
+  call void @bar(<2 x ptr addrspace(1)> %p)
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
index eb1128ac5417a..f358eba72d84d 100644
--- a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
+++ b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
@@ -25,6 +25,7 @@
 ; SPIRV-O0-NEXT:    FunctionPass Manager
 ; SPIRV-O0-NEXT:      Expand IR instructions
 ; SPIRV-O0-NEXT:      Expand Atomic instructions
+; SPIRV-O0-NEXT:      SPIRV convert masked memory intrinsics
 ; SPIRV-O0-NEXT:      Lower Garbage Collection Instructions
 ; SPIRV-O0-NEXT:      Shadow Stack GC Lowering
 ; SPIRV-O0-NEXT:      Remove unreachable blocks from the CFG
@@ -105,6 +106,7 @@
 ; SPIRV-Opt-NEXT:    FunctionPass Manager
 ; SPIRV-Opt-NEXT:      Expand IR instructions
 ; SPIRV-Opt-NEXT:      Expand Atomic instructions
+; SPIRV-Opt-NEXT:      SPIRV convert masked memory intrinsics
 ; SPIRV-Opt-NEXT:      Dominator Tree Construction
 ; SPIRV-Opt-NEXT:      Basic Alias Analysis (stateless AA impl)
 ; SPIRV-Opt-NEXT:      Natural Loop Information

>From 6027c89bfc5ca7f9a053471bb4705a5a95cc09a7 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 10 Mar 2026 11:04:09 +0100
Subject: [PATCH 2/5] Address review comments

---
 llvm/lib/Target/SPIRV/SPIRV.h                 |  2 +-
 .../SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp | 67 ++++++++++---------
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 19 ++++--
 .../masked-gather-scatter-no-extension.ll     |  4 +-
 4 files changed, 53 insertions(+), 39 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index b36d4f9cab31c..00bdb26cc0a73 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -32,7 +32,7 @@ FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerCombiner();
 FunctionPass *createSPIRVPreLegalizerPass();
 FunctionPass *createSPIRVPostLegalizerPass();
-FunctionPass *
+ModulePass *
 createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM);
 ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 ModulePass *createSPIRVPrepareGlobalsPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
index aa1a83657bff8..a30b2ec9aa3fa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
@@ -12,12 +12,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
-#include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 
 using namespace llvm;
@@ -26,34 +25,31 @@ using namespace llvm;
 
 namespace {
 
-class SPIRVConvertMaskedMemIntrinsics
-    : public FunctionPass,
-      public InstVisitor<SPIRVConvertMaskedMemIntrinsics> {
+class SPIRVConvertMaskedMemIntrinsics : public ModulePass {
   const SPIRVTargetMachine *TM = nullptr;
 
 public:
   static char ID;
 
-  SPIRVConvertMaskedMemIntrinsics() : FunctionPass(ID) {
+  SPIRVConvertMaskedMemIntrinsics() : ModulePass(ID) {
     initializeSPIRVConvertMaskedMemIntrinsicsPass(
         *PassRegistry::getPassRegistry());
   }
 
   SPIRVConvertMaskedMemIntrinsics(const SPIRVTargetMachine *TM)
-      : FunctionPass(ID), TM(TM) {
+      : ModulePass(ID), TM(TM) {
     initializeSPIRVConvertMaskedMemIntrinsicsPass(
         *PassRegistry::getPassRegistry());
   }
 
-  bool runOnFunction(Function &F) override;
-  void visitIntrinsicInst(IntrinsicInst &I);
+  bool runOnModule(Module &M) override;
 
   StringRef getPassName() const override {
     return "SPIRV convert masked memory intrinsics";
   }
 
 private:
-  SmallVector<Instruction *, 4> ToErase;
+  bool processIntrinsic(IntrinsicInst &I);
 };
 
 } // namespace
@@ -64,27 +60,34 @@ INITIALIZE_PASS(SPIRVConvertMaskedMemIntrinsics,
                 "spirv-convert-masked-mem-intrinsics",
                 "Convert masked memory intrinsics for SPIR-V", false, false)
 
-bool SPIRVConvertMaskedMemIntrinsics::runOnFunction(Function &F) {
+bool SPIRVConvertMaskedMemIntrinsics::runOnModule(Module &M) {
   if (!TM)
     return false;
 
-  ToErase.clear();
-  visit(F);
+  bool Changed = false;
+  SmallVector<IntrinsicInst *, 8> ToProcess;
 
-  for (Instruction *I : ToErase)
-    I->eraseFromParent();
+  for (Function &F : M) {
+    if (!F.isIntrinsic())
+      continue;
+    Intrinsic::ID IID = F.getIntrinsicID();
+    if (IID != Intrinsic::masked_gather && IID != Intrinsic::masked_scatter)
+      continue;
 
-  return !ToErase.empty();
+    for (User *U : F.users()) {
+      if (auto *II = dyn_cast<IntrinsicInst>(U))
+        ToProcess.push_back(II);
+    }
+  }
+
+  for (IntrinsicInst *II : ToProcess)
+    Changed |= processIntrinsic(*II);
+
+  return Changed;
 }
 
-void SPIRVConvertMaskedMemIntrinsics::visitIntrinsicInst(IntrinsicInst &I) {
+bool SPIRVConvertMaskedMemIntrinsics::processIntrinsic(IntrinsicInst &I) {
   if (I.getIntrinsicID() == Intrinsic::masked_gather) {
-    const SPIRVSubtarget &ST =
-        TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
-    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
-      report_fatal_error(
-          "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
-
     IRBuilder<> B(I.getParent());
     B.SetInsertPoint(&I);
 
@@ -102,14 +105,11 @@ void SPIRVConvertMaskedMemIntrinsics::visitIntrinsicInst(IntrinsicInst &I) {
 
     auto *NewI = B.CreateIntrinsic(Intrinsic::spv_masked_gather, Types, Args);
     I.replaceAllUsesWith(NewI);
-    ToErase.push_back(&I);
-  } else if (I.getIntrinsicID() == Intrinsic::masked_scatter) {
-    const SPIRVSubtarget &ST =
-        TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
-    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
-      report_fatal_error(
-          "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
+    I.eraseFromParent();
+    return true;
+  }
 
+  if (I.getIntrinsicID() == Intrinsic::masked_scatter) {
     IRBuilder<> B(I.getParent());
     B.SetInsertPoint(&I);
 
@@ -126,11 +126,14 @@ void SPIRVConvertMaskedMemIntrinsics::visitIntrinsicInst(IntrinsicInst &I) {
                                     Mask->getType()};
 
     B.CreateIntrinsic(Intrinsic::spv_masked_scatter, Types, Args);
-    ToErase.push_back(&I);
+    I.eraseFromParent();
+    return true;
   }
+
+  return false;
 }
 
-FunctionPass *
+ModulePass *
 llvm::createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM) {
   return new SPIRVConvertMaskedMemIntrinsics(TM);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7896b1877362c..90c4db70c0e6d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -29,6 +29,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -304,6 +305,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
                           MachineInstr &I) const;
   bool selectMaskedScatter(MachineInstr &I) const;
 
+  bool diagnoseUnsupported(const MachineInstr &I, const Twine &Msg) const;
+
   bool selectFrameIndex(Register ResVReg, SPIRVTypeInst ResType,
                         MachineInstr &I) const;
   bool selectAllocaArray(Register ResVReg, SPIRVTypeInst ResType,
@@ -1780,6 +1783,14 @@ bool SPIRVInstructionSelector::selectMaskedScatter(MachineInstr &I) const {
   return true;
 }
 
+bool SPIRVInstructionSelector::diagnoseUnsupported(const MachineInstr &I,
+                                                   const Twine &Msg) const {
+  const Function &F = I.getMF()->getFunction();
+  F.getContext().diagnose(
+      DiagnosticInfoUnsupported(F, Msg, I.getDebugLoc(), DS_Error));
+  return false;
+}
+
 bool SPIRVInstructionSelector::selectStackSave(Register ResVReg,
                                                SPIRVTypeInst ResType,
                                                MachineInstr &I) const {
@@ -4380,13 +4391,13 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_masked_gather:
     if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
       return selectMaskedGather(ResVReg, ResType, I);
-    report_fatal_error(
-        "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
+    return diagnoseUnsupported(
+        I, "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter");
   case Intrinsic::spv_masked_scatter:
     if (STI.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter))
       return selectMaskedScatter(I);
-    report_fatal_error(
-        "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
+    return diagnoseUnsupported(
+        I, "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter");
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
index ec3a670ed85d5..6732b9f81b560 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
@@ -1,6 +1,6 @@
-; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 2>&1 | FileCheck %s
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 2>&1 | FileCheck %s
 
-; CHECK: LLVM ERROR: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter
+; CHECK: error:{{.*}}: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter
 
 define spir_kernel void @test_gather_no_ext(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
 entry:

>From 6986f9557f3c5d5065c9358ebb96945464670bdc Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 10 Mar 2026 11:53:02 +0100
Subject: [PATCH 3/5] Address comments by jmmartinez

---
 .../SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp | 45 ++++++++++++++-----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 15 ++++---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 21 ++++-----
 .../masked-gather-scatter-no-extension.ll     | 17 +++++--
 .../vector-of-pointers-no-extension.ll        |  5 +--
 llvm/test/CodeGen/SPIRV/llc-pipeline.ll       |  6 ++-
 6 files changed, 74 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
index a30b2ec9aa3fa..956016f00ee86 100644
--- a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
@@ -12,7 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
@@ -65,31 +68,42 @@ bool SPIRVConvertMaskedMemIntrinsics::runOnModule(Module &M) {
     return false;
 
   bool Changed = false;
-  SmallVector<IntrinsicInst *, 8> ToProcess;
 
-  for (Function &F : M) {
+  for (Function &F : make_early_inc_range(M)) {
     if (!F.isIntrinsic())
       continue;
     Intrinsic::ID IID = F.getIntrinsicID();
     if (IID != Intrinsic::masked_gather && IID != Intrinsic::masked_scatter)
       continue;
 
-    for (User *U : F.users()) {
+    for (User *U : make_early_inc_range(F.users())) {
       if (auto *II = dyn_cast<IntrinsicInst>(U))
-        ToProcess.push_back(II);
+        Changed |= processIntrinsic(*II);
     }
-  }
 
-  for (IntrinsicInst *II : ToProcess)
-    Changed |= processIntrinsic(*II);
+    if (F.use_empty())
+      F.eraseFromParent();
+  }
 
   return Changed;
 }
 
 bool SPIRVConvertMaskedMemIntrinsics::processIntrinsic(IntrinsicInst &I) {
+  const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
+
   if (I.getIntrinsicID() == Intrinsic::masked_gather) {
-    IRBuilder<> B(I.getParent());
-    B.SetInsertPoint(&I);
+    if (!ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+      I.getContext().emitError(
+          &I, "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter "
+              "extension");
+      // Replace with poison to allow compilation to continue and report error.
+      I.replaceAllUsesWith(PoisonValue::get(I.getType()));
+      I.eraseFromParent();
+      return true;
+    }
+
+    IRBuilder<> B(&I);
 
     Value *Ptrs = I.getArgOperand(0);
     Value *Mask = I.getArgOperand(1);
@@ -110,8 +124,17 @@ bool SPIRVConvertMaskedMemIntrinsics::processIntrinsic(IntrinsicInst &I) {
   }
 
   if (I.getIntrinsicID() == Intrinsic::masked_scatter) {
-    IRBuilder<> B(I.getParent());
-    B.SetInsertPoint(&I);
+    if (!ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+      I.getContext().emitError(
+          &I, "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter "
+              "extension");
+      // Erase the intrinsic to allow compilation to continue and report error.
+      I.eraseFromParent();
+      return true;
+    }
+
+    IRBuilder<> B(&I);
 
     Value *Values = I.getArgOperand(0);
     Value *Ptrs = I.getArgOperand(1);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index c60663b43bdc4..cf4ab00a4f3b3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -320,11 +320,16 @@ SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVTypeInst ElemType,
   assert(NumElems >= 2 && "SPIR-V OpTypeVector requires at least 2 components");
 
   if (EleOpc == SPIRV::OpTypePointer) {
-    assert(cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
-               .canUseExtension(
-                   SPIRV::Extension::SPV_INTEL_masked_gather_scatter) &&
-           "Vector of pointers requires SPV_INTEL_masked_gather_scatter "
-           "extension");
+    if (!cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
+             .canUseExtension(
+                 SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+      const Function &F = MIRBuilder.getMF().getFunction();
+      F.getContext().diagnose(DiagnosticInfoUnsupported(
+          F,
+          "Vector of pointers requires SPV_INTEL_masked_gather_scatter "
+          "extension",
+          DebugLoc(), DS_Error));
+    }
   } else {
     assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
             EleOpc == SPIRV::OpTypeBool) &&
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 90c4db70c0e6d..f5d4bc4c3f6ad 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1733,16 +1733,17 @@ bool SPIRVInstructionSelector::selectMaskedGather(Register ResVReg,
                                                   SPIRVTypeInst ResType,
                                                   MachineInstr &I) const {
   assert(I.getNumExplicitDefs() == 1 && "Expected single def for gather");
-  // Operand indices (after explicit defs):
-  // 0: intrinsic ID
-  // 1: vector of pointers
-  // 2: alignment (i32 immediate)
-  // 3: mask (vector of i1)
-  // 4: passthru/fill value
-  Register PtrsReg = I.getOperand(I.getNumExplicitDefs() + 1).getReg();
-  uint32_t Alignment = I.getOperand(I.getNumExplicitDefs() + 2).getImm();
-  Register MaskReg = I.getOperand(I.getNumExplicitDefs() + 3).getReg();
-  Register PassthruReg = I.getOperand(I.getNumExplicitDefs() + 4).getReg();
+  // Operand indices:
+  // 0: result (def)
+  // 1: intrinsic ID
+  // 2: vector of pointers
+  // 3: alignment (i32 immediate)
+  // 4: mask (vector of i1)
+  // 5: passthru/fill value
+  Register PtrsReg = I.getOperand(2).getReg();
+  uint32_t Alignment = I.getOperand(3).getImm();
+  Register MaskReg = I.getOperand(4).getReg();
+  Register PassthruReg = I.getOperand(5).getReg();
   Register AlignmentReg = buildI32Constant(Alignment, I);
 
   MachineBasicBlock &BB = *I.getParent();
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
index 6732b9f81b560..2afda3549feb2 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
@@ -1,12 +1,21 @@
-; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o - 2>&1 | FileCheck %s
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
 
-; CHECK: error:{{.*}}: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter
+; CHECK: error: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter extension
 
-define spir_kernel void @test_gather_no_ext(<4 x i64> %addrs, <4 x i1> %mask, <4 x i32> %passthru) {
+define spir_kernel void @test_gather_no_ext(<4 x ptr addrspace(1)> %ptrs, <4 x i1> %mask, <4 x i32> %passthru) {
 entry:
-  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
   %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
   ret void
 }
 
 declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
+
+; CHECK: error: llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter extension
+
+define spir_kernel void @test_scatter_no_ext(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, <4 x i1> %mask) {
+entry:
+  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
index f84d3b655583b..d892b3487b725 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/vector-of-pointers-no-extension.ll
@@ -1,7 +1,6 @@
-; REQUIRES: asserts
-; RUN: not --crash llc -O0 -mtriple=spirv64-unknown-unknown %s 2>&1 | FileCheck %s
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
 
-; CHECK: Vector of pointers requires SPV_INTEL_masked_gather_scatter
+; CHECK: error:{{.*}}Vector of pointers requires SPV_INTEL_masked_gather_scatter extension
 
 declare spir_func void @foo(<2 x i64>)
 
diff --git a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
index f358eba72d84d..ed9ce718c6c55 100644
--- a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
+++ b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
@@ -25,7 +25,8 @@
 ; SPIRV-O0-NEXT:    FunctionPass Manager
 ; SPIRV-O0-NEXT:      Expand IR instructions
 ; SPIRV-O0-NEXT:      Expand Atomic instructions
-; SPIRV-O0-NEXT:      SPIRV convert masked memory intrinsics
+; SPIRV-O0-NEXT:    SPIRV convert masked memory intrinsics
+; SPIRV-O0-NEXT:    FunctionPass Manager
 ; SPIRV-O0-NEXT:      Lower Garbage Collection Instructions
 ; SPIRV-O0-NEXT:      Shadow Stack GC Lowering
 ; SPIRV-O0-NEXT:      Remove unreachable blocks from the CFG
@@ -106,7 +107,8 @@
 ; SPIRV-Opt-NEXT:    FunctionPass Manager
 ; SPIRV-Opt-NEXT:      Expand IR instructions
 ; SPIRV-Opt-NEXT:      Expand Atomic instructions
-; SPIRV-Opt-NEXT:      SPIRV convert masked memory intrinsics
+; SPIRV-Opt-NEXT:    SPIRV convert masked memory intrinsics
+; SPIRV-Opt-NEXT:    FunctionPass Manager
 ; SPIRV-Opt-NEXT:      Dominator Tree Construction
 ; SPIRV-Opt-NEXT:      Basic Alias Analysis (stateless AA impl)
 ; SPIRV-Opt-NEXT:      Natural Loop Information

>From 8c361a0054eae1a4e34b4b172510d8c86c032d5a Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 10 Mar 2026 12:40:23 +0100
Subject: [PATCH 4/5] Address comments by MrSidims

---
 llvm/lib/Target/SPIRV/CMakeLists.txt          |   1 -
 llvm/lib/Target/SPIRV/SPIRV.h                 |   3 -
 .../SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp | 162 ------------------
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  94 ++++++++++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  11 --
 llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp  |   2 -
 .../Target/SPIRV/SPIRVTargetTransformInfo.cpp |   9 +
 .../Target/SPIRV/SPIRVTargetTransformInfo.h   |   3 +
 .../masked-gather-scatter-no-extension.ll     |  21 ++-
 llvm/test/CodeGen/SPIRV/llc-pipeline.ll       |   4 -
 10 files changed, 116 insertions(+), 194 deletions(-)
 delete mode 100644 llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp

diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index f492f4bb6507e..58989237ad3ea 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -22,7 +22,6 @@ add_llvm_target(SPIRVCodeGen
   SPIRVCallLowering.cpp
   SPIRVInlineAsmLowering.cpp
   SPIRVCommandLine.cpp
-  SPIRVConvertMaskedMemIntrinsics.cpp
   SPIRVEmitIntrinsics.cpp
   SPIRVGlobalRegistry.cpp
   SPIRVInstrInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 00bdb26cc0a73..da4da1a3fe83b 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -32,8 +32,6 @@ FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerCombiner();
 FunctionPass *createSPIRVPreLegalizerPass();
 FunctionPass *createSPIRVPostLegalizerPass();
-ModulePass *
-createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM);
 ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 ModulePass *createSPIRVPrepareGlobalsPass();
 MachineFunctionPass *createSPIRVEmitNonSemanticDIPass(SPIRVTargetMachine *TM);
@@ -61,7 +59,6 @@ void initializeSPIRVPrepareGlobalsPass(PassRegistry &);
 void initializeSPIRVStripConvergentIntrinsicsPass(PassRegistry &);
 void initializeSPIRVLegalizeImplicitBindingPass(PassRegistry &);
 void initializeSPIRVLegalizeZeroSizeArraysLegacyPass(PassRegistry &);
-void initializeSPIRVConvertMaskedMemIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
deleted file mode 100644
index 956016f00ee86..0000000000000
--- a/llvm/lib/Target/SPIRV/SPIRVConvertMaskedMemIntrinsics.cpp
+++ /dev/null
@@ -1,162 +0,0 @@
-//===- SPIRVConvertMaskedMemIntrinsics.cpp - Convert masked mem intrinsics ==//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass converts llvm.masked.gather/scatter to spv.masked.gather/scatter
-// to prevent them from being scalarized by the generic scalarization pass.
-//
-//===----------------------------------------------------------------------===//
-
-#include "SPIRV.h"
-#include "SPIRVSubtarget.h"
-#include "SPIRVTargetMachine.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/IntrinsicsSPIRV.h"
-#include "llvm/IR/Module.h"
-#include "llvm/InitializePasses.h"
-
-using namespace llvm;
-
-#define DEBUG_TYPE "spirv-convert-masked-mem-intrinsics"
-
-namespace {
-
-class SPIRVConvertMaskedMemIntrinsics : public ModulePass {
-  const SPIRVTargetMachine *TM = nullptr;
-
-public:
-  static char ID;
-
-  SPIRVConvertMaskedMemIntrinsics() : ModulePass(ID) {
-    initializeSPIRVConvertMaskedMemIntrinsicsPass(
-        *PassRegistry::getPassRegistry());
-  }
-
-  SPIRVConvertMaskedMemIntrinsics(const SPIRVTargetMachine *TM)
-      : ModulePass(ID), TM(TM) {
-    initializeSPIRVConvertMaskedMemIntrinsicsPass(
-        *PassRegistry::getPassRegistry());
-  }
-
-  bool runOnModule(Module &M) override;
-
-  StringRef getPassName() const override {
-    return "SPIRV convert masked memory intrinsics";
-  }
-
-private:
-  bool processIntrinsic(IntrinsicInst &I);
-};
-
-} // namespace
-
-char SPIRVConvertMaskedMemIntrinsics::ID = 0;
-
-INITIALIZE_PASS(SPIRVConvertMaskedMemIntrinsics,
-                "spirv-convert-masked-mem-intrinsics",
-                "Convert masked memory intrinsics for SPIR-V", false, false)
-
-bool SPIRVConvertMaskedMemIntrinsics::runOnModule(Module &M) {
-  if (!TM)
-    return false;
-
-  bool Changed = false;
-
-  for (Function &F : make_early_inc_range(M)) {
-    if (!F.isIntrinsic())
-      continue;
-    Intrinsic::ID IID = F.getIntrinsicID();
-    if (IID != Intrinsic::masked_gather && IID != Intrinsic::masked_scatter)
-      continue;
-
-    for (User *U : make_early_inc_range(F.users())) {
-      if (auto *II = dyn_cast<IntrinsicInst>(U))
-        Changed |= processIntrinsic(*II);
-    }
-
-    if (F.use_empty())
-      F.eraseFromParent();
-  }
-
-  return Changed;
-}
-
-bool SPIRVConvertMaskedMemIntrinsics::processIntrinsic(IntrinsicInst &I) {
-  const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
-
-  if (I.getIntrinsicID() == Intrinsic::masked_gather) {
-    if (!ST.canUseExtension(
-            SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
-      I.getContext().emitError(
-          &I, "llvm.masked.gather requires SPV_INTEL_masked_gather_scatter "
-              "extension");
-      // Replace with poison to allow compilation to continue and report error.
-      I.replaceAllUsesWith(PoisonValue::get(I.getType()));
-      I.eraseFromParent();
-      return true;
-    }
-
-    IRBuilder<> B(&I);
-
-    Value *Ptrs = I.getArgOperand(0);
-    Value *Mask = I.getArgOperand(1);
-    Value *Passthru = I.getArgOperand(2);
-
-    // Alignment is stored as a parameter attribute, not as a regular parameter
-    uint32_t Alignment = I.getParamAlign(0).valueOrOne().value();
-
-    SmallVector<Value *, 4> Args = {Ptrs, B.getInt32(Alignment), Mask,
-                                    Passthru};
-    SmallVector<Type *, 4> Types = {I.getType(), Ptrs->getType(),
-                                    Mask->getType(), Passthru->getType()};
-
-    auto *NewI = B.CreateIntrinsic(Intrinsic::spv_masked_gather, Types, Args);
-    I.replaceAllUsesWith(NewI);
-    I.eraseFromParent();
-    return true;
-  }
-
-  if (I.getIntrinsicID() == Intrinsic::masked_scatter) {
-    if (!ST.canUseExtension(
-            SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
-      I.getContext().emitError(
-          &I, "llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter "
-              "extension");
-      // Erase the intrinsic to allow compilation to continue and report error.
-      I.eraseFromParent();
-      return true;
-    }
-
-    IRBuilder<> B(&I);
-
-    Value *Values = I.getArgOperand(0);
-    Value *Ptrs = I.getArgOperand(1);
-    Value *Mask = I.getArgOperand(2);
-
-    // Alignment is stored as a parameter attribute on the ptrs parameter (arg
-    // 1)
-    uint32_t Alignment = I.getParamAlign(1).valueOrOne().value();
-
-    SmallVector<Value *, 4> Args = {Values, Ptrs, B.getInt32(Alignment), Mask};
-    SmallVector<Type *, 3> Types = {Values->getType(), Ptrs->getType(),
-                                    Mask->getType()};
-
-    B.CreateIntrinsic(Intrinsic::spv_masked_scatter, Types, Args);
-    I.eraseFromParent();
-    return true;
-  }
-
-  return false;
-}
-
-ModulePass *
-llvm::createSPIRVConvertMaskedMemIntrinsicsPass(const SPIRVTargetMachine *TM) {
-  return new SPIRVConvertMaskedMemIntrinsics(TM);
-}
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c84f41dada005..73d05f4dc574a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -299,6 +299,8 @@ class SPIRVEmitIntrinsics
   bool processFunctionPointers(Module &M);
   void parseFunDeclarations(Module &M);
   void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
+  bool processMaskedMemIntrinsic(IntrinsicInst &I);
+  bool convertMaskedMemIntrinsics(Module &M);
 
   void emitUnstructuredLoopControls(Function &F, IRBuilder<> &B);
 
@@ -3281,9 +3283,101 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
   }
 }
 
+// Lowers one llvm.masked.gather / llvm.masked.scatter call to the SPIR-V
+// spv_masked_gather / spv_masked_scatter intrinsic, which takes the alignment
+// as an explicit i32 operand. Returns true when I was replaced or erased and
+// false when I is any other intrinsic (I is then left untouched).
+bool SPIRVEmitIntrinsics::processMaskedMemIntrinsic(IntrinsicInst &I) {
+  const Intrinsic::ID IID = I.getIntrinsicID();
+  const bool IsGather = IID == Intrinsic::masked_gather;
+  if (!IsGather && IID != Intrinsic::masked_scatter)
+    return false;
+
+  const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(*I.getFunction());
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+    I.getContext().emitError(
+        &I, IsGather ? "llvm.masked.gather requires "
+                       "SPV_INTEL_masked_gather_scatter extension"
+                     : "llvm.masked.scatter requires "
+                       "SPV_INTEL_masked_gather_scatter extension");
+    // No legal lowering exists without the extension. Drop the call (a
+    // gather's users get poison instead) so that compilation can continue and
+    // any further errors still get reported.
+    if (!I.getType()->isVoidTy())
+      I.replaceAllUsesWith(PoisonValue::get(I.getType()));
+    I.eraseFromParent();
+    return true;
+  }
+
+  // Operand layout of the generic intrinsics (the alignment is not a call
+  // argument but a parameter attribute on the pointer-vector operand):
+  //   gather:  (ptrs, mask, passthru)
+  //   scatter: (values, ptrs, mask)
+  const unsigned PtrsIdx = IsGather ? 0 : 1;
+  Value *Ptrs = I.getArgOperand(PtrsIdx);
+  Value *Mask = I.getArgOperand(IsGather ? 1 : 2);
+  const uint32_t Alignment = I.getParamAlign(PtrsIdx).valueOrOne().value();
+
+  // New intrinsic calls are inserted immediately before I.
+  IRBuilder<> B(&I);
+  Value *AlignOp = B.getInt32(Alignment);
+
+  if (IsGather) {
+    Value *Passthru = I.getArgOperand(2);
+
+    // spv_masked_gather operands: (ptrs, i32 align, mask, passthru); it is
+    // overloaded on the result, pointer, mask and passthru types.
+    SmallVector<Value *, 4> GatherArgs = {Ptrs, AlignOp, Mask, Passthru};
+    SmallVector<Type *, 4> GatherTys = {I.getType(), Ptrs->getType(),
+                                        Mask->getType(), Passthru->getType()};
+
+    Value *Gathered =
+        B.CreateIntrinsic(Intrinsic::spv_masked_gather, GatherTys, GatherArgs);
+    I.replaceAllUsesWith(Gathered);
+    I.eraseFromParent();
+    return true;
+  }
+
+  Value *Values = I.getArgOperand(0);
+
+  // spv_masked_scatter operands: (values, ptrs, i32 align, mask); it is
+  // overloaded on the value, pointer and mask types.
+  SmallVector<Value *, 4> ScatterArgs = {Values, Ptrs, AlignOp, Mask};
+  SmallVector<Type *, 3> ScatterTys = {Values->getType(), Ptrs->getType(),
+                                       Mask->getType()};
+
+  B.CreateIntrinsic(Intrinsic::spv_masked_scatter, ScatterTys, ScatterArgs);
+  I.eraseFromParent();
+  return true;
+}
+
+// Lowers all llvm.masked.gather/scatter calls in M and erases the then-unused
+// declarations. Returns true if the module was modified in any way.
+bool SPIRVEmitIntrinsics::convertMaskedMemIntrinsics(Module &M) {
+  bool Changed = false;
+  for (Function &F : make_early_inc_range(M)) {
+    Intrinsic::ID IID = F.getIntrinsicID();
+    if (IID != Intrinsic::masked_gather && IID != Intrinsic::masked_scatter)
+      continue;
+    // processMaskedMemIntrinsic erases the visited call, hence early-inc.
+    for (User *U : make_early_inc_range(F.users())) {
+      if (auto *II = dyn_cast<IntrinsicInst>(U))
+        Changed |= processMaskedMemIntrinsic(*II);
+    }
+    // Dropping an unused declaration changes the module as well; report it.
+    if (F.use_empty()) {
+      F.eraseFromParent();
+      Changed = true;
+    }
+  }
+  return Changed;
+}
+
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   bool Changed = false;
 
+  Changed |= convertMaskedMemIntrinsics(M);
+
   parseFunDeclarations(M);
   insertConstantsForFPFastMathDefault(M);
   GVUsers.init(M);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 3b86bde347287..9c0642d0c9114 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1873,17 +1873,6 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpAtomicFMaxEXT:
     AddAtomicFloatRequirements(MI, Reqs, ST);
     break;
-  case SPIRV::OpConvertPtrToU:
-  case SPIRV::OpConvertUToPtr: {
-    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-    SPIRVTypeInst ResultType = MRI.getVRegDef(MI.getOperand(1).getReg());
-    if (ResultType->getOpcode() == SPIRV::OpTypeVector &&
-        ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
-      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
-      Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
-    }
-    break;
-  }
   case SPIRV::OpConvertBF16ToFINTEL:
   case SPIRV::OpConvertFToBF16INTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 57da3e92ec582..1759b34af3e90 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -61,7 +61,6 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   initializeSPIRVPreLegalizerPass(PR);
   initializeSPIRVPostLegalizerPass(PR);
   initializeSPIRVMergeRegionExitTargetsPass(PR);
-  initializeSPIRVConvertMaskedMemIntrinsicsPass(PR);
   initializeSPIRVEmitIntrinsicsPass(PR);
   initializeSPIRVEmitNonSemanticDIPass(PR);
   initializeSPIRVPrepareFunctionsPass(PR);
@@ -176,7 +175,6 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
 
 void SPIRVPassConfig::addIRPasses() {
   addPass(createAtomicExpandLegacyPass());
-  addPass(createSPIRVConvertMaskedMemIntrinsicsPass(&TM));
 
   TargetPassConfig::addIRPasses();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp
index 95093d2b3c263..d69591377d315 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVTargetTransformInfo.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 
 using namespace llvm;
@@ -38,3 +39,11 @@ Value *llvm::SPIRVTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
     return nullptr;
   }
 }
+
+bool SPIRVTTIImpl::isLegalMaskedGather(Type *, Align) const {
+  return ST->canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+}
+
+bool SPIRVTTIImpl::isLegalMaskedScatter(Type *DataType, Align Alignment) const {
+  return isLegalMaskedGather(DataType, Alignment); // Same extension gate.
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
index 60c4e2de2fb23..35a1aa1922eed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
@@ -61,6 +61,9 @@ class SPIRVTTIImpl final : public BasicTTIImplBase<SPIRVTTIImpl> {
                                           Value *NewV) const override;
 
   bool allowVectorElementIndexingUsingGEP() const override { return false; }
+
+  bool isLegalMaskedGather(Type *DataType, Align Alignment) const override;
+  bool isLegalMaskedScatter(Type *DataType, Align Alignment) const override;
 };
 
 } // namespace llvm
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
index 2afda3549feb2..e042131343eb7 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
@@ -1,21 +1,20 @@
 ; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
 
-; CHECK: error: llvm.masked.gather requires SPV_INTEL_masked_gather_scatter extension
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
+declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
 
-define spir_kernel void @test_gather_no_ext(<4 x ptr addrspace(1)> %ptrs, <4 x i1> %mask, <4 x i32> %passthru) {
+; CHECK: error: {{.*}}Vector of pointers requires SPV_INTEL_masked_gather_scatter extension
+
+define spir_kernel void @test_gather_no_ext(<4 x i64> %addrs) {
 entry:
-  %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> zeroinitializer)
   ret void
 }
 
-declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
-
-; CHECK: error: llvm.masked.scatter requires SPV_INTEL_masked_gather_scatter extension
-
-define spir_kernel void @test_scatter_no_ext(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, <4 x i1> %mask) {
+define spir_kernel void @test_scatter_no_ext(<4 x i32> %data, <4 x i64> %addrs) {
 entry:
-  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> %mask)
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
   ret void
 }
-
-declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
diff --git a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
index ed9ce718c6c55..eb1128ac5417a 100644
--- a/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
+++ b/llvm/test/CodeGen/SPIRV/llc-pipeline.ll
@@ -25,8 +25,6 @@
 ; SPIRV-O0-NEXT:    FunctionPass Manager
 ; SPIRV-O0-NEXT:      Expand IR instructions
 ; SPIRV-O0-NEXT:      Expand Atomic instructions
-; SPIRV-O0-NEXT:    SPIRV convert masked memory intrinsics
-; SPIRV-O0-NEXT:    FunctionPass Manager
 ; SPIRV-O0-NEXT:      Lower Garbage Collection Instructions
 ; SPIRV-O0-NEXT:      Shadow Stack GC Lowering
 ; SPIRV-O0-NEXT:      Remove unreachable blocks from the CFG
@@ -107,8 +105,6 @@
 ; SPIRV-Opt-NEXT:    FunctionPass Manager
 ; SPIRV-Opt-NEXT:      Expand IR instructions
 ; SPIRV-Opt-NEXT:      Expand Atomic instructions
-; SPIRV-Opt-NEXT:    SPIRV convert masked memory intrinsics
-; SPIRV-Opt-NEXT:    FunctionPass Manager
 ; SPIRV-Opt-NEXT:      Dominator Tree Construction
 ; SPIRV-Opt-NEXT:      Basic Alias Analysis (stateless AA impl)
 ; SPIRV-Opt-NEXT:      Natural Loop Information

>From a54ef257c2fa5cc3de3bcc35add30d7bd91bd58e Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 10 Mar 2026 12:48:31 +0100
Subject: [PATCH 5/5] Split no-extension tests

---
 ...no-extension.ll => masked-gather-no-extension.ll} |  8 --------
 .../masked-scatter-no-extension.ll                   | 12 ++++++++++++
 2 files changed, 12 insertions(+), 8 deletions(-)
 rename llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/{masked-gather-scatter-no-extension.ll => masked-gather-no-extension.ll} (59%)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-scatter-no-extension.ll

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-no-extension.ll
similarity index 59%
rename from llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
rename to llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-no-extension.ll
index e042131343eb7..f3e940f1a5ff2 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-scatter-no-extension.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-gather-no-extension.ll
@@ -1,7 +1,6 @@
 ; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
 
 declare <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)>, i32, <4 x i1>, <4 x i32>)
-declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
 
 ; CHECK: error: {{.*}}Vector of pointers requires SPV_INTEL_masked_gather_scatter extension
 
@@ -11,10 +10,3 @@ entry:
   %data = call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> zeroinitializer)
   ret void
 }
-
-define spir_kernel void @test_scatter_no_ext(<4 x i32> %data, <4 x i64> %addrs) {
-entry:
-  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
-  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
-  ret void
-}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-scatter-no-extension.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-scatter-no-extension.ll
new file mode 100644
index 0000000000000..4cbef7f905047
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/masked-scatter-no-extension.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
+
+declare void @llvm.masked.scatter.v4i32.v4p1(<4 x i32>, <4 x ptr addrspace(1)>, i32, <4 x i1>)
+; SPV_INTEL_masked_gather_scatter is not enabled, so llc must fail with:
+; CHECK: error: {{.*}}Vector of pointers requires SPV_INTEL_masked_gather_scatter extension
+
+define spir_kernel void @test_scatter_no_ext(<4 x i32> %data, <4 x i64> %addrs) {
+entry:
+  %ptrs = inttoptr <4 x i64> %addrs to <4 x ptr addrspace(1)>
+  call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %data, <4 x ptr addrspace(1)> %ptrs, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}



More information about the llvm-commits mailing list