[Mlir-commits] [mlir] [Math] Fix 0-rank support for PolynomialApproximation (PR #114826)
Kunwar Grover
llvmlistbot at llvm.org
Mon Nov 4 08:29:16 PST 2024
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/114826
This patch distinguishes 0-rank vectors from scalars in PolynomialApproximation. It fixes a bug where 0-rank vectors were treated as scalars, so values were not broadcast to the 0-rank vector type.
>From 119a3a728baac479ff8c56d8ae3c5de73174a164 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 16:27:29 +0000
Subject: [PATCH] Fix 0-dim support for PolynomialApproximation
---
.../Transforms/PolynomialApproximation.cpp | 15 ++++---
.../Math/polynomial-approximation.mlir | 41 +++++++++++++++++++
2 files changed, 48 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f0503555bfe4bf..953ff6c12e16c2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,8 +43,7 @@ using namespace mlir::vector;
struct VectorShape {
ArrayRef<int64_t> sizes;
ArrayRef<bool> scalableFlags;
-
- bool empty() const { return sizes.empty(); }
+ bool scalar = true; // False for any VectorType, including 0-rank.
};
// Returns vector shape if the type is a vector. Returns an empty shape if it is
@@ -52,7 +51,8 @@ struct VectorShape {
static VectorShape vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
return vectorType
- ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+ ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
+ /*scalar=*/false} // also for 0-rank vectors
: VectorShape{};
}
@@ -67,9 +67,8 @@ static VectorShape vectorShape(Value value) {
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, VectorShape shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return !shape.empty()
- ? VectorType::get(shape.sizes, type, shape.scalableFlags)
- : type;
+ return shape.scalar ? type // 0-rank shapes give a 0-rank VectorType.
+ : VectorType::get(shape.sizes, type, shape.scalableFlags);
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -77,7 +76,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
VectorShape shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
- return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
+ return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
}
//----------------------------------------------------------------------------//
@@ -1609,7 +1608,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
VectorShape shape = vectorShape(op.getOperand());
// Only support already-vectorized rsqrt's.
- if (shape.empty() || shape.sizes.back() % 8 != 0)
+ if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 93ecd67f14dd3d..81d071e6bbba36 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
return %11 : vector<4xf16>
}
+// CHECK-LABEL: @math_zero_rank
+func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
+ // 0-rank vectors are not scalars: every op below should still be expanded.
+ // CHECK-NOT: math.atan
+ %0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.atan2
+ %1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.tanh
+ %2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log
+ %3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log2
+ %4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log1p
+ %5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.erf
+ %6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.exp
+ %7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.expm1
+ %8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.cbrt
+ %9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.sin
+ %10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.cos
+ %11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
+
+ return %11 : vector<f16>
+}
// AVX2-LABEL: @rsqrt_f16
func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {
More information about the Mlir-commits
mailing list