[llvm] 1bada0a - [NVPTX] Add IR pass for FMA transformation in the llc pipeline (#154735)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 9 09:09:16 PST 2025


Author: Rajat Bajpai
Date: 2025-12-09T22:39:11+05:30
New Revision: 1bada0af22d878b2ec0cfefd655c09b801b44918

URL: https://github.com/llvm/llvm-project/commit/1bada0af22d878b2ec0cfefd655c09b801b44918
DIFF: https://github.com/llvm/llvm-project/commit/1bada0af22d878b2ec0cfefd655c09b801b44918.diff

LOG: [NVPTX] Add IR pass for FMA transformation in the llc pipeline (#154735)

This change introduces a new IR pass in the llc pipeline for NVPTX that
transforms sequences of FMUL followed by FADD or FSUB into a single FMA
instruction.

Currently, all FMA folding for NVPTX occurs at the DAGCombine stage,
which is too late for any IR-level passes that might want to optimize or
analyze FMAs. By moving this transformation earlier into the IR phase,
we enable more opportunities for FMA folding, including across basic
blocks.

Additionally, this new pass relies on the contract instruction level
fast-math flag to perform these transformations, rather than depending
on the -fp-contract=fast or -enable-unsafe-fp-math options passed to
llc.

Added: 
    llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp
    llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll

Modified: 
    llvm/lib/Target/NVPTX/CMakeLists.txt
    llvm/lib/Target/NVPTX/NVPTX.h
    llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
    llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index f9c24750c4836..6fe58c25c757d 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -18,6 +18,7 @@ set(NVPTXCodeGen_sources
   NVPTXAssignValidGlobalNames.cpp
   NVPTXAtomicLower.cpp
   NVPTXCtorDtorLowering.cpp
+  NVPTXIRPeephole.cpp
   NVPTXForwardParams.cpp
   NVPTXFrameLowering.cpp
   NVPTXGenericToNVVM.cpp

diff  --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 95fd05f2a926f..210624fbb235c 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
 FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
                                               bool NoTrapAfterNoreturn);
 FunctionPass *createNVPTXTagInvariantLoadsPass();
+FunctionPass *createNVPTXIRPeepholePass();
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 MachineFunctionPass *createNVPTXForwardParamsPass();
@@ -75,12 +76,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
 void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
 void initializeNVPTXPeepholePass(PassRegistry &);
 void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
+void initializeNVPTXIRPeepholePass(PassRegistry &);
 void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
 
 struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
+struct NVPTXIRPeepholePass : PassInfoMixin<NVPTXIRPeepholePass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
   NVVMReflectPass() : SmVersion(0) {}
   NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp b/llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp
new file mode 100644
index 0000000000000..bd16c7213b1e7
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp
@@ -0,0 +1,167 @@
+//===------ NVPTXIRPeephole.cpp - NVPTX IR Peephole --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements IR-level peephole optimizations. These transformations
+// run late in the NVPTX IR pass pipeline just before the instruction selection.
+//
+// Currently, it implements the following transformation(s):
+// 1. FMA folding (float/double types):
+//    Transforms FMUL+FADD/FSUB sequences into FMA intrinsics when the
+//    'contract' fast-math flag is present. Supported patterns:
+//    - fadd(fmul(a, b), c) => fma(a, b, c)
+//    - fadd(c, fmul(a, b)) => fma(a, b, c)
+//    - fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
+//    - fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+//    - fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+//    - fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTXUtilities.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+
+#define DEBUG_TYPE "nvptx-ir-peephole"
+
+using namespace llvm;
+
+static bool tryFoldBinaryFMul(BinaryOperator *BI) {
+  Value *Op0 = BI->getOperand(0);
+  Value *Op1 = BI->getOperand(1);
+
+  auto *FMul0 = dyn_cast<BinaryOperator>(Op0);
+  auto *FMul1 = dyn_cast<BinaryOperator>(Op1);
+
+  BinaryOperator *FMul = nullptr;
+  Value *OtherOperand = nullptr;
+  bool IsFirstOperand = false;
+
+  // Either Op0 or Op1 should be a valid FMul
+  if (FMul0 && FMul0->getOpcode() == Instruction::FMul && FMul0->hasOneUse() &&
+      FMul0->hasAllowContract()) {
+    FMul = FMul0;
+    OtherOperand = Op1;
+    IsFirstOperand = true;
+  } else if (FMul1 && FMul1->getOpcode() == Instruction::FMul &&
+             FMul1->hasOneUse() && FMul1->hasAllowContract()) {
+    FMul = FMul1;
+    OtherOperand = Op0;
+    IsFirstOperand = false;
+  } else {
+    return false;
+  }
+
+  bool IsFSub = BI->getOpcode() == Instruction::FSub;
+  LLVM_DEBUG({
+    const char *OpName = IsFSub ? "FSub" : "FAdd";
+    dbgs() << "Found " << OpName << " with FMul (single use) as "
+           << (IsFirstOperand ? "first" : "second") << " operand: " << *BI
+           << "\n";
+  });
+
+  Value *MulOp0 = FMul->getOperand(0);
+  Value *MulOp1 = FMul->getOperand(1);
+  IRBuilder<> Builder(BI);
+  Value *FMA = nullptr;
+
+  if (!IsFSub) {
+    // fadd(fmul(a, b), c) => fma(a, b, c)
+    // fadd(c, fmul(a, b)) => fma(a, b, c)
+    FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                  {MulOp0, MulOp1, OtherOperand});
+  } else {
+    if (IsFirstOperand) {
+      // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+      Value *NegOtherOp =
+          Builder.CreateFNegFMF(OtherOperand, BI->getFastMathFlags());
+      FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                    {MulOp0, MulOp1, NegOtherOp});
+    } else {
+      // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+      Value *NegMulOp0 =
+          Builder.CreateFNegFMF(MulOp0, FMul->getFastMathFlags());
+      FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                    {NegMulOp0, MulOp1, OtherOperand});
+    }
+  }
+
+  // Combine fast-math flags from the original instructions
+  auto *FMAInst = cast<Instruction>(FMA);
+  FastMathFlags BinaryFMF = BI->getFastMathFlags();
+  FastMathFlags FMulFMF = FMul->getFastMathFlags();
+  FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
+                         FastMathFlags::unionValue(BinaryFMF, FMulFMF);
+  FMAInst->setFastMathFlags(NewFMF);
+
+  LLVM_DEBUG({
+    const char *OpName = IsFSub ? "FSub" : "FAdd";
+    dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
+  });
+  BI->replaceAllUsesWith(FMA);
+  BI->eraseFromParent();
+  FMul->eraseFromParent();
+  return true;
+}
+
+static bool foldFMA(Function &F) {
+  bool Changed = false;
+
+  // Iterate and process float/double FAdd/FSub instructions with allow-contract
+  for (auto &I : llvm::make_early_inc_range(instructions(F))) {
+    if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
+      // Only FAdd and FSub are supported.
+      if (BI->getOpcode() != Instruction::FAdd &&
+          BI->getOpcode() != Instruction::FSub)
+        continue;
+
+      // At minimum, the instruction should have allow-contract.
+      if (!BI->hasAllowContract())
+        continue;
+
+      // Only float and double are supported.
+      if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
+        continue;
+
+      if (tryFoldBinaryFMul(BI))
+        Changed = true;
+    }
+  }
+  return Changed;
+}
+
+namespace {
+
+struct NVPTXIRPeephole : public FunctionPass {
+  static char ID;
+  NVPTXIRPeephole() : FunctionPass(ID) {}
+  bool runOnFunction(Function &F) override;
+};
+
+} // namespace
+
+char NVPTXIRPeephole::ID = 0;
+INITIALIZE_PASS(NVPTXIRPeephole, "nvptx-ir-peephole", "NVPTX IR Peephole",
+                false, false)
+
+bool NVPTXIRPeephole::runOnFunction(Function &F) { return foldFMA(F); }
+
+FunctionPass *llvm::createNVPTXIRPeepholePass() {
+  return new NVPTXIRPeephole();
+}
+
+PreservedAnalyses NVPTXIRPeepholePass::run(Function &F,
+                                           FunctionAnalysisManager &) {
+  if (!foldFMA(F))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}

