[Mlir-commits] [mlir] [mlir][Vector] Add vector bitwidth target to xfer op flattening (PR #81966)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 15 19:58:44 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that the flattening doesn't happen if the bitwidth of the trailing dimension of the read/written vector is greater than or equal to this bitwidth (i.e., we are already able to fill at least one vector register of that size).

---
Full diff: https://github.com/llvm/llvm-project/pull/81966.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+7-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+40-5) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+34-2) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..cb3b3de8051d6f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,8 +328,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is smaller than `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..04e5a816dd91e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..5ba3ac824770ce 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+      %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i32
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
+    return %v : vector<5x4x3x20xi32>
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+//   CHECK-NOT:    tensor.collapse_shape
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+      %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
+      %arg1 : vector<5x4x3x20xi32>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
+      vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+//   CHECK-NOT:    tensor.collapse_shape
+
+
+
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..57d104e80d7243 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    constexpr unsigned targetVectorBitwidth = 512;
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

``````````

</details>


https://github.com/llvm/llvm-project/pull/81966


More information about the Mlir-commits mailing list