[Mlir-commits] [mlir] ec32d54 - [mlir] MathApproximations: scalars shape must be 0-rank
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Oct 29 04:02:44 PDT 2021
Author: Eugene Zhulenev
Date: 2021-10-29T04:02:38-07:00
New Revision: ec32d540f808c98b7d7e27cfdb2baa6719501eff
URL: https://github.com/llvm/llvm-project/commit/ec32d540f808c98b7d7e27cfdb2baa6719501eff
DIFF: https://github.com/llvm/llvm-project/commit/ec32d540f808c98b7d7e27cfdb2baa6719501eff.diff
LOG: [mlir] MathApproximations: scalars shape must be 0-rank
Using [1] to represent the shape of a scalar is incorrect, and will break with vectors of size 1.
- remove redundant helper functions
- fix a couple of style warnings
Reviewed By: cota
Differential Revision: https://reviews.llvm.org/D112764
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 8c9ff7d5407d3..10d3c6a0b8acf 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -11,6 +11,9 @@
//
//===----------------------------------------------------------------------===//
+#include <climits>
+#include <cstddef>
+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
@@ -20,68 +23,35 @@
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
-#include <climits>
-#include <cstddef>
using namespace mlir;
using namespace mlir::math;
using namespace mlir::vector;
-using TypePredicate = llvm::function_ref<bool(Type)>;
-
-// Returns vector shape if the element type is matching the predicate (scalars
-// that do match the predicate have shape equal to `{1}`).
-static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
- TypePredicate pred) {
- // If the type matches the predicate then its shape is `{1}`.
- if (pred(type))
- return SmallVector<int64_t, 2>{1};
-
- // Otherwise check if the type is a vector type.
- auto vectorType = type.dyn_cast<VectorType>();
- if (vectorType && pred(vectorType.getElementType())) {
- return llvm::to_vector<2>(vectorType.getShape());
- }
-
- return llvm::None;
-}
-
-// Returns vector shape of the type. If the type is a scalar returns `1`.
-static SmallVector<int64_t, 2> vectorShape(Type type) {
- auto vectorType = type.dyn_cast<VectorType>();
- return vectorType ? llvm::to_vector<2>(vectorType.getShape())
- : SmallVector<int64_t, 2>{1};
-}
-
-// Returns vector element type. If the type is a scalar returns the argument.
-LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) {
+// Returns vector shape if the type is a vector. Returns an empty shape if it is
+// not a vector.
+static ArrayRef<int64_t> vectorShape(Type type) {
auto vectorType = type.dyn_cast<VectorType>();
- return vectorType ? vectorType.getElementType() : type;
+ return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
}
-LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); }
-
-LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
- return type.isInteger(32);
+static ArrayRef<int64_t> vectorShape(Value value) {
+ return vectorShape(value.getType());
}
//----------------------------------------------------------------------------//
// Broadcast scalar types and values into vector types and values.
//----------------------------------------------------------------------------//
-// Returns true if shape != {1}.
-static bool isNonScalarShape(ArrayRef<int64_t> shape) {
- return shape.size() > 1 || shape[0] > 1;
-}
-
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
assert(!type.isa<VectorType>() && "must be scalar type");
- return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
+ return !shape.empty() ? VectorType::get(shape, type) : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -89,8 +59,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
ArrayRef<int64_t> shape) {
assert(!value.getType().isa<VectorType>() && "must be scalar value");
auto type = broadcast(value.getType(), shape);
- return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
- : value;
+ return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}
//----------------------------------------------------------------------------//
@@ -228,9 +197,8 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
// an integral power of two (see std::frexp). Returned values have float type.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool is_positive = false) {
- assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
-
- auto shape = vectorShape(arg.getType());
+ assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
+ ArrayRef<int64_t> shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -269,9 +237,8 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
// Computes exp2 for an i32 argument.
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
- assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
-
- auto shape = vectorShape(arg.getType());
+ assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
+ ArrayRef<int64_t> shape = vectorShape(arg);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
@@ -294,12 +261,15 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
namespace {
Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<Value> coeffs, Value x) {
- auto shape = vectorShape(x.getType(), isF32);
- if (coeffs.size() == 0) {
- return broadcast(builder, f32Cst(builder, 0.0f), *shape);
- } else if (coeffs.size() == 1) {
+ assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
+ ArrayRef<int64_t> shape = vectorShape(x);
+
+ if (coeffs.empty())
+ return broadcast(builder, f32Cst(builder, 0.0f), shape);
+
+ if (coeffs.size() == 1)
return coeffs[0];
- }
+
Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
coeffs[coeffs.size() - 2]);
for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
@@ -326,13 +296,14 @@ struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
// Clamp operand into [plusClamp, minusClamp] range.
@@ -413,13 +384,14 @@ template <typename Op>
LogicalResult
LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
bool base2) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
Value cstZero = bcast(f32Cst(builder, 0.0f));
@@ -559,13 +531,14 @@ struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
LogicalResult
Log1pApproximation::matchAndRewrite(math::Log1pOp op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
// Approximate log(1+x) using the following, due to W. Kahan:
@@ -605,13 +578,14 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
LogicalResult
ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
const int intervalsCount = 3;
@@ -728,15 +702,17 @@ struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// TODO: Consider a common pattern rewriter with all methods below to
// write the approximations.
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
auto fmla = [&](Value a, Value b, Value c) {
return builder.create<math::FmaOp>(a, b, c);
@@ -779,7 +755,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
Value expY = fmla(q1, y2, q0);
expY = fmla(q2, y4, expY);
- auto i32Vec = broadcast(builder.getI32Type(), *shape);
+ auto i32Vec = broadcast(builder.getI32Type(), shape);
// exp2(k)
Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
@@ -848,13 +824,14 @@ struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
// expm1(x) = exp(x) - 1 = u - 1.
@@ -915,13 +892,15 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
static_assert(
llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
"SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
- auto shape = vectorShape(op.operand().getType(), isF32);
- if (!shape.hasValue())
+
+ if (!getElementTypeOrSelf(op.operand()).isF32())
return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
auto mul = [&](Value a, Value b) -> Value {
return builder.create<arith::MulFOp>(a, b);
@@ -931,7 +910,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
};
auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
- auto i32Vec = broadcast(builder.getI32Type(), *shape);
+ auto i32Vec = broadcast(builder.getI32Type(), shape);
auto fPToSingedInteger = [&](Value a) -> Value {
return builder.create<arith::FPToSIOp>(a, i32Vec);
};
@@ -1037,14 +1016,18 @@ struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const {
- auto shape = vectorShape(op.operand().getType(), isF32);
+ if (!getElementTypeOrSelf(op.operand()).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ ArrayRef<int64_t> shape = vectorShape(op.operand());
+
// Only support already-vectorized rsqrt's.
- if (!shape.hasValue() || shape->back() % 8 != 0)
+ if (shape.empty() || shape.back() % 8 != 0)
return rewriter.notifyMatchFailure(op, "unsupported operand type");
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
- return broadcast(builder, value, *shape);
+ return broadcast(builder, value, shape);
};
Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
More information about the Mlir-commits
mailing list