[llvm] [SPIRV] Support for extension SPV_INTEL_masked_gather_scatter (PR #131566)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 16 22:56:04 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: VISHAKH PRAKASH (VishMCW)

<details>
<summary>Changes</summary>

Add support for the SPV_INTEL_masked_gather_scatter extension
- Introduce a new pass, SPIRVCodeGenPrepare, that runs before all other IR passes when the extension is enabled

---
Full diff: https://github.com/llvm/llvm-project/pull/131566.diff


14 Files Affected:

- (modified) llvm/docs/SPIRVUsage.rst (+2) 
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+12) 
- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1) 
- (added) llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp (+125) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+3-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+12-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+15-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+28) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+5-1) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll (+50) 


``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 3e19ff881dffc..781a16dff0d0f 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -209,6 +209,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Adds the ability to declare extended instruction sets that have no semantic impact and can be safely removed from a module.
    * - ``SPV_INTEL_fp_max_error``
      - Adds the ability to specify the maximum error for floating-point operations.
+   * - ``SPV_INTEL_masked_gather_scatter``
+     - Allows OpTypeVector to have a physical pointer type as its component type and introduces gather and scatter instructions.
 
 To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
 
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index df3e137c80980..5a70350af1804 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -143,4 +143,16 @@ let TargetPrefix = "spv" in {
 
   // FPMaxErrorDecorationINTEL
   def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
+
+  // Masked Gather Scatter Intrinsics
+  def int_spv_masked_gather
+    :DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+            [LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
+             LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>],
+            [IntrReadMem, IntrWillReturn, ImmArg<ArgIndex<1>>]>;
+  def int_spv_masked_scatter
+    :DefaultAttrsIntrinsic<[],
+                [llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty,
+                 LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
+                [IntrWriteMem, IntrWillReturn, ImmArg<ArgIndex<2>>]>;
 }
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 4a2b534b948d6..48ef19b334695 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -46,6 +46,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVTargetMachine.cpp
   SPIRVUtils.cpp
   SPIRVEmitNonSemanticDI.cpp
+  SPIRVCodeGenPrepare.cpp
 
   LINK_COMPONENTS
   Analysis
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index d765dfe370be2..a843a7144f3c2 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -19,6 +19,7 @@ class SPIRVSubtarget;
 class InstructionSelector;
 class RegisterBankInfo;
 
+ModulePass *createSPIRVCodeGenPreparePass( const SPIRVTargetMachine &TM);
 ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
 FunctionPass *createSPIRVStructurizerPass();
 FunctionPass *createSPIRVMergeRegionExitTargetsPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp
new file mode 100644
index 0000000000000..ea497880360ff
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVCodeGenPrepare.cpp
@@ -0,0 +1,125 @@
+//===-- SPIRVCodeGenPrepare.cpp - preserve masked scatter gather --*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass preserves the @llvm.masked.gather and @llvm.masked.scatter
+// intrinsics by replacing them with the corresponding spv intrinsics.
+//===----------------------------------------------------------------------===//
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Pass.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVCodeGenPreparePass(PassRegistry &);
+} // namespace llvm
+
+namespace {
+class SPIRVCodeGenPrepare : public ModulePass {
+
+  const SPIRVTargetMachine &TM;
+
+public:
+  static char ID;
+  SPIRVCodeGenPrepare(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
+    initializeSPIRVCodeGenPreparePass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override;
+
+  StringRef getPassName() const override {
+    return "SPIRV CodeGen prepare pass";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    ModulePass::getAnalysisUsage(AU);
+  }
+};
+
+} // namespace
+
+char SPIRVCodeGenPrepare::ID = 0;
+INITIALIZE_PASS(SPIRVCodeGenPrepare, "codegen-prepare", "SPIRV codegen prepare",
+                false, false)
+
+static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
+                                     ArrayRef<unsigned> OpNos) {
+  Function *F = nullptr;
+  if (OpNos.empty()) {
+    F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID);
+  } else {
+    SmallVector<Type *> Tys;
+    for (unsigned OpNo : OpNos) {
+      Tys.push_back(II->getOperand(OpNo)->getType());
+    }
+
+    F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID, Tys);
+  }
+  II->setCalledFunction(F);
+  return true;
+}
+
+static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic,
+                                     const SPIRVSubtarget &ST,
+                                     SPIRVGlobalRegistry &GR) {
+  auto IntrinsicID = Intrinsic->getIntrinsicID();
+  if (ST.canUseExtension(
+          SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
+    switch (IntrinsicID) {
+    case Intrinsic::masked_scatter: {
+      return toSpvOverloadedIntrinsic(
+          Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_scatter,
+          {0, 1});
+    } break;
+
+    case Intrinsic::masked_gather: {
+      VectorType* Vty = dyn_cast<VectorType>(Intrinsic -> getOperand(0) -> getType());
+      PointerType* PTy = dyn_cast<PointerType>(Vty -> getElementType());
+      
+      VectorType* ResVecType = dyn_cast<VectorType>(Intrinsic -> getType());
+      Type *CompType = ResVecType -> getElementType();
+      GR.addPointerToBaseTypeMap(PTy, CompType);
+      return toSpvOverloadedIntrinsic(
+          Intrinsic, Intrinsic::SPVIntrinsics::spv_masked_gather, {3, 0});
+    } break;
+    default:
+      break;
+    }
+  }
+  return false;
+}
+
+bool SPIRVCodeGenPrepare::runOnModule(Module &M) {
+  bool Changed = false;
+  for (Function &F : M) {
+    const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F);
+    SPIRVGlobalRegistry &GR = *(STI.getSPIRVGlobalRegistry());
+    for (BasicBlock &BB : F) {
+      for (Instruction &I : BB) {
+        if (auto *II = dyn_cast<IntrinsicInst>(&I))
+          Changed |= lowerIntrinsicToFunction(II, STI, GR);
+      }
+    }
+  }
+  return Changed;
+}
+
+ModulePass *llvm::createSPIRVCodeGenPreparePass(const SPIRVTargetMachine &TM) {
+  return new SPIRVCodeGenPrepare(TM);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 37119bf01545c..357bccbbc2c1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -92,7 +92,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_long_composites",
          SPIRV::Extension::Extension::SPV_INTEL_long_composites},
         {"SPV_INTEL_fp_max_error",
-         SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
+         SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
+        {"SPV_INTEL_masked_gather_scatter",
+         SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
                                   llvm::StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cbec1c95eadc3..829de62ed8213 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -231,11 +231,17 @@ SPIRVType *SPIRVGlobalRegistry::createOpType(
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
                                                 SPIRVType *ElemType,
                                                 MachineIRBuilder &MIRBuilder) {
+
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
   auto EleOpc = ElemType->getOpcode();
   (void)EleOpc;
-  assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
-          EleOpc == SPIRV::OpTypeBool) &&
-         "Invalid vector element type");
+  if (!ST.canUseExtension(
+          SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
+    assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
+            EleOpc == SPIRV::OpTypeBool) &&
+           "Invalid vector element type");
+  }
 
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeVector)
@@ -1060,6 +1066,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
+
   if (Ty->isFloatingPointTy())
     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
   if (Ty->isVoidTy())
