[Mlir-commits] [mlir] 1469ebf - [mlir][vector] Allow unroll of contraction in arbitrary order

Christopher Bate llvmlistbot at llvm.org
Mon Jun 6 13:31:27 PDT 2022


Author: Christopher Bate
Date: 2022-06-06T14:31:04-06:00
New Revision: 1469ebf8382107e0344173f362b690d19e24029d

URL: https://github.com/llvm/llvm-project/commit/1469ebf8382107e0344173f362b690d19e24029d
DIFF: https://github.com/llvm/llvm-project/commit/1469ebf8382107e0344173f362b690d19e24029d.diff

LOG: [mlir][vector] Allow unroll of contraction in arbitrary order

Adds support for vector unroll transformations to unroll in different
orders. For example, the `vector.contract` can be unrolled into a
set of smaller contractions. There is a choice of how to unroll the
decomposition based on the traversal order of (dim0, dim1, dim2).
The choice of traversal order can now be specified by a callback which
is given by the caller of the transform. For now, only the
`vector.contract`, `vector.transfer_read/transfer_write` operations
support the callback.

Differential Revision: https://reviews.llvm.org/D127004

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e7226f4a6ac0f..e215be49b74ef 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,6 +128,19 @@ struct UnrollVectorOptions {
     };
     return *this;
   }
+
+  /// Function that returns the traversal order (in terms of "for loop order",
+  /// i.e. slowest varying dimension to fastest varying dimension) that should
+  /// be used when unrolling the given operation into units of the native vector
+  /// size.
+  ///
+  /// A returned order is expected to be a permutation of `[0, numLoops)`, where
+  /// `numLoops` is the number of unrolled dimensions of `op`. Returning `None`
+  /// (or leaving the callback unset) keeps the default order
+  /// `0, 1, ..., numLoops - 1`. Currently consulted by the unroll patterns for
+  /// `vector.contract`, `vector.transfer_read` and `vector.transfer_write`.
+  using UnrollTraversalOrderFnType =
+      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
+  /// Callback selecting the unroll traversal order; may be null.
+  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
+  /// Sets the callback used to choose the unroll traversal order.
+  UnrollVectorOptions &
+  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
+    traversalOrderCallback = std::move(traversalOrderFn);
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 7f00788d888b2..b6c4114b391c1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,8 +15,11 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-unrolling"
 
@@ -36,20 +39,78 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
   return elementOffsets;
 }
 
+/// A functor that accomplishes the same thing as `getVectorOffset` but allows
+/// for reordering the traversal of the dimensions. The order of traversal is
+/// given in "for loop order" (outer to inner): `loopOrder[0]` is the slowest
+/// varying dimension and `loopOrder.back()` the fastest varying one. It must
+/// be a permutation of `[0, originalShape.size())`; anything else would leave
+/// entries of `sliceStrides` unset (zero) or index past its end.
+namespace {
+class DecomposeShapeIterator {
+private:
+  /// Shape of one unrolled slice (the target/native vector shape).
+  SmallVector<int64_t, 4> vectorShape;
+  /// Dimensions in traversal order, outermost loop first.
+  SmallVector<int64_t> loopOrder;
+  /// Per dimension, how far the linear slice index advances between two
+  /// consecutive slices along that dimension.
+  SmallVector<int64_t> sliceStrides;
+  /// Total number of slices in the decomposition.
+  int64_t maxIndexVal{1};
+
+  /// Returns true if `order` contains each of `0, ..., rank - 1` exactly once.
+  static bool isPermutation(ArrayRef<int64_t> order, size_t rank) {
+    if (order.size() != rank)
+      return false;
+    SmallVector<bool> seen(rank, false);
+    for (int64_t dim : order) {
+      if (dim < 0 || dim >= static_cast<int64_t>(rank) || seen[dim])
+        return false;
+      seen[dim] = true;
+    }
+    return true;
+  }
+
+public:
+  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> targetShape,
+                         ArrayRef<int64_t> loopOrder)
+      : vectorShape(targetShape.begin(), targetShape.end()),
+        loopOrder(loopOrder.begin(), loopOrder.end()),
+        sliceStrides(originalShape.size()) {
+    // A mismatched order (e.g. one sized by the number of operands or indices
+    // instead of the shape rank) would otherwise silently produce wrong
+    // offsets or read out of bounds.
+    assert(targetShape.size() == originalShape.size() &&
+           "target shape must have the same rank as the original shape");
+    assert(isPermutation(loopOrder, originalShape.size()) &&
+           "loop order must be a permutation of the shape dimensions");
+
+    // Compute the number of slices along each dimension.
+    SmallVector<int64_t> sliceDimCounts(originalShape.size());
+    for (unsigned r = 0; r < originalShape.size(); ++r) {
+      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
+      maxIndexVal *= sliceDimCounts[r];
+    }
+
+    // Reversing "loop order" gives dimensions from fastest varying to slowest
+    // varying (smallest stride to largest stride).
+    int64_t accum = 1;
+    for (int64_t idx : llvm::reverse(loopOrder)) {
+      sliceStrides[idx] = accum;
+      accum *= sliceDimCounts[idx];
+    }
+  }
+
+  /// Turns the linear index into a d-tuple of slice coordinates (in units of
+  /// vectors of shape `vectorShape`). The linear index is interpreted as a
+  /// position in the traversal defined by `loopOrder`.
+  SmallVector<int64_t> delinearize(int64_t index) const {
+    assert(index >= 0 && index < maxIndexVal && "slice index out of range");
+    // Traverse in for loop order (largest stride to smallest stride).
+    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
+    for (int64_t idx : loopOrder) {
+      vectorOffsets[idx] = index / sliceStrides[idx];
+      index %= sliceStrides[idx];
+    }
+    return vectorOffsets;
+  }
+
+  /// Returns the number of slices; valid indices are `[0, maxIndex())`.
+  int64_t maxIndex() const { return maxIndexVal; }
+
+  /// Returns the element offsets of slice `index`, where slices are numbered
+  /// according to `loopOrder`.
+  SmallVector<int64_t> getVectorOffset(int64_t index) const {
+    SmallVector<int64_t> vectorOffsets = delinearize(index);
+    SmallVector<int64_t> elementOffsets =
+        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+    return elementOffsets;
+  }
+};
+} // namespace
+
 /// Compute the indices of the slice `index` for a tranfer op.
