[Mlir-commits] [mlir] 9f12215 - Recommit "[mlir][vector] Allow unroll of contraction in arbitrary order"

Christopher Bate llvmlistbot at llvm.org
Thu Jun 9 13:05:08 PDT 2022


Author: Christopher Bate
Date: 2022-06-09T14:01:19-06:00
New Revision: 9f1221521f4b0e4ea4b76d84b0d6ef12da2a4a8b

URL: https://github.com/llvm/llvm-project/commit/9f1221521f4b0e4ea4b76d84b0d6ef12da2a4a8b
DIFF: https://github.com/llvm/llvm-project/commit/9f1221521f4b0e4ea4b76d84b0d6ef12da2a4a8b.diff

LOG: Recommit "[mlir][vector] Allow unroll of contraction in arbitrary order"

Fixed issue with vector.contract default unroll permutation.

Adds support for vector unroll transformations to unroll in different
orders. For example, the vector.contract can be unrolled into a
smaller set of contractions. There is a choice of how to unroll the
decomposition based on the traversal order of (dim0, dim1, dim2).
The choice of traversal order can now be specified by a callback which
is given by the caller of the transform. For now, only the
vector.contract, vector.transfer_read/transfer_write operations
support the callback.

Differential Revision: https://reviews.llvm.org/D127004

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e7226f4a6ac0f..e215be49b74ef 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,6 +128,19 @@ struct UnrollVectorOptions {
     };
     return *this;
   }
+
+  /// Function that returns the traversal order (in terms of "for loop order",
+  /// i.e. slowest varying dimension to fastest varying dimension) that should
+  /// be used when unrolling the given operation into units of the native vector
+  /// size.
+  using UnrollTraversalOrderFnType =
+      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
+  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
+  UnrollVectorOptions &
+  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
+    traversalOrderCallback = std::move(traversalOrderFn);
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index 7f00788d888b2..350c0488a27b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-unrolling"
 
@@ -36,20 +38,81 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
   return elementOffsets;
 }
 
+/// A functor that accomplishes the same thing as `getVectorOffset` but allows
+/// for reordering the traversal of the dimensions. The order of traversal is
+/// given in "for loop order" (outer to inner).
+namespace {
+class DecomposeShapeIterator {
+private:
+  SmallVector<int64_t, 4> vectorShape;
+  SmallVector<int64_t> loopOrder;
+  SmallVector<int64_t> sliceStrides;
+  int64_t maxIndexVal{1};
+
+public:
+  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> targetShape,
+                         ArrayRef<int64_t> loopOrder)
+      : vectorShape(targetShape.begin(), targetShape.end()),
+        loopOrder(loopOrder.begin(), loopOrder.end()),
+        sliceStrides(originalShape.size()) {
+    assert(originalShape.size() == targetShape.size());
+    assert(loopOrder.size() == targetShape.size());
+
+    // Compute the count for each dimension.
+    SmallVector<int64_t> sliceDimCounts(originalShape.size());
+    for (unsigned r = 0; r < originalShape.size(); ++r) {
+      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
+      maxIndexVal *= sliceDimCounts[r];
+    }
+
+    // Reversing "loop order" gives dimensions from fastest varying to slowest
+    // varying (smallest stride to largest stride).
+    int64_t accum = 1;
+    for (auto idx : llvm::reverse(loopOrder)) {
+      sliceStrides[idx] = accum;
+      accum *= sliceDimCounts[idx];
+    }
+  }
+
+  // Turn the linear index into a d-tuple based on units of vectors of size
+  // `vectorShape`. The linear index is assumed to represent traversal of the
+  // dimensions based on `loopOrder`.
+  SmallVector<int64_t> delinearize(int64_t index) const {
+    // Traverse in for loop order (largest stride to smallest stride).
+    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
+    for (auto idx : loopOrder) {
+      vectorOffsets[idx] = index / sliceStrides[idx];
+      index %= sliceStrides[idx];
+    }
+    return vectorOffsets;
+  }
+
+  int64_t maxIndex() const { return maxIndexVal; }
+
+  /// Return the offset within d-tuple based on the ordering given by
+  /// `loopOrder`.
+  SmallVector<int64_t> getVectorOffset(int64_t index) const {
+    SmallVector<int64_t> vectorOffsets = delinearize(index);
+    SmallVector<int64_t> elementOffsets =
+        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+    return elementOffsets;
+  }
+};
+} // namespace
+
 /// Compute the indices of the slice `index` for a tranfer op.
-static SmallVector<Value>
-sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
-                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
-                     AffineMap permutationMap, Location loc,
-                     OpBuilder &builder) {
+static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
+                                               ArrayRef<Value> indices,
+                                               AffineMap permutationMap,
+                                               Location loc,
+                                               OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
-  SmallVector<int64_t, 4> elementOffsets =
-      getVectorOffset(originalShape, targetShape, index);
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -99,6 +162,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   return targetShape;
 }
 