@@ -1088,11 +1095,12 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
       ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
   }
-
   unsigned AddrSpace = typeToAddressSpace(Ty);
   SPIRVType *SpvElementType = nullptr;
   if (Type *ElemTy = ::getPointeeType(Ty))
     SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
+  else if (Type *ElemTy = this->findPointerToBaseTypeMap(Ty))
+    SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR);
   else
     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 89599f17ef737..24cc62b3b69a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -78,7 +78,8 @@ class SPIRVGlobalRegistry {
 
   // Holds the maximum ID we have in the module.
   unsigned Bound;
-
+  /// maps the pointer type to the base type
+  DenseMap<Type *, Type *> PointerToBaseTypeMap;
   // Maps values associated with untyped pointers into deduced element types of
   // untyped pointers.
   DenseMap<Value *, Type *> DeducedElTys;
@@ -635,6 +636,19 @@ class SPIRVGlobalRegistry {
   void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg);
   void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
   void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);
+
+  void addPointerToBaseTypeMap(Type *PTy, Type *BaseTy) {
+      if(PTy == nullptr)
+          return;
+      assert(PTy->isPointerTy() && "PTy must be a pointer type");
+    PointerToBaseTypeMap[PTy] = BaseTy;
+  }
+
+  Type *findPointerToBaseTypeMap(const Type *PTy) {
+    auto BaseTyIter = PointerToBaseTypeMap.find(PTy); 
+    return BaseTyIter == PointerToBaseTypeMap.end() ? nullptr : BaseTyIter -> second;
+  }
+
 };
 } // end namespace llvm
 #endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a8f862271dbab..d16621eff4f3b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -956,3 +956,9 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
                   "$res = OpAliasScopeDeclINTEL $AliasDomain">;
 def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
                   "$res = OpAliasScopeListDeclINTEL">;
