[Mlir-commits] [mlir] 57e1943 - [mlir] Add support for non-f32 polynomial approximation

Robert Suderman llvmlistbot at llvm.org
Mon Mar 27 14:58:52 PDT 2023


Author: Robert Suderman
Date: 2023-03-27T21:57:57Z
New Revision: 57e1943e8f8f1b28460678d1c25eaa36ba09b59b

URL: https://github.com/llvm/llvm-project/commit/57e1943e8f8f1b28460678d1c25eaa36ba09b59b
DIFF: https://github.com/llvm/llvm-project/commit/57e1943e8f8f1b28460678d1c25eaa36ba09b59b.diff

LOG: [mlir] Add support for non-f32 polynomial approximation

Polynomial approximations assume f32 values. We can convert all non-f32
cases to operate on f32 values by inserting intermediate casts.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D146677

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6cbb50b0425b3..8f0f8dab77e22 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -331,7 +331,8 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
   SmallVector<Value> operands;
   for (auto operand : op->getOperands())
     operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
-  auto result = rewriter.create<math::Atan2Op>(loc, newType, operands);
+  auto result =
+      rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
   rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
   return success();
 }
@@ -1381,13 +1382,24 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
+  // Patterns for leveraging existing f32 lowerings on other data types.
+  patterns
+      .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
+           ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
+           ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
+           ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
+           ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
+           ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
+          patterns.getContext());
+
   patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
                LogApproximation, Log2Approximation, Log1pApproximation,
                ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
-               SinAndCosApproximation<true, math::SinOp>,
+               CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
-  if (options.enableAvx2)
-    patterns.add<RsqrtApproximation>(patterns.getContext());
+  if (options.enableAvx2) {
+    patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
+        patterns.getContext());
+  }
 }

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 4b490e4ea990c..b87d4b5ecdbc6 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -642,4 +642,55 @@ func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
 func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> {
   %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
   func.return %0 : vector<4xf32>
-}
\ No newline at end of file
+}
+
+
+// CHECK-LABEL: @math_f16
+func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+
+  // CHECK-NOT: math.atan
+  %0 = "math.atan"(%arg0) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.atan2
+  %1 = "math.atan2"(%0, %arg0) : (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.tanh
+  %2 = "math.tanh"(%1) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.log
+  %3 = "math.log"(%2) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.log2
+  %4 = "math.log2"(%3) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.log1p
+  %5 = "math.log1p"(%4) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.erf
+  %6 = "math.erf"(%5) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.exp
+  %7 = "math.exp"(%6) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.expm1
+  %8 = "math.expm1"(%7) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.cbrt
+  %9 = "math.cbrt"(%8) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.sin
+  %10 = "math.sin"(%9) : (vector<4xf16>) -> vector<4xf16>
+
+  // CHECK-NOT: math.cos
+  %11 = "math.cos"(%10) : (vector<4xf16>) -> vector<4xf16>
+
+  return %11 : vector<4xf16>
+}
+
+
+// AVX2-LABEL: @rsqrt_f16
+func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {
+  // AVX2-NOT: math.rsqrt
+  %0 = "math.rsqrt"(%arg0) : (vector<2x8xf16>) -> vector<2x8xf16>
+  return %0 : vector<2x8xf16>
+}


        


More information about the Mlir-commits mailing list