-static SmallVector<Value>
-sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
-                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
-                     AffineMap permutationMap, Location loc,
-                     OpBuilder &builder) {
+static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+                                               ArrayRef<Value> indices,
+                                               AffineMap permutationMap,
+                                               Location loc,
+                                               OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
-  SmallVector<int64_t, 4> elementOffsets =
-      getVectorOffset(originalShape, targetShape, index);
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -99,6 +160,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   return targetShape;
 }
 
+/// Returns the order in which the `numLoops` unrolled dimensions of `op` are
+/// traversed, listed outermost ("slowest varying") first. The default is the
+/// identity order `0, 1, ..., numLoops - 1`; if `options` provides a
+/// traversal-order callback that returns a value for `op`, that value is used
+/// instead. `numLoops` must be the rank of the shape being decomposed (the
+/// unroll shape of `op`), because the result is consumed as a permutation of
+/// that shape's dimensions.
+static SmallVector<int64_t>
+getUnrollOrder(unsigned numLoops, Operation *op,
+               const vector::UnrollVectorOptions &options) {
+  SmallVector<int64_t> loopOrder =
+      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
+  if (options.traversalOrderCallback != nullptr) {
+    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
+    if (order.hasValue()) {
+      // Catch callbacks (or callers) that disagree on the number of unrolled
+      // dimensions; a wrong-sized order would corrupt the decomposition.
+      assert(order->size() == numLoops &&
+             "traversal order callback returned an order of the wrong size");
+      loopOrder = std::move(*order);
+    }
+  }
+  return loopOrder;
+}
+
 namespace {
 
 struct UnrollTransferReadPattern
@@ -122,8 +197,7 @@ struct UnrollTransferReadPattern
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
+
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
@@ -131,17 +205,22 @@ struct UnrollTransferReadPattern
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                           readOp.getIndices().end());
-    for (int64_t i = 0; i < sliceCount; i++) {
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(ratio.size(), readOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+      SmallVector<int64_t, 4> elementOffsets =
+          indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
 
-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
@@ -174,20 +253,21 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                           writeOp.getIndices().end());
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalIndices.size(), writeOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < sliceCount; i++) {
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
       SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
+          indexToOffsets.getVectorOffset(i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
-
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -238,8 +318,6 @@ struct UnrollContractionPattern
     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
 
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -247,9 +325,14 @@ struct UnrollContractionPattern
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
+
+    SmallVector<int64_t> loopOrder = getUnrollOrder(
+        contractOp.getIndexingMaps().size(), contractOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
 
       // Helper to coompute the new shape of each operand and extract the slice.

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 3d0affd2a4be0..f6f218b6e39eb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER
 
 // CHECK-LABEL: func @transfer_read_unroll
 //       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
@@ -13,6 +14,19 @@
 //  CHECK-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 //  CHECK-NEXT:   return %[[VEC3]] : vector<4x4xf32>
 
+// ORDER-LABEL: func @transfer_read_unroll
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   return %[[VEC3]] : vector<4x4xf32>
+
 func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -33,6 +47,19 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
 //  CHECK-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
 //  CHECK-NEXT:   return
 
+// ORDER-LABEL: func @transfer_write_unroll
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   return
+
 func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -222,6 +249,25 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
 //  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
 //  CHECK-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
 //  CHECK-NEXT:   return %[[VEC5]] : vector<6x4xf32>
+
+// ORDER-LABEL: func @transfer_read_unroll_
diff erent_rank
+//       ORDER-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   return %[[VEC5]] : vector<6x4xf32>
+
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
 func.func @transfer_read_unroll_
diff erent_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 3b0aeb48665f4..272c04a34e8ea 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,50 +1,156 @@
 // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1"  --split-input-file | FileCheck %s --check-prefix=ORDER
 
-func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
                           %init : vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = vector.contract
          {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                            affine_map<(i, j, k) -> (j, k)>,
                            affine_map<(i, j, k) -> (i, j)>],
           iterator_types = ["parallel", "parallel", "reduction"]}
-       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }
 // CHECK-LABEL: func @vector_contract_f32
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [0, 4]
+//       CHECK:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [4, 4]
+//       CHECK:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
 //       CHECK:   return
 
+// ORDER-LABEL: func @vector_contract_f32
+// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [0, 4]
+//       ORDER:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [4, 4]
+//       ORDER:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   return
+
+
+
 func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                           %init : vector<8x8xf16>) -> vector<8x8xf16> {
   %0 = vector.contract
@@ -158,3 +264,4 @@ func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
 //       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
 //       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
 //       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>
+

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a81aa536df4ad..cbec267734eba 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
@@ -322,12 +323,18 @@ struct TestVectorUnrollingPatterns
         }
         return nativeShape;
       };
