[llvm] [SPIR-V] Add pass to remove spv_ptrcast intrinsics (PR #128896)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 26 07:55:19 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

OpenCL is allowed to cast pointers, meaning OpenCL code can resolve some type mismatches this way. In logical SPIR-V, such casts are restricted. This new pass legalizes such pointer casts when targeting logical SPIR-V.

For now, this pass supports 3 cases we witnessed:
 - loading a vec3 from a vec4*.
 - loading a scalar from a vec*.
 - loading the 1st element of an array.

---

Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128896.diff


11 Files Affected:

- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+23-71) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+35) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+3) 
- (added) llvm/lib/Target/SPIRV/SPIRVLegalizePointerLoad.cpp (+265) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+11) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+10) 
- (added) llvm/test/CodeGen/SPIRV/pointers/array-skips-gep.ll (+38) 
- (added) llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll (+128) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index efdd8c8d24fbd..d8322f77e0800 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -27,6 +27,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVInstrInfo.cpp
   SPIRVInstructionSelector.cpp
   SPIRVStripConvergentIntrinsics.cpp
+  SPIRVLegalizePointerLoad.cpp
   SPIRVMergeRegionExitTargets.cpp
   SPIRVISelLowering.cpp
   SPIRVLegalizerInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6d00a046ff7ca..8ccb3bfc25a1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
 FunctionPass *createSPIRVStructurizerPass();
 FunctionPass *createSPIRVMergeRegionExitTargetsPass();
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
+FunctionPass *createSPIRVLegalizePointerLoadPass(SPIRVTargetMachine *TM);
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerCombiner();
 FunctionPass *createSPIRVPreLegalizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5dfba8427258f..b73e85abfad26 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -57,12 +57,6 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 
 namespace {
 
-inline MetadataAsValue *buildMD(Value *Arg) {
-  LLVMContext &Ctx = Arg->getContext();
-  return MetadataAsValue::get(
-      Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(Arg)));
-}
-
 class SPIRVEmitIntrinsics
     : public ModulePass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
@@ -142,23 +136,10 @@ class SPIRVEmitIntrinsics
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
-  CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
-                            Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
-                            IRBuilder<> &B) {
-    SmallVector<Value *, 4> Args;
-    Args.push_back(Arg2);
-    Args.push_back(buildMD(Arg));
-    for (auto *Imm : Imms)
-      Args.push_back(Imm);
-    return B.CreateIntrinsic(IntrID, {Types}, Args);
-  }
-
   Type *reconstructType(Value *Op, bool UnknownElemTypeI8,
                         bool IsPostprocessing);
 
   void buildAssignType(IRBuilder<> &B, Type *ElemTy, Value *Arg);
-  void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
-  void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);
 
   void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
   void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
@@ -445,37 +426,6 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
   GR->addAssignPtrTypeInstr(Arg, AssignCI);
 }
 
-void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
-                                         Value *Arg) {
-  ElemTy = normalizeType(ElemTy);
-  Value *OfType = PoisonValue::get(ElemTy);
-  CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
-  if (AssignPtrTyCI == nullptr ||
-      AssignPtrTyCI->getParent()->getParent() != CurrF) {
-    AssignPtrTyCI = buildIntrWithMD(
-        Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
-        {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-    GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
-    GR->addDeducedElementType(Arg, ElemTy);
-    GR->addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
-  } else {
-    updateAssignType(AssignPtrTyCI, Arg, OfType);
-  }
-}
-
-void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
-                                           Value *OfType) {
-  AssignCI->setArgOperand(1, buildMD(OfType));
-  if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
-      Intrinsic::spv_assign_ptr_type)
-    return;
-
-  // update association with the pointee type
-  Type *ElemTy = normalizeType(OfType->getType());
-  GR->addDeducedElementType(AssignCI, ElemTy);
-  GR->addDeducedElementType(Arg, ElemTy);
-}
-
 CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
                                                Type *ElemTy) {
   IRBuilder<> B(Op->getContext());
@@ -495,7 +445,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
                                   B.getInt32(getPointerAddressSpace(OpTy))};
   CallInst *PtrCasted =
       B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-  buildAssignPtr(B, ElemTy, PtrCasted);
+  GR->buildAssignPtr(B, ElemTy, PtrCasted);
   return PtrCasted;
 }
 
@@ -1026,7 +976,8 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
         continue;
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
         if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-          updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy));