+
+//SPV_INTEL_masked_gather_scatter
+def OpMaskedGatherINTEL: Op<6428, (outs ID:$res) , (ins TYPE:$type, ID:$PtrVector, i32imm:$alignment, ID:$mask, ID:$fillempty),
+				  "$res = OpMaskedGatherINTEL $type $PtrVector $alignment $mask $fillempty">;
+def OpMaskedScatterINTEL: Op<6429, (outs) , (ins ID:$inVector, ID:$PtrVector, i32imm:$alignment, ID:$mask),
+				  "OpMaskedScatterINTEL $inVector $PtrVector $alignment $mask">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b188f36ca9a9e..1b25470791286 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3192,6 +3192,34 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_discard: {
     return selectDiscard(ResVReg, ResType, I);
   }
+  case Intrinsic::spv_masked_gather: {
+    Register MemLoc = I.getOperand(2).getReg();
+    int32_t Alignment = I.getOperand(3).getImm();
+    Register Mask = I.getOperand(4).getReg();
+    Register PassThrough = I.getOperand(5).getReg();
+    return BuildMI(*(I.getParent()), I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpMaskedGatherINTEL))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(MemLoc)
+        .addImm(Alignment)
+        .addUse(Mask)
+        .addUse(PassThrough)
+        .constrainAllUses(TII, TRI, RBI);
+  }
+  case Intrinsic::spv_masked_scatter: {
+    Register Value = I.getOperand(1).getReg();
+    Register MemLocs = I.getOperand(2).getReg();
+    int32_t Alignment = I.getOperand(3).getImm();
+    Register Mask = I.getOperand(4).getReg();
+    auto MIB = BuildMI(*(I.getParent()), I, I.getDebugLoc(),
+                       TII.get(SPIRV::OpMaskedScatterINTEL))
+                   .addUse(Value)
+                   .addUse(MemLocs)
+                   .addImm(Alignment)
+                   .addUse(Mask);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 63894acacbc73..e487cca7decd5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1770,6 +1770,14 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
     break;
   }
+  case SPIRV::OpMaskedGatherINTEL:
+  case SPIRV::OpMaskedScatterINTEL:
+    if (ST.canUseExtension(
+            SPIRV::Extension::Extension::SPV_INTEL_masked_gather_scatter)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
+      Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
+    }
+    break;
 
   default:
     break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index caee778eddbc4..d3ee45b6591a7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -313,6 +313,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
 defm SPV_INTEL_long_composites : ExtensionOperand<117>;
 defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
 defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
+defm SPV_INTEL_masked_gather_scatter : ExtensionOperand<120>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -513,6 +514,7 @@ defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composi
 defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
 defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
 defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