-      populateVectorUnrollPatterns(patterns,
-                                   UnrollVectorOptions()
-                                       .setNativeShapeFn(nativeShapeFn)
-                                       .setFilterConstraint([](Operation *op) {
-                                         return success(isa<ContractionOp>(op));
-                                       }));
+
+      UnrollVectorOptions opts;
+      opts.setNativeShapeFn(nativeShapeFn)
+          .setFilterConstraint(
+              [](Operation *op) { return success(isa<ContractionOp>(op)); });
+      if (!unrollOrder.empty()) {
+        opts.setUnrollTraversalOrderFn([this](Operation *op)
+                                           -> Optional<SmallVector<int64_t>> {
+          return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
+        });
+      }
+      populateVectorUnrollPatterns(patterns, opts);
     } else {
       populateVectorUnrollPatterns(
           patterns, UnrollVectorOptions()
@@ -340,6 +347,10 @@ struct TestVectorUnrollingPatterns
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
+  // Explicit unroll traversal order (outermost loop first), e.g.
+  // `unroll-order=2,0,1`. When non-empty it is returned unchanged by the
+  // traversal-order callback for every op accepted by the unroll filter
+  // (contractions); in the visible code it is only wired up on the branch
+  // that builds options from `nativeShapeFn`.
+  ListOption<int64_t> unrollOrder{*this, "unroll-order",
+                                  llvm::cl::desc("set the unroll order"),
+                                  llvm::cl::ZeroOrMore};
+
   Option<bool> unrollBasedOnType{
       *this, "unroll-based-on-type",
       llvm::cl::desc("Set the unroll factor based on type of the operation"),
@@ -472,6 +483,11 @@ struct TestVectorTransferUnrollingPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestVectorTransferUnrollingPatterns)
 
+  // Default constructor, plus a copy constructor that forwards to
+  // `PassWrapper`. NOTE(review): presumably required because this pass now
+  // declares an `Option` member (`reverseUnrollOrder`), making the implicit
+  // copy constructor unusable for pass cloning — confirm against the pass
+  // infrastructure.
+  TestVectorTransferUnrollingPatterns() = default;
+  TestVectorTransferUnrollingPatterns(
+      const TestVectorTransferUnrollingPatterns &pass)
+      : PassWrapper(pass) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect>();
   }
@@ -485,17 +501,36 @@ struct TestVectorTransferUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateVectorUnrollPatterns(
-        patterns,
-        UnrollVectorOptions()
-            .setNativeShape(ArrayRef<int64_t>{2, 2})
-            .setFilterConstraint([](Operation *op) {
-              return success(
-                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
-            }));
+    UnrollVectorOptions opts;
+    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
+        .setFilterConstraint([](Operation *op) {
+          return success(
+              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+        });
+    if (reverseUnrollOrder.getValue()) {
+      opts.setUnrollTraversalOrderFn(
+          [](Operation *op) -> Optional<SmallVector<int64_t>> {
+            int64_t numLoops = 0;
+            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
+              numLoops = readOp.getVectorType().getRank();
+            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
+              numLoops = writeOp.getVectorType().getRank();
+            else
+              return None;
+            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
+            return llvm::to_vector(order);
+          });
+    }
+    populateVectorUnrollPatterns(patterns, opts);
     populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
+
+  // When set, transfer_read/transfer_write ops are unrolled with the order
+  // `rank-1, ..., 1, 0`, i.e. the last vector dimension becomes the slowest
+  // varying (outermost) loop.
+  Option<bool> reverseUnrollOrder{
+      *this, "reverse-unroll-order",
+      llvm::cl::desc(
+          "reverse the order of unrolling of vector transfer operations"),
+      llvm::cl::init(false)};
 };
 
 struct TestVectorTransferFullPartialSplitPatterns


        


More information about the Mlir-commits mailing list