[llvm] [SPIR-V] Improve type inference (PR #94626)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 07:54:08 PDT 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/94626

This PR continues https://github.com/llvm/llvm-project/pull/94467. It fixes the emission of type intrinsics and the recording of constants, and adds the corresponding test cases:
* type-deduce-global-dup.ll -- covers the fix for integer constant emission on 32-bit platforms and the corrected type deduction for globals
* type-deduce-simple-for.ll -- covers the fix for GEP translation (a previous issue led to incorrect translation that broke the logic of a for-range loop implementation)


>From 06c7b1c762f4ca22769271adab01ff6507feccc0 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 6 Jun 2024 07:50:54 -0700
Subject: [PATCH] improve type inference

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 69 ++++++++-----------
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  1 +
 ...insics-no-divergent-spv_assign_ptr_type.ll |  2 +-
 .../SPIRV/pointers/type-deduce-global-dup.ll  | 27 ++++++++
 .../SPIRV/pointers/type-deduce-simple-for.ll  | 65 +++++++++++++++++
 5 files changed, 124 insertions(+), 40 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-global-dup.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-simple-for.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index bbd25dc85f52b..ad7c5dc552a9e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -50,6 +50,13 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
 namespace {
+
+inline MetadataAsValue *buildMD(Value *Arg) {
+  LLVMContext &Ctx = Arg->getContext();
+  return MetadataAsValue::get(
+      Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(Arg)));
+}
+
 class SPIRVEmitIntrinsics
     : public ModulePass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
@@ -84,12 +91,9 @@ class SPIRVEmitIntrinsics
   CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
                             Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
                             IRBuilder<> &B) {
-    ConstantAsMetadata *CM = ValueAsMetadata::getConstant(Arg);
-    MDTuple *TyMD = MDNode::get(F->getContext(), CM);
-    MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
     SmallVector<Value *, 4> Args;
     Args.push_back(Arg2);
-    Args.push_back(VMD);
+    Args.push_back(buildMD(Arg));
     for (auto *Imm : Imms)
       Args.push_back(Imm);
     return B.CreateIntrinsic(IntrID, {Types}, Args);
@@ -228,20 +232,23 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
 void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
                                          Value *Arg) {
   Value *OfType = PoisonValue::get(ElemTy);
-  CallInst *AssignPtrTyCI = buildIntrWithMD(
-      Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
-      {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-  GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
-  GR->addDeducedElementType(Arg, ElemTy);
-  GR->addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
+  CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
+  if (AssignPtrTyCI == nullptr ||
+      AssignPtrTyCI->getParent()->getParent() != F) {
+    AssignPtrTyCI = buildIntrWithMD(
+        Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
+        {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
+    GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+    GR->addDeducedElementType(Arg, ElemTy);
+    GR->addAssignPtrTypeInstr(Arg, AssignPtrTyCI);
+  } else {
+    updateAssignType(AssignPtrTyCI, Arg, OfType);
+  }
 }
 
 void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
                                            Value *OfType) {
-  LLVMContext &Ctx = Arg->getContext();
-  AssignCI->setArgOperand(
-      1, MetadataAsValue::get(
-             Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OfType))));
+  AssignCI->setArgOperand(1, buildMD(OfType));
   if (cast<IntrinsicInst>(AssignCI)->getIntrinsicID() !=
       Intrinsic::spv_assign_ptr_type)
     return;
@@ -560,9 +567,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
         B.SetInsertPoint(F->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
       }
       SmallVector<Type *, 2> Types = {OpTy, OpTy};
-      MetadataAsValue *VMD = MetadataAsValue::get(
-          Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
-      SmallVector<Value *, 2> Args = {Op, VMD,
+      SmallVector<Value *, 2> Args = {Op, buildMD(OpTyVal),
                                       B.getInt32(getPointerAddressSpace(OpTy))};
       CallInst *PtrCastI =
           B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -689,8 +694,7 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
   Constant *TyC = UndefValue::get(IA->getFunctionType());
   MDString *ConstraintString = MDString::get(Ctx, IA->getConstraintString());
   SmallVector<Value *> Args = {
-      MetadataAsValue::get(Ctx,
-                           MDNode::get(Ctx, ValueAsMetadata::getConstant(TyC))),
+      buildMD(TyC),
       MetadataAsValue::get(Ctx, MDNode::get(Ctx, ConstraintString))};
   for (unsigned OpIdx = 0; OpIdx < Call.arg_size(); OpIdx++)
     Args.push_back(Call.getArgOperand(OpIdx));
@@ -821,13 +825,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
   if (PointerElemTy == ExpectedElementType)
     return;
 
-  setInsertPointSkippingPhis(B, I);
-  Constant *ExpectedElementTypeConst =
-      Constant::getNullValue(ExpectedElementType);
-  ConstantAsMetadata *CM =
-      ValueAsMetadata::getConstant(ExpectedElementTypeConst);
-  MDTuple *TyMD = MDNode::get(F->getContext(), CM);
-  MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
+  setInsertPointSkippingPhis(B, I); // PoisonValue::get(ElemTy);
+  MetadataAsValue *VMD = buildMD(PoisonValue::get(ExpectedElementType));
   unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
   bool FirstPtrCastOrAssignPtrType = true;
 
@@ -873,12 +872,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
   // spv_assign_ptr_type instead.
   if (FirstPtrCastOrAssignPtrType &&
       (isa<Instruction>(Pointer) || isa<Argument>(Pointer))) {
-    CallInst *CI = buildIntrWithMD(
-        Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
-        ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
-    GR->addDeducedElementType(CI, ExpectedElementType);
-    GR->addDeducedElementType(Pointer, ExpectedElementType);
-    GR->addAssignPtrTypeInstr(Pointer, CI);
+    buildAssignPtr(B, ExpectedElementType, Pointer);
     return;
   }
 
@@ -1167,12 +1161,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
 
   setInsertPointAfterDef(B, I);
   Type *ElemTy = deduceElementType(I);
-  Constant *EltTyConst = UndefValue::get(ElemTy);
-  unsigned AddressSpace = getPointerAddressSpace(I->getType());
-  CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
-                                 EltTyConst, I, {B.getInt32(AddressSpace)}, B);
-  GR->addDeducedElementType(CI, ElemTy);
-  GR->addAssignPtrTypeInstr(I, CI);
+  buildAssignPtr(B, ElemTy, I);
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1407,12 +1396,14 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       continue;
 
     insertAssignPtrTypeIntrs(I, B);
-    deduceOperandElementType(I);
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
     insertSpirvDecorations(I, B);
   }
 
+  for (auto &I : instructions(Func))
+    deduceOperandElementType(&I);
+
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index a0a253c23b1e8..2ed3db312ee98 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -453,6 +453,7 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           Ty = TargetExtIt == TargetExtConstTypes.end()
                    ? MI.getOperand(1).getCImm()->getType()
                    : TargetExtIt->second;
+          GR->add(MI.getOperand(1).getCImm(), &MF, Reg);
         } else if (MIOp == TargetOpcode::G_FCONSTANT) {
           Ty = MI.getOperand(1).getFPImm()->getType();
         } else {
diff --git a/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll b/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll
index f728eda079860..e4e0fb8a5e1a9 100644
--- a/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll
+++ b/llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll
@@ -4,7 +4,7 @@
 
 define spir_kernel void @test_pointer_cast(ptr addrspace(1) %src) {
 ; CHECK-NOT: call void @llvm.spv.assign.ptr.type.p1(ptr addrspace(1) %src, metadata i8 undef, i32 1)
-; CHECK: call void @llvm.spv.assign.ptr.type.p1(ptr addrspace(1) %src, metadata i32 0, i32 1)
+; CHECK: call void @llvm.spv.assign.ptr.type.p1(ptr addrspace(1) %src, metadata i32 poison, i32 1)
   %b = bitcast ptr addrspace(1) %src to ptr addrspace(1)
   %g = getelementptr inbounds i32, ptr addrspace(1) %b, i64 52
   ret void
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-global-dup.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-global-dup.ll
new file mode 100644
index 0000000000000..e6130cedf6800
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-global-dup.ll
@@ -0,0 +1,27 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#ArrayTy:]] = OpTypeArray %[[#Char]] %[[#]]
+; CHECK-SPIRV-DAG: %[[#CharPtrTy:]] = OpTypePointer CrossWorkgroup %[[#Char]]
+; CHECK-SPIRV-DAG: %[[#Const1:]] = OpConstantComposite %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV-DAG: %[[#CharPtrPtrTy:]] = OpTypePointer CrossWorkgroup %[[#CharPtrTy]]
+; CHECK-SPIRV-DAG: %[[#PtrArrayTy:]] = OpTypePointer CrossWorkgroup %[[#ArrayTy]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArrayTy]] CrossWorkgroup %[[#Const1]]
+; CHECK-SPIRV-DAG: OpVariable %[[#CharPtrPtrTy]] CrossWorkgroup %[[#]]
+
+@a_var = addrspace(1) global [2 x i8] c"\01\01"
+@p_var = addrspace(1) global ptr addrspace(1) getelementptr inbounds ([2 x i8], ptr addrspace(1) @a_var, i32 0, i64 1)
+
+define spir_func zeroext i8 @foo() {
+entry:
+  ret i8 1
+}
+
+define spir_func zeroext i8 @bar() {
+entry:
+  ret i8 1
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-simple-for.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-simple-for.ll
new file mode 100644
index 0000000000000..ab7f923797f30
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-simple-for.ll
@@ -0,0 +1,65 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - --translator-compatibility-mode | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-COMPAT,CHECK-COMPAT64
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-DEFVER
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - --translator-compatibility-mode | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-COMPAT,CHECK-COMPAT32
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#IntTy:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#IntConst1:]] = OpConstant %[[#IntTy]] 1
+; CHECK-SPIRV-DAG: %[[#ArrTy:]] = OpTypeArray %[[#IntTy]] %[[#IntConst1]]
+; CHECK-SPIRV-DAG: %[[#BoolTy:]] = OpTypeBool
+; CHECK-SPIRV-DAG: %[[#LongTy:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#LongConst4:]] = OpConstant %[[#LongTy]] 4
+; CHECK-SPIRV-DAG: %[[#IntConst123:]] = OpConstant %[[#IntTy]] 123
+; CHECK-SPIRV-DAG: %[[#CharPtrTy:]] = OpTypePointer Function %[[#CharTy]]
+; CHECK-SPIRV-DAG: %[[#ArrPtrTy:]] = OpTypePointer Function %[[#ArrTy]]
+; CHECK-SPIRV-DAG: %[[#IntPtrTy:]] = OpTypePointer Function %[[#IntTy]]
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#LblEntry:]] = OpLabel
+; CHECK-SPIRV: %[[#Value:]] = OpVariable %[[#ArrPtrTy]] Function
+; CHECK-SPIRV: %[[#ValueAsCharPtr:]] = OpBitcast %[[#CharPtrTy]] %[[#Value]]
+; CHECK-SPIRV: %[[#Eof:]] = OpInBoundsPtrAccessChain %[[#CharPtrTy]] %[[#ValueAsCharPtr]] %[[#LongConst4]]
+; CHECK-SPIRV: %[[#EofAsArray:]] = OpBitcast %[[#ArrPtrTy]] %[[#Eof]]
+; CHECK-SPIRV: OpBranch %[[#LblCond:]]
+; CHECK-SPIRV: %[[#LblCond]] = OpLabel
+; CHECK-SPIRV: %[[#Iter:]] = OpPhi %[[#ArrPtrTy]] %[[#Value]] %[[#LblEntry]] %[[#CurrValue:]] %[[#LblBody:]]
+; CHECK-COMPAT64: %[[#IterInt:]] = OpConvertPtrToU %[[#LongTy]] %[[#Iter]]
+; CHECK-COMPAT64: %[[#EofInt:]] = OpConvertPtrToU %[[#LongTy]] %[[#EofAsArray]]
+; CHECK-COMPAT32: %[[#IterInt:]] = OpConvertPtrToU %[[#IntTy]] %[[#Iter]]
+; CHECK-COMPAT32: %[[#EofInt:]] = OpConvertPtrToU %[[#IntTy]] %[[#EofAsArray]]
+; CHECK-COMPAT: %[[#Is:]] = OpIEqual %[[#BoolTy]] %[[#IterInt]] %[[#EofInt]]
+; CHECK-DEFVER: %[[#Is:]] = OpPtrEqual %[[#BoolTy]] %[[#Iter]] %[[#EofAsArray]]
+; CHECK-SPIRV: OpBranchConditional %[[#Is]] %[[#LblExit:]] %[[#LblBody]]
+; CHECK-SPIRV: %[[#LblBody]] = OpLabel
+; CHECK-SPIRV: %[[#IterAsIntPtr:]] = OpBitcast %[[#IntPtrTy]] %[[#Iter]]
+; CHECK-SPIRV: OpStore %[[#IterAsIntPtr]] %[[#IntConst123]] Aligned 4
+; CHECK-SPIRV: %[[#IterAsCharPtr:]] = OpBitcast %[[#CharPtrTy]] %[[#Iter]]
+; CHECK-SPIRV: %[[#CurrValueAsCharPtr:]] = OpInBoundsPtrAccessChain %[[#CharPtrTy]] %[[#IterAsCharPtr]] %[[#LongConst4]]
+; CHECK-SPIRV: %[[#CurrValue]] = OpBitcast %[[#ArrPtrTy]] %[[#CurrValueAsCharPtr]]
+; CHECK-SPIRV: OpBranch %[[#LblCond]]
+; CHECK-SPIRV: %[[#LblExit]] = OpLabel
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo() {
+entry:
+  %v = alloca [1 x i32], align 4
+  %eof = getelementptr inbounds i8, ptr %v, i64 4
+  br label %cond
+
+cond:
+  %iter = phi ptr [ %v, %entry ], [ %curr, %body ]
+  %is = icmp eq ptr %iter, %eof
+  br i1 %is, label %exit, label %body
+
+body:
+  store i32 123, ptr %iter, align 4
+  %curr = getelementptr inbounds i8, ptr %iter, i64 4
+  br label %cond
+
+exit:
+  ret void
+}



More information about the llvm-commits mailing list