[Mlir-commits] [mlir] [mlir][Math] Fix 0-rank support for PolynomialApproximation (PR #114826)
Kunwar Grover
llvmlistbot at llvm.org
Mon Nov 4 09:22:19 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/114826
>From 119a3a728baac479ff8c56d8ae3c5de73174a164 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 16:27:29 +0000
Subject: [PATCH 1/2] Fix 0-rank support for PolynomialApproximation
---
.../Transforms/PolynomialApproximation.cpp | 15 ++++---
.../Math/polynomial-approximation.mlir | 41 +++++++++++++++++++
2 files changed, 48 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f0503555bfe4bf..953ff6c12e16c2 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,8 +43,7 @@ using namespace mlir::vector;
struct VectorShape {
ArrayRef<int64_t> sizes;
ArrayRef<bool> scalableFlags;
-
- bool empty() const { return sizes.empty(); }
+ bool scalar = true;
};
// Returns vector shape if the type is a vector. Returns an empty shape if it is
@@ -52,7 +51,8 @@ struct VectorShape {
static VectorShape vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
return vectorType
- ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+ ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
+ /*scalar=*/false}
: VectorShape{};
}
@@ -67,9 +67,8 @@ static VectorShape vectorShape(Value value) {
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, VectorShape shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return !shape.empty()
- ? VectorType::get(shape.sizes, type, shape.scalableFlags)
- : type;
+ return shape.scalar ? type
+ : VectorType::get(shape.sizes, type, shape.scalableFlags);
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -77,7 +76,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
VectorShape shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
- return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
+ return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
}
//----------------------------------------------------------------------------//
@@ -1609,7 +1608,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
VectorShape shape = vectorShape(op.getOperand());
// Only support already-vectorized rsqrt's.
- if (shape.empty() || shape.sizes.back() % 8 != 0)
+ if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 93ecd67f14dd3d..81d071e6bbba36 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -894,6 +894,47 @@ func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> {
return %11 : vector<4xf16>
}
+// CHECK-LABEL: @math_zero_rank
+func.func @math_zero_rank(%arg0 : vector<f16>) -> vector<f16> {
+
+ // CHECK-NOT: math.atan
+ %0 = "math.atan"(%arg0) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.atan2
+ %1 = "math.atan2"(%0, %arg0) : (vector<f16>, vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.tanh
+ %2 = "math.tanh"(%1) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log
+ %3 = "math.log"(%2) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log2
+ %4 = "math.log2"(%3) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.log1p
+ %5 = "math.log1p"(%4) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.erf
+ %6 = "math.erf"(%5) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.exp
+ %7 = "math.exp"(%6) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.expm1
+ %8 = "math.expm1"(%7) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.cbrt
+ %9 = "math.cbrt"(%8) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.sin
+ %10 = "math.sin"(%9) : (vector<f16>) -> vector<f16>
+
+ // CHECK-NOT: math.cos
+ %11 = "math.cos"(%10) : (vector<f16>) -> vector<f16>
+
+ return %11 : vector<f16>
+}
// AVX2-LABEL: @rsqrt_f16
func.func @rsqrt_f16(%arg0 : vector<2x8xf16>) -> vector<2x8xf16> {
>From 796128c12270f4935394524f50da064509e5f694 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 4 Nov 2024 17:21:59 +0000
Subject: [PATCH 2/2] Replace VectorShape::scalar flag with std::optional<VectorShape>
---
.../Transforms/PolynomialApproximation.cpp | 59 +++++++++----------
1 file changed, 29 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 953ff6c12e16c2..c7e01817b04143 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -43,20 +43,18 @@ using namespace mlir::vector;
struct VectorShape {
ArrayRef<int64_t> sizes;
ArrayRef<bool> scalableFlags;
- bool scalar = true;
};
// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
-static VectorShape vectorShape(Type type) {
+static std::optional<VectorShape> vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
- return vectorType
- ? VectorShape{vectorType.getShape(), vectorType.getScalableDims(),
- /*scalar=*/false}
- : VectorShape{};
+ return vectorType ? std::optional(VectorShape{vectorType.getShape(),
+ vectorType.getScalableDims()})
+ : std::nullopt;
}
-static VectorShape vectorShape(Value value) {
+static std::optional<VectorShape> vectorShape(Value value) {
return vectorShape(value.getType());
}
@@ -65,18 +63,18 @@ static VectorShape vectorShape(Value value) {
//----------------------------------------------------------------------------//
// Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, VectorShape shape) {
+static Type broadcast(Type type, std::optional<VectorShape> shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return shape.scalar ? type
- : VectorType::get(shape.sizes, type, shape.scalableFlags);
+ return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
+ : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
- VectorShape shape) {
+ std::optional<VectorShape> shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
- return shape.scalar ? value : builder.create<BroadcastOp>(type, value);
+ return shape ? builder.create<BroadcastOp>(type, value) : value;
}
//----------------------------------------------------------------------------//
@@ -226,7 +224,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool isPositive = false) {
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
- VectorShape shape = vectorShape(arg);
+ std::optional<VectorShape> shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -266,7 +264,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
- VectorShape shape = vectorShape(arg);
+ std::optional<VectorShape> shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -292,7 +290,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
Type elementType = getElementTypeOrSelf(x);
assert((elementType.isF32() || elementType.isF16()) &&
"x must be f32 or f16 type");
- VectorShape shape = vectorShape(x);
+ std::optional<VectorShape> shape = vectorShape(x);
if (coeffs.empty())
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -390,7 +388,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
if (!getElementTypeOrSelf(operand).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
Value abs = builder.create<math::AbsFOp>(operand);
@@ -489,7 +487,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- VectorShape shape = vectorShape(op.getResult());
+ std::optional<VectorShape> shape = vectorShape(op.getResult());
// Compute atan in the valid range.
auto div = builder.create<arith::DivFOp>(y, x);
@@ -555,7 +553,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -643,7 +641,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -790,7 +788,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -845,7 +843,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ std::optional<VectorShape> shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -940,7 +938,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ std::optional<VectorShape> shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1018,7 +1016,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- VectorShape shape = vectorShape(operand);
+ std::optional<VectorShape> shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1127,8 +1125,9 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
namespace {
-Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
- Value value, float lowerBound, float upperBound) {
+Value clampWithNormals(ImplicitLocOpBuilder &builder,
+ const std::optional<VectorShape> shape, Value value,
+ float lowerBound, float upperBound) {
assert(!std::isnan(lowerBound));
assert(!std::isnan(upperBound));
@@ -1319,7 +1318,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1389,7 +1388,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1516,7 +1515,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- VectorShape shape = vectorShape(operand);
+ std::optional<VectorShape> shape = vectorShape(operand);
Type floatTy = getElementTypeOrSelf(operand.getType());
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1605,10 +1604,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- VectorShape shape = vectorShape(op.getOperand());
+ std::optional<VectorShape> shape = vectorShape(op.getOperand());
// Only support already-vectorized rsqrt's.
- if (shape.sizes.empty() || shape.sizes.back() % 8 != 0)
+ if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
More information about the Mlir-commits
mailing list