[Mlir-commits] [mlir] [mlir][math] Propagate scalability in polynomial approximation (PR #84949)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 12 10:01:28 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-math

Author: Benjamin Maxwell (MacDue)

Changes:

This updates the rewrites to propagate the scalable flags. Since the scalable flags do not alter the vector shape, this is a fairly simple change.

The added tests are scalable versions of the existing vector tests; a short sketch of the core change follows.
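
For readers skimming the diff, the sketch below distills the idea (illustrative only, mirroring what the patch does in `PolynomialApproximation.cpp`, not a drop-in copy): the static sizes and the scalable flags of a vector type are carried together, so any broadcast built by the approximation rewrites produces a result type with the same scalability as the input.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;

// Shape of a vector, including which dimensions are scalable.
struct VectorShape {
  ArrayRef<int64_t> sizes;
  ArrayRef<bool> scalableFlags;

  bool empty() const { return sizes.empty(); }
};

// Scalar types pass through untouched; vector shapes keep their scalable dims,
// so e.g. a broadcast over vector<[8]xf32> stays vector<[8]xf32>.
static Type broadcast(Type elementType, VectorShape shape) {
  return shape.empty()
             ? elementType
             : VectorType::get(shape.sizes, elementType, shape.scalableFlags);
}
```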

---
Full diff: https://github.com/llvm/llvm-project/pull/84949.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+34-23) 
- (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+89) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 962cb28b7c2ab9..428c1c37c4e8b5 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -39,14 +39,24 @@ using namespace mlir;
 using namespace mlir::math;
 using namespace mlir::vector;
 
+// Helper to encapsulate a vector's shape (including scalable dims).
+struct VectorShape {
+  ArrayRef<int64_t> sizes;
+  ArrayRef<bool> scalableFlags;
+
+  bool empty() const { return sizes.empty(); }
+};
+
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
-static ArrayRef<int64_t> vectorShape(Type type) {
+static VectorShape vectorShape(Type type) {
   auto vectorType = dyn_cast<VectorType>(type);
-  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
+  return vectorType
+             ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+             : VectorShape{};
 }
 
-static ArrayRef<int64_t> vectorShape(Value value) {
+static VectorShape vectorShape(Value value) {
   return vectorShape(value.getType());
 }
 
@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
 //----------------------------------------------------------------------------//
 
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, ArrayRef<int64_t> shape) {
+static Type broadcast(Type type, VectorShape shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return !shape.empty() ? VectorType::get(shape, type) : type;
+  return !shape.empty()
+             ? VectorType::get(shape.sizes, type, shape.scalableFlags)
+             : type;
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
-                       ArrayRef<int64_t> shape) {
+                       VectorShape shape) {
   assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
   return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool isPositive = false) {
   assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
-  ArrayRef<int64_t> shape = vectorShape(arg);
+  VectorShape shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 // Computes exp2 for an i32 argument.
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
-  ArrayRef<int64_t> shape = vectorShape(arg);
+  VectorShape shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
   Type elementType = getElementTypeOrSelf(x);
   assert((elementType.isF32() || elementType.isF16()) &&
          "x must be f32 or f16 type");
-  ArrayRef<int64_t> shape = vectorShape(x);
+  VectorShape shape = vectorShape(x);
 
   if (coeffs.empty())
     return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
   if (!getElementTypeOrSelf(operand).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   Value abs = builder.create<math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-  ArrayRef<int64_t> shape = vectorShape(op.getResult());
+  VectorShape shape = vectorShape(op.getResult());
 
   // Compute atan in the valid range.
   auto div = builder.create<arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op,
                                        "only f32 and f16 type is supported.");
-  ArrayRef<int64_t> shape = vectorShape(operand);
+  VectorShape shape = vectorShape(operand);
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
 
 namespace {
 
-Value clampWithNormals(ImplicitLocOpBuilder &builder,
-                       const llvm::ArrayRef<int64_t> shape, Value value,
-                       float lowerBound, float upperBound) {
+Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
+                       Value value, float lowerBound, float upperBound) {
   assert(!std::isnan(lowerBound));
   assert(!std::isnan(upperBound));
 
@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-  ArrayRef<int64_t> shape = vectorShape(operand);
+  VectorShape shape = vectorShape(operand);
 
   Type floatTy = getElementTypeOrSelf(operand.getType());
   Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   if (!getElementTypeOrSelf(op.getOperand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  VectorShape shape = vectorShape(op.getOperand());
 
   // Only support already-vectorized rsqrt's.
-  if (shape.empty() || shape.back() % 8 != 0)
+  if (shape.empty() || shape.sizes.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 834a7dc0af66d6..82b2646bea4a86 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @erf_scalable_vector(
+// CHECK-SAME:                     %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
+// CHECK-NOT:       erf
+// CHECK-NOT:       vector<8xf32>
+// CHECK-COUNT-20:  select
+// CHECK:           %[[res:.*]] = arith.select
+// CHECK:           return %[[res]] : vector<[8]xf32>
+// CHECK:         }
+func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.erf %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @exp_scalar(
 // CHECK-SAME:                     %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
@@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @exp_scalable_vector
+// CHECK-NOT:  math.exp
+// CHECK-NOT:  vector<8xf32>
+// CHECK:      vector<[8]xf32>
+// CHECK-NOT:  vector<8xf32>
+// CHECK-NOT:  math.exp
+func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.exp %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @expm1_scalar(
 // CHECK-SAME:                       %[[X:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
@@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
   return %0 : vector<8x8xf32>
 }
 
+// CHECK-LABEL:   func @expm1_scalable_vector(
+// CHECK-SAME:                       %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+// CHECK-NOT:       vector<8x8xf32>
+// CHECK-NOT:       exp
+// CHECK-NOT:       log
+// CHECK-NOT:       expm1
+// CHECK:           vector<8x[8]xf32>
+// CHECK-NOT:       vector<8x8xf32>
+// CHECK-NOT:       exp
+// CHECK-NOT:       log
+// CHECK-NOT:       expm1
+func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+  %0 = math.expm1 %arg0 : vector<8x[8]xf32>
+  return %0 : vector<8x[8]xf32>
+}
+
 // CHECK-LABEL:   func @log_scalar(
 // CHECK-SAME:                             %[[X:.*]]: f32) -> f32 {
 // CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
@@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log_scalable_vector(
+// CHECK-SAME:                     %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
+// CHECK-COUNT-5:   select
+// CHECK:           %[[VAL_71:.*]] = arith.select
+// CHECK:           return %[[VAL_71]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @log2_scalar(
 // CHECK-SAME:                      %[[VAL_0:.*]]: f32) -> f32 {
 // CHECK:           %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
@@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log2_scalable_vector(
+// CHECK-SAME:                      %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
+// CHECK-COUNT-5:   select
+// CHECK:           %[[VAL_71:.*]] = arith.select
+// CHECK:           return %[[VAL_71]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log2 %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // CHECK-LABEL:   func @log1p_scalar(
 // CHECK-SAME:                       %[[X:.*]]: f32) -> f32 {
 // CHECK:           %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
@@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @log1p_scalable_vector(
+// CHECK-SAME:                       %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
+// CHECK-COUNT-6:   select
+// CHECK:           %[[VAL_79:.*]] = arith.select
+// CHECK:           return %[[VAL_79]] : vector<[8]xf32>
+// CHECK:         }
+func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.log1p %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
 
 // CHECK-LABEL:   func @tanh_scalar(
 // CHECK-SAME:                      %[[VAL_0:.*]]: f32) -> f32 {
@@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL:   func @tanh_scalable_vector(
+// CHECK-SAME:                      %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
+// CHECK-NOT:       tanh
+// CHECK-COUNT-2:   select
+// CHECK:           %[[VAL_33:.*]] = arith.select
+// CHECK:           return %[[VAL_33]] : vector<[8]xf32>
+// CHECK:         }
+func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+  %0 = math.tanh %arg0 : vector<[8]xf32>
+  return %0 : vector<[8]xf32>
+}
+
 // We only approximate rsqrt for vectors and when the AVX2 option is enabled.
 // CHECK-LABEL:   func @rsqrt_scalar
 // AVX2-LABEL:    func @rsqrt_scalar

``````````



https://github.com/llvm/llvm-project/pull/84949


More information about the Mlir-commits mailing list