[Mlir-commits] [mlir] a7c2102 - [mlir][math] Expand fused math.fma into a multiply and an add
Robert Suderman
llvmlistbot at llvm.org
Fri Apr 7 15:16:21 PDT 2023
Author: Balaji V. Iyer
Date: 2023-04-07T22:14:56Z
New Revision: a7c2102d988b2ae2214f1483d2b4066955b4dc98
URL: https://github.com/llvm/llvm-project/commit/a7c2102d988b2ae2214f1483d2b4066955b4dc98
DIFF: https://github.com/llvm/llvm-project/commit/a7c2102d988b2ae2214f1483d2b4066955b4dc98.diff
LOG: [mlir][math] Expand fused math.fma into a multiply and an add
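Fused multiply-add operations (math.fma) are currently lowered directly to a libm
call. This is problematic in environments where libm is not available. This patch
adds an expansion pattern that breaks a fused multiply-add into a multiply followed
by an add. Note that the expanded form rounds the intermediate product, so its
result is not always bit-identical to a true fused multiply-add.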
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D147811
Added:
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index a1801dd995a6f..597618033cd0a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -16,6 +16,7 @@ class RewritePatternSet;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 91aef84348a96..f3e807e13102b 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -90,6 +90,24 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
return success();
}
+// Expands math.fma into an arith.mulf followed by an arith.addf, for targets
+// where a libm fma implementation is not available.
+// Note: the expansion is not bit-exact with a fused multiply-add. The product
+// is rounded to `type` before the addition, so the result is rounded twice
+// instead of once and can differ from the result of math.fma.
+// NOTE(review): no fastmath flags are passed to the new ops; confirm intent.
+static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operandA = op.getOperand(0);
+ Value operandB = op.getOperand(1);
+ Value operandC = op.getOperand(2);
+ Type type = op.getType();
+ Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
+ Value add = b.create<arith::AddFOp>(type, mult, operandC);
+ rewriter.replaceOp(op, add);
+ return success();
+}
+
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -145,3 +157,7 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}
+
+void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
+ patterns.add(convertFmaFOp);
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index a66ea082f1ef3..cc6c401d0c356 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -119,3 +119,15 @@ func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
// CHECK-LABEL: @ctlz_vector
// CHECK-NOT: math.ctlz
+
+// -----
+
+// CHECK-LABEL: func @fmaf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64, [[ARG2:%.+]]: f64) -> f64
+func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
+ // CHECK-NEXT: [[MULF:%.+]] = arith.mulf [[ARG0]], [[ARG1]]
+ // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[MULF]], [[ARG2]]
+ // CHECK-NEXT: return [[ADDF]]
+ %ret = math.fma %a, %b, %c : f64
+ return %ret : f64
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 29b862e410c0f..12bc3afe7ef75 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -39,6 +39,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandCtlzPattern(patterns);
populateExpandTanPattern(patterns);
populateExpandTanhPattern(patterns);
+ populateExpandFmaFPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
More information about the Mlir-commits
mailing list