+          GR->updateAssignType(AssignCI, CI,
+                               getNormalizedPoisonValue(OpElemTy));
           propagateElemType(CI, PrevElemTy, VisitedSubst);
         }
       }
@@ -1212,7 +1163,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
                             {B.getInt32(getPointerAddressSpace(OpTy))}, B);
         GR->addAssignPtrTypeInstr(Op, CI);
       } else {
-        updateAssignType(AssignCI, Op, OpTyVal);
+        GR->updateAssignType(AssignCI, Op, OpTyVal);
         DenseSet<std::pair<Value *, Value *>> VisitedSubst{
             std::make_pair(I, Op)};
         propagateElemTypeRec(Op, KnownElemTy, PrevElemTy, VisitedSubst);
@@ -1522,7 +1473,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
 
   // Our previous guess about the type seems to be wrong, let's update
   // inferred type according to a new, more precise type information.
-  updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
+  GR->updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
 }
 
 void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -1579,7 +1530,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     if (FirstPtrCastOrAssignPtrType) {
       // If this would be the first spv_ptrcast, do not emit spv_ptrcast and
       // emit spv_assign_ptr_type instead.
-      buildAssignPtr(B, ExpectedElementType, Pointer);
+      GR->buildAssignPtr(B, ExpectedElementType, Pointer);
       return;
     } else if (isTodoType(Pointer)) {
       eraseTodoType(Pointer);
@@ -1591,10 +1542,10 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
           assert(PrevElemTy);
           DenseSet<std::pair<Value *, Value *>> VisitedSubst{
               std::make_pair(I, Pointer)};
-          updateAssignType(AssignCI, Pointer, ExpectedElementVal);
+          GR->updateAssignType(AssignCI, Pointer, ExpectedElementVal);
           propagateElemType(Pointer, PrevElemTy, VisitedSubst);
         } else {
-          buildAssignPtr(B, ExpectedElementType, Pointer);
+          GR->buildAssignPtr(B, ExpectedElementType, Pointer);
         }
         return;
       }
@@ -1607,7 +1558,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
   auto *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
   I->setOperand(OperandToReplace, PtrCastI);
   // We need to set up a pointee type for the newly created spv_ptrcast.
-  buildAssignPtr(B, ExpectedElementType, PtrCastI);
+  GR->buildAssignPtr(B, ExpectedElementType, PtrCastI);
 }
 
 void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
@@ -1923,7 +1874,7 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
 
   setInsertPointAfterDef(B, I);
   if (Type *ElemTy = deduceElementType(I, UnknownElemTypeI8)) {
-    buildAssignPtr(B, ElemTy, I);
+    GR->buildAssignPtr(B, ElemTy, I);
     return false;
   }
   return true;
@@ -2019,10 +1970,11 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
         Type *OpTy = Op->getType();
         Type *OpTyElem = getPointeeType(OpTy);
         if (OpTyElem) {
-          buildAssignPtr(B, OpTyElem, Op);
+          GR->buildAssignPtr(B, OpTyElem, Op);
         } else if (isPointerTy(OpTy)) {
           Type *ElemTy = GR->findDeducedElementType(Op);
-          buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
+          GR->buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true),
+                             Op);
         } else {
           CallInst *AssignCI =
               buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy},
@@ -2083,14 +2035,14 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
       if (!IsConstComposite && isPointerTy(OpTy) &&
           (OpElemTy = GR->findDeducedElementType(Op)) != nullptr &&
           OpElemTy != IntegerType::getInt8Ty(I->getContext())) {
-        buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
+        GR->buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
         SmallVector<Type *, 2> Types = {OpTy, OpTy};
         SmallVector<Value *, 2> Args = {
             NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
             B.getInt32(getPointerAddressSpace(OpTy))};
         CallInst *PtrCasted =
             B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-        buildAssignPtr(B, OpElemTy, PtrCasted);
+        GR->buildAssignPtr(B, OpElemTy, PtrCasted);
         NewOp = PtrCasted;
       }
       I->setOperand(OpNo, NewOp);
