[Mlir-commits] [mlir] 740e2e9 - [mlir][math] Math expansion for math.tan
Robert Suderman
llvmlistbot at llvm.org
Tue Feb 28 17:22:45 PST 2023
Author: Robert Suderman
Date: 2023-03-01T01:13:54Z
New Revision: 740e2e908ca49118a6e1f27e380dbb3665a99cc8
URL: https://github.com/llvm/llvm-project/commit/740e2e908ca49118a6e1f27e380dbb3665a99cc8
DIFF: https://github.com/llvm/llvm-project/commit/740e2e908ca49118a6e1f27e380dbb3665a99cc8.diff
LOG: [mlir][math] Math expansion for math.tan
We can implement a polynomial approximation of math.tan by
decomposing it into `math.sin` and `math.cos`. While this is not
technically a polynomial approximation, it should be the most
straightforward approximation.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D144980
Added:
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9dbead1768e8e..a1801dd995a6f 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -14,6 +14,7 @@ namespace mlir {
class RewritePatternSet;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
+void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 98d76f3771c8b..364dd05c093ba 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -54,6 +55,17 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
return success();
}
+static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { // Rewrites math.tan(x) as math.sin(x) / math.cos(x).
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter); // Builder carries op's location, so the create calls below pass no explicit location.
+ Value operand = op.getOperand();
+ Type type = operand.getType(); // The operand type is reused as the result type of every op built below.
+ Value sin = b.create<math::SinOp>(type, operand);
+ Value cos = b.create<math::CosOp>(type, operand);
+ Value div = b.create<arith::DivFOp>(type, sin, cos); // The quotient sin / cos replaces the original tan result.
+ rewriter.replaceOp(op, div);
+ return success(); // No failure path: every math.tan op handed to this function is rewritten.
+}
+
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) {
auto operand = op.getOperand();
@@ -107,6 +119,10 @@ void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
+void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { // Adds the convertTanOp rewrite to `patterns`.
+ patterns.add(convertTanOp);
+}
+
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index bfd33d5040457..49ac15fd97b7e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -24,6 +24,19 @@ func.func @tanh(%arg: f32) -> f32 {
// -----
+// CHECK-LABEL: func @tan
+func.func @tan(%arg: f32) -> f32 {
+ %res = math.tan %arg : f32
+ return %res : f32
+}
+// The expansion must rewrite tan(x) as sin(x) / cos(x) on the same operand.
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[SIN:.+]] = math.sin %[[ARG0]]
+// CHECK: %[[COS:.+]] = math.cos %[[ARG0]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]]
+
+// -----
+
// CHECK-LABEL: func @ctlz
func.func @ctlz(%arg: i32) -> i32 {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 4b13aa36e6f9f..28819518b2780 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -35,6 +35,7 @@ struct TestExpandMathPass
void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateExpandCtlzPattern(patterns);
+ populateExpandTanPattern(patterns);
populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
More information about the Mlir-commits
mailing list