[Mlir-commits] [mlir] 71441ed - [mlir][Vector] Add vector bitwidth target to xfer op flattening (#81966)

llvmlistbot at llvm.org
Wed Feb 21 09:22:52 PST 2024


Author: Diego Caballero
Date: 2024-02-21T09:22:48-08:00
New Revision: 71441ed1716e6ed3f053dea9c1ceb9cfe2822aea

URL: https://github.com/llvm/llvm-project/commit/71441ed1716e6ed3f053dea9c1ceb9cfe2822aea
DIFF: https://github.com/llvm/llvm-project/commit/71441ed1716e6ed3f053dea9c1ceb9cfe2822aea.diff

LOG: [mlir][Vector] Add vector bitwidth target to xfer op flattening (#81966)

This PR adds an optional bitwidth parameter to the vector xfer op
flattening transformation so that flattening doesn't happen if the
bitwidth of the trailing dimension of the read/written vector is equal
to or larger than this bitwidth (i.e., we are already able to fill at
least one vector register with that size).
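
For illustration: with a 128-bit target, a transfer of vector<5x4x3x2xi8>
is still flattened (its trailing slice spans 2 x 8 = 16 bits), while a
transfer whose trailing dimension already covers 128 bits or more is left
as-is. A downstream pass could wire up the new parameter roughly as
follows (a minimal sketch modeled on the test pass updated below; the
128-bit value is an arbitrary example):

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // Only flatten transfers whose trailing dimension is narrower than
    // 128 bits; wider transfers can already fill a 128-bit register.
    vector::populateFlattenVectorTransferPatterns(
        patterns, /*targetVectorBitwidth=*/128);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }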

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7c943f07066c70..46bb3ddec0baf6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -330,8 +330,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is strictly smaller than `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..04e5a816dd91e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, flattening will only happen if
+/// the bitwidth of the trailing dimension of the vector read is strictly
+/// smaller than the provided bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Flattening is only applied when the bitwidth of the trailing vector
+  // dimension is strictly smaller than this threshold.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Flattening is only applied when the bitwidth of the trailing vector
+  // dimension is strictly smaller than this threshold.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..1775b5fa4a346a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
 func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
     return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
 // CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 // The shape of the memref and the vector don't match, but the vector is a
@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -66,7 +76,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 // The input memref has a dynamic trailing shape and hence is not flattened.
@@ -99,7 +112,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
 // CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous(
@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
     return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_match_contiguous(
@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
     return
 }
 
-// CHECK-LABEL: func @transfer_write_dims_match_contiguous
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_contiguous(
@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
 // CHECK:           return
 // CHECK:         }
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_non_contiguous(
@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
 // -----
 
 func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
 // -----
 
 func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
 // CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:       return %[[VEC2D]] : vector<8x4xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 // CHECK-SAME:    {in_bounds = [true]}
 // CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
 
+// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative(
@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
 // CHECK-LABEL: func @transfer_read_flattenable_negative
 //       CHECK:   vector.transfer_read {{.*}} vector<2x2x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative2(
@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
 // CHECK:           return %[[VAL_4]] : vector<1x8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
 // CHECK:           return %[[VAL_4]] : vector<1x8x1xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 // CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
 // CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_mulf(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
 // CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
 // CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 // All shape casts are folded away
@@ -389,3 +455,7 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
+//   CHECK-128B-NOT:   memref.collapse_shape
+

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index acd38980514a56..178a58e796b246 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -466,21 +466,35 @@ struct TestFlattenVectorTransferPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestFlattenVectorTransferPatterns)
 
+  TestFlattenVectorTransferPatterns() = default;
+  TestFlattenVectorTransferPatterns(
+      const TestFlattenVectorTransferPatterns &pass)
+      : PassWrapper(pass) {}
+
   StringRef getArgument() const final {
     return "test-vector-transfer-flatten-patterns";
   }
+
   StringRef getDescription() const final {
     return "Test patterns to rewrite contiguous row-major N-dimensional "
            "vector.transfer_{read,write} ops into 1D transfers";
   }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
     registry.insert<affine::AffineDialect>();
     registry.insert<vector::VectorDialect>();
   }
+
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc("Transfers whose trailing vector dimension is at or "
+                     "above this bitwidth are not flattened"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
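
The new pass option can be exercised from the command line in the same
way as the updated RUN line above, e.g. (the input file name here is
hypothetical):

  mlir-opt input.mlir -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128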

