[Mlir-commits] [mlir] 53fe155 - Revert "[mlir][vector] Allow unroll of contraction in arbitrary order"

Christopher Bate llvmlistbot at llvm.org
Tue Jun 7 13:57:16 PDT 2022


Author: Christopher Bate
Date: 2022-06-07T14:54:01-06:00
New Revision: 53fe155b3f4d455432041caaee90c8145eeb2a4b

URL: https://github.com/llvm/llvm-project/commit/53fe155b3f4d455432041caaee90c8145eeb2a4b
DIFF: https://github.com/llvm/llvm-project/commit/53fe155b3f4d455432041caaee90c8145eeb2a4b.diff

LOG: Revert "[mlir][vector] Allow unroll of contraction in arbitrary order"

Reverts commit 1469ebf8382107e0344173f362b690d19e24029d (original commit)
Reverts commit a392a39f75af586e3d3cd046a8361939277e067f (build fix for above commit)

The commit broke tests in out-of-tree projects, indicating that the previous
change contained a logical error that the in-tree tests did not cover.
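
For context, the reverted change let callers choose the order in which the
grid of native-size slices is traversed. The sketch below is a standalone
illustration of the delinearization that the removed DecomposeShapeIterator
performed (plain C++ with no MLIR dependencies; the function name and layout
are illustrative, not the reverted API):

  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Map a linear slice index to per-dimension element offsets, visiting
  // dimensions in the given loop order (outermost first). Dimensions later
  // in loopOrder vary faster. Sketch only; not the MLIR API.
  std::vector<int64_t> sliceOffsets(const std::vector<int64_t> &originalShape,
                                    const std::vector<int64_t> &targetShape,
                                    const std::vector<int64_t> &loopOrder,
                                    int64_t index) {
    size_t rank = originalShape.size();
    std::vector<int64_t> sliceCounts(rank), strides(rank);
    for (size_t r = 0; r < rank; ++r)
      sliceCounts[r] = (originalShape[r] + targetShape[r] - 1) / targetShape[r];
    // The last dimension in loopOrder varies fastest (stride 1).
    int64_t accum = 1;
    for (auto it = loopOrder.rbegin(); it != loopOrder.rend(); ++it) {
      strides[*it] = accum;
      accum *= sliceCounts[*it];
    }
    std::vector<int64_t> offsets(rank);
    for (int64_t dim : loopOrder) { // largest stride to smallest
      offsets[dim] = (index / strides[dim]) * targetShape[dim];
      index %= strides[dim];
    }
    return offsets;
  }

  int main() {
    // Unroll a 4x4 vector into 2x2 slices with loop order {1, 0}, i.e.
    // dimension 1 outermost, as the reverse-unroll-order test exercised.
    // Prints [0, 0], [2, 0], [0, 2], [2, 2].
    for (int64_t i = 0; i < 4; ++i) {
      std::vector<int64_t> off = sliceOffsets({4, 4}, {2, 2}, {1, 0}, i);
      std::cout << "[" << off[0] << ", " << off[1] << "]\n";
    }
    return 0;
  }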

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e215be49b74ef..e7226f4a6ac0f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,19 +128,6 @@ struct UnrollVectorOptions {
     };
     return *this;
   }
-
-  /// Function that returns the traversal order (in terms of "for loop order",
-  /// i.e. slowest varying dimension to fastest varying dimension) that should
-  /// be used when unrolling the given operation into units of the native vector
-  /// size.
-  using UnrollTraversalOrderFnType =
-      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
-  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
-  UnrollVectorOptions &
-  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
-    traversalOrderCallback = std::move(traversalOrderFn);
-    return *this;
-  }
 };
 
 //===----------------------------------------------------------------------===//

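For reference, clients opted into a custom traversal through the
setUnrollTraversalOrderFn hook removed above. A sketch of the pre-revert
usage, adapted from the test code removed further below (simplified to
transfer_read only; the hook no longer exists once this commit lands):

  UnrollVectorOptions opts;
  opts.setNativeShape(ArrayRef<int64_t>{2, 2})
      .setFilterConstraint([](Operation *op) {
        return success(
            isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
      })
      // Visit the unrolled slices with the vector dimensions reversed.
      .setUnrollTraversalOrderFn(
          [](Operation *op) -> Optional<SmallVector<int64_t>> {
            auto readOp = dyn_cast<vector::TransferReadOp>(op);
            if (!readOp)
              return None;
            int64_t numLoops = readOp.getVectorType().getRank();
            return llvm::to_vector(
                llvm::reverse(llvm::seq<int64_t>(0, numLoops)));
          });
  populateVectorUnrollPatterns(patterns, opts);
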
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index dbca588f87f81..7f00788d888b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,11 +15,8 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
-#include <numeric>
 
 #define DEBUG_TYPE "vector-unrolling"
 
@@ -39,78 +36,20 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
   return elementOffsets;
 }
 
-/// A functor that accomplishes the same thing as `getVectorOffset` but allows
-/// for reordering the traversal of the dimensions. The order of traversal is
-/// given in "for loop order" (outer to inner).
-namespace {
-class DecomposeShapeIterator {
-private:
-  SmallVector<int64_t, 4> vectorShape;
-  SmallVector<int64_t> loopOrder;
-  SmallVector<int64_t> sliceStrides;
-  int64_t maxIndexVal{1};
-
-public:
-  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> targetShape,
-                         ArrayRef<int64_t> loopOrder)
-      : vectorShape(targetShape.begin(), targetShape.end()),
-        loopOrder(loopOrder.begin(), loopOrder.end()),
-        sliceStrides(originalShape.size()) {
-    // Compute the count for each dimension.
-    SmallVector<int64_t> sliceDimCounts(originalShape.size());
-    for (unsigned r = 0; r < originalShape.size(); ++r) {
-      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
-      maxIndexVal *= sliceDimCounts[r];
-    }
-
-    // Reversing "loop order" gives dimensions from fastest varying to slowest
-    // varying (smallest stride to largest stride).
-    int64_t accum = 1;
-    for (auto idx : llvm::reverse(loopOrder)) {
-      sliceStrides[idx] = accum;
-      accum *= sliceDimCounts[idx];
-    }
-  }
-
-  // Turn the linear index into a d-tuple based on units of vectors of size
-  // `vectorShape`. The linear index is assumed to represent traversal of the
-  // dimensions based on `order`.
-  SmallVector<int64_t> delinearize(int64_t index) const {
-    // Traverse in for loop order (largest stride to smallest stride).
-    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
-    for (auto idx : loopOrder) {
-      vectorOffsets[idx] = index / sliceStrides[idx];
-      index %= sliceStrides[idx];
-    }
-    return vectorOffsets;
-  }
-
-  int64_t maxIndex() const { return maxIndexVal; }
-
-  /// Return the offset within d-tuple based on the ordering given by
-  /// `loopOrder`.
-  SmallVector<int64_t> getVectorOffset(int64_t index) const {
-    SmallVector<int64_t> vectorOffsets = delinearize(index);
-    SmallVector<int64_t> elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
-    return elementOffsets;
-  }
-};
-} // namespace
-
 /// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
-                                               ArrayRef<Value> indices,
-                                               AffineMap permutationMap,
-                                               Location loc,
-                                               OpBuilder &builder) {
+static SmallVector<Value>
+sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
+                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
+                     AffineMap permutationMap, Location loc,
+                     OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
+  SmallVector<int64_t, 4> elementOffsets =
+      getVectorOffset(originalShape, targetShape, index);
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value> slicedIndices(indices.begin(), indices.end());
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -160,20 +99,6 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   return targetShape;
 }
 
-static SmallVector<int64_t>
-getUnrollOrder(unsigned numLoops, Operation *op,
-               const vector::UnrollVectorOptions &options) {
-  SmallVector<int64_t> loopOrder =
-      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
-  if (options.traversalOrderCallback != nullptr) {
-    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
-    if (order.hasValue()) {
-      loopOrder = std::move(*order);
-    }
-  }
-  return loopOrder;
-}
-
 namespace {
 
 struct UnrollTransferReadPattern
@@ -197,7 +122,8 @@ struct UnrollTransferReadPattern
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
    // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
@@ -205,22 +131,17 @@ struct UnrollTransferReadPattern
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                           readOp.getIndices().end());
-
-    SmallVector<int64_t> loopOrder =
-        getUnrollOrder(ratio.size(), readOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t, 4> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
+    for (int64_t i = 0; i < sliceCount; i++) {
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(elementOffsets, originalIndices,
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
 
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
@@ -253,21 +174,20 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                           writeOp.getIndices().end());
-
-    SmallVector<int64_t> loopOrder =
-        getUnrollOrder(originalIndices.size(), writeOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+    for (int64_t i = 0; i < sliceCount; i++) {
       SmallVector<int64_t, 4> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
+          getVectorOffset(originalSize, *targetShape, i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
+
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(elementOffsets, originalIndices,
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
@@ -318,6 +238,8 @@ struct UnrollContractionPattern
     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
 
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -325,14 +247,9 @@ struct UnrollContractionPattern
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
-
-    SmallVector<int64_t> loopOrder = getUnrollOrder(
-        contractOp.getIndexingMaps().size(), contractOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
 
      // Helper to compute the new shape of each operand and extract the slice.

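With the hook gone, the slice grid is again always walked in row-major order
(last dimension fastest), which is what the restored getVectorOffset calls
above compute. A minimal standalone sketch of that default, assuming shapes
that divide evenly:

  #include <cstdint>
  #include <vector>

  // Row-major slice offsets: the last dimension varies fastest. For a 4x4
  // vector with a 2x2 native shape, indices 0..3 yield [0, 0], [0, 2],
  // [2, 0], [2, 2], matching the CHECK lines in the tests below.
  std::vector<int64_t> rowMajorOffsets(
      const std::vector<int64_t> &originalShape,
      const std::vector<int64_t> &targetShape, int64_t index) {
    size_t rank = originalShape.size();
    std::vector<int64_t> offsets(rank);
    for (size_t r = rank; r-- > 0;) {
      int64_t sliceCount = originalShape[r] / targetShape[r];
      offsets[r] = (index % sliceCount) * targetShape[r];
      index /= sliceCount;
    }
    return offsets;
  }
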
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index f6f218b6e39eb..3d0affd2a4be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER
 
 // CHECK-LABEL: func @transfer_read_unroll
 //       CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
@@ -14,19 +13,6 @@
 //  CHECK-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 //  CHECK-NEXT:   return %[[VEC3]] : vector<4x4xf32>
 
-// ORDER-LABEL: func @transfer_read_unroll
-//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-//  ORDER-NEXT:   return %[[VEC3]] : vector<4x4xf32>
-
 func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -47,19 +33,6 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
 //  CHECK-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
 //  CHECK-NEXT:   return
 
-// ORDER-LABEL: func @transfer_write_unroll
-//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//       ORDER:   %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-//  ORDER-NEXT:   vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-//  ORDER-NEXT:   %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-//  ORDER-NEXT:   vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-//  ORDER-NEXT:   %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-//  ORDER-NEXT:   vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-//  ORDER-NEXT:   %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-//  ORDER-NEXT:   vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-//  ORDER-NEXT:   return
-
 func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -249,25 +222,6 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
 //  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
 //  CHECK-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
 //  CHECK-NEXT:   return %[[VEC5]] : vector<6x4xf32>
-
-// ORDER-LABEL: func @transfer_read_unroll_different_rank
-//       ORDER-DAG:   %[[C4:.*]] = arith.constant 4 : index
-//       ORDER-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//       ORDER-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//       ORDER:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-//  ORDER-NEXT:   %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-//  ORDER-NEXT:   return %[[VEC5]] : vector<6x4xf32>
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
 func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 272c04a34e8ea..3b0aeb48665f4 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,156 +1,50 @@
 // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
-// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1"  --split-input-file | FileCheck %s --check-prefix=ORDER
 
-func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
                           %init : vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = vector.contract
          {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                            affine_map<(i, j, k) -> (j, k)>,
                            affine_map<(i, j, k) -> (i, j)>],
           iterator_types = ["parallel", "parallel", "reduction"]}
-       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
+       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }
 // CHECK-LABEL: func @vector_contract_f32
-// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [0, 0]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [0, 0]
-//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  CHECK-SAME:   offsets = [0, 0]
-//       CHECK:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [0, 2]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [0, 2]
-//       CHECK:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [0, 0]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [4, 0]
-//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  CHECK-SAME:   offsets = [0, 4]
-//       CHECK:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [0, 2]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [4, 2]
-//       CHECK:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [4, 0]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [0, 0]
-//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  CHECK-SAME:   offsets = [4, 0]
-//       CHECK:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [4, 2]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [0, 2]
-//       CHECK:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [4, 0]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [4, 0]
-//       CHECK:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  CHECK-SAME:   offsets = [4, 4]
-//       CHECK:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       CHECK:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  CHECK-SAME:   offsets = [4, 2]
-//       CHECK:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  CHECK-SAME:   offsets = [4, 2]
-//       CHECK:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
 //       CHECK:   return
 
-// ORDER-LABEL: func @vector_contract_f32
-// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [0, 0]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [0, 0]
-//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  ORDER-SAME:   offsets = [0, 0]
-//       ORDER:   [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [0, 0]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [4, 0]
-//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  ORDER-SAME:   offsets = [0, 4]
-//       ORDER:   [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [4, 0]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [0, 0]
-//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  ORDER-SAME:   offsets = [4, 0]
-//       ORDER:   [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [4, 0]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [4, 0]
-//       ORDER:   [[c:%.+]] = vector.extract_strided_slice [[arg2]] 
-//  ORDER-SAME:   offsets = [4, 4]
-//       ORDER:   [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [0, 2]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [0, 2]
-//       ORDER:   [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [0, 2]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [4, 2]
-//       ORDER:   [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [4, 2]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [0, 2]
-//       ORDER:   [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   [[a:%.+]] = vector.extract_strided_slice [[arg0]] 
-//  ORDER-SAME:   offsets = [4, 2]
-//       ORDER:   [[b:%.+]] = vector.extract_strided_slice [[arg1]] 
-//  ORDER-SAME:   offsets = [4, 2]
-//       ORDER:   [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
-//  ORDER-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-//       ORDER:   return
-
-
-
 func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                           %init : vector<8x8xf16>) -> vector<8x8xf16> {
   %0 = vector.contract
@@ -264,4 +158,3 @@ func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
 //       CHECK:   %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
 //       CHECK:   %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
 //       CHECK:   return %[[V7]] : vector<2x3x8x4xf32>
-

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index cbec267734eba..a81aa536df4ad 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
@@ -323,18 +322,12 @@ struct TestVectorUnrollingPatterns
         }
         return nativeShape;
       };
-
-      UnrollVectorOptions opts;
-      opts.setNativeShapeFn(nativeShapeFn)
-          .setFilterConstraint(
-              [](Operation *op) { return success(isa<ContractionOp>(op)); });
-      if (!unrollOrder.empty()) {
-        opts.setUnrollTraversalOrderFn([this](Operation *op)
-                                           -> Optional<SmallVector<int64_t>> {
-          return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
-        });
-      }
-      populateVectorUnrollPatterns(patterns, opts);
+      populateVectorUnrollPatterns(patterns,
+                                   UnrollVectorOptions()
+                                       .setNativeShapeFn(nativeShapeFn)
+                                       .setFilterConstraint([](Operation *op) {
+                                         return success(isa<ContractionOp>(op));
+                                       }));
     } else {
       populateVectorUnrollPatterns(
           patterns, UnrollVectorOptions()
@@ -347,10 +340,6 @@ struct TestVectorUnrollingPatterns
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
-  ListOption<int64_t> unrollOrder{*this, "unroll-order",
-                                  llvm::cl::desc("set the unroll order"),
-                                  llvm::cl::ZeroOrMore};
-
   Option<bool> unrollBasedOnType{
       *this, "unroll-based-on-type",
       llvm::cl::desc("Set the unroll factor based on type of the operation"),
@@ -483,11 +472,6 @@ struct TestVectorTransferUnrollingPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestVectorTransferUnrollingPatterns)
 
-  TestVectorTransferUnrollingPatterns() = default;
-  TestVectorTransferUnrollingPatterns(
-      const TestVectorTransferUnrollingPatterns &pass)
-      : PassWrapper(pass) {}
-
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect>();
   }
@@ -501,36 +485,17 @@ struct TestVectorTransferUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    UnrollVectorOptions opts;
-    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
-        .setFilterConstraint([](Operation *op) {
-          return success(
-              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
-        });
-    if (reverseUnrollOrder.getValue()) {
-      opts.setUnrollTraversalOrderFn(
-          [](Operation *op) -> Optional<SmallVector<int64_t>> {
-            int64_t numLoops = 0;
-            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
-              numLoops = readOp.getVectorType().getRank();
-            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
-              numLoops = writeOp.getVectorType().getRank();
-            else
-              return None;
-            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
-            return llvm::to_vector(order);
-          });
-    }
-    populateVectorUnrollPatterns(patterns, opts);
+    populateVectorUnrollPatterns(
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+            }));
     populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
-
-  Option<bool> reverseUnrollOrder{
-      *this, "reverse-unroll-order",
-      llvm::cl::desc(
-          "reverse the order of unrolling of vector transfer operations"),
-      llvm::cl::init(false)};
 };
 
 struct TestVectorTransferFullPartialSplitPatterns
