[Mlir-commits] [mlir] [mlir][Vector] Add vector bitwidth target to xfer op flattening (PR #81966)

Diego Caballero llvmlistbot at llvm.org
Fri Feb 16 18:25:21 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/81966

>From 9029a3828da29d68809668596b73b664e03b8b5c Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 15 Feb 2024 19:18:12 +0000
Subject: [PATCH 1/3] [mlir][Vector] Add vector bitwidth target to xfer op
 flattening

This PR adds an optional bitwidth parameter to the vector xfer op
flattening transformation so that the flattening doesn't happen if the
bitwidth of the trailing dimension of the read/written vector is greater
than or equal to this bitwidth (i.e., that dimension alone already fills
at least one vector register of that size).
---
 .../Vector/Transforms/VectorRewritePatterns.h |  9 +++-
 .../Transforms/VectorTransferOpTransforms.cpp | 45 ++++++++++++++++---
 .../Vector/vector-transfer-flatten.mlir       | 36 ++++++++++++++-
 .../Dialect/Vector/TestVectorTransforms.cpp   |  3 +-
 4 files changed, 83 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..cb3b3de8051d6f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,8 +328,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is strictly smaller than `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..04e5a816dd91e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the bitwidth of the trailing dimension of the vector read is smaller than
+/// the provided bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Flattening is skipped when the trailing vector dimension already spans
+  // at least this many bits.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Flattening is skipped when the trailing vector dimension already spans
+  // at least this many bits.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..5ba3ac824770ce 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+      %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i32
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
+    return %v : vector<5x4x3x20xi32>
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+//   CHECK-NOT:    tensor.collapse_shape
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+      %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
+      %arg1 : vector<5x4x3x20xi32>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
+      vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+//   CHECK-NOT:    tensor.collapse_shape
+
+
+
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..57d104e80d7243 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    constexpr unsigned targetVectorBitwidth = 512;
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

>From 283794a8ed2fe7e000cff42472241fdfacdd5d92 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 16 Feb 2024 23:06:38 +0000
Subject: [PATCH 2/3] Add target-vector-bitwidth option to the test pass

---
 .../Dialect/Vector/vector-transfer-flatten.mlir   |  5 +----
 .../lib/Dialect/Vector/TestVectorTransforms.cpp   | 15 ++++++++++++++-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5ba3ac824770ce..5c9338c87ffe2e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=512 -split-input-file | FileCheck %s
 
 func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
@@ -418,6 +418,3 @@ func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
 // CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
 //   CHECK-NOT:    tensor.collapse_shape
 
-
-
-
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 57d104e80d7243..db0f3550763ead 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -466,21 +466,34 @@ struct TestFlattenVectorTransferPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestFlattenVectorTransferPatterns)
 
+  TestFlattenVectorTransferPatterns() = default;
+  TestFlattenVectorTransferPatterns(
+      const TestFlattenVectorTransferPatterns &pass)
+      : PassWrapper(pass) {}
+
   StringRef getArgument() const final {
     return "test-vector-transfer-flatten-patterns";
   }
+
   StringRef getDescription() const final {
     return "Test patterns to rewrite contiguous row-major N-dimensional "
            "vector.transfer_{read,write} ops into 1D transfers";
   }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
     registry.insert<affine::AffineDialect>();
     registry.insert<vector::VectorDialect>();
   }
+
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(512)};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    constexpr unsigned targetVectorBitwidth = 512;
     populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }

>From 3b0ad690a5d14d0407eed48f5f2017b478dd222b Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Sat, 17 Feb 2024 02:21:35 +0000
Subject: [PATCH 3/3] Improve testing

---
 .../Vector/vector-transfer-flatten.mlir       | 103 ++++++++++++------
 .../Dialect/Vector/TestVectorTransforms.cpp   |   2 +-
 2 files changed, 73 insertions(+), 32 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5c9338c87ffe2e..1775b5fa4a346a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=512 -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
 func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
     return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
 // CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 // The shape of the memref and the vector don't match, but the vector is a
@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 // The input memref has a dynamic trailing shape and hence is not flattened.
@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
 // CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous(
@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
     return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_match_contiguous(
@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
     return
 }
 
-// CHECK-LABEL: func @transfer_write_dims_match_contiguous
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_contiguous(
@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
 // CHECK:           return
 // CHECK:         }
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_non_contiguous(
@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
 // -----
 
 func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
 // -----
 
 func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
 // CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:       return %[[VEC2D]] : vector<8x4xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 // CHECK-SAME:    {in_bounds = [true]}
 // CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
 
+// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
+//       CHECK-128B:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative(
@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
 // CHECK-LABEL: func @transfer_read_flattenable_negative
 //       CHECK:   vector.transfer_read {{.*}} vector<2x2x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative2(
@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 //       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
 // CHECK:           return %[[VAL_4]] : vector<1x8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
 // CHECK:           return %[[VAL_4]] : vector<1x8x1xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 // CHECK:           %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
 // CHECK:           return %[[VAL_4]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_mulf(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
 // CHECK:           %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
 // CHECK:           return %[[VAL_2]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
 // -----
 
 // All shape casts are folded away
@@ -390,31 +456,6 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
 
-// -----
-
-func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
-      %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i32
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
-    return %v : vector<5x4x3x20xi32>
-}
-
-// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
-//   CHECK-NOT:    tensor.collapse_shape
-
-// -----
-
-func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
-      %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
-      %arg1 : vector<5x4x3x20xi32>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
-      vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
-    return
-}
-
-// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
-//   CHECK-NOT:    tensor.collapse_shape
+// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
+//   CHECK-128B-NOT:   memref.collapse_shape
 
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index db0f3550763ead..e27156856a15c2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -490,7 +490,7 @@ struct TestFlattenVectorTransferPatterns
       *this, "target-vector-bitwidth",
       llvm::cl::desc(
           "Minimum vector bitwidth to enable the flattening transformation"),
-      llvm::cl::init(512)};
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());



More information about the Mlir-commits mailing list