[Mlir-commits] [mlir] ec32d54 - [mlir] MathApproximations: scalars shape must be 0-rank

Eugene Zhulenev llvmlistbot at llvm.org
Fri Oct 29 04:02:44 PDT 2021


Author: Eugene Zhulenev
Date: 2021-10-29T04:02:38-07:00
New Revision: ec32d540f808c98b7d7e27cfdb2baa6719501eff

URL: https://github.com/llvm/llvm-project/commit/ec32d540f808c98b7d7e27cfdb2baa6719501eff
DIFF: https://github.com/llvm/llvm-project/commit/ec32d540f808c98b7d7e27cfdb2baa6719501eff.diff

LOG: [mlir] MathApproximations: scalars shape must be 0-rank

Using [1] to represent the shape of a scalar is incorrect and will break with vectors of size 1.

- remove redundant helper functions
- fix a couple of style warnings

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D112764

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 8c9ff7d5407d3..10d3c6a0b8acf 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -11,6 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <climits>
+#include <cstddef>
+
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
@@ -20,68 +23,35 @@
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ArrayRef.h"
-#include <climits>
-#include <cstddef>
 
 using namespace mlir;
 using namespace mlir::math;
 using namespace mlir::vector;
 
-using TypePredicate = llvm::function_ref<bool(Type)>;
-
-// Returns vector shape if the element type is matching the predicate (scalars
-// that do match the predicate have shape equal to `{1}`).
-static Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
-                                                     TypePredicate pred) {
-  // If the type matches the predicate then its shape is `{1}`.
-  if (pred(type))
-    return SmallVector<int64_t, 2>{1};
-
-  // Otherwise check if the type is a vector type.
-  auto vectorType = type.dyn_cast<VectorType>();
-  if (vectorType && pred(vectorType.getElementType())) {
-    return llvm::to_vector<2>(vectorType.getShape());
-  }
-
-  return llvm::None;
-}
-
-// Returns vector shape of the type. If the type is a scalar returns `1`.
-static SmallVector<int64_t, 2> vectorShape(Type type) {
-  auto vectorType = type.dyn_cast<VectorType>();
-  return vectorType ? llvm::to_vector<2>(vectorType.getShape())
-                    : SmallVector<int64_t, 2>{1};
-}
-
-// Returns vector element type. If the type is a scalar returns the argument.
-LLVM_ATTRIBUTE_UNUSED static Type elementType(Type type) {
+// Returns vector shape if the type is a vector. Returns an empty shape if it is
+// not a vector.
+static ArrayRef<int64_t> vectorShape(Type type) {
   auto vectorType = type.dyn_cast<VectorType>();
-  return vectorType ? vectorType.getElementType() : type;
+  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
 }
 
-LLVM_ATTRIBUTE_UNUSED static bool isF32(Type type) { return type.isF32(); }
-
-LLVM_ATTRIBUTE_UNUSED static bool isI32(Type type) {
-  return type.isInteger(32);
+static ArrayRef<int64_t> vectorShape(Value value) {
+  return vectorShape(value.getType());
 }
 
 //----------------------------------------------------------------------------//
 // Broadcast scalar types and values into vector types and values.
 //----------------------------------------------------------------------------//
 
-// Returns true if shape != {1}.
-static bool isNonScalarShape(ArrayRef<int64_t> shape) {
-  return shape.size() > 1 || shape[0] > 1;
-}
-
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
   assert(!type.isa<VectorType>() && "must be scalar type");
-  return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
+  return !shape.empty() ? VectorType::get(shape, type) : type;
 }
 
 // Broadcasts scalar value into vector (iff shape is non-scalar).
@@ -89,8 +59,7 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        ArrayRef<int64_t> shape) {
   assert(!value.getType().isa<VectorType>() && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
-  return isNonScalarShape(shape) ? builder.create<BroadcastOp>(type, value)
-                                 : value;
+  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
 }
 
 //----------------------------------------------------------------------------//
@@ -228,9 +197,8 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
 // an integral power of two (see std::frexp). Returned values have float type.
 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
                                      bool is_positive = false) {
-  assert(isF32(elementType(arg.getType())) && "argument must be f32 type");
-
-  auto shape = vectorShape(arg.getType());
+  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
+  ArrayRef<int64_t> shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -269,9 +237,8 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
 
 // Computes exp2 for an i32 argument.
 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
-  assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
-
-  auto shape = vectorShape(arg.getType());
+  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
+  ArrayRef<int64_t> shape = vectorShape(arg);
 
   auto bcast = [&](Value value) -> Value {
     return broadcast(builder, value, shape);
@@ -294,12 +261,15 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
 namespace {
 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                 llvm::ArrayRef<Value> coeffs, Value x) {
-  auto shape = vectorShape(x.getType(), isF32);
-  if (coeffs.size() == 0) {
-    return broadcast(builder, f32Cst(builder, 0.0f), *shape);
-  } else if (coeffs.size() == 1) {
+  assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
+  ArrayRef<int64_t> shape = vectorShape(x);
+
+  if (coeffs.empty())
+    return broadcast(builder, f32Cst(builder, 0.0f), shape);
+
+  if (coeffs.size() == 1)
     return coeffs[0];
-  }
+
   Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
                                           coeffs[coeffs.size() - 2]);
   for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
@@ -326,13 +296,14 @@ struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
 LogicalResult
 TanhApproximation::matchAndRewrite(math::TanhOp op,
                                    PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   // Clamp operand into [plusClamp, minusClamp] range.
@@ -413,13 +384,14 @@ template <typename Op>
 LogicalResult
 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
                                              bool base2) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   Value cstZero = bcast(f32Cst(builder, 0.0f));
@@ -559,13 +531,14 @@ struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
 LogicalResult
 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
                                     PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   // Approximate log(1+x) using the following, due to W. Kahan:
@@ -605,13 +578,14 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
 LogicalResult
 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                             PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   const int intervalsCount = 3;
@@ -728,15 +702,17 @@ struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
 LogicalResult
 ExpApproximation::matchAndRewrite(math::ExpOp op,
                                   PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
 
   // TODO: Consider a common pattern rewriter with all methods below to
   // write the approximations.
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
   auto fmla = [&](Value a, Value b, Value c) {
     return builder.create<math::FmaOp>(a, b, c);
@@ -779,7 +755,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   Value expY = fmla(q1, y2, q0);
   expY = fmla(q2, y4, expY);
 
-  auto i32Vec = broadcast(builder.getI32Type(), *shape);
+  auto i32Vec = broadcast(builder.getI32Type(), shape);
 
   // exp2(k)
   Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
@@ -848,13 +824,14 @@ struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
 LogicalResult
 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                     PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   // expm1(x) = exp(x) - 1 = u - 1.
@@ -915,13 +892,15 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   static_assert(
       llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
       "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
-  auto shape = vectorShape(op.operand().getType(), isF32);
-  if (!shape.hasValue())
+
+  if (!getElementTypeOrSelf(op.operand()).isF32())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
   auto mul = [&](Value a, Value b) -> Value {
     return builder.create<arith::MulFOp>(a, b);
@@ -931,7 +910,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   };
   auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
 
-  auto i32Vec = broadcast(builder.getI32Type(), *shape);
+  auto i32Vec = broadcast(builder.getI32Type(), shape);
   auto fPToSingedInteger = [&](Value a) -> Value {
     return builder.create<arith::FPToSIOp>(a, i32Vec);
   };
@@ -1037,14 +1016,18 @@ struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
 LogicalResult
 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                     PatternRewriter &rewriter) const {
-  auto shape = vectorShape(op.operand().getType(), isF32);
+  if (!getElementTypeOrSelf(op.operand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ArrayRef<int64_t> shape = vectorShape(op.operand());
+
   // Only support already-vectorized rsqrt's.
-  if (!shape.hasValue() || shape->back() % 8 != 0)
+  if (shape.empty() || shape.back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   auto bcast = [&](Value value) -> Value {
-    return broadcast(builder, value, *shape);
+    return broadcast(builder, value, shape);
   };
 
   Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));


        


More information about the Mlir-commits mailing list