[Mlir-commits] [mlir] a7c2102 - [mlir][math]Expand Fused math.fmaf to a multiply-add

Robert Suderman llvmlistbot at llvm.org
Fri Apr 7 15:16:21 PDT 2023


Author: Balaji V. Iyer
Date: 2023-04-07T22:14:56Z
New Revision: a7c2102d988b2ae2214f1483d2b4066955b4dc98

URL: https://github.com/llvm/llvm-project/commit/a7c2102d988b2ae2214f1483d2b4066955b4dc98
DIFF: https://github.com/llvm/llvm-project/commit/a7c2102d988b2ae2214f1483d2b4066955b4dc98.diff

LOG: [mlir][math]Expand Fused math.fmaf to a multiply-add

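Fused multiply-add operations (math.fma) are currently lowered directly to libm
calls. This is problematic in situations where libm is not available. This patch
expands a fused multiply-add into a multiply followed by an add. Note that the
expansion rounds the intermediate product, so results may differ slightly from a
true (single-rounding) fused multiply-add.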

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D147811

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index a1801dd995a6f..597618033cd0a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -16,6 +16,7 @@ class RewritePatternSet;
 void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandFmaFPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 91aef84348a96..f3e807e13102b 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -90,6 +90,18 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Expands fma(a, b, c) into (a * b) + c using arith.mulf and arith.addf.
+static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Type ty = op.getType();
+  Value product = builder.create<arith::MulFOp>(ty, op.getOperand(0),
+                                                op.getOperand(1));
+  // Unlike a fused op, the product is rounded before the add is performed.
+  Value sum = builder.create<arith::AddFOp>(ty, product, op.getOperand(2));
+  rewriter.replaceOp(op, sum);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -145,3 +157,7 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }
+
+void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFmaFOp); // Rewrites math.fma into arith.mulf + addf.
+}

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index a66ea082f1ef3..cc6c401d0c356 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -119,3 +119,15 @@ func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
 
 // CHECK-LABEL: @ctlz_vector
 // CHECK-NOT: math.ctlz
+
+// -----
+
+// CHECK-LABEL:    func @fmaf_func
+// CHECK-SAME:     ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64, [[ARG2:%.+]]: f64) -> f64
+func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
+  // CHECK-NEXT:     [[MULF:%.+]] = arith.mulf [[ARG0]], [[ARG1]]
+  // CHECK-NEXT:     [[ADDF:%.+]] = arith.addf [[MULF]], [[ARG2]]
+  // CHECK-NEXT:     return [[ADDF]]
+  %ret = math.fma %a, %b, %c : f64 // Must become a mulf of a*b, then addf of c.
+  return %ret : f64
+}

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 29b862e410c0f..12bc3afe7ef75 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -39,6 +39,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandCtlzPattern(patterns);
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
+  populateExpandFmaFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 


        


More information about the Mlir-commits mailing list