diff  --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index ee37c9826012c..7d645bff7110f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
 FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
 FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
+FUNCTION_PASS("nvptx-ir-peephole", NVPTXIRPeepholePass())
 #undef FUNCTION_PASS

diff  --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index a6837a482608c..74bae28044e66 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -51,6 +51,13 @@ static cl::opt<bool>
                                cl::desc("Disable load/store vectorizer"),
                                cl::init(false), cl::Hidden);
 
+// NVPTX IR Peephole is a new pass; this option lets us turn it off in case
+// we encounter some issues.
+static cl::opt<bool>
+    DisableNVPTXIRPeephole("disable-nvptx-ir-peephole",
+                           cl::desc("Disable NVPTX IR Peephole"),
+                           cl::init(false), cl::Hidden);
+
 // TODO: Remove this flag when we are confident with no regressions.
 static cl::opt<bool> DisableRequireStructuredCFG(
     "disable-nvptx-require-structured-cfg",
@@ -115,6 +122,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   initializeNVPTXExternalAAWrapperPass(PR);
   initializeNVPTXPeepholePass(PR);
   initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
+  initializeNVPTXIRPeepholePass(PR);
   initializeNVPTXPrologEpilogPassPass(PR);
 }
 
@@ -379,6 +387,8 @@ void NVPTXPassConfig::addIRPasses() {
       addPass(createLoadStoreVectorizerPass());
     addPass(createSROAPass());
     addPass(createNVPTXTagInvariantLoadsPass());
+    if (!DisableNVPTXIRPeephole)
+      addPass(createNVPTXIRPeepholePass());
   }
 
   if (ST.hasPTXASUnreachableBug()) {

diff  --git a/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll
new file mode 100644
index 0000000000000..6d9ad8d3ad436
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll
@@ -0,0 +1,247 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=nvptx-ir-peephole -S | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+define float @test_fsub_fmul_c(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fsub_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul = fmul contract float %a, %b
+  %sub = fsub contract float %mul, %c
+  ret float %sub
+}
+
+
+; fsub(c, fmul(a, b)) => fma(-a, b, c)
+define float @test_fsub_c_fmul(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fsub_c_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul = fmul contract float %a, %b
+  %sub = fsub contract float %c, %mul
+  ret float %sub
+}
+
+
+; fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
+define float @test_fsub_fmul_fmul(float %a, float %b, float %c, float %d) {
+; CHECK-LABEL: define float @test_fsub_fmul_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul contract float [[C]], [[D]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[MUL2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul1 = fmul contract float %a, %b
+  %mul2 = fmul contract float %c, %d
+  %sub = fsub contract float %mul1, %mul2
+  ret float %sub
+}
+
+
+; fsub(fmul(a, b), fmul(c, d)) => fma(fneg(c), d, fmul(a, b))
+; fmul(a, b) has multiple uses.
+define float @test_fsub_fmul_fmul_multiple_use(float %a, float %b, float %c, float %d) {
+; CHECK-LABEL: define float @test_fsub_fmul_fmul_multiple_use(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
+; CHECK-NEXT:    [[MUL1:%.*]] = fmul contract float [[A]], [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[D]], float [[MUL1]])
+; CHECK-NEXT:    [[ADD:%.*]] = fadd float [[TMP2]], [[MUL1]]
+; CHECK-NEXT:    ret float [[ADD]]
+;
+  %mul1 = fmul contract float %a, %b
+  %mul2 = fmul contract float %c, %d
+  %sub = fsub contract float %mul1, %mul2
+  %add = fadd float %sub, %mul1
+  ret float %add
+}
+
+
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) where fsub and fmul are in different BBs
+define float @test_fsub_fmul_different_BB(float %a, float %b, float %c, i32 %n) {
+; CHECK-LABEL: define float @test_fsub_fmul_different_BB(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) {
+; CHECK-NEXT:  [[INIT:.*]]:
+; CHECK-NEXT:    [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10
+; CHECK-NEXT:    br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]]
+; CHECK:       [[ITERATION]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[I_NEXT]] = add i32 [[I]], 1
+; CHECK-NEXT:    [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00
+; CHECK-NEXT:    [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = fneg contract float [[C_PHI]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP0]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+init:
+  %mul = fmul contract float %a, %b
+  %cmp_iter = icmp sgt i32 %n, 10
+  br i1 %cmp_iter, label %iteration, label %exit
+
+iteration:
+  %i = phi i32 [ 0, %init ], [ %i_next, %iteration ]
+  %acc = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %i_next = add i32 %i, 1
+  %acc_next = fadd contract float %acc, 1.0
+  %cmp_loop = icmp slt i32 %i_next, %n
+  br i1 %cmp_loop, label %iteration, label %exit
+
+exit:
+  %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %sub = fsub contract float %mul, %c_phi
+  ret float %sub
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c)
+define float @test_fadd_fmul_c(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fadd_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul = fmul contract float %a, %b
+  %add = fadd contract float %mul, %c
+  ret float %add
+}
+
+
+; fadd(c, fmul(a, b)) => fma(a, b, c)
+define float @test_fadd_c_fmul(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fadd_c_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul = fmul contract float %a, %b
+  %add = fadd contract float %c, %mul
+  ret float %add
+}
+
+
+; fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
+define float @test_fadd_fmul_fmul(float %a, float %b, float %c, float %d) {
+; CHECK-LABEL: define float @test_fadd_fmul_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul contract float [[C]], [[D]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[MUL2]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul1 = fmul contract float %a, %b
+  %mul2 = fmul contract float %c, %d
+  %add = fadd contract float %mul1, %mul2
+  ret float %add
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c) where fadd and fmul are in different BBs
+define float @test_fadd_fmul_different_BB(float %a, float %b, float %c, i32 %n) {
+; CHECK-LABEL: define float @test_fadd_fmul_different_BB(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) {
+; CHECK-NEXT:  [[INIT:.*]]:
+; CHECK-NEXT:    [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10
+; CHECK-NEXT:    br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]]
+; CHECK:       [[ITERATION]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[I_NEXT]] = add i32 [[I]], 1
+; CHECK-NEXT:    [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00
+; CHECK-NEXT:    [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C_PHI]])
+; CHECK-NEXT:    ret float [[TMP0]]
+;
+init:
+  %mul = fmul contract float %a, %b
+  %cmp_iter = icmp sgt i32 %n, 10
+  br i1 %cmp_iter, label %iteration, label %exit
+
+iteration:
+  %i = phi i32 [ 0, %init ], [ %i_next, %iteration ]
+  %acc = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %i_next = add i32 %i, 1
+  %acc_next = fadd contract float %acc, 1.0
+  %cmp_loop = icmp slt i32 %i_next, %n
+  br i1 %cmp_loop, label %iteration, label %exit
+
+exit:
+  %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %add = fadd contract float %mul, %c_phi
+  ret float %add
+}
+
+
+; These scenarios shouldn't work.
+; fadd(fpext(fmul(a, b)), c) => fma(fpext(a), fpext(b), c)
+define double @test_fadd_fpext_fmul_c(float %a, float %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_fpext_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = fmul contract float [[A]], [[B]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[MUL]] to double
+; CHECK-NEXT:    [[ADD:%.*]] = fadd contract double [[EXT]], [[C]]
+; CHECK-NEXT:    ret double [[ADD]]
+;
+  %mul = fmul contract float %a, %b
+  %ext = fpext float %mul to double
+  %add = fadd contract double %ext, %c
+  ret double %add
+}
+
+
+; fadd(c, fpext(fmul(a, b))) => fma(fpext(a), fpext(b), c)
+define double @test_fadd_c_fpext_fmul(float %a, float %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_c_fpext_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = fmul contract float [[A]], [[B]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[MUL]] to double
+; CHECK-NEXT:    [[ADD:%.*]] = fadd contract double [[C]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD]]
+;
+  %mul = fmul contract float %a, %b
+  %ext = fpext float %mul to double
+  %add = fadd contract double %c, %ext
+  ret double %add
+}
+
+
+; Double precision tests
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+define double @test_fsub_fmul_c_double(double %a, double %b, double %c) {
+; CHECK-LABEL: define double @test_fsub_fmul_c_double(
+; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract double [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[TMP1]])
+; CHECK-NEXT:    ret double [[TMP2]]
+;
+  %mul = fmul contract double %a, %b
+  %sub = fsub contract double %mul, %c
+  ret double %sub
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c)
+define double @test_fadd_fmul_c_double(double %a, double %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_fmul_c_double(
+; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[C]])
+; CHECK-NEXT:    ret double [[TMP1]]
+;
+  %mul = fmul contract double %a, %b
+  %add = fadd contract double %mul, %c
+  ret double %add
+}


        


More information about the llvm-commits mailing list