[Mlir-commits] [mlir] 627fa0b - [mlir] MathApproximations: unroll virtual vectors into hardware vectors for ISA specific operation

Eugene Zhulenev llvmlistbot at llvm.org
Thu Oct 28 12:52:10 PDT 2021


Author: Eugene Zhulenev
Date: 2021-10-28T12:52:04-07:00
New Revision: 627fa0b9a897b8d73b481e475abc4518efe55db7

URL: https://github.com/llvm/llvm-project/commit/627fa0b9a897b8d73b481e475abc4518efe55db7
DIFF: https://github.com/llvm/llvm-project/commit/627fa0b9a897b8d73b481e475abc4518efe55db7.diff

LOG: [mlir] MathApproximations: unroll virtual vectors into hardware vectors for ISA specific operation

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D112736

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index df66c1582645..8c9ff7d5407d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -92,6 +93,101 @@ static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                                  : value;
 }
 
+//----------------------------------------------------------------------------//
+// Helper function to handle n-D vectors with 1-D operations.
+//----------------------------------------------------------------------------//
+
+// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
+// and calls the compute function with 1-D vector operands. Stitches back all
+// results into the original n-D vector result.
+//
+// Examples: vectorWidth = 8
+//   - vector<4x8xf32> unrolled 4 times (no shape_cast is emitted)
+//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
+//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
+//
+// Some math approximations rely on ISA-specific operations that only accept
+// fixed size 1-D vectors (e.g. AVX expects vectors of width 8). `compute` is
+// called synchronously once per 1-D slice and is not retained.
+//
+// It is the caller's responsibility to verify that the inner dimension is
+// divisible by the vectorWidth, and that all operands have the same shape.
+static Value
+handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
+                              ValueRange operands, int64_t vectorWidth,
+                              function_ref<Value(ValueRange)> compute) {
+  assert(!operands.empty() && "operands must be not empty");
+  assert(vectorWidth > 0 && "vector width must be larger than 0");
+
+  VectorType inputType = operands[0].getType().cast<VectorType>();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  // If input shape matches target vector width, we can just call the
+  // user-provided compute function with the operands.
+  if (inputShape == llvm::makeArrayRef(vectorWidth))
+    return compute(operands);
+
+  // Check if the inner dimension has to be expanded, or we can directly iterate
+  // over the outer dimensions of the vector.
+  int64_t innerDim = inputShape.back();
+  int64_t expansionDim = innerDim / vectorWidth;
+  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
+
+  // Maybe expand operands to the higher rank vector shape that we'll use to
+  // iterate over and extract one dimensional vectors.
+  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
+  SmallVector<Value> expandedOperands(operands);
+
+  if (expansionDim > 1) {
+    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
+    expandedShape.insert(expandedShape.end() - 1, expansionDim);
+    expandedShape.back() = vectorWidth;
+
+    for (Value &operand : expandedOperands) {
+      Type eltType = operand.getType().cast<VectorType>().getElementType();
+      operand = builder.create<vector::ShapeCastOp>(
+          VectorType::get(expandedShape, eltType), operand);
+    }
+  }
+
+  // Iterate over all outer dimensions of the compute shape vector type.
+  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
+  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);
+
+  SmallVector<int64_t> ones(iterationDims.size(), 1);
+  auto strides = computeStrides(iterationDims, ones);
+
+  // Compute results for each one dimensional vector.
+  SmallVector<Value> results(maxLinearIndex);
+
+  for (int64_t i = 0; i < maxLinearIndex; ++i) {
+    auto offsets = delinearize(strides, i);
+
+    SmallVector<Value> extracted;
+    for (Value operand : expandedOperands)
+      extracted.push_back(builder.create<vector::ExtractOp>(operand, offsets));
+
+    results[i] = compute(extracted);
+  }
+
+  // Stitch results together into one large vector.
+  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
+  Value result = builder.create<ConstantOp>(
+      resultExpandedType, builder.getZeroAttr(resultExpandedType));
+
+  for (int64_t i = 0; i < maxLinearIndex; ++i)
+    result = builder.create<vector::InsertOp>(results[i], result,
+                                              delinearize(strides, i));
+
+  // Reshape back to the original vector shape. Without expansion the stitched
+  // vector already has that shape, so skip emitting a no-op shape_cast.
+  if (expansionDim > 1)
+    result = builder.create<vector::ShapeCastOp>(
+        VectorType::get(inputShape, resultEltType), result);
+  return result;
+}
+
 //----------------------------------------------------------------------------//
 // Helper functions to create constants.
 //----------------------------------------------------------------------------//
