[Mlir-commits] [mlir] 740e2e9 - [mlir][math] Math expansion for math.tan

Robert Suderman llvmlistbot at llvm.org
Tue Feb 28 17:22:45 PST 2023


Author: Robert Suderman
Date: 2023-03-01T01:13:54Z
New Revision: 740e2e908ca49118a6e1f27e380dbb3665a99cc8

URL: https://github.com/llvm/llvm-project/commit/740e2e908ca49118a6e1f27e380dbb3665a99cc8
DIFF: https://github.com/llvm/llvm-project/commit/740e2e908ca49118a6e1f27e380dbb3665a99cc8.diff

LOG: [mlir][math] Math expansion for math.tan

We can expand `math.tan` by decomposing it into `math.sin` and `math.cos`
(tan(x) = sin(x) / cos(x)). While this is not technically a polynomial
approximation, it is the most straightforward expansion, and the resulting
sin/cos ops can themselves be lowered by the existing approximations.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D144980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9dbead1768e8e..a1801dd995a6f 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -14,6 +14,7 @@ namespace mlir {
 class RewritePatternSet;
 
 void populateExpandCtlzPattern(RewritePatternSet &patterns);
+void populateExpandTanPattern(RewritePatternSet &patterns); // Expands math.tan
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 98d76f3771c8b..364dd05c093ba 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -54,6 +55,17 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
+  Value x = op.getOperand();
+  Type ty = x.getType();
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value sinX = builder.create<math::SinOp>(ty, x);
+  Value cosX = builder.create<math::CosOp>(ty, x);
+  Value tanX = builder.create<arith::DivFOp>(ty, sinX, cosX); // sin(x)/cos(x)
+  rewriter.replaceOp(op, tanX);
+  return success();
+}
+
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
                                    PatternRewriter &rewriter) {
   auto operand = op.getOperand();
@@ -107,6 +119,10 @@ void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
   patterns.add(convertCtlzOp);
 }
 
+void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertTanOp); // math.tan -> math.sin / math.cos
+}
+
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index bfd33d5040457..49ac15fd97b7e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -24,6 +24,19 @@ func.func @tanh(%arg: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @tan(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[SIN:.+]] = math.sin %[[ARG0]]
+// CHECK: %[[COS:.+]] = math.cos %[[ARG0]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]]
+// CHECK: return %[[DIV]]
+func.func @tan(%arg: f32) -> f32 {
+  %res = math.tan %arg : f32
+  return %res : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @ctlz
 func.func @ctlz(%arg: i32) -> i32 {
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 4b13aa36e6f9f..28819518b2780 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -35,6 +35,7 @@ struct TestExpandMathPass
 void TestExpandMathPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateExpandCtlzPattern(patterns);
+  populateExpandTanPattern(patterns); // math.tan -> sin / cos / divf
   populateExpandTanhPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }


        


More information about the Mlir-commits mailing list