[llvm] [SPIR-V] Type inference must realize that a <1 x Type> vector type is not a legal vector type in LLT (PR #124560)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 31 06:43:49 PST 2025


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/124560

>From 2fdc3501467e8c80005ca364dd20bb71966f8df3 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 27 Jan 2025 07:18:21 -0800
Subject: [PATCH 1/3] account for possible <1 x LLVM Type> input

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 42 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVUtils.h            | 17 ++++++++
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 702206b8e0dc56e..96f67d6117e9733 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -427,7 +427,7 @@ Type *SPIRVEmitIntrinsics::reconstructType(Value *Op, bool UnknownElemTypeI8,
 
 void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
                                           Value *Arg) {
-  Value *OfType = PoisonValue::get(Ty);
+  Value *OfType = getNormalizedPoisonValue(Ty);
   CallInst *AssignCI = nullptr;
   if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
       allowEmitFakeUse(Arg)) {
@@ -447,6 +447,7 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
 
 void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
                                          Value *Arg) {
+  ElemTy = normalizeType(ElemTy);
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
@@ -470,7 +471,7 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
     return;
 
   // update association with the pointee type
-  Type *ElemTy = OfType->getType();
+  Type *ElemTy = normalizeType(OfType->getType());
   GR->addDeducedElementType(AssignCI, ElemTy);
   GR->addDeducedElementType(Arg, ElemTy);
 }
@@ -490,7 +491,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op,
   }
   Type *OpTy = Op->getType();
   SmallVector<Type *, 2> Types = {OpTy, OpTy};
-  SmallVector<Value *, 2> Args = {Op, buildMD(PoisonValue::get(ElemTy)),
+  SmallVector<Value *, 2> Args = {Op, buildMD(getNormalizedPoisonValue(ElemTy)),
                                   B.getInt32(getPointerAddressSpace(OpTy))};
   CallInst *PtrCasted =
       B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -766,7 +767,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   // remember the found relationship
   if (Ty && !IgnoreKnownType) {
     // specify nested types if needed, otherwise return unchanged
-    GR->addDeducedElementType(I, Ty);
+    GR->addDeducedElementType(I, normalizeType(Ty));
   }
 
   return Ty;
@@ -852,7 +853,7 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       }
       if (Ty != OpTy) {
         Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
-        GR->addDeducedCompositeType(U, NewTy);
+        GR->addDeducedCompositeType(U, normalizeType(NewTy));
         return NewTy;
       }
     }
@@ -990,6 +991,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
   if (KnownElemTy)
     return false;
   if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+    OpElemTy = normalizeType(OpElemTy);
     GR->addDeducedElementType(F, OpElemTy);
     GR->addReturnType(
         F, TypedPointerType::get(OpElemTy,
@@ -1002,7 +1004,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
         continue;
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
         if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
-          updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
+          updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy));
           propagateElemType(CI, PrevElemTy, VisitedSubst);
         }
       }
@@ -1162,11 +1164,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
-    Value *OpTyVal = PoisonValue::get(KnownElemTy);
+    Value *OpTyVal = getNormalizedPoisonValue(KnownElemTy);
     Type *OpTy = Op->getType();
     if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       Type *PrevElemTy = GR->findDeducedElementType(Op);
-      GR->addDeducedElementType(Op, KnownElemTy);
+      GR->addDeducedElementType(Op, normalizeType(KnownElemTy));
       // check if KnownElemTy is complete
       if (!Uncomplete)
         eraseTodoType(Op);
@@ -1492,7 +1494,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt(
 
   // Our previous guess about the type seems to be wrong, let's update
   // inferred type according to a new, more precise type information.
-  updateAssignType(AssignCI, V, PoisonValue::get(AssignedType));
+  updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType));
 }
 
 void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
@@ -1507,7 +1509,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     return;
 
   setInsertPointSkippingPhis(B, I);
-  Value *ExpectedElementVal = PoisonValue::get(ExpectedElementType);
+  Value *ExpectedElementVal = getNormalizedPoisonValue(ExpectedElementType);
   MetadataAsValue *VMD = buildMD(ExpectedElementVal);
   unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
   bool FirstPtrCastOrAssignPtrType = true;
