[llvm] PreISelIntrinsicLowering: Lower llvm.exp to a loop if scalable vec arg (PR #117568)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 07:20:49 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Stephen Long (steplong)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/117568.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp (+57) 
- (added) llvm/test/Transforms/PreISelIntrinsicLowering/expand-exp.ll (+23) 


``````````diff
diff --git a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp
index 4a3d1673c2a7c1..74f54e43a8386f 100644
--- a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp
+++ b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp
@@ -335,6 +335,59 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses(Function &F) const {
   return Changed;
 }
 
+static bool lowerExpIntrinsicToLoop(Module &M, Function &F, CallInst *CI) {
+  ScalableVectorType *ScalableTy =
+      dyn_cast<ScalableVectorType>(F.getArg(0)->getType());
+  if (!ScalableTy) {
+    return false;
+  }
+
+  BasicBlock *PreLoopBB = CI->getParent();
+  BasicBlock *PostLoopBB = nullptr;
+  Function *ParentFunc = PreLoopBB->getParent();
+  LLVMContext &Ctx = PreLoopBB->getContext();
+
+  PostLoopBB = PreLoopBB->splitBasicBlock(CI);
+  BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
+  PreLoopBB->getTerminator()->setSuccessor(0, LoopBB);
+
+  // Loop preheader: trip count = vscale * min element count of the vector.
+  IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
+  Value *VScale = PreLoopBuilder.CreateVScale(
+      ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
+  Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
+                              ScalableTy->getMinNumElements());
+  Value *LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
+
+  // Loop body: each iteration replaces lane LoopIndex of Vec with its exp.
+  IRBuilder<> LoopBuilder(LoopBB);
+  Type *Int64Ty = LoopBuilder.getInt64Ty();
+
+  PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
+  LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
+  PHINode *Vec = LoopBuilder.CreatePHI(ScalableTy, 2);
+  Vec->addIncoming(CI->getArgOperand(0), PreLoopBB);
+
+  Value *Elem = LoopBuilder.CreateExtractElement(Vec, LoopIndex);
+  Function *Exp = Intrinsic::getOrInsertDeclaration(
+      &M, Intrinsic::exp, ScalableTy->getElementType());
+  Value *Res = LoopBuilder.CreateCall(Exp, Elem);
+  Value *NewVec = LoopBuilder.CreateInsertElement(Vec, Res, LoopIndex);
+  Vec->addIncoming(NewVec, LoopBB);
+
+  Value *One = ConstantInt::get(Int64Ty, 1U);
+  Value *NextLoopIndex = LoopBuilder.CreateAdd(LoopIndex, One);
+  LoopIndex->addIncoming(NextLoopIndex, LoopBB);
+
+  Value *ExitCond =
+      LoopBuilder.CreateICmp(CmpInst::ICMP_EQ, NextLoopIndex, LoopEnd);
+  LoopBuilder.CreateCondBr(ExitCond, PostLoopBB, LoopBB);
+
+  CI->replaceAllUsesWith(NewVec);
+  CI->eraseFromParent();
+  return true;
+}
+
 bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const {
   bool Changed = false;
   for (Function &F : M) {
@@ -453,6 +506,10 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const {
     case Intrinsic::objc_sync_exit:
       Changed |= lowerObjCCall(F, "objc_sync_exit");
       break;
+    case Intrinsic::exp:
+      Changed |= forEachCall(
+          F, [&](CallInst *CI) { return lowerExpIntrinsicToLoop(M, F, CI); });
+      break;
     }
   }
   return Changed;
diff --git a/llvm/test/Transforms/PreISelIntrinsicLowering/expand-exp.ll b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-exp.ll
new file mode 100644
index 00000000000000..6ad181033f233a
--- /dev/null
+++ b/llvm/test/Transforms/PreISelIntrinsicLowering/expand-exp.ll
@@ -0,0 +1,23 @@
+; RUN: opt -passes=pre-isel-intrinsic-lowering -S < %s | FileCheck %s
+
+define <vscale x 4 x float> @softmax_kernel() {
+; CHECK-LABEL: define <vscale x 4 x float> @softmax_kernel(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[LOOPEND:%.*]] = mul i64 [[VSCALE]], 4
+; CHECK-NEXT:    br label %[[LOOPBODY:.*]]
+; CHECK:       [[LOOPBODY]]:
+; CHECK-NEXT:    [[IDX:%.*]] = phi i64 [ 0, %0 ], [ [[NEW_IDX:%.*]], %[[LOOPBODY]] ]
+; CHECK-NEXT:    [[VEC:%.*]] = phi <vscale x 4 x float> [ zeroinitializer, %0 ], [ [[NEW_VEC:.*]], %[[LOOPBODY]] ]
+; CHECK-NEXT:    [[ELEM:%.*]] = extractelement <vscale x 4 x float> [[VEC]], i64 [[IDX]]
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.exp.f32(float [[ELEM]])
+; CHECK-NEXT:    [[NEW_VEC:%.*]] = insertelement <vscale x 4 x float> [[VEC]], float [[RES]], i64 [[IDX]]
+; CHECK-NEXT:    [[NEW_IDX]] = add i64 [[IDX]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[NEW_IDX]], [[LOOPEND]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[LOOPEXIT:.*]], label %[[LOOPBODY]]
+; CHECK:       [[LOOPEXIT]]:
+; CHECK-NEXT:    ret <vscale x 4 x float> [[NEW_VEC]]
+  %1 = call <vscale x 4 x float> @llvm.exp.nxv4f32(<vscale x 4 x float> zeroinitializer)
+  ret <vscale x 4 x float> %1
+}
+
+declare <vscale x 4 x float> @llvm.exp.nxv4f32(<vscale x 4 x float>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/117568


More information about the llvm-commits mailing list