[Mlir-commits] [mlir] 57e1943 - [mlir] Add support for non-f32 polynomial approximation
Robert Suderman
llvmlistbot at llvm.org
Mon Mar 27 14:58:52 PDT 2023
Author: Robert Suderman
Date: 2023-03-27T21:57:57Z
New Revision: 57e1943e8f8f1b28460678d1c25eaa36ba09b59b
URL: https://github.com/llvm/llvm-project/commit/57e1943e8f8f1b28460678d1c25eaa36ba09b59b
DIFF: https://github.com/llvm/llvm-project/commit/57e1943e8f8f1b28460678d1c25eaa36ba09b59b.diff
LOG: [mlir] Add support for non-f32 polynomial approximation
Polynomial approximations assume f32 values. Non-f32 cases are converted to
operate on f32 via intermediate casts (extend the operands to f32, truncate the result back).
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D146677
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6cbb50b0425b3..8f0f8dab77e22 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -331,7 +331,8 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
SmallVector<Value> operands;
for (auto operand : op->getOperands())
operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
- auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
+ auto result =
+ rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
return success();
}
@@ -1381,13 +1382,24 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
+ // Non-f32 ops: extend to f32, reuse the f32 expansion, truncate back.
+ patterns
+ .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
+ ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
+ ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
+ ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
+ ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
+ ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
+ patterns.getContext());
+
patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
LogApproximation, Log2Approximation, Log1pApproximation,
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
- SinAndCosApproximation<true, math::SinOp>,
+ CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
- if (options.enableAvx2)
- patterns.add<RsqrtApproximation>(patterns.getContext());
+ if (options.enableAvx2) {
+ patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
+ patterns.getContext());
+ }
}
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 4b490e4ea990c..b87d4b5ecdbc6 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -642,4 +642,55 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
%0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
func.return %0 : vector<4xf32>
-}
\ No newline at end of file
+}
+
+// Every op below has a ReuseF32Expansion pattern; none may survive in f16.
+// CHECK-LABEL: @math_f16
+func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+ // Each op consumes the previous result, presumably so none is erased as dead.
+ // CHECK-NOT: math.atan
+ %0 = "math.atan"(%arg0) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.atan2
+ %1 = "math.atan2"(%0, %arg0) : (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.tanh
+ %2 = "math.tanh"(%1) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.log
+ %3 = "math.log"(%2) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.log2
+ %4 = "math.log2"(%3) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.log1p
+ %5 = "math.log1p"(%4) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.erf
+ %6 = "math.erf"(%5) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.exp
+ %7 = "math.exp"(%6) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.expm1
+ %8 = "math.expm1"(%7) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.cbrt
+ %9 = "math.cbrt"(%8) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.sin
+ %10 = "math.sin"(%9) : (vector<4xf16>) -> vector<4xf16>
+
+ // CHECK-NOT: math.cos
+ %11 = "math.cos"(%10) : (vector<4xf16>) -> vector<4xf16>
+
+ return %11 : vector<4xf16>
+}
+
+// rsqrt patterns are gated on options.enableAvx2; AVX2 is presumably that RUN prefix.
+// AVX2-LABEL: @rsqrt_f16
+func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {
+ // AVX2-NOT: math.rsqrt
+ %0 = "math.rsqrt"(%arg0) : (vector<2x8xf16>) -> vector<2x8xf16>
+ return %0 : vector<2x8xf16>
+}
More information about the Mlir-commits mailing list