@@ -1653,7 +1655,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
       if (!ElemTy) {
         ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx);
         if (ElemTy) {
-          GR->addDeducedElementType(CalledArg, ElemTy);
+          GR->addDeducedElementType(CalledArg, normalizeType(ElemTy));
         } else {
           for (User *U : CalledArg->users()) {
             if (Instruction *Inst = dyn_cast<Instruction>(U)) {
@@ -1984,8 +1986,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           Type *ElemTy = GR->findDeducedElementType(Op);
           buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op);
         } else {
-          CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
-                                               {OpTy}, Op, Op, {}, B);
+          CallInst *AssignCI =
+              buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy},
+                              getNormalizedPoisonValue(OpTy), Op, {}, B);
           GR->addAssignPtrTypeInstr(Op, AssignCI);
         }
       }
@@ -2034,7 +2037,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
       Type *OpTy = Op->getType();
       Value *OpTyVal = Op;
       if (OpTy->isTargetExtTy())
-        OpTyVal = PoisonValue::get(OpTy);
+        OpTyVal = getNormalizedPoisonValue(OpTy);
       CallInst *NewOp =
           buildIntrWithMD(Intrinsic::spv_track_constant,
                           {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
@@ -2045,7 +2048,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
         buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
         SmallVector<Type *, 2> Types = {OpTy, OpTy};
         SmallVector<Value *, 2> Args = {
-            NewOp, buildMD(PoisonValue::get(OpElemTy)),
+            NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
             B.getInt32(getPointerAddressSpace(OpTy))};
         CallInst *PtrCasted =
             B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
@@ -2178,7 +2181,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
     if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
       if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) {
         DenseSet<std::pair<Value *, Value *>> VisitedSubst;
-        updateAssignType(AssignCI, Arg, PoisonValue::get(ElemTy));
+        updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy));
         propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()),
                           VisitedSubst);
       } else {
@@ -2232,7 +2235,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
           continue;
         if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
             II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
-          updateAssignType(II, &F, PoisonValue::get(FPElemTy));
+          updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy));
           break;
         }
       }
@@ -2256,7 +2259,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
   for (Function *F : Worklist) {
     SmallVector<Value *> Args;
     for (const auto &Arg : F->args())
-      Args.push_back(PoisonValue::get(Arg.getType()));
+      Args.push_back(getNormalizedPoisonValue(Arg.getType()));
     IRB.CreateCall(F, Args);
   }
   IRB.CreateRetVoid();
