[Mlir-commits] [mlir] [mlir][Math] Fix 0-rank support for PolynomialApproximation (PR #114826)

Kunwar Grover llvmlistbot at llvm.org
Tue Nov 5 01:06:12 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/114826

>From 119a3a728baac479ff8c56d8ae3c5de73174a164 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 16:27:29 +0000
Subject: [PATCH 1/3] Fix 0-rank support for PolynomialApproximation

---
 .../Transforms/PolynomialApproximation.cpp    | 15 ++++---
 .../Math/polynomial-approximation.mlir        | 41 +++++++++++++++++++
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f0503555bfe4bf..953ff6c12e16c2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,8 +43,7 @@ using namespace mlir::vector;
 struct VectorShape {
   ArrayRef<int64_t> sizes;
   ArrayRef<bool> scalableFlags;
-
-  bool empty() const { return sizes.empty(); }
+  bool scalar = true;
 };
 
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
@@ -52,7 +51,8 @@ struct VectorShape {
 static VectorShape vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
   return vectorType
-             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
+                           /*scalar=*/false}
              : VectorShape{};
 }
 
@@ -67,9 +67,8 @@ static VectorShape vectorShape(Value value) {
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, VectorShape shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty()
-             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
-             : type;
+  return shape.scalar ? type
+                      : VectorType::get(shape.sizes, type, shape.scalableFlags);
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -77,7 +76,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        VectorShape shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
-  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
+  return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
 }
 
 //----------------------------------------------------------------------------//
@@ -1609,7 +1608,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   VectorShape shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.empty() || shape.sizes.back() % 8 != 0)
+  if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 93ecd67f14dd3d..81d071e6bbba36 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
   return %11 : vector<4xf16>
 }
 
+// CHECK-LABEL: @math_zero_rank
+func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
+
+  // CHECK-NOT: math.atan
+  %0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.atan2
+  %1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.tanh
+  %2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log
+  %3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log2
+  %4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log1p
+  %5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.erf
+  %6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.exp
+  %7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.expm1
+  %8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cbrt
+  %9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.sin
+  %10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cos
+  %11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
+
+  return %11 : vector<f16>
+}
 
 // AVX2-LABEL: @rsqrt_f16
 func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {

>From 796128c12270f4935394524f50da064509e5f694 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 17:21:59 +0000
Subject: [PATCH 2/3] Replace VectorShape::scalar flag with std::optional<VectorShape>

---
 .../Transforms/PolynomialApproximation.cpp    | 59 +++++++++----------
 1 file changed, 29 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 953ff6c12e16c2..c7e01817b04143 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,20 +43,18 @@ using namespace mlir::vector;
 struct VectorShape {
   ArrayRef<int64_t> sizes;
   ArrayRef<bool> scalableFlags;
-  bool scalar = true;
 };
 
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
-static VectorShape vectorShape(Type type) {
+static std::optional<VectorShape> vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
-  return vectorType
-             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
-                           /*scalar=*/false}
-             : VectorShape{};
+  return vectorType ? std::optional(VectorShape{vectorType.getShape(),
+                                                vectorType.getScalableDims()})
+                    : std::nullopt;
 }
 
-static VectorShape vectorShape(Value value) {
+static std::optional<VectorShape> vectorShape(Value value) {
   return vectorShape(value.getType());
 }
 
@@ -65,18 +63,18 @@ static VectorShape vectorShape(Value value) {
 //----------------------------------------------------------------------------//
 
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, VectorShape shape) {
+static Type broadcast(Type type, std::optional<VectorShape> shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return shape.scalar ? type
-                      : VectorType::get(shape.sizes, type, shape.scalableFlags);
+  return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
+               : type;
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
-                       VectorShape shape) {
+                       std::optional<VectorShape> shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
-  return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
+  return shape ? builder.create<BroadcastOp>(type, value) : value;
 }
 
 //----------------------------------------------------------------------------//
@@ -226,7 +224,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool isPositive = false) {
   assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
-  VectorShape shape = vectorShape(arg);
+  std::optional<VectorShape> shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -266,7 +264,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 // Computes exp2 for an i32 argument.
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
-  VectorShape shape = vectorShape(arg);
+  std::optional<VectorShape> shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -292,7 +290,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
   Type elementType = getElementTypeOrSelf(x);
   assert((elementType.isF32() || elementType.isF16()) &&
          "x must be f32 or f16 type");
-  VectorShape shape = vectorShape(x);
+  std::optional<VectorShape> shape = vectorShape(x);
 
   if (coeffs.empty())
     return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -390,7 +388,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
   if (!getElementTypeOrSelf(operand).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   Value abs = builder.create<math::AbsFOp>(operand);
@@ -489,7 +487,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  VectorShape shape = vectorShape(op.getResult());
+  std::optional<VectorShape> shape = vectorShape(op.getResult());
 
   // Compute atan in the valid range.
   auto div = builder.create<arith::DivFOp>(y, x);
@@ -555,7 +553,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -643,7 +641,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -790,7 +788,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -845,7 +843,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  std::optional<VectorShape> shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -940,7 +938,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  std::optional<VectorShape> shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1018,7 +1016,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  VectorShape shape = vectorShape(operand);
+  std::optional<VectorShape> shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1127,8 +1125,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
 
 namespace {
 
-Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
-                       Value value, float lowerBound, float upperBound) {
+Value clampWithNormals(ImplicitLocOpBuilder &builder,
+                       const std::optional<VectorShape> shape, Value value,
+                       float lowerBound, float upperBound) {
   assert(!std::isnan(lowerBound));
   assert(!std::isnan(upperBound));
 
@@ -1319,7 +1318,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1389,7 +1388,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1516,7 +1515,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-  VectorShape shape = vectorShape(operand);
+  std::optional<VectorShape> shape = vectorShape(operand);
 
   Type floatTy = getElementTypeOrSelf(operand.getType());
   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1605,10 +1604,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  VectorShape shape = vectorShape(op.getOperand());
+  std::optional<VectorShape> shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
+  if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

>From 643d6ca735ba9cbf6b3f9f0263992bcff8283312 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 5 Nov 2024 08:39:58 +0000
Subject: [PATCH 3/3] Simplify vectorShape() and update its doc comment (review feedback)

---
 .../Math/Transforms/PolynomialApproximation.cpp       | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index c7e01817b04143..24c892f68b5031 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -45,13 +45,12 @@ struct VectorShape {
   ArrayRef<bool> scalableFlags;
 };
 
-// Returns vector shape if the type is a vector. Returns an empty shape if it is
-// not a vector.
+// Returns vector shape if the type is a vector, otherwise return nullopt.
 static std::optional<VectorShape> vectorShape(Type type) {
-  auto vectorType = dyn_cast<VectorType>(type);
-  return vectorType ? std::optional(VectorShape{vectorType.getShape(),
-                                                vectorType.getScalableDims()})
-                    : std::nullopt;
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
+  }
+  return std::nullopt;
 }
 
 static std::optional<VectorShape> vectorShape(Value value) {



More information about the Mlir-commits mailing list