[Mlir-commits] [mlir] 96cee29 - [mlir] Allow polynomial approximations for N-d vectors.

Alexander Belyaev llvmlistbot at llvm.org
Tue Oct 26 11:52:43 PDT 2021


Author: Alexander Belyaev
Date: 2021-10-26T20:50:00+02:00
New Revision: 96cee29762f0aa638a96db46f27282b585622d0c

URL: https://github.com/llvm/llvm-project/commit/96cee29762f0aa638a96db46f27282b585622d0c
DIFF: https://github.com/llvm/llvm-project/commit/96cee29762f0aa638a96db46f27282b585622d0c.diff

LOG: [mlir] Allow polynomial approximations for N-d vectors.

Polynomial approximations can be extended to support N-d vectors.
N-dimensional vectors are useful when vectorizing operations on N-dimensional
tiles. Before lowering to LLVM these vectors are usually unrolled or flattened
to 1-dimensional vectors.

Differential Revision: https://reviews.llvm.org/D112566

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 306a1d117e1f7..dd90dd3461c09 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -32,27 +32,28 @@ using namespace mlir::vector;
 
 using TypePredicate = llvm::function_ref<bool(Type)>;
 
-// Returns vector width if the element type is matching the predicate (scalars
-// that do match the predicate have width equal to `1`).
-static Optional<int> vectorWidth(Type type, TypePredicate pred) {
-  // If the type matches the predicate then its width is `1`.
+// Returns vector shape if the element type is matching the predicate (scalars
+// that do match the predicate have shape equal to `{1}`).
+static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
+                                                     TypePredicate pred) {
+  // If the type matches the predicate then its shape is `{1}`.
   if (pred(type))
-    return 1;
+    return SmallVector<int64_t, 2>{1};
 
   // Otherwise check if the type is a vector type.
   auto vectorType = type.dyn_cast<VectorType>();
   if (vectorType && pred(vectorType.getElementType())) {
-    assert(vectorType.getRank() == 1 && "only 1d vectors are supported");
-    return vectorType.getDimSize(0);
+    return llvm::to_vector<2>(vectorType.getShape());
   }
 
   return llvm::None;
 }
 
-// Returns vector width of the type. If the type is a scalar returns `1`.
-static int vectorWidth(Type type) {
+// Returns vector shape of the type. If the type is a scalar returns `{1}`.
+static SmallVector<int64_t, 2> vectorShape(Type type) {
   auto vectorType = type.dyn_cast<VectorType>();
-  return vectorType ? vectorType.getDimSize(0) : 1;
+  return vectorType ? llvm::to_vector<2>(vectorType.getShape())
+                    : SmallVector<int64_t, 2>{1};
 }
 
 // Returns vector element type. If the type is a scalar returns the argument.
@@ -71,17 +72,24 @@ LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
 // Broadcast scalar types and values into vector types and values.
 //----------------------------------------------------------------------------//
 
-// Broadcasts scalar type into vector type (iff width is greater then 1).
-static Type broadcast(Type type, int width) {
+// Returns true if shape != {1}.
+static bool isNonScalarShape(ArrayRef<int64_t> shape) {
+  return shape.size() > 1 || shape[0] > 1;
+}
+
+// Broadcasts scalar type into vector type (iff shape is non-scalar).
+static Type broadcast(Type type, ArrayRef<int64_t> shape) {
   assert(!type.isa<VectorType>() && "must be scalar type");
-  return width > 1 ? VectorType::get({width}, type) : type;
+  return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
 }
 
-// Broadcasts scalar value into vector (iff width is greater then 1).
-static Value broadcast(ImplicitLocOpBuilder &builder, Value value, int width) {
+// Broadcasts scalar value into vector (iff shape is non-scalar).
+static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
+                       ArrayRef<int64_t> shape) {
   assert(!value.getType().isa<VectorType>() && "must be scalar value");
-  auto type = broadcast(value.getType(), width);
-  return width > 1 ? builder.create<BroadcastOp>(type, value) : value;
+  auto type = broadcast(value.getType(), shape);
+  return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
+                                 : value;
 }
 
 //----------------------------------------------------------------------------//
@@ -126,15 +134,15 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool is_positive = false) {
   assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
 
-  int width = vectorWidth(arg.getType());
+  auto shape = vectorShape(arg.getType());
 
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, width);
+    return broadcast(builder, value, shape);
   };
 
   auto i32 = builder.getIntegerType(32);