+defm MaskedGatherScatterINTEL : CapabilityOperand<6427, 0, 0, [SPV_INTEL_masked_gather_scatter], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index 0aa214dd354ee..ad26ed3b597db 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -175,8 +175,12 @@ TargetPassConfig *SPIRVTargetMachine::createPassConfig(PassManagerBase &PM) {
 }
 
 void SPIRVPassConfig::addIRPasses() {
-  TargetPassConfig::addIRPasses();
 
+  if (TM.getSubtargetImpl()->canUseExtension(
+          SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
+    addPass(createSPIRVCodeGenPreparePass(TM));
+  }
+  TargetPassConfig::addIRPasses();
   if (TM.getSubtargetImpl()->isVulkanEnv()) {
     // 1.  Simplify loop for subsequent transformations. After this steps, loops
     // have the following properties:
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll
new file mode 100644
index 0000000000000..b10c63b700a5a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_masked_gather_scatter/intel-gather-scatter.ll
@@ -0,0 +1,50 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_masked_gather_scatter %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOT: Name [[#]] "llvm.masked.gather.v4i32.v4p4"
+; CHECK-NOT: Name [[#]] "llvm.masked.scatter.v4i32.v4p4"
+
+; CHECK-DAG: OpCapability MaskedGatherScatterINTEL
+; CHECK-DAG: OpExtension "SPV_INTEL_masked_gather_scatter"
+
+; CHECK-DAG: %[[#TYPEINT:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYPEPTRINT:]] = OpTypePointer Generic %[[#TYPEINT]]
+; CHECK-DAG: %[[#TYPEVECPTR:]] = OpTypeVector %[[#TYPEPTRINT]] 4
+; CHECK-DAG: %[[#TYPEVECINT:]] = OpTypeVector %[[#TYPEINT]] 4
+
+; CHECK-DAG: %[[#CONST4:]] = OpConstant %[[#TYPEINT]]  4
+; CHECK-DAG: %[[#CONST0:]] = OpConstant %[[#TYPEINT]]  0
+; CHECK-DAG: %[[#CONST1:]] = OpConstant %[[#TYPEINT]]  1
+; CHECK-DAG: %[[#TRUE:]] = OpConstantTrue %[[#]] 
+; CHECK-DAG: %[[#FALSE:]] = OpConstantFalse %[[#]] 
+; CHECK-DAG: %[[#MASK1:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#FALSE]] %[[#TRUE]] %[[#TRUE]]
+; CHECK-DAG: %[[#FILL:]] = OpConstantComposite %[[#]] %[[#CONST4]] %[[#CONST0]] %[[#CONST1]] %[[#CONST0]]
+; CHECK-DAG: %[[#MASK2:]] = OpConstantComposite %[[#]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]] %[[#TRUE]]
+
+; CHECK: %[[#VECGATHER:]] = OpLoad %[[#TYPEVECPTR]] 
+; CHECK: %[[#VECSCATTER:]] = OpLoad %[[#TYPEVECPTR]] 
+; CHECK: %[[#GATHER:]] = OpMaskedGatherINTEL %[[#TYPEVECINT]] %[[#VECGATHER]] 4 %[[#MASK1]] %[[#FILL]]
+; CHECK: OpMaskedScatterINTEL %[[#GATHER]] %[[#VECSCATTER]] 4 %[[#MASK2]]
+
+; Function Attrs: nounwind readnone
+define spir_kernel void @foo() {
+entry:
+  %arg0 = alloca <4 x ptr addrspace(4)>
+  %arg1 = alloca <4 x ptr addrspace(4)>
+  %0 = load <4 x ptr addrspace(4)>, ptr %arg0
+  %1 = load <4 x ptr addrspace(4)>, ptr %arg1
+  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)> %0, i32 4, <4 x i1> <i1 true, i1 false, i1 true, i1 true>, <4 x i32> <i32 4, i32 0, i32 1, i32 0>)
+  call void @llvm.masked.scatter.v4i32.v4p4(<4 x i32> %res, <4 x ptr addrspace(4)> %1, i32 4, <4 x i1> splat (i1 true))
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p4(<4 x ptr addrspace(4)>, i32, <4 x i1>, <4 x i32>)
+
+declare void @llvm.masked.scatter.v4i32.v4p4(<4 x i32>, <4 x ptr addrspace(4)>, i32, <4 x i1>)
+
+!llvm.module.flags = !{!0}
+!opencl.spir.version = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 2}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/131566


More information about the llvm-commits mailing list