+static SmallVector<int64_t>
+getUnrollOrder(unsigned numLoops, Operation *op,
+               const vector::UnrollVectorOptions &options) {
+  SmallVector<int64_t> loopOrder =
+      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
+  if (options.traversalOrderCallback != nullptr) {
+    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
+    if (order.hasValue()) {
+      loopOrder = std::move(*order);
+    }
+  }
+  return loopOrder;
+}
+
 namespace {
 
 struct UnrollTransferReadPattern
@@ -121,9 +198,7 @@ struct UnrollTransferReadPattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
+
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
@@ -131,17 +206,22 @@ struct UnrollTransferReadPattern
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                           readOp.getIndices().end());
-    for (int64_t i = 0; i < sliceCount; i++) {
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalSize.size(), readOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+      SmallVector<int64_t, 4> elementOffsets =
+          indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
 
-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
@@ -174,20 +254,21 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                           writeOp.getIndices().end());
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalSize.size(), writeOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < sliceCount; i++) {
+    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
       SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
+          indexToOffsets.getVectorOffset(i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
-
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
+          sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -236,10 +317,7 @@ struct UnrollContractionPattern
       return failure();
     auto dstVecType = contractOp.getResultType().cast<VectorType>();
     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
 
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -247,9 +325,14 @@ struct UnrollContractionPattern
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
+
+    SmallVector<int64_t> loopOrder = getUnrollOrder(
+        contractOp.getIteratorTypes().size(), contractOp, options);
+    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
+                                          loopOrder);
+    const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
 
       // Helper to coompute the new shape of each operand and extract the slice.

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 3d0affd2a4be0..f6f218b6e39eb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER
 
 // CHECK-LABEL: func @transfer_read_unroll
 //       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
@@ -13,6 +14,19 @@
 //  CHECK-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 //  CHECK-NEXT:   return %[[VEC3]] : vector<4x4xf32>
 
+// ORDER-LABEL: func @transfer_read_unroll
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//  ORDER-NEXT:   return %[[VEC3]] : vector<4x4xf32>
+
 func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -33,6 +47,19 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
 //  CHECK-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
 //  CHECK-NEXT:   return
 
+// ORDER-LABEL: func @transfer_write_unroll
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+//  ORDER-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+//  ORDER-NEXT:   return
+
 func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -222,6 +249,25 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
 //  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
 //  CHECK-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
 //  CHECK-NEXT:   return %[[VEC5]] : vector<6x4xf32>
+
+// ORDER-LABEL: func @transfer_read_unroll_different_rank
+//       ORDER-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  ORDER-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
+//  ORDER-NEXT:   return %[[VEC5]] : vector<6x4xf32>
+
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
 func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 3b0aeb48665f4..d0d10887d6a2d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,50 +1,157 @@
 // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1"  | FileCheck %s --check-prefix=ORDER
+// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=0,3,1,2" | FileCheck %s --check-prefix=BATCHED
 
-func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
                           %init : vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = vector.contract
          {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                            affine_map<(i, j, k) -> (j, k)>,
                            affine_map<(i, j, k) -> (i, j)>],
           iterator_types = ["parallel", "parallel", "reduction"]}
-       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }
 // CHECK-LABEL: func @vector_contract_f32
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
-//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [0, 4]
+//       CHECK:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [0, 2]
+//       CHECK:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 0]
+//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  CHECK-SAME:   offsets = [4, 4]
+//       CHECK:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-//       CHECK:   vector.contract {
+
+//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  CHECK-SAME:   offsets = [4, 2]
+//       CHECK:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
 //       CHECK:   return
 