@@ -2286,7 +2289,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
             buildAssignPtr(B, ElemTy, Arg);
           }
         } else if (isa<Instruction>(Param)) {
-          GR->addDeducedElementType(Param, ElemTy);
+          GR->addDeducedElementType(Param, normalizeType(ElemTy));
           // insertAssignTypeIntrs() will complete buildAssignPtr()
         } else {
           B.SetInsertPoint(CI->getParent()
@@ -2302,6 +2305,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
         if (!RefF || !isPointerTy(RefF->getReturnType()) ||
             GR->findDeducedElementType(RefF))
           continue;
+        ElemTy = normalizeType(ElemTy);
         GR->addDeducedElementType(RefF, ElemTy);
         GR->addReturnType(
             RefF, TypedPointerType::get(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index fd48098257065aa..ed7b2ef1becd952 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -383,6 +383,23 @@ inline const Type *unifyPtrType(const Type *Ty) {
   return toTypedPointer(const_cast<Type *>(Ty));
 }
 
+// Modify an LLVM type to conform with future transformations in IRTranslator.
+// At the moment use cases comprise only a <1 x Type> vector. To extend when/if
+// needed.
+inline Type *normalizeType(Type *Ty) {
+  // Collapse a <1 x Type> vector to its element type: it is not a legal
+  // vector type in LLT, and IRTranslator will represent it as the scalar.
+  if (auto *VecTy = dyn_cast<FixedVectorType>(Ty))
+    if (VecTy->getNumElements() == 1)
+      return normalizeType(VecTy->getElementType());
+  return Ty; // Every other type is already in normalized form.
+}
+
+// Returns a poison constant whose type is the normalized form of Ty.
+inline PoisonValue *getNormalizedPoisonValue(Type *Ty) {
+  return PoisonValue::get(normalizeType(Ty));
+}
+
 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"

>From 692565e5d28071cdc95abd19d8f64464aa9d28c3 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 31 Jan 2025 05:02:07 -0800
Subject: [PATCH 2/3] fix insert/extract; add a test case

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  10 +
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |   5 +
 .../validate/triton-tut-softmax-kernel.ll     | 221 ++++++++++++++++++
 3 files changed, 236 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 96f67d6117e9733..52614d378c465a1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1706,6 +1706,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
 }
 
 Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
+  // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
+  // type in LLT and IRTranslator will replace it by the scalar.
+  if (isVector1(I.getType()))
+    return &I;
+
   SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
                                   I.getOperand(1)->getType(),
                                   I.getOperand(2)->getType()};
@@ -1719,6 +1724,11 @@ Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
 
 Instruction *
 SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) {
+  // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
+  // type in LLT and IRTranslator will replace it by the scalar.
+  if (isVector1(I.getVectorOperandType()))
+    return &I;
+
   IRBuilder<> B(I.getParent());
   B.SetInsertPoint(&I);
   SmallVector<Type *, 3> Types = {I.getType(), I.getVectorOperandType(),
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index ed7b2ef1becd952..552adf2df7d1796 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -383,6 +383,11 @@ inline const Type *unifyPtrType(const Type *Ty) {
   return toTypedPointer(const_cast<Type *>(Ty));
 }
 
+inline bool isVector1(Type *Ty) {
+  const auto *Vec = dyn_cast<FixedVectorType>(Ty); // null if not fixed vector
+  return Vec != nullptr && Vec->getNumElements() == 1;
+}
+
 // Modify an LLVM type to conform with future transformations in IRTranslator.
 // At the moment use cases comprise only a <1 x Type> vector. To extend when/if
 // needed.
diff --git a/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
new file mode 100644
index 000000000000000..48771cc07e95984
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
@@ -0,0 +1,221 @@
+; This is an excerpt from the tutorial of the Triton language converted into
+; LLVM IR via the Triton XPU backend and cleaned of irrelevant details.
+; The only pass criterion is that spirv-val considers output valid.
+
+; This particular case is related to the translation of <1 x Ty> vectors.
+
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}
+
+define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) {
+  %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0)
+  %9 = trunc i64 %8 to i32
+  %10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0)
+  %11 = trunc i64 %10 to i32
+  %12 = tail call spir_func i64 @_Z12get_local_idj(i32 0)
+  %13 = trunc i64 %12 to i32
+  %14 = and i32 %13, 255
+  %15 = or disjoint i32 %14, 256
+  %16 = or disjoint i32 %14, 512
+  %17 = or disjoint i32 %14, 768
+  %18 = icmp slt i32 %14, %5
+  %19 = icmp slt i32 %15, %5
+  %20 = icmp slt i32 %16, %5
+  %21 = icmp slt i32 %17, %5
+  %22 = icmp sgt i32 %4, %9
+  br i1 %22, label %.lr.ph, label %._crit_edge
+
+.lr.ph:                                           ; preds = %7
+  %23 = lshr i64 %12, 5
+  %24 = and i32 %13, 31
+  %25 = zext nneg i32 %15 to i64
+  %26 = zext nneg i32 %16 to i64
+  %27 = zext nneg i32 %17 to i64
+  %28 = and i64 %12, 255
+  %29 = and i64 %23, 7
+  %30 = icmp eq i32 %24, 0
+  %31 = getelementptr float, ptr addrspace(3) %6, i64 %29
+  %32 = icmp slt i32 %13, 8
+  %sext = shl i64 %12, 32
+  %33 = ashr exact i64 %sext, 30
+  %34 = getelementptr i8, ptr addrspace(3) %6, i64 %33
+  %35 = and i32 %13, 7
+  %36 = icmp eq i32 %35, 0
+  %37 = and i1 %32, %36
+  br label %38
+
+38:                                               ; preds = %.lr.ph, %123
+  %39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ]
+  %40 = mul i32 %39, %2
+  %41 = sext i32 %40 to i64
+  %42 = getelementptr float, ptr addrspace(1) %1, i64 %41
+  %43 = getelementptr float, ptr addrspace(1) %42, i64 %25
+  %44 = getelementptr float, ptr addrspace(1) %42, i64 %26
+  %45 = getelementptr float, ptr addrspace(1) %42, i64 %27
+  br i1 %18, label %46, label %49
+
+46:                                               ; preds = %38
+  %47 = getelementptr float, ptr addrspace(1) %42, i64 %28
+  %48 = load <1 x float>, ptr addrspace(1) %47, align 4 ; <1 x float> value: the case under test
+  br label %49
+
+49:                                               ; preds = %46, %38
+  %50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ] ; -inf splat when not loaded
+  %51 = extractelement <1 x float> %50, i64 0 ; extract from <1 x float> (a scalar after IRTranslator)
+  br i1 %19, label %52, label %54
+
+52:                                               ; preds = %49
+  %53 = load <1 x float>, ptr addrspace(1) %43, align 4 ; same <1 x float> load/phi/extract pattern as %48
+  br label %54
+
+54:                                               ; preds = %52, %49
+  %55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ]
+  %56 = extractelement <1 x float> %55, i64 0
+  br i1 %20, label %57, label %59
+
+57:                                               ; preds = %54
+  %58 = load <1 x float>, ptr addrspace(1) %44, align 4 ; same <1 x float> load/phi/extract pattern as %48
+  br label %59
+
+59:                                               ; preds = %57, %54
+  %60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ]
+  %61 = extractelement <1 x float> %60, i64 0
+  br i1 %21, label %62, label %64
+
+62:                                               ; preds = %59
+  %63 = load <1 x float>, ptr addrspace(1) %45, align 4 ; same <1 x float> load/phi/extract pattern as %48
+  br label %64
+
+64:                                               ; preds = %62, %59
+  %65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ]
+  %66 = extractelement <1 x float> %65, i64 0
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %67 = tail call float @llvm.maxnum.f32(float %51, float %56)
+  %68 = tail call float @llvm.maxnum.f32(float %67, float %61)
+  %69 = tail call float @llvm.maxnum.f32(float %68, float %66)
+  %70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69)
+  br i1 %30, label %71, label %72
+
+71:                                               ; preds = %64
+  store float %70, ptr addrspace(3) %31, align 4
+  br label %72
+
+72:                                               ; preds = %71, %64
+  tail call spir_func void @_Z7barrierj(i32 1)
+  br i1 %32, label %74, label %.thread1
+
+.thread1:                                         ; preds = %72
+  %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float undef, i32 8)
+  br label %78
+
+74:                                               ; preds = %72
+  %75 = load float, ptr addrspace(3) %34, align 4
+  %76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8)
+  br i1 %37, label %77, label %78
+
+77:                                               ; preds = %74
+  store float %76, ptr addrspace(3) %34, align 4
+  br label %78
+
+78:                                               ; preds = %.thread1, %77, %74
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %79 = load float, ptr addrspace(3) %6, align 4
+  %80 = fsub float %51, %79
+  %81 = fsub float %56, %79
+  %82 = fsub float %61, %79
+  %83 = fsub float %66, %79
+  %84 = fmul float %80, 0x3FF7154760000000
+  %85 = tail call float @llvm.exp2.f32(float %84)
+  %86 = fmul float %81, 0x3FF7154760000000
+  %87 = tail call float @llvm.exp2.f32(float %86)
+  %88 = fmul float %82, 0x3FF7154760000000
+  %89 = tail call float @llvm.exp2.f32(float %88)
+  %90 = fmul float %83, 0x3FF7154760000000
+  %91 = tail call float @llvm.exp2.f32(float %90)
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %92 = fadd float %85, %87
+  %93 = fadd float %89, %92
+  %94 = fadd float %91, %93
+  %95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94)
+  br i1 %30, label %96, label %97
+
+96:                                               ; preds = %78
+  store float %95, ptr addrspace(3) %31, align 4
+  br label %97
+
+97:                                               ; preds = %96, %78
+  tail call spir_func void @_Z7barrierj(i32 1)
+  br i1 %32, label %99, label %.thread
+
+.thread:                                          ; preds = %97
+  %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float undef, i32 8)
+  br label %103
+
+99:                                               ; preds = %97
+  %100 = load float, ptr addrspace(3) %34, align 4
+  %101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8)
+  br i1 %37, label %102, label %103
+
+102:                                              ; preds = %99
+  store float %101, ptr addrspace(3) %34, align 4
+  br label %103
+
+103:                                              ; preds = %.thread, %102, %99
+  tail call spir_func void @_Z7barrierj(i32 1)
+  %104 = load float, ptr addrspace(3) %6, align 4
+  %105 = fdiv float %87, %104
+  %106 = fdiv float %89, %104
+  %107 = fdiv float %91, %104
+  %108 = mul i32 %39, %3
+  %109 = sext i32 %108 to i64
+  %110 = getelementptr float, ptr addrspace(1) %0, i64 %109
+  %111 = getelementptr float, ptr addrspace(1) %110, i64 %25
+  %112 = getelementptr float, ptr addrspace(1) %110, i64 %26
+  %113 = getelementptr float, ptr addrspace(1) %110, i64 %27
+  br i1 %18, label %114, label %117
+
+114:                                              ; preds = %103
+  %115 = fdiv float %85, %104
+  %116 = getelementptr float, ptr addrspace(1) %110, i64 %28
+  store float %115, ptr addrspace(1) %116, align 4
+  br label %117
+
+117:                                              ; preds = %114, %103
+  br i1 %19, label %118, label %119
+
+118:                                              ; preds = %117
+  store float %105, ptr addrspace(1) %111, align 4
+  br label %119
+
+119:                                              ; preds = %118, %117
+  br i1 %20, label %120, label %121
+
+120:                                              ; preds = %119
+  store float %106, ptr addrspace(1) %112, align 4
+  br label %121
+
+121:                                              ; preds = %120, %119
+  br i1 %21, label %122, label %123
+
+122:                                              ; preds = %121
+  store float %107, ptr addrspace(1) %113, align 4
+  br label %123
+
+123:                                              ; preds = %122, %121
+  %124 = add i32 %39, %11
+  %125 = icmp slt i32 %124, %4
+  br i1 %125, label %38, label %._crit_edge
+
+._crit_edge:                                      ; preds = %123, %7
+  ret void
+}
+
+declare float @llvm.exp2.f32(float)
+declare float @llvm.maxnum.f32(float, float)
+declare spir_func i64 @_Z12get_group_idj(i32)    ; get_group_id(unsigned)
+declare spir_func i64 @_Z14get_num_groupsj(i32)  ; get_num_groups(unsigned)
+declare spir_func i64 @_Z12get_local_idj(i32)    ; get_local_id(unsigned)
+declare spir_func void @_Z7barrierj(i32)         ; barrier(unsigned)
+declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, float)
+declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32, i32, float, i32)
+declare spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32, i32, float)
+declare spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, float, i32)

>From 8b8f3f601c10579d0b121559a642cd5f41870f45 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 31 Jan 2025 06:42:54 -0800
Subject: [PATCH 3/3] undef => poison

---
 llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
index 48771cc07e95984..d8a6c85b3d40739 100644
--- a/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
+++ b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
@@ -104,7 +104,7 @@ define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0,
   br i1 %32, label %74, label %.thread1
 
 .thread1:                                         ; preds = %72
-  %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float undef, i32 8)
+  %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float poison, i32 8)
   br label %78
 
 74:                                               ; preds = %72
@@ -147,7 +147,7 @@ define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0,
   br i1 %32, label %99, label %.thread
 
 .thread:                                          ; preds = %97
-  %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float undef, i32 8)
+  %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float poison, i32 8)
   br label %103
 
 99:                                               ; preds = %97



More information about the llvm-commits mailing list