-  auto i32Vec = broadcast(i32, width);
-  auto f32Vec = broadcast(builder.getF32Type(), width);
+  auto i32Vec = broadcast(i32, shape);
+  auto f32Vec = broadcast(builder.getF32Type(), shape);
 
   Value cst126f = f32Cst(builder, 126.0f);
   Value cstHalf = f32Cst(builder, 0.5f);
@@ -167,13 +175,13 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
 
-  int width = vectorWidth(arg.getType());
+  auto shape = vectorShape(arg.getType());
 
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, width);
+    return broadcast(builder, value, shape);
   };
 
-  auto f32Vec = broadcast(builder.getF32Type(), width);
+  auto f32Vec = broadcast(builder.getF32Type(), shape);
   // The exponent of f32 located at 23-bit.
   auto exponetBitLocation = bcast(i32Cst(builder, 23));
   // Set the exponent bias to zero.
@@ -222,13 +230,13 @@ struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
 LogicalResult
 TanhApproximation::matchAndRewrite(math::TanhOp op,
                                    PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // Clamp operand into [plusClamp, minusClamp] range.
@@ -309,13 +317,13 @@ template <typename Op>
 LogicalResult
 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                              bool base2) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   Value cstZero = bcast(f32Cst(builder, 0.0f));
@@ -455,13 +463,13 @@ struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
 LogicalResult
 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // Approximate log(1+x) using the following, due to W. Kahan:
@@ -624,15 +632,15 @@ struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
 LogicalResult
 ExpApproximation::matchAndRewrite(math::ExpOp op,
                                   PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
 
   // TODO: Consider a common pattern rewriter with all methods below to
   // write the approximations.
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
   auto fmla = [&](Value a, Value b, Value c) {
     return builder.create<math::FmaOp>(a, b, c);
@@ -675,7 +683,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   Value expY = fmla(q1, y2, q0);
   expY = fmla(q2, y4, expY);
 
-  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto i32Vec = broadcast(builder.getI32Type(), *shape);
 
   // exp2(k)
   Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
@@ -744,13 +752,13 @@ struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
 LogicalResult
 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   // expm1(x) = exp(x) - 1 = u - 1.
@@ -811,13 +819,13 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   static_assert(
       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
-  auto width = vectorWidth(op.operand().getType(), isF32);
-  if (!width.hasValue())
+  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!shape.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
   auto mul = [&](Value a, Value b) -> Value {
     return builder.create<arith::MulFOp>(a, b);
@@ -827,7 +835,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   };
   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
 
-  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto i32Vec = broadcast(builder.getI32Type(), *shape);
   auto fPToSingedInteger = [&](Value a) -> Value {
     return builder.create<arith::FPToSIOp>(a, i32Vec);
   };
@@ -933,14 +941,14 @@ struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
 LogicalResult
 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                     PatternRewriter &rewriter) const {
-  auto width = vectorWidth(op.operand().getType(), isF32);
+  auto shape = vectorShape(op.operand().getType(), isF32);
   // Only support already-vectorized rsqrt's.
-  if (!width.hasValue() || *width != 8)
+  if (!shape.hasValue() || (*shape)[0] != 8)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *width);
+    return broadcast(builder, value, *shape);
   };
 
   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f7e7c215f7aaf..f3b2cdfc16efb 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -183,8 +183,8 @@ func @expm1_scalar(%arg0: f32) -> f32 {
 }
 
 // CHECK-LABEL:   func @expm1_vector(
-// CHECK-SAME:                       %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8xf32>
+// CHECK-SAME:                       %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32>
 // CHECK-NOT:       exp
 // CHECK-COUNT-4:   select
 // CHECK-NOT:       log
@@ -192,11 +192,11 @@ func @expm1_scalar(%arg0: f32) -> f32 {
 // CHECK-NOT:       expm1
 // CHECK-COUNT-3:   select
 // CHECK:           %[[VAL_115:.*]] = select
-// CHECK:           return %[[VAL_115]] : vector<8xf32>
+// CHECK:           return %[[VAL_115]] : vector<8x8xf32>
 // CHECK:         }
-func @expm1_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
-  %0 = math.expm1 %arg0 : vector<8xf32>
-  return %0 : vector<8xf32>
+func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = math.expm1 %arg0 : vector<8x8xf32>
+  return %0 : vector<8x8xf32>
 }
 
 // CHECK-LABEL:   func @log_scalar(


        


More information about the Mlir-commits mailing list