+// ORDER-LABEL: func @vector_contract_f32
+// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [0, 4]
+//       ORDER:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 0]
+//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
+//  ORDER-SAME:   offsets = [4, 4]
+//       ORDER:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [0, 2]
+//       ORDER:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
+//  ORDER-SAME:   offsets = [4, 2]
+//       ORDER:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
+//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+
+//       ORDER:   return
+
+
+
 func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                           %init : vector<8x8xf16>) -> vector<8x8xf16> {
   %0 = vector.contract
@@ -158,3 +265,32 @@ func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
 //       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
 //       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
 //       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>
+
+// -----
+
+func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf32>, %init: vector<8x8x8xf32>) -> vector<8x8x8xf32> {    
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(d0,d1,d2,c0) -> (d0,d1,c0)>,
+                           affine_map<(d0,d1,d2,c0) -> (d0,d2,c0)>,
+                           affine_map<(d0,d1,d2,c0) -> (d0,d1,d2)>],
+          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8x4xf32>, vector<8x8x4xf32> into vector<8x8x8xf32>
+  return %0 : vector<8x8x8xf32>
+}
+
+
+//    CHECK-LABEL: vector_contract_batched
+// CHECK-COUNT-16: vector.contract
+//      CHECK-NOT: vector.contract
+//          CHECK: return 
+
+//    UNROLL-LABEL: vector_contract_batched
+//  UNROLL-COUNT-1: vector.contract
+//      UNROLL-NOT: vector.contract
+//          UNROLL: return
+
+
+//    BATCHED-LABEL: vector_contract_batched
+// BATCHED-COUNT-16: vector.contract
+//      BATCHED-NOT: vector.contract
+//          BATCHED: return

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a81aa536df4ad..6901e51c22e88 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -18,10 +18,12 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -312,34 +314,51 @@ struct TestVectorUnrollingPatterns
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
           [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
-        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
-        if (auto floatType = contractOp.getLhsType()
-                                 .getElementType()
-                                 .dyn_cast<FloatType>()) {
-          if (floatType.getWidth() == 16) {
-            nativeShape[2] = 4;
-          }
-        }
+        SmallVector<int64_t, 4> nativeShape(
+            contractOp.getIteratorTypes().size(), 4);
+        Type lhsType = contractOp.getLhsType().getElementType();
+        nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
         return nativeShape;
       };
+
+      UnrollVectorOptions opts;
+      opts.setNativeShapeFn(nativeShapeFn)
+          .setFilterConstraint(
+              [](Operation *op) { return success(isa<ContractionOp>(op)); });
+
+      if (!unrollOrder.empty()) {
+        opts.setUnrollTraversalOrderFn([this](Operation *op)
+                                           -> Optional<SmallVector<int64_t>> {
+          vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
+          if (contractOp.getIteratorTypes().size() == unrollOrder.size())
+            return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
+          return None;
+        });
+      }
+      populateVectorUnrollPatterns(patterns, opts);
+    } else {
+      auto nativeShapeFn =
+          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+        auto contractOp = dyn_cast<ContractionOp>(op);
+        if (!contractOp)
+          return None;
+        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
+      };
       populateVectorUnrollPatterns(patterns,
                                    UnrollVectorOptions()
                                        .setNativeShapeFn(nativeShapeFn)
                                        .setFilterConstraint([](Operation *op) {
                                          return success(isa<ContractionOp>(op));
                                        }));
-    } else {
-      populateVectorUnrollPatterns(
-          patterns, UnrollVectorOptions()
-                        .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
-                        .setFilterConstraint([](Operation *op) {
-                          return success(isa<ContractionOp>(op));
-                        }));
     }
     populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
+  ListOption<int64_t> unrollOrder{*this, "unroll-order",
+                                  llvm::cl::desc("set the unroll order"),
+                                  llvm::cl::ZeroOrMore};
+
   Option<bool> unrollBasedOnType{
       *this, "unroll-based-on-type",
       llvm::cl::desc("Set the unroll factor based on type of the operation"),
@@ -472,6 +491,11 @@ struct TestVectorTransferUnrollingPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestVectorTransferUnrollingPatterns)
 
+  TestVectorTransferUnrollingPatterns() = default;
+  TestVectorTransferUnrollingPatterns(
+      const TestVectorTransferUnrollingPatterns &pass)
+      : PassWrapper(pass) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect>();
   }
@@ -485,17 +509,36 @@ struct TestVectorTransferUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateVectorUnrollPatterns(
-        patterns,
-        UnrollVectorOptions()
-            .setNativeShape(ArrayRef<int64_t>{2, 2})
-            .setFilterConstraint([](Operation *op) {
-              return success(
-                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
-            }));
+    UnrollVectorOptions opts;
+    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
+        .setFilterConstraint([](Operation *op) {
+          return success(
+              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+        });
+    if (reverseUnrollOrder.getValue()) {
+      opts.setUnrollTraversalOrderFn(
+          [](Operation *op) -> Optional<SmallVector<int64_t>> {
+            int64_t numLoops = 0;
+            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
+              numLoops = readOp.getVectorType().getRank();
+            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
+              numLoops = writeOp.getVectorType().getRank();
+            else
+              return None;
+            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
+            return llvm::to_vector(order);
+          });
+    }
+    populateVectorUnrollPatterns(patterns, opts);
     populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
+
+  Option<bool> reverseUnrollOrder{
+      *this, "reverse-unroll-order",
+      llvm::cl::desc(
+          "reverse the order of unrolling of vector transfer operations"),
+      llvm::cl::init(false)};
 };
 
 struct TestVectorTransferFullPartialSplitPatterns


        


More information about the Mlir-commits mailing list