[Mlir-commits] [mlir] [Math] Fix 0-rank support for PolynomialApproximation (PR #114826)

Kunwar Grover llvmlistbot at llvm.org
Mon Nov 4 08:29:16 PST 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/114826

This patch disambiguates 0-rank vectors and scalars in PolynomialApproximation. This fixes a bug in PolynomialApproximation where 0-rank vectors were treated as scalars, so arguments were not broadcast properly.

>From 119a3a728baac479ff8c56d8ae3c5de73174a164 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 16:27:29 +0000
Subject: [PATCH] Fix 0-dim support for PolynomialApproximation

---
 .../Transforms/PolynomialApproximation.cpp    | 15 ++++---
 .../Math/polynomial-approximation.mlir        | 41 +++++++++++++++++++
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f0503555bfe4bf..953ff6c12e16c2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,8 +43,7 @@ using namespace mlir::vector;
 struct VectorShape {
   ArrayRef<int64_t> sizes;
   ArrayRef<bool> scalableFlags;
-
-  bool empty() const { return sizes.empty(); }
+  bool scalar = true;
 };
 
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
@@ -52,7 +51,8 @@ struct VectorShape {
 static VectorShape vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
   return vectorType
-             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
+                           /*scalar=*/false}
              : VectorShape{};
 }
 
@@ -67,9 +67,8 @@ static VectorShape vectorShape(Value value) {
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, VectorShape shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty()
-             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
-             : type;
+  return shape.scalar ? type
+                      : VectorType::get(shape.sizes, type, shape.scalableFlags);
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -77,7 +76,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        VectorShape shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
-  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
+  return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
 }
 
 //----------------------------------------------------------------------------//
@@ -1609,7 +1608,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   VectorShape shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.empty() || shape.sizes.back() % 8 != 0)
+  if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 93ecd67f14dd3d..81d071e6bbba36 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
   return %11 : vector<4xf16>
 }
 
+// CHECK-LABEL: @math_zero_rank
+func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
+
+  // CHECK-NOT: math.atan
+  %0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.atan2
+  %1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.tanh
+  %2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log
+  %3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log2
+  %4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log1p
+  %5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.erf
+  %6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.exp
+  %7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.expm1
+  %8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cbrt
+  %9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.sin
+  %10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cos
+  %11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
+
+  return %11 : vector<f16>
+}
 
 // AVX2-LABEL: @rsqrt_f16
 func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {



More information about the Mlir-commits mailing list