@@ -943,7 +1039,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
                                     PatternRewriter &rewriter) const {
   auto shape = vectorShape(op.operand().getType(), isF32);
   // Only support already-vectorized rsqrt's.
-  if (!shape.hasValue() || (*shape)[0] != 8)
+  if (!shape.hasValue() || shape->back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -967,7 +1063,10 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
 
   // Compute an approximate result.
-  Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
+  Value yApprox = handleMultidimensionalVectors(
+      builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
+        return builder.create<x86vector::RsqrtOp>(operands);
+      });
 
   // Do a single step of Newton-Raphson iteration to improve the approximation.
   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index f3b2cdfc16ef..f388b84d83c8 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -402,9 +402,9 @@ func @rsqrt_scalar(%arg0: f32) -> f32 {
   return %0 : f32
 }
 
-// CHECK-LABEL:   func @rsqrt_vector
+// CHECK-LABEL:   func @rsqrt_vector_8xf32
 // CHECK:           math.rsqrt
-// AVX2-LABEL:    func @rsqrt_vector(
+// AVX2-LABEL:    func @rsqrt_vector_8xf32(
 // AVX2-SAME:       %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // AVX2:   %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
 // AVX2:   %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
@@ -421,7 +421,89 @@ func @rsqrt_scalar(%arg0: f32) -> f32 {
 // AVX2:   %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
 // AVX2:   return %[[VAL_13]] : vector<8xf32>
 // AVX2: }
-func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+func @rsqrt_vector_8xf32(%arg0: vector<8xf32>) -> vector<8xf32> {
   %0 = math.rsqrt %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// Inner dimension 5 is not a multiple of the AVX2 vector width: rsqrt is kept.
+//
+// CHECK-LABEL:  func @rsqrt_vector_5xf32
+// CHECK:          math.rsqrt
+// AVX2-LABEL:   func @rsqrt_vector_5xf32
+// AVX2:           math.rsqrt
+func @rsqrt_vector_5xf32(%input: vector<5xf32>) -> vector<5xf32> {
+  %rsqrt = math.rsqrt %input : vector<5xf32>
+  return %rsqrt : vector<5xf32>
+}
+
+// A 1-D vector<16xf32> is reshaped to vector<2x8xf32> and unrolled into two
+// AVX2-sized rsqrt operations.
+//
+// CHECK-LABEL: func @rsqrt_vector_16xf32
+// CHECK:         math.rsqrt
+// AVX2-LABEL:  func @rsqrt_vector_16xf32(
+// AVX2-SAME:     %[[ARG:.*]]: vector<16xf32>
+// AVX2-SAME:   ) -> vector<16xf32>
+// AVX2:          %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
+// AVX2:          %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32>
+// AVX2:          %[[VEC0:.*]] = vector.extract %[[EXPAND]][0]
+// AVX2:          %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
+// AVX2:          %[[VEC1:.*]] = vector.extract %[[EXPAND]][1]
+// AVX2:          %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
+// AVX2:          %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
+// AVX2:          %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
+// AVX2:          %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32>
+func @rsqrt_vector_16xf32(%input: vector<16xf32>) -> vector<16xf32> {
+  %rsqrt = math.rsqrt %input : vector<16xf32>
+  return %rsqrt : vector<16xf32>
+}
+
+// Rows of vector<2x8xf32> are already AVX2-sized: unrolled with no shape_cast.
+//
+// CHECK-LABEL: func @rsqrt_vector_2x8xf32
+// CHECK:         math.rsqrt
+// AVX2-LABEL:  func @rsqrt_vector_2x8xf32(
+// AVX2-SAME:     %[[ARG:.*]]: vector<2x8xf32>
+// AVX2-SAME:   ) -> vector<2x8xf32>
+// AVX2:          %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32>
+// AVX2-NOT:      vector.shape_cast
+// AVX2:          %[[VEC0:.*]] = vector.extract %[[ARG]][0]
+// AVX2:          %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]]
+// AVX2:          %[[VEC1:.*]] = vector.extract %[[ARG]][1]
+// AVX2:          %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]]
+// AVX2:          %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0]
+// AVX2:          %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1]
+// AVX2-NOT:      vector.shape_cast
+func @rsqrt_vector_2x8xf32(%input: vector<2x8xf32>) -> vector<2x8xf32> {
+  %rsqrt = math.rsqrt %input : vector<2x8xf32>
+  return %rsqrt : vector<2x8xf32>
+}
+
+// A vector<2x16xf32> is reshaped to vector<2x2x8xf32> and unrolled into four
+// AVX2-sized rsqrt operations.
+//
+// CHECK-LABEL: func @rsqrt_vector_2x16xf32
+// CHECK:         math.rsqrt
+// AVX2-LABEL:  func @rsqrt_vector_2x16xf32(
+// AVX2-SAME:     %[[ARG:.*]]: vector<2x16xf32>
+// AVX2-SAME:   ) -> vector<2x16xf32>
+// AVX2:          %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32>
+// AVX2:          %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32>
+// AVX2:          %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0]
+// AVX2:          %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]]
+// AVX2:          %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1]
+// AVX2:          %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]]
+// AVX2:          %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0]
+// AVX2:          %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]]
+// AVX2:          %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1]
+// AVX2:          %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]]
+// AVX2:          %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0]
+// AVX2:          %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1]
+// AVX2:          %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0]
+// AVX2:          %[[RESULT3:.*]] = vector.insert %[[RSQRT11]], %[[RESULT2]] [1, 1]
+// AVX2:          %[[RSQRT:.*]] = vector.shape_cast %[[RESULT3]] : vector<2x2x8xf32> to vector<2x16xf32>
+func @rsqrt_vector_2x16xf32(%input: vector<2x16xf32>) -> vector<2x16xf32> {
+  %rsqrt = math.rsqrt %input : vector<2x16xf32>
+  return %rsqrt : vector<2x16xf32>
+}


        


More information about the Mlir-commits mailing list