[Mlir-commits] [mlir] ee8b8d6 - [mlir][math] Uplift from arith to math.fma

Ivan Butygin llvmlistbot at llvm.org
Sun Jun 18 08:13:59 PDT 2023


Author: Ivan Butygin
Date: 2023-06-18T17:11:21+02:00
New Revision: ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d

URL: https://github.com/llvm/llvm-project/commit/ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d
DIFF: https://github.com/llvm/llvm-project/commit/ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d.diff

LOG: [mlir][math] Uplift from arith to math.fma

Add pass to uplift from arith mulf + addf ops to math.fma if fastmath flags allow it.

Differential Revision: https://reviews.llvm.org/D152633

Added: 
    mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Math/Transforms/Passes.td
    mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
    mlir/test/Dialect/Math/uplift-to-fma.mlir

Modified: 
    mlir/include/mlir/Dialect/Math/CMakeLists.txt
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/CMakeLists.txt
index f33061b2d87cf..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/Math/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Math/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..a37f069da46b0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math)
+add_public_tablegen_target(MLIRMathTransformsIncGen)
+
+add_mlir_doc(Passes MathPasses ./ -gen-pass-doc)

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 576ace34eac1c..817d6e1dae051 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,7 +9,16 @@
 #ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
 
+#include "mlir/Pass/Pass.h"
+
 namespace mlir {
+namespace math {
+// Pass declarations and registration hooks generated from Passes.td.
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace math
 
 class RewritePatternSet;
 
@@ -34,6 +43,11 @@ void populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options = {});
 
+/// Adds patterns that rewrite `arith.addf(arith.mulf(a, b), c)` (in either
+/// operand order) into `math.fma(a, b, c)` when both the addf and the mulf
+/// carry the `contract` fastmath flag.
+void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
new file mode 100644
index 0000000000000..d81a92b0371e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -0,0 +1,23 @@
+//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
+  let summary = "Uplift arith ops to math.fma.";
+  let description = [{
+    Uplift sequences of `arith.mulf` + `arith.addf` ops to `math.fma` when the
+    fastmath flags allow it, i.e. when both ops carry the `contract` flag.
+  }];
+  let dependentDialects = ["math::MathDialect"];
+}
+
+#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 57335fefd3f8c..c98b157c0cbdb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
@@ -70,6 +71,7 @@ inline void registerAllPasses() {
   registerNVGPUPasses();
   registerSparseTensorPasses();
   LLVM::registerLLVMPasses();
+  math::registerMathPasses();
   memref::registerMemRefPasses();
   registerSCFPasses();
   registerShapePasses();

diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index c23c2f1a9b650..2d446b453edc9 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,10 +2,14 @@ add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
   PolynomialApproximation.cpp
+  UpliftToFMA.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
 
+  DEPENDS
+  MLIRMathTransformsIncGen
+
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRDialectUtils

diff  --git a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
new file mode 100644
index 0000000000000..6b0d0f5e7466f
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
@@ -0,0 +1,97 @@
+//===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements uplifting from arith ops to math.fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHUPLIFTTOFMA
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+/// Returns true if the fastmath flags of `op` permit contracting it into an
+/// fma, i.e. the `contract` flag is set.
+template <typename Op>
+static bool isValidForFMA(Op op) {
+  return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
+}
+
+namespace {
+
+/// Rewrites `arith.addf(arith.mulf(a, b), c)` (in either operand order) into
+/// `math.fma(a, b, c)`. Both the addf and the mulf must carry the `contract`
+/// fastmath flag; the fma gets the intersection of their flags. Only the addf
+/// is replaced, so a mulf with other users stays alive.
+struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::AddFOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isValidForFMA(op))
+      return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
+
+    // Returns the mulf defining `v` if it may be contracted, null otherwise.
+    auto getContractableMul = [](Value v) -> arith::MulFOp {
+      auto mul = v.getDefiningOp<arith::MulFOp>();
+      return (mul && isValidForFMA(mul)) ? mul : arith::MulFOp();
+    };
+
+    // Prefer the lhs, but fall back to the rhs when the lhs is not a
+    // contractable mulf, so e.g. `addf(mulf, mulf contract)` still fuses.
+    Value c;
+    arith::MulFOp ab;
+    if ((ab = getContractableMul(op.getLhs()))) {
+      c = op.getRhs();
+    } else if ((ab = getContractableMul(op.getRhs()))) {
+      c = op.getLhs();
+    } else {
+      bool hasMul = op.getLhs().getDefiningOp<arith::MulFOp>() ||
+                    op.getRhs().getDefiningOp<arith::MulFOp>();
+      return rewriter.notifyMatchFailure(
+          op, hasMul ? "mulf op is not suitable for fma" : "no mulf op");
+    }
+
+    Value a = ab.getLhs();
+    Value b = ab.getRhs();
+    // The fused op may only assume what both source ops allowed.
+    arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
+    rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the patterns from `populateUpliftToFMAPatterns`
+/// to the operation it runs on.
+struct MathUpliftToFMA final
+    : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
+  using MathUpliftToFMABase::MathUpliftToFMABase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateUpliftToFMAPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the `UpliftFma` pattern (addf + mulf -> math.fma) in `patterns`.
+void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) {
+  patterns.insert<UpliftFma>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/Math/uplift-to-fma.mlir b/mlir/test/Dialect/Math/uplift-to-fma.mlir
new file mode 100644
index 0000000000000..071ddd05284f3
--- /dev/null
+++ b/mlir/test/Dialect/Math/uplift-to-fma.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s --split-input-file --math-uplift-to-fma | FileCheck %s
+
+// No uplifting without fastmath flags.
+// CHECK-LABEL: func @test
+//  CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+//       CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]]
+//       CHECK: %[[V2:.*]] = arith.addf %[[V1]], %[[ARG3]]
+//       CHECK: return %[[V2]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+  %1 = arith.mulf %arg1, %arg2 : f32
+  %2 = arith.addf %1, %arg3 : f32
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+//  CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+//       CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<fast> : f32
+//       CHECK: return %[[RES]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+  %1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
+  %2 = arith.addf %1, %arg3 fastmath<fast> : f32
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @test
+//  CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+//       CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<contract> : f32
+//       CHECK: return %[[RES]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+  %1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
+  %2 = arith.addf %arg3, %1 fastmath<contract> : f32
+  return %2 : f32
+}
+
+// -----
+
+// The lhs mulf lacks `contract`, so the contractable rhs mulf is fused.
+// CHECK-LABEL: func @test
+//  CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+//       CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
+//       CHECK: %[[RES:.*]] = math.fma %[[ARG2]], %[[ARG3]], %[[V1]] fastmath<contract> : f32
+//       CHECK: return %[[RES]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+  %1 = arith.mulf %arg1, %arg2 : f32
+  %2 = arith.mulf %arg2, %arg3 fastmath<contract> : f32
+  %3 = arith.addf %1, %2 fastmath<contract> : f32
+  return %3 : f32
+}


        


More information about the Mlir-commits mailing list