[Mlir-commits] [mlir] 2eb9e33 - [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (#73523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 00:36:03 PST 2023


Author: Andrzej Warzyński
Date: 2023-12-05T08:35:58Z
New Revision: 2eb9e33cc57d5acc2232d468a99f0e35c8f583dc

URL: https://github.com/llvm/llvm-project/commit/2eb9e33cc57d5acc2232d468a99f0e35c8f583dc
DIFF: https://github.com/llvm/llvm-project/commit/2eb9e33cc57d5acc2232d468a99f0e35c8f583dc.diff

LOG: [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (#73523)

Updates patterns for flattening `vector.transfer_read` by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:

```mlir
  %2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case would be considered for collapsing
(all indices are 0):

```mlir
  %2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Also adds some new comments and renames the `firstContiguousInnerDim`
variable to `firstDimToCollapse` (the latter better matches its actual
meaning).

Similar updates for `vector.transfer_write` will be implemented in a
follow-up patch.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index aab7075006031..ed42e6508b431 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -542,18 +544,18 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -561,26 +563,81 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //      memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //      memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index ac0fe64c70cd6..2ad992af989c9 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -265,6 +265,11 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;
   auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
 
+  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
+  // with leading dynamic dimensions are already supported.
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   // Cond 1: A contiguous memref will always have a unit trailing stride.
   if (strides.back() != 1)
     return false;

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2ffe85bf3bfa6..603792e537a10 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,6 +41,61 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+                     %idx_1: index,
+                     %idx_2: index,
+                     %m_in: memref<1x43x4x6xi32>,
+                     %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+                     %idx_1: index,
+                     %idx_2: index,
+                     %m_in: memref<1x?x4x6xi32>,
+                     %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+    memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
     %c0 = arith.constant 0 : index

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index feb716cdbf404..86b8d5f9b0995 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
+    registry.insert<affine::AffineDialect>();
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());


        


More information about the Mlir-commits mailing list