[Mlir-commits] [mlir] [Math] Fix 0-rank support for PolynomialApproximation (PR #114826)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 4 08:29:51 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This patch disambiguates 0-rank vectors from scalars in PolynomialApproximation. It fixes a bug where 0-rank vectors were treated as scalars, so their arguments were not broadcast properly.

---
Full diff: https://github.com/llvm/llvm-project/pull/114826.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+7-8) 
- (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f0503555bfe4bf..953ff6c12e16c2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,8 +43,7 @@ using namespace mlir::vector;
 struct VectorShape {
   ArrayRef<int64_t> sizes;
   ArrayRef<bool> scalableFlags;
-
-  bool empty() const { return sizes.empty(); }
+  bool scalar = true;
 };
 
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
@@ -52,7 +51,8 @@ struct VectorShape {
 static VectorShape vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
   return vectorType
-             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
+                           /*scalar=*/false}
              : VectorShape{};
 }
 
@@ -67,9 +67,8 @@ static VectorShape vectorShape(Value value) {
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, VectorShape shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty()
-             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
-             : type;
+  return shape.scalar ? type
+                      : VectorType::get(shape.sizes, type, shape.scalableFlags);
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -77,7 +76,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        VectorShape shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
-  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
+  return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
 }
 
 //----------------------------------------------------------------------------//
@@ -1609,7 +1608,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   VectorShape shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.empty() || shape.sizes.back() % 8 != 0)
+  if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 93ecd67f14dd3d..81d071e6bbba36 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
   return %11 : vector<4xf16>
 }
 
+// CHECK-LABEL: @math_zero_rank
+func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
+
+  // CHECK-NOT: math.atan
+  %0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.atan2
+  %1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.tanh
+  %2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log
+  %3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log2
+  %4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.log1p
+  %5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.erf
+  %6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.exp
+  %7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.expm1
+  %8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cbrt
+  %9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.sin
+  %10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
+
+  // CHECK-NOT: math.cos
+  %11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
+
+  return %11 : vector<f16>
+}
 
 // AVX2-LABEL: @rsqrt_f16
 func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/114826


More information about the Mlir-commits mailing list