[Mlir-commits] [mlir] ee8b8d6 - [mlir][math] Uplift from arith to math.fma
Ivan Butygin
llvmlistbot at llvm.org
Sun Jun 18 08:13:59 PDT 2023
Author: Ivan Butygin
Date: 2023-06-18T17:11:21+02:00
New Revision: ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d
URL: https://github.com/llvm/llvm-project/commit/ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d
DIFF: https://github.com/llvm/llvm-project/commit/ee8b8d6b58e1ff45d1bcbad38c6cf458f872b38d.diff
LOG: [mlir][math] Uplift from arith to math.fma
Add pass to uplift from arith mulf + addf ops to math.fma if fastmath flags allow it.
Differential Revision: https://reviews.llvm.org/D152633
Added:
mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Math/Transforms/Passes.td
mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
mlir/test/Dialect/Math/uplift-to-fma.mlir
Modified:
mlir/include/mlir/Dialect/Math/CMakeLists.txt
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/CMakeLists.txt
index f33061b2d87cf..9f57627c321fb 100644
--- a/mlir/include/mlir/Dialect/Math/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Math/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..a37f069da46b0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Math/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+# Run mlir-tblgen over Passes.td to produce Passes.h.inc (pass declarations,
+# generated with -gen-pass-decls under the group name "Math").
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Math)
+add_public_tablegen_target(MLIRMathTransformsIncGen)
+
+# Emit the pass documentation (-gen-pass-doc) as MathPasses.
+add_mlir_doc(Passes MathPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 576ace34eac1c..817d6e1dae051 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,7 +9,17 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
+namespace math {
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+#define GEN_PASS_DECL_MATHUPLIFTTOFMA
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace math
class RewritePatternSet;
@@ -34,6 +44,8 @@ void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
+void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
new file mode 100644
index 0000000000000..d81a92b0371e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -0,0 +1,22 @@
+//===-- Passes.td - Math pass definition file --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+// Declares the `math-uplift-to-fma` pass. It creates math.fma ops, hence the
+// dependency on the math dialect.
+def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
+  let summary = "Uplift arith ops to math.fma.";
+  let description = [{
+    Uplift sequence of addf and mulf ops to math.fma if fastmath flags allow it.
+    Both ops must carry the `contract` fastmath flag; the resulting math.fma
+    keeps only the fastmath flags that are set on both ops.
+  }];
+  let dependentDialects = ["math::MathDialect"];
+}
+
+#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 57335fefd3f8c..c98b157c0cbdb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -25,6 +25,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
@@ -70,6 +71,7 @@ inline void registerAllPasses() {
registerNVGPUPasses();
registerSparseTensorPasses();
LLVM::registerLLVMPasses();
+ math::registerMathPasses();
memref::registerMemRefPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index c23c2f1a9b650..2d446b453edc9 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -2,10 +2,14 @@ add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
PolynomialApproximation.cpp
+ UpliftToFMA.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math/Transforms
+ DEPENDS
+ MLIRMathTransformsIncGen
+
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
diff --git a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
new file mode 100644
index 0000000000000..6b0d0f5e7466f
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp
@@ -0,0 +1,79 @@
+//===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements uplifting from arith ops to math.fma.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHUPLIFTTOFMA
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+/// Returns true if `op` has the `contract` fastmath flag set, i.e. it may be
+/// fused with a neighbouring op into a single math.fma.
+template <typename OpTy>
+static bool isValidForFMA(OpTy op) {
+  const arith::FastMathFlags contractBit =
+      op.getFastmath() & arith::FastMathFlags::contract;
+  return contractBit != arith::FastMathFlags::none;
+}
+
+namespace {
+
+/// Rewrites `arith.addf(arith.mulf(a, b), c)` into `math.fma(a, b, c)`; the
+/// product may be either operand of the addition. Both the addition and the
+/// multiplication being fused must carry the `contract` fastmath flag, and the
+/// new op receives the intersection of their fastmath flags.
+struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::AddFOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isValidForFMA(op))
+      return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
+
+    arith::MulFOp lhsMul = op.getLhs().getDefiningOp<arith::MulFOp>();
+    arith::MulFOp rhsMul = op.getRhs().getDefiningOp<arith::MulFOp>();
+    if (!lhsMul && !rhsMul)
+      return rewriter.notifyMatchFailure(op, "no mulf op");
+
+    // Prefer the LHS product, but fall back to the RHS one when the LHS is not
+    // a contractable mulf. Otherwise a non-contractable product on the left
+    // would hide a perfectly fusable product on the right.
+    Value c;
+    arith::MulFOp ab;
+    if (lhsMul && isValidForFMA(lhsMul)) {
+      ab = lhsMul;
+      c = op.getRhs();
+    } else if (rhsMul && isValidForFMA(rhsMul)) {
+      ab = rhsMul;
+      c = op.getLhs();
+    } else {
+      return rewriter.notifyMatchFailure(lhsMul ? lhsMul : rhsMul,
+                                         "mulf op is not suitable for fma");
+    }
+
+    Value a = ab.getLhs();
+    Value b = ab.getRhs();
+    // Only flags present on both original ops may survive on the fused op.
+    arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
+    rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
+    return success();
+  }
+};
+
+/// Pass wrapper: collects the FMA uplifting patterns and applies them greedily
+/// to the operation the pass runs on, signalling failure if the driver fails.
+struct MathUpliftToFMA final
+    : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
+  using MathUpliftToFMABase::MathUpliftToFMABase;
+
+  void runOnOperation() override {
+    RewritePatternSet fmaPatterns(&getContext());
+    populateUpliftToFMAPatterns(fmaPatterns);
+
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(fmaPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Adds the arith.mulf + arith.addf -> math.fma uplifting pattern to
+/// `patterns`.
+void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<UpliftFma>(ctx);
+}
diff --git a/mlir/test/Dialect/Math/uplift-to-fma.mlir b/mlir/test/Dialect/Math/uplift-to-fma.mlir
new file mode 100644
index 0000000000000..071ddd05284f3
--- /dev/null
+++ b/mlir/test/Dialect/Math/uplift-to-fma.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --split-input-file --math-uplift-to-fma | FileCheck %s
+
+// No uplifting without fastmath flags: neither op carries `contract`, so the
+// mulf/addf pair is expected to be left untouched.
+// CHECK-LABEL: func @test
+// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK: %[[V1:.*]] = arith.mulf %[[ARG1]], %[[ARG2]]
+// CHECK: %[[V2:.*]] = arith.addf %[[V1]], %[[ARG3]]
+// CHECK: return %[[V2]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+ %1 = arith.mulf %arg1, %arg2 : f32
+ %2 = arith.addf %1, %arg3 : f32
+ return %2 : f32
+}
+
+// -----
+
+// Both ops are marked `fast`: the pair is fused and the resulting math.fma
+// keeps the `fast` flags.
+// CHECK-LABEL: func @test
+// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<fast> : f32
+// CHECK: return %[[RES]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+ %1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
+ %2 = arith.addf %1, %arg3 fastmath<fast> : f32
+ return %2 : f32
+}
+
+// -----
+
+// The product is the right-hand operand of the addf, and the two ops carry
+// different flags (`fast` vs `contract`): the fused op is expected to keep
+// only `contract`.
+// CHECK-LABEL: func @test
+// CHECK-SAME: (%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK: %[[RES:.*]] = math.fma %[[ARG1]], %[[ARG2]], %[[ARG3]] fastmath<contract> : f32
+// CHECK: return %[[RES]]
+func.func @test(%arg1: f32, %arg2: f32, %arg3: f32) -> f32 {
+ %1 = arith.mulf %arg1, %arg2 fastmath<fast> : f32
+ %2 = arith.addf %arg3, %1 fastmath<contract> : f32
+ return %2 : f32
+}
More information about the Mlir-commits
mailing list