@@ -2172,7 +2124,7 @@ void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
       continue;
     if (hasPointeeTypeAttr(Arg) &&
         (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
-      buildAssignPtr(B, ElemTy, Arg);
+      GR->buildAssignPtr(B, ElemTy, Arg);
       continue;
     }
     // search in function's call sites
@@ -2188,7 +2140,7 @@ void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
         break;
     }
     if (ElemTy) {
-      buildAssignPtr(B, ElemTy, Arg);
+      GR->buildAssignPtr(B, ElemTy, Arg);
       continue;
     }
     if (HaveFunPtrs) {
@@ -2200,7 +2152,7 @@ void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
           SmallVector<std::pair<Value *, unsigned>> Ops;
           deduceOperandElementTypeFunctionPointer(CI, Ops, ElemTy, false);
           if (ElemTy) {
-            buildAssignPtr(B, ElemTy, Arg);
+            GR->buildAssignPtr(B, ElemTy, Arg);
             break;
           }
         }
@@ -2219,11 +2171,11 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
     if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) {
         DenseSet<std::pair<Value *, Value *>> VisitedSubst;
-        updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
+        GR->updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
         propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()),
                           VisitedSubst);
       } else {
-        buildAssignPtr(B, ElemTy, Arg);
+        GR->buildAssignPtr(B, ElemTy, Arg);
       }
     }
   }
@@ -2273,7 +2225,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
           continue;
         if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
             II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
-          updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
+          GR->updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
           break;
         }
       }
@@ -2324,7 +2276,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
           if (!hasPointeeTypeAttr(Arg)) {
             B.SetInsertPointPastAllocas(Arg->getParent());
             B.SetCurrentDebugLocation(DebugLoc());
-            buildAssignPtr(B, ElemTy, Arg);
+            GR->buildAssignPtr(B, ElemTy, Arg);
           }
         } else if (isa<Instruction>(Param)) {
           GR->addDeducedElementType(Param, normalizeType(ElemTy));
@@ -2334,7 +2286,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
                                ->getParent()
                                ->getEntryBlock()
                                .getFirstNonPHIOrDbgOrAlloca());
-          buildAssignPtr(B, ElemTy, Param);
+          GR->buildAssignPtr(B, ElemTy, Param);
         }
         CallInst *Ref = dyn_cast<CallInst>(Param);
         if (!Ref)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0ed414ebc8bbe..7cca1e3d4780c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -22,6 +22,9 @@
 #include "SPIRVUtils.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
@@ -1739,3 +1742,46 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
   }
   return LLT::scalar(64);
 }
+
+// Records |ElemTy| as the pointee type of the pointer |Arg|. A new
+// spv_assign_ptr_type is emitted at |B|'s insertion point unless |Arg| already
+// has one inside the function |B| is inserting into, in which case that
+// existing intrinsic is updated instead.
+void SPIRVGlobalRegistry::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
+                                         Value *Arg) {
+  // Normalize first, as the SPIRVEmitIntrinsics-local helper this replaces
+  // did, so the intrinsic and the deduced-type maps receive the same types as
+  // before this helper was moved into the registry.
+  ElemTy = normalizeType(ElemTy);
+  Value *OfType = PoisonValue::get(ElemTy);
+  CallInst *AssignPtrTyCI = findAssignPtrTypeInstr(Arg);
+  Function *CurrF =
+      B.GetInsertBlock() ? B.GetInsertBlock()->getParent() : nullptr;
+  if (AssignPtrTyCI == nullptr ||
+      AssignPtrTyCI->getParent()->getParent() != CurrF) {
+    AssignPtrTyCI = buildIntrWithMD(
+        Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
+        {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+    addDeducedElementType(AssignPtrTyCI, ElemTy);
+    addDeducedElementType(Arg, ElemTy);
+    addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
+  } else {
+    updateAssignType(AssignPtrTyCI, Arg, OfType);
+  }
+}
+
+// Replaces argument 1 of |AssignCI| with |OfType| wrapped as metadata. Only
+// when |AssignCI| is an spv_assign_ptr_type is the normalized type of |OfType|
+// also recorded as the deduced element type of both |AssignCI| and |Arg|.
+void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg,
+                                           Value *OfType) {
+  AssignCI->setArgOperand(1, buildMD(OfType));
+  if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
+      Intrinsic::spv_assign_ptr_type)
+    return;
+
+  // update association with the pointee type
+  Type *ElemTy = normalizeType(OfType->getType());
+  addDeducedElementType(AssignCI, ElemTy);
+  addDeducedElementType(Arg, ElemTy);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2c24ba79ea8e6..fc9dd297b1994 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -620,6 +620,9 @@ class SPIRVGlobalRegistry {
 
   const TargetRegisterClass *getRegClass(SPIRVType *SpvType) const;
   LLT getRegType(SPIRVType *SpvType) const;
+
+  void buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg);
+  void updateAssignType(CallInst *AssignCI, Value *Arg, Value *OfType);
 };
 } // end namespace llvm
 #endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerLoad.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerLoad.cpp
new file mode 100644
index 0000000000000..76a8d690372a9
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerLoad.cpp
@@ -0,0 +1,265 @@
+//===-- SPIRVLegalizePointerLoad.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
+// This pass modifies such loads to have an IR we can directly lower to valid
+// logical SPIR-V.
+// OpenCL can avoid this because it relies on ptrcast, which is not supported
+// by logical SPIR-V.
+//
+// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
+// pointed-to values, and must replace all occurrences of `ptrcast`. This is
+// why unhandled cases are reported as unreachable: we MUST cover all cases.
+//
+// 1. Loading the first element of an array
+//
+//        %array = [10 x i32]
+//        %value = load i32, ptr %array
+//
+//    LLVM can skip the GEP instruction, and only request loading the first 4
+//    bytes. In logical SPIR-V, we need an OpAccessChain to access the first
+//    element. This pass will add a getelementptr instruction before the load.
+//
+//
+// 2. Implicit downcast from load
+//
+//        %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
+//        %2 = load <3 x i32>, ptr %1
+//
+//    The pointer in the GEP instruction is only used for offset computations,
+//    but it doesn't NEED to match the pointed type. OpAccessChain however
+//    requires this. Also, LLVM loads define the bitwidth of the load, not the
+//    pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
+//    instruction basetype, but we only want to load the first 3 elements, hence
+//    do a partial load. In logical SPIR-V, this is not legal. What we must do
+//    is load the full vector (basetype), extract 3 elements, and recombine them
+//    to form a 3-element vector.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVLegalizePointerLoadPass(PassRegistry &);
+}
+
+class SPIRVLegalizePointerLoad : public FunctionPass {
+
+  // Replaces all uses of |Old| with |New|, then moves any global registry
+  // mappings for |Old| (deduced element type, assign-ptr-type call) to |New|.
+  void replaceAllUsesWith(Value *Old, Value *New) {
+    Old->replaceAllUsesWith(New);
+    GR->updateIfExistDeducedElementType(Old, New, /* deleteOld = */ true);
+    GR->updateIfExistAssignPtrTypeInstr(Old, New, /* deleteOld = */ true);
+  }
+
+  // Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
+  // builder position.
+  void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
+    Value *OfType = PoisonValue::get(Ty);
+    CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
+                                         {Arg->getType()}, OfType, Arg, {}, B);
+    GR->addAssignPtrTypeInstr(Arg, AssignCI);
+  }
+
+  // Loads the whole vector of type |From| pointed to by |Source| and returns
+  // its first element, a scalar of |From|'s element type.
+  Value *loadScalarFromVector(IRBuilder<> &B, Value *Source,
+                              FixedVectorType *From) {
+
+    LoadInst *NewLoad = B.CreateLoad(From, Source);
+    buildAssignType(B, From, NewLoad);
+
+    SmallVector<Value *, 2> Args = {NewLoad, B.getInt64(0)};
+    SmallVector<Type *, 3> Types = {From->getElementType(), Args[0]->getType(),
+                                    Args[1]->getType()};
+    Value *Extracted =
+        B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+    buildAssignType(B, Extracted->getType(), Extracted);
+    return Extracted;
+  }
+
+  // Loads the full vector of type |From| from the pointer |Source| and builds
+  // a shorter vector of type |To| from it. |To| must be a vector type whose
+  // element type matches |From|'s. Returns the resulting value.
+  Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *From,
+                              FixedVectorType *To, Value *Source) {
+    // We expect the codegen to avoid doing implicit bitcast from a load.
+    assert(To->getElementType() == From->getElementType());
+    assert(To->getNumElements() < From->getNumElements());
+
+    LoadInst *NewLoad = B.CreateLoad(From, Source);
+    buildAssignType(B, From, NewLoad);
+
+    auto ConstInt = ConstantInt::get(IntegerType::get(B.getContext(), 32), 0);
+    ElementCount VecElemCount = ElementCount::getFixed(To->getNumElements());
+    Value *Output = ConstantVec...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/128896


More information about the llvm-commits mailing list