[Mlir-commits] [mlir] bbddd19 - [mlir][math] Expand coverage of atan2 expansion

Jacques Pienaar llvmlistbot at llvm.org
Tue Feb 8 15:00:50 PST 2022


Author: Jacques Pienaar
Date: 2022-02-08T15:00:39-08:00
New Revision: bbddd19ec723a15ea1558cce5e47cb2460fa8e24

URL: https://github.com/llvm/llvm-project/commit/bbddd19ec723a15ea1558cce5e47cb2460fa8e24
DIFF: https://github.com/llvm/llvm-project/commit/bbddd19ec723a15ea1558cce5e47cb2460fa8e24.diff

LOG: [mlir][math] Expand coverage of atan2 expansion

Reuse the higher precision F32 approximation for the F16 one (by expanding and
truncating). This is partly an RFC, as I'm not sure what the expectations are
here. The possibilities include: these approximations are only for F32 and
should not be expanded; reusing higher-precision ones for lower precision is
undesirable due to increased compute cost, and only approximations per exact
type are preferred; or this is appropriate (at least as a fallback), but we need
to see how to make it more generic across all the patterns here.

Differential Revision: https://reviews.llvm.org/D118968

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index ed9170cdf55ac..9c8e413b0e55f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -23,11 +23,15 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::math;
@@ -279,6 +283,65 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
 }
 } // namespace
 
+//----------------------------------------------------------------------------//
+// Helper function/pattern to insert casts for reusing the F32 expansion.
+//----------------------------------------------------------------------------//
+
+// Recomputes `op` in F32: extends each operand with arith.extf, creates a `T`
+// on the extended values, and truncates its result back to the original type.
+template <typename T>
+LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
+  // Conservatively require all operands and results to share a single type.
+  Type origType = op->getResultTypes().front();
+  for (Type t : llvm::drop_begin(op->getResultTypes()))
+    if (origType != t)
+      return rewriter.notifyMatchFailure(op, "required all types to match");
+  for (Type t : op->getOperandTypes())
+    if (origType != t)
+      return rewriter.notifyMatchFailure(op, "required all types to match");
+
+  // extf/truncf only apply to floats; also guards the bit-width query below.
+  Type elemType = getElementTypeOrSelf(origType);
+  if (!elemType.isa<FloatType>())
+    return rewriter.notifyMatchFailure(op,
+                                       "unable to find F32 equivalent type");
+  // Skip if already F32 or larger than 32 bits.
+  if (elemType.isF32() || elemType.getIntOrFloatBitWidth() > 32)
+    return failure();
+
+  // Create F32 equivalent type, keeping the shape of shaped types.
+  Type newType = rewriter.getF32Type();
+  if (auto shaped = origType.dyn_cast<ShapedType>())
+    newType = shaped.clone(newType);
+
+  Location loc = op->getLoc();
+  SmallVector<Value> operands;
+  for (auto operand : op->getOperands())
+    operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
+  auto result = rewriter.create<T>(loc, newType, operands);
+  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
+  return success();
+}
+
+namespace {
+// Fallback pattern for a single-result op `T` whose operands and result share
+// one type: delegates to insertCasts so the F32 expansion can be reused.
+// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
+// all in lower precision. Currently this is only fallback support and performs
+// simplistic casting.
+template <typename T>
+struct ReuseF32Expansion : public OpRewritePattern<T> {
+public:
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
+    static_assert(
+        T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
+        "requires same operands and result types");
+    return insertCasts<T>(op, rewriter);
+  }
+};
+} // namespace
+
 //----------------------------------------------------------------------------//
 // AtanOp approximation.
 //----------------------------------------------------------------------------//
@@ -1209,6 +1272,7 @@ void mlir::populateMathPolynomialApproximationPatterns(
   patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
                LogApproximation, Log2Approximation, Log1pApproximation,
                ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+               ReuseF32Expansion<math::Atan2Op>,
                SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 8655c1cf61b87..e8c09ba2ca0ac 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -542,7 +542,9 @@ func @atan_scalar(%arg0: f32) -> f32 {
 // CHECK-DAG:  %[[N3:.+]] = arith.constant -0.0106783099
 // CHECK-DAG:  %[[N4:.+]] = arith.constant 1.00209987
 // CHECK-DAG:  %[[HALF_PI:.+]] = arith.constant 1.57079637
-// CHECK-DAG:  %[[RATIO:.+]] = arith.divf %arg0, %arg1
+// CHECK-DAG:  %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK-DAG:  %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32
+// CHECK-DAG:  %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]]
 // CHECK-DAG:  %[[ABS:.+]] = math.abs %[[RATIO]]
 // CHECK-DAG:  %[[DIV:.+]] = arith.divf %cst, %[[ABS]]
 // CHECK-DAG:  %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
@@ -562,30 +564,31 @@ func @atan_scalar(%arg0: f32) -> f32 {
 // CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]]
 // CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]]
 // CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]]
-// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]]
 // CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]]
 
 // Handle PI / 2 edge case:
-// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
-// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]]
+// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
 // CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
 
 // Handle -PI / 2 edge case:
 // CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
-// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
 // CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
 
 // Handle Nan edgecase:
-// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]]
 // CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
 // CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
 // CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
-// CHECK: return %[[EDGE3]]
+// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]]
+// CHECK: return %[[RET]]
 
-func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
-  %0 = math.atan2 %arg0, %arg1 : f32
-  return %0 : f32
+func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 {
+  %0 = math.atan2 %arg0, %arg1 : f16
+  return %0 : f16
 }
 


        


More information about the Mlir-commits mailing list