[llvm] [SPIRV] Add type inference of function parameters by call instances (PR #85077)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 14:38:07 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/85077

>From 5b064779c3b8ac60faf7791ecdec1d54baa1a3b1 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 13 Mar 2024 06:44:20 -0700
Subject: [PATCH 1/4] add type inference of function parameters by call
 instances

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 62 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  5 ++
 .../SPIRV/pointers/type-deduce-args-rev.ll    | 28 +++++++++
 .../SPIRV/pointers/type-deduce-args.ll        | 28 +++++++++
 5 files changed, 124 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f1fbe2ba1bc416..77319f58ff4d97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,7 +209,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
   Argument *Arg = F.getArg(ArgIdx);
-  if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
+  if (HasPointeeTypeAttr(Arg)) {
     Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
                                              : Arg->getParamByRefType();
     SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c5b901235402c1..e9099fac1c1a3c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -91,6 +91,7 @@ class SPIRVEmitIntrinsics
                                         IRBuilder<> &B);
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
+  void processParamTypes(Function *F, IRBuilder<> &B);
 
 public:
   static char ID;
@@ -794,6 +795,62 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  DenseMap<unsigned, Argument *> Args;
+  unsigned i = 0;
+  for (Argument &Arg : F->args()) {
+    if (isUntypedPointerTy(Arg.getType()) &&
+        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(&Arg))
+      Args[i++] = &Arg;
+  }
+  if (i == 0)
+    return;
+
+  // Args contains opaque pointers without element type definition
+  B.SetInsertPointPastAllocas(F);
+  std::unordered_set<Value *> Visited;
+  for (User *U : F->users()) {
+    CallInst *CI = dyn_cast<CallInst>(U);
+    if (!CI)
+      continue;
+    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
+         OpIdx++) {
+      auto It = Args.find(OpIdx);
+      Argument *Arg = It == Args.end() ? nullptr : It->second;
+      if (!Arg)
+        continue;
+      Value *OpArg = CI->getArgOperand(OpIdx);
+      if (!isPointerTy(OpArg->getType()))
+        continue;
+      // maybe we already know the operand's element type
+      auto DeducedIt = DeducedElTys.find(OpArg);
+      Type *ElemTy = DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
+      if (!ElemTy) {
+        for (User *OpU : OpArg->users()) {
+          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
+            Visited.clear();
+            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
+            if (ElemTy)
+              break;
+          }
+        }
+      }
+      if (ElemTy) {
+        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+        CallInst *AssignPtrTyCI = buildIntrWithMD(
+            Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+        DeducedElTys[AssignPtrTyCI] = ElemTy;
+        DeducedElTys[Arg] = ElemTy;
+        Args.erase(It);
+      }
+    }
+    if (Args.size() == 0)
+      break;
+  }
+}
+
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
@@ -839,6 +896,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
     processInstrAfterVisit(I, B);
   }
+
+  // infer element types of untyped pointer parameters from call instances
+  if (!F->isIntrinsic())
+    processParamTypes(F, B);
+
   return true;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index d5ed501def9986..eb87349f0941c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -126,5 +126,10 @@ inline unsigned getPointerAddressSpace(const Type *T) {
              : cast<TypedPointerType>(SubT)->getAddressSpace();
 }
 
+// Return true if the Argument is decorated with a pointee type
+inline bool HasPointeeTypeAttr(Argument *Arg) {
+  return Arg->hasByValAttr() || Arg->hasByRefAttr();
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
new file mode 100644
index 00000000000000..3f8edfef78b03c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
new file mode 100644
index 00000000000000..be8582f9226d5c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+
+define void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}

>From 80935f0952cda868b6dc76ae018442736616064e Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 13 Mar 2024 07:26:17 -0700
Subject: [PATCH 2/4] harden the test case

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  8 +-
 .../SPIRV/pointers/type-deduce-args-rev.ll    |  2 +-
 .../SPIRV/pointers/type-deduce-args.ll        | 85 +++++++++++++++++--
 3 files changed, 83 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e9099fac1c1a3c..458af9229ed7b1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -802,9 +802,10 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
     if (isUntypedPointerTy(Arg.getType()) &&
         DeducedElTys.find(&Arg) == DeducedElTys.end() &&
         !HasPointeeTypeAttr(&Arg))
-      Args[i++] = &Arg;
+      Args[i] = &Arg;
+    i++;
   }
-  if (i == 0)
+  if (Args.size() == 0)
     return;
 
   // Args contains opaque pointers without element type definition
@@ -825,7 +826,8 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
         continue;
       // maybe we already know the operand's element type
       auto DeducedIt = DeducedElTys.find(OpArg);
-      Type *ElemTy = DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
+      Type *ElemTy =
+          DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
       if (!ElemTy) {
         for (User *OpU : OpArg->users()) {
           if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
index 3f8edfef78b03c..ae7fb99907b131 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
@@ -18,7 +18,7 @@
 define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
 entry:
   %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
-  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  call void @foo(ptr addrspace(1) %unknown_type_ptr)
   ret void
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
index be8582f9226d5c..ee411f26466027 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
@@ -1,28 +1,97 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "unknown_type_ptr"
 ; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
-; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[BarArg:.*]] "known_type_ptr"
 ; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: OpName %[[UntypedArg:.*]] "arg"
+; CHECK-SPIRV-DAG: OpName %[[FunUntypedArg:.*]] "foo_untyped_arg"
+; CHECK-SPIRV-DAG: OpName %[[UnusedArg1:.*]] "unused_arg1"
+; CHECK-SPIRV-DAG: OpName %[[Foo2Arg:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo2:.*]] "foo2"
+; CHECK-SPIRV-DAG: OpName %[[Bar2Arg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar2:.*]] "bar2"
+; CHECK-SPIRV-DAG: OpName %[[Foo5Arg1:.*]] "unknown_type_ptr1"
+; CHECK-SPIRV-DAG: OpName %[[Foo5Arg2:.*]] "unknown_type_ptr2"
+; CHECK-SPIRV-DAG: OpName %[[Foo5:.*]] "foo5"
+; CHECK-SPIRV-DAG: OpName %[[Bar5Arg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar5:.*]] "bar5"
+; CHECK-SPIRV-DAG: %[[Char:.*]] = OpTypeInt 8 0
 ; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Half:.*]] = OpTypeFloat 16
 ; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[HalfConst:.*]] = OpConstant %[[Half]] 15360
+; CHECK-SPIRV-DAG: %[[CharPtr:.*]] = OpTypePointer CrossWorkgroup %[[Char]]
 ; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
 ; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
+; CHECK-SPIRV-DAG: %[[Fun2:.*]] = OpTypeFunction %[[Void]] %[[Half]] %[[LongPtr]]
+; CHECK-SPIRV-DAG: %[[Fun5:.*]] = OpTypeFunction %[[Void]] %[[Half]] %[[LongPtr]] %[[Half]] %[[LongPtr]] %[[Half]]
+; CHECK-SPIRV-DAG: %[[FunUntyped:.*]] = OpTypeFunction %[[Void]] %[[CharPtr]]
+
 ; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
 ; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
 ; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
-; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
-; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+; CHECK-SPIRV: %[[BarArg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[BarArg]]
+
+; CHECK-SPIRV: %[[FunUntypedArg]] = OpFunction %[[Void]] None %[[FunUntyped]]
+; CHECK-SPIRV: %[[UntypedArg]] = OpFunctionParameter %[[CharPtr]]
+
+; CHECK-SPIRV: %[[Foo2]] = OpFunction %[[Void]] None %[[Fun2]]
+; CHECK-SPIRV: %[[UnusedArg1]] = OpFunctionParameter %[[Half]]
+; CHECK-SPIRV: %[[Foo2Arg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: %[[Bar2]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[Bar2Arg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo2]] %[[HalfConst]] %[[Bar2Arg]]
+
+; CHECK-SPIRV: %[[Foo5]] = OpFunction %[[Void]] None %[[Fun5]]
+; CHECK-SPIRV: OpFunctionParameter %[[Half]]
+; CHECK-SPIRV: %[[Foo5Arg1]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionParameter %[[Half]]
+; CHECK-SPIRV: %[[Foo5Arg2]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionParameter %[[Half]]
+; CHECK-SPIRV: %[[Bar5]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[Bar5Arg]] = OpFunctionParameter %[[LongPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo5]] %[[HalfConst]] %[[Bar5Arg]] %[[HalfConst]] %[[Bar5Arg]] %[[HalfConst]]
+
+define void @foo(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %known_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
+  call void @foo(ptr addrspace(1) %known_type_ptr)
+  ret void
+}
+
+define void @foo_untyped_arg(ptr addrspace(1) %arg) {
+entry:
+  ret void
+}
+
+define void @foo2(half %unused_arg1, ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  ret void
+}
+
+define spir_kernel void @bar2(ptr addrspace(1) %known_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
+  call void @foo2(half 1.0, ptr addrspace(1) %known_type_ptr)
+  ret void
+}
 
-define void @foo(ptr addrspace(1) %known_type_ptr) {
+define void @foo5(half %unused_arg1, ptr addrspace(1) %unknown_type_ptr1, half %unused_arg2, ptr addrspace(1) %unknown_type_ptr2, half %unused_arg3) {
 entry:
   ret void
 }
 
-define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+define spir_kernel void @bar5(ptr addrspace(1) %known_type_ptr) {
 entry:
-  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
-  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
+  call void @foo5(half 1.0, ptr addrspace(1) %known_type_ptr, half 1.0, ptr addrspace(1) %known_type_ptr, half 1.0)
   ret void
 }

>From 4928c9a8d9caaa543e650c87fb607624089feb39 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 13 Mar 2024 09:48:50 -0700
Subject: [PATCH 3/4] fix wrong insertion position of bitcasts

---
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp           | 6 +-----
 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll  | 2 +-
 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll | 2 +-
 3 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 61748070fc0fb2..e6e9131d8dc28f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,11 +104,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   SPIRV::StorageClass::StorageClass SC =
       static_cast<SPIRV::StorageClass::StorageClass>(
           OpType->getOperand(1).getImm());
-  MachineInstr *PrevI = I.getPrevNode();
-  MachineBasicBlock &MBB = *I.getParent();
-  MachineBasicBlock::iterator InsPt =
-      PrevI ? PrevI->getIterator() : MBB.begin();
-  MachineIRBuilder MIB(MBB, InsPt);
+  MachineIRBuilder MIB(I);
   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
   if (!GR.isBitcastCompatible(NewPtrType, OpType))
     report_fatal_error(
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
index a30d0792e39988..18752fdf843d20 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
@@ -9,7 +9,7 @@
 ; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
 ; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
 ; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
-; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
+; CHECK-NEXT: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
 
 %struct.S = type { i32 }
 %struct.__wrapper_class = type { [7 x %struct.S] }
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
index 4701f02ea33af3..202bcfbf2599a9 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
@@ -13,7 +13,7 @@
 ; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]]
 ; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]]
 ; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
-; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
+; CHECK-NEXT: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
 
 %struct.S = type { i32 }
 %struct.__wrapper_class = type { [7 x %struct.S] }

>From c40cfffe3e42978be885e851fd9193fff3fdbed9 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 13 Mar 2024 14:37:55 -0700
Subject: [PATCH 4/4] add ByVal to supported OpDecorate options

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 77319f58ff4d97..6f23f055b8c2ab 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -319,6 +319,12 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
         buildOpDecorate(VRegs[i][0], MIRBuilder,
                         SPIRV::Decoration::FuncParamAttr, {Attr});
       }
+      if (Arg.hasAttribute(Attribute::ByVal)) {
+        auto Attr =
+            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
+        buildOpDecorate(VRegs[i][0], MIRBuilder,
+                        SPIRV::Decoration::FuncParamAttr, {Attr});
+      }
 
       if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
         std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =



More information about the llvm-commits mailing list