[Mlir-commits] [mlir] 96cee29 - [mlir] Allow polynomial approximations for N-d vectors.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Oct 26 11:52:43 PDT 2021
Author: Alexander Belyaev
Date: 2021-10-26T20:50:00+02:00
New Revision: 96cee29762f0aa638a96db46f27282b585622d0c
URL: https://github.com/llvm/llvm-project/commit/96cee29762f0aa638a96db46f27282b585622d0c
DIFF: https://github.com/llvm/llvm-project/commit/96cee29762f0aa638a96db46f27282b585622d0c.diff
LOG: [mlir] Allow polynomial approximations for N-d vectors.
Polynomial approximation can be extended to support N-d vectors.
N-dimensional vectors are useful when vectorizing operations on N-dimensional
tiles. Before lowering to LLVM, these vectors are usually unrolled or flattened
to 1-dimensional vectors.
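For illustration, a hypothetical input that this change makes directly
approximable (the function name and shapes below are chosen for this example
only, they are not part of the commit):

  func @tanh_2d(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
    %0 = math.tanh %arg0 : vector<4x4xf32>
    return %0 : vector<4x4xf32>
  }

After the pass, the polynomial approximation is emitted on vector<4x4xf32>
values; unrolling or flattening to 1-D vectors can then happen later as part
of the usual lowering to LLVM.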
Differential Revision: https://reviews.llvm.org/D112566
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 306a1d117e1f7..dd90dd3461c09 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -32,27 +32,28 @@ using namespace mlir::vector;
using TypePredicate = llvm::function_ref<bool(Type)>;
-// Returns vector width if the element type is matching the predicate (scalars
-// that do match the predicate have width equal to `1`).
-static Optional<int> vectorWidth(Type type, TypePredicate pred) {
- // If the type matches the predicate then its width is `1`.
+// Returns vector shape if the element type is matching the predicate (scalars
+// that do match the predicate have shape equal to `{1}`).
+static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
+ TypePredicate pred) {
+ // If the type matches the predicate then its shape is `{1}`.
if (pred(type))
- return 1;
+ return SmallVector<int64_t, 2>{1};
// Otherwise check if the type is a vector type.
auto vectorType = type.dyn_cast<VectorType>();
if (vectorType && pred(vectorType.getElementType())) {
- assert(vectorType.getRank() == 1 && "only 1d vectors are supported");
- return vectorType.getDimSize(0);
+ return llvm::to_vector<2>(vectorType.getShape());
}
return llvm::None;
}
-// Returns vector width of the type. If the type is a scalar returns `1`.
-static int vectorWidth(Type type) {
+// Returns vector shape of the type. If the type is a scalar returns `1`.
+static SmallVector<int64_t, 2> vectorShape(Type type) {
auto vectorType = type.dyn_cast<VectorType>();
- return vectorType ? vectorType.getDimSize(0) : 1;
+ return vectorType ? llvm::to_vector<2>(vectorType.getShape())
+ : SmallVector<int64_t, 2>{1};
}
// Returns vector element type. If the type is a scalar returns the argument.
@@ -71,17 +72,24 @@ LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//
-// Broadcasts scalar type into vector type (iff width is greater then 1).
-static Type broadcast(Type type, int width) {
+// Returns true if shape != {1}.
+static bool isNonScalarShape(ArrayRef<int64_t> shape) {
+ return shape.size() > 1 || shape[0] > 1;
+}
+
+// Broadcasts scalar type into vector type (iff shape is non-scalar).
+static Type broadcast(Type type, ArrayRef<int64_t> shape) {
assert(!type.isa<VectorType>() && "must be scalar type");
- return width > 1 ? VectorType::get({width}, type) : type;
+ return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
}
-// Broadcasts scalar value into vector (iff width is greater then 1).
-static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) {
+// Broadcasts scalar value into vector (iff shape is non-scalar).
+static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
+ ArrayRef<int64_t> shape) {
assert(!value.getType().isa<VectorType>() && "must be scalar value");
- auto type = broadcast(value.getType(), width);
- return width > 1 ? builder.create<BroadcastOp>(type, value) : value;
+ auto type = broadcast(value.getType(), shape);
+ return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
+ : value;
}
//----------------------------------------------------------------------------//
@@ -126,15 +134,15 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool is_positive = false) {
assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
- int width = vectorWidth(arg.getType());
+ auto shape = vectorShape(arg.getType());
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, width);
+ return broadcast(builder, value, shape);
};
auto i32 = builder.getIntegerType(32);
- auto i32Vec = broadcast(i32, width);
- auto f32Vec = broadcast(builder.getF32Type(), width);
+ auto i32Vec = broadcast(i32, shape);
+ auto f32Vec = broadcast(builder.getF32Type(), shape);
Value cst126f = f32Cst(builder, 126.0f);
Value cstHalf = f32Cst(builder, 0.5f);
@@ -167,13 +175,13 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
- int width = vectorWidth(arg.getType());
+ auto shape = vectorShape(arg.getType());
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, width);
+ return broadcast(builder, value, shape);
};
- auto f32Vec = broadcast(builder.getF32Type(), width);
+ auto f32Vec = broadcast(builder.getF32Type(), shape);
// The exponent of f32 located at 23-bit.
auto exponetBitLocation = bcast(i32Cst(builder, 23));
// Set the exponent bias to zero.
@@ -222,13 +230,13 @@ struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// Clamp operand into [plusClamp, minusClamp] range.
@@ -309,13 +317,13 @@ template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
bool base2) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
Value cstZero = bcast(f32Cst(builder, 0.0f));
@@ -455,13 +463,13 @@ struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// Approximate log(1+x) using the following, due to W. Kahan:
@@ -624,15 +632,15 @@ struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// TODO: Consider a common pattern rewriter with all methods below to
// write the approximations.
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
auto fmla = [&](Value a, Value b, Value c) {
return builder.create<math::FmaOp>(a, b, c);
@@ -675,7 +683,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value expY = fmla(q1, y2, q0);
expY = fmla(q2, y4, expY);
- auto i32Vec = broadcast(builder.getI32Type(), *width);
+ auto i32Vec = broadcast(builder.getI32Type(), *shape);
// exp2(k)
Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
@@ -744,13 +752,13 @@ struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
// expm1(x) = exp(x) - 1 = u - 1.
@@ -811,13 +819,13 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
static_assert(
llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
"SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
- auto width = vectorWidth(op.operand().getType(), isF32);
- if (!width.hasValue())
+ auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!shape.hasValue())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
auto mul = [&](Value a, Value b) -> Value {
return builder.create<arith::MulFOp>(a, b);
@@ -827,7 +835,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
};
auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
- auto i32Vec = broadcast(builder.getI32Type(), *width);
+ auto i32Vec = broadcast(builder.getI32Type(), *shape);
auto fPToSingedInteger = [&](Value a) -> Value {
return builder.create<arith::FPToSIOp>(a, i32Vec);
};
@@ -933,14 +941,14 @@ struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const {
- auto width = vectorWidth(op.operand().getType(), isF32);
+ auto shape = vectorShape(op.operand().getType(), isF32);
// Only support already-vectorized rsqrt's.
- if (!width.hasValue() || *width != 8)
+ if (!shape.hasValue() || (*shape)[0] != 8)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *width);
+ return broadcast(builder, value, *shape);
};
Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f7e7c215f7aaf..f3b2cdfc16efb 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -183,8 +183,8 @@ func @expm1_scalar(%arg0: f32) -> f32 {
}
// CHECK-LABEL: func @expm1_vector(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8xf32>
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
// CHECK-NOT: exp
// CHECK-COUNT-4: select
// CHECK-NOT: log
@@ -192,11 +192,11 @@ func @expm1_scalar(%arg0: f32) -> f32 {
// CHECK-NOT: expm1
// CHECK-COUNT-3: select
// CHECK: %[[VAL_115:.*]] = select
-// CHECK: return %[[VAL_115]] : vector<8xf32>
+// CHECK: return %[[VAL_115]] : vector<8x8xf32>
// CHECK: }
-func @expm1_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
- %0 = math.expm1 %arg0 : vector<8xf32>
- return %0 : vector<8xf32>
+func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+ %0 = math.expm1 %arg0 : vector<8x8xf32>
+ return %0 : vector<8x8xf32>
}
// CHECK-LABEL: func @log_scalar(