[Mlir-commits] [mlir] [mlir][math] Propagate scalability in polynomial approximation (PR #84949)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Mar 13 09:34:49 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/84949
>From d8d370837431eacb665e357e216acd6a1586f9a8 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 12 Mar 2024 16:52:04 +0000
Subject: [PATCH 1/2] [mlir][math] Propagate scalability in polynomial
approximation
This updates the rewrites to propagate the scalable flags, which is
straightforward because the rewrites never alter the vector shape.
The added tests are simply scalable versions of the existing vector
tests.
---
.../Transforms/PolynomialApproximation.cpp | 57 +++++++-----
.../Math/polynomial-approximation.mlir | 89 +++++++++++++++++++
2 files changed, 123 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 962cb28b7c2ab9..428c1c37c4e8b5 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -39,14 +39,24 @@ using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;
+// Helper to encapsulate a vector's shape (including scalable dims).
+struct VectorShape {
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableFlags;
+
+ bool empty() const { return sizes.empty(); }
+};
+
// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
-static ArrayRef<int64_t> vectorShape(Type type) {
+static VectorShape vectorShape(Type type) {
auto vectorType = dyn_cast<VectorType>(type);
- return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
+ return vectorType
+ ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
+ : VectorShape{};
}
-static ArrayRef<int64_t> vectorShape(Value value) {
+static VectorShape vectorShape(Value value) {
return vectorShape(value.getType());
}
@@ -55,14 +65,16 @@ static ArrayRef<int64_t> vectorShape(Value value) {
//----------------------------------------------------------------------------//
// Broadcasts scalar type into vector type (iff shape is non-scalar).
-static Type broadcast(Type type, ArrayRef<int64_t> shape) {
+static Type broadcast(Type type, VectorShape shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return !shape.empty() ? VectorType::get(shape, type) : type;
+ return !shape.empty()
+ ? VectorType::get(shape.sizes, type, shape.scalableFlags)
+ : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
- ArrayRef<int64_t> shape) {
+ VectorShape shape) {
assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
@@ -215,7 +227,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool isPositive = false) {
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
- ArrayRef<int64_t> shape = vectorShape(arg);
+ VectorShape shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -255,7 +267,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
- ArrayRef<int64_t> shape = vectorShape(arg);
+ VectorShape shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -281,7 +293,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
Type elementType = getElementTypeOrSelf(x);
assert((elementType.isF32() || elementType.isF16()) &&
"x must be f32 or f16 type");
- ArrayRef<int64_t> shape = vectorShape(x);
+ VectorShape shape = vectorShape(x);
if (coeffs.empty())
return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
@@ -379,7 +391,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
if (!getElementTypeOrSelf(operand).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
Value abs = builder.create<math::AbsFOp>(operand);
@@ -478,7 +490,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- ArrayRef<int64_t> shape = vectorShape(op.getResult());
+ VectorShape shape = vectorShape(op.getResult());
// Compute atan in the valid range.
auto div = builder.create<arith::DivFOp>(y, x);
@@ -544,7 +556,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -632,7 +644,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -779,7 +791,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -829,7 +841,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
if (!(elementType.isF32() || elementType.isF16()))
return rewriter.notifyMatchFailure(op,
"only f32 and f16 type is supported.");
- ArrayRef<int64_t> shape = vectorShape(operand);
+ VectorShape shape = vectorShape(operand);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -938,9 +950,8 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
namespace {
-Value clampWithNormals(ImplicitLocOpBuilder &builder,
- const llvm::ArrayRef<int64_t> shape, Value value,
- float lowerBound, float upperBound) {
+Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
+ Value value, float lowerBound, float upperBound) {
assert(!std::isnan(lowerBound));
assert(!std::isnan(upperBound));
@@ -1131,7 +1142,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1201,7 +1212,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
@@ -1328,7 +1339,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- ArrayRef<int64_t> shape = vectorShape(operand);
+ VectorShape shape = vectorShape(operand);
Type floatTy = getElementTypeOrSelf(operand.getType());
Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
@@ -1417,10 +1428,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
if (!getElementTypeOrSelf(op.getOperand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
- ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+ VectorShape shape = vectorShape(op.getOperand());
// Only support already-vectorized rsqrt's.
- if (shape.empty() || shape.back() % 8 != 0)
+ if (shape.empty() || shape.sizes.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 834a7dc0af66d6..82b2646bea4a86 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -94,6 +94,20 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @erf_scalable_vector(
+// CHECK-SAME: %[[arg0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
+// CHECK-NOT: erf
+// CHECK-NOT: vector<8xf32>
+// CHECK-COUNT-20: select
+// CHECK: %[[res:.*]] = arith.select
+// CHECK: return %[[res]] : vector<[8]xf32>
+// CHECK: }
+func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.erf %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
// CHECK-LABEL: func @exp_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32
@@ -151,6 +165,17 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @exp_scalable_vector
+// CHECK-NOT: math.exp
+// CHECK-NOT: vector<8xf32>
+// CHECK: vector<[8]xf32>
+// CHECK-NOT: vector<8xf32>
+// CHECK-NOT: math.exp
+func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.exp %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
// CHECK-LABEL: func @expm1_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
@@ -277,6 +302,22 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
return %0 : vector<8x8xf32>
}
+// CHECK-LABEL: func @expm1_scalable_vector(
+// CHECK-SAME: %{{.*}}: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+// CHECK-NOT: vector<8x8xf32>
+// CHECK-NOT: exp
+// CHECK-NOT: log
+// CHECK-NOT: expm1
+// CHECK: vector<8x[8]xf32>
+// CHECK-NOT: vector<8x8xf32>
+// CHECK-NOT: exp
+// CHECK-NOT: log
+// CHECK-NOT: expm1
+func.func @expm1_scalable_vector(%arg0: vector<8x[8]xf32>) -> vector<8x[8]xf32> {
+ %0 = math.expm1 %arg0 : vector<8x[8]xf32>
+ return %0 : vector<8x[8]xf32>
+}
+
// CHECK-LABEL: func @log_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
@@ -357,6 +398,18 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @log_scalable_vector(
+// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
+// CHECK-COUNT-5: select
+// CHECK: %[[VAL_71:.*]] = arith.select
+// CHECK: return %[[VAL_71]] : vector<[8]xf32>
+// CHECK: }
+func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.log %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
// CHECK-LABEL: func @log2_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32
@@ -381,6 +434,18 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @log2_scalable_vector(
+// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
+// CHECK-COUNT-5: select
+// CHECK: %[[VAL_71:.*]] = arith.select
+// CHECK: return %[[VAL_71]] : vector<[8]xf32>
+// CHECK: }
+func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.log2 %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
// CHECK-LABEL: func @log1p_scalar(
// CHECK-SAME: %[[X:.*]]: f32) -> f32 {
// CHECK: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
@@ -414,6 +479,17 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @log1p_scalable_vector(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
+// CHECK-COUNT-6: select
+// CHECK: %[[VAL_79:.*]] = arith.select
+// CHECK: return %[[VAL_79]] : vector<[8]xf32>
+// CHECK: }
+func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.log1p %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
// CHECK-LABEL: func @tanh_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
@@ -470,6 +546,19 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @tanh_scalable_vector(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
+// CHECK-NOT: tanh
+// CHECK-COUNT-2: select
+// CHECK: %[[VAL_33:.*]] = arith.select
+// CHECK: return %[[VAL_33]] : vector<[8]xf32>
+// CHECK: }
+func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
+ %0 = math.tanh %arg0 : vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
// CHECK-LABEL: func @rsqrt_scalar
// AVX2-LABEL: func @rsqrt_scalar
>From 57e8db5515bbf8889ffb61a156c140ccc0ac768c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 13 Mar 2024 16:32:51 +0000
Subject: [PATCH 2/2] Use CHECK-COUNT in tests
---
.../Math/polynomial-approximation.mlir | 32 +++++++++----------
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 82b2646bea4a86..93ecd67f14dd3d 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -99,8 +99,8 @@ func.func @erf_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<[8]xf32>
// CHECK-NOT: erf
// CHECK-NOT: vector<8xf32>
-// CHECK-COUNT-20: select
-// CHECK: %[[res:.*]] = arith.select
+// CHECK-COUNT-20: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
+// CHECK: %[[res:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[res]] : vector<[8]xf32>
// CHECK: }
func.func @erf_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
@@ -166,11 +166,11 @@ func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
}
// CHECK-LABEL: func @exp_scalable_vector
-// CHECK-NOT: math.exp
-// CHECK-NOT: vector<8xf32>
-// CHECK: vector<[8]xf32>
-// CHECK-NOT: vector<8xf32>
-// CHECK-NOT: math.exp
+// CHECK-NOT: math.exp
+// CHECK-NOT: vector<8xf32>
+// CHECK-COUNT-46: vector<[8]x{{(i32)|(f32)}}>
+// CHECK-NOT: vector<8xf32>
+// CHECK-NOT: math.exp
func.func @exp_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
%0 = math.exp %arg0 : vector<[8]xf32>
return %0 : vector<[8]xf32>
@@ -308,7 +308,7 @@ func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
// CHECK-NOT: exp
// CHECK-NOT: log
// CHECK-NOT: expm1
-// CHECK: vector<8x[8]xf32>
+// CHECK-COUNT-127: vector<8x[8]x{{(i32)|(f32)|(i1)}}>
// CHECK-NOT: vector<8x8xf32>
// CHECK-NOT: exp
// CHECK-NOT: log
@@ -401,8 +401,8 @@ func.func @log_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-LABEL: func @log_scalable_vector(
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_LN2:.*]] = arith.constant dense<0.693147182> : vector<[8]xf32>
-// CHECK-COUNT-5: select
-// CHECK: %[[VAL_71:.*]] = arith.select
+// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
+// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
// CHECK: }
func.func @log_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
@@ -437,8 +437,8 @@ func.func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-LABEL: func @log2_scalable_vector(
// CHECK-SAME: %{{.*}}: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_LOG2E:.*]] = arith.constant dense<1.44269502> : vector<[8]xf32>
-// CHECK-COUNT-5: select
-// CHECK: %[[VAL_71:.*]] = arith.select
+// CHECK-COUNT-5: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
+// CHECK: %[[VAL_71:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_71]] : vector<[8]xf32>
// CHECK: }
func.func @log2_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
@@ -482,8 +482,8 @@ func.func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-LABEL: func @log1p_scalable_vector(
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[CST_ONE:.*]] = arith.constant dense<1.000000e+00> : vector<[8]xf32>
-// CHECK-COUNT-6: select
-// CHECK: %[[VAL_79:.*]] = arith.select
+// CHECK-COUNT-6: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
+// CHECK: %[[VAL_79:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_79]] : vector<[8]xf32>
// CHECK: }
func.func @log1p_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
@@ -550,8 +550,8 @@ func.func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
// CHECK-SAME: %[[VAL_0:.*]]: vector<[8]xf32>) -> vector<[8]xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<-7.99881172> : vector<[8]xf32>
// CHECK-NOT: tanh
-// CHECK-COUNT-2: select
-// CHECK: %[[VAL_33:.*]] = arith.select
+// CHECK-COUNT-2: select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
+// CHECK: %[[VAL_33:.*]] = arith.select {{.*}} : vector<[8]xi1>, vector<[8]xf32>
// CHECK: return %[[VAL_33]] : vector<[8]xf32>
// CHECK: }
func.func @tanh_scalable_vector(%arg0: vector<[8]xf32>) -> vector<[8]xf32> {
More information about the Mlir-commits
mailing list