[Mlir-commits] [mlir] [mlir][vector] Improve flattening vector.transfer_write ops. (PR #94051)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 31 14:44:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

We can flatten transfer ops even when the collapsed indices are not all zero, because the linearized offset into the collapsed dimension can be computed. This is already supported for vector.transfer_read; this revision refactors that logic into a shared helper and reuses it for vector.transfer_write.

---
Full diff: https://github.com/llvm/llvm-project/pull/94051.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+65-71) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+10-8) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 997b56a1ce142..ae47d7fc28811 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -510,20 +510,59 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
 /// TODO: Extract the logic that writes to outIndices so that this method
 /// simply checks one pre-condition.
-static LogicalResult
-checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
-                                 SmallVector<Value> &outIndices) {
-  int64_t rank = indices.size();
-  if (firstDimToCollapse >= rank)
-    return failure();
-  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
-    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
-    if (!cst || cst.value() != 0)
-      return failure();
+static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
+                                              Location loc,
+                                              ArrayRef<int64_t> shape,
+                                              ValueRange indices,
+                                              int64_t firstDimToCollapse) {
+  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
+
+  // If all the collapsed indices are zero then no extra logic is needed.
+  // Otherwise, a new offset/index has to be computed.
+  SmallVector<Value> collapsedIndices(indices.begin(),
+                                      indices.begin() + firstDimToCollapse);
+  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                       indices.end());
+  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+    collapsedIndices.push_back(indicesToCollapse[0]);
+    return collapsedIndices;
+  }
+
+  // Compute the remaining trailing index/offset required for accessing
+  // the collapsed memref:
+  //
+  //    offset = 0
+  //    for (i = firstDimToCollapse; i < rank; ++i)
+  //      offset += indices[i] * product(shape[i+1 .. rank-1])
+  //
+  // For this example:
+  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
+  //      memref<1x43x2xi32>, vector<1x2xi32>
+  // which would be collapsed to:
+  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
+  //      memref<1x86xi32>, vector<2xi32>
+  // one would get the following offset:
+  //    %offset = %arg0 * 2
+  OpFoldResult collapsedOffset =
+      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+  auto collapsedStrides = computeSuffixProduct(
+      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
+
+  // Compute the collapsed offset.
+  auto &&[collapsedExpr, collapsedVals] =
+      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
+  collapsedOffset = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, collapsedExpr, collapsedVals);
+
+  if (collapsedOffset.is<Value>()) {
+    collapsedIndices.push_back(collapsedOffset.get<Value>());
+  } else {
+    collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+        loc, *getConstantIntValue(collapsedOffset)));
   }
-  outIndices = indices;
-  outIndices.resize(firstDimToCollapse + 1);
-  return success();
+
+  return collapsedIndices;
 }
 
 namespace {
@@ -594,54 +633,9 @@ class FlattenContiguousRowMajorTransferReadPattern
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
     // 2.2 New indices
-    // If all the collapsed indices are zero then no extra logic is needed.
-    // Otherwise, a new offset/index has to be computed.
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstDimToCollapse,
-                                                collapsedIndices))) {
-      // Copy all the leading indices.
-      SmallVector<Value> indices = transferReadOp.getIndices();
-      collapsedIndices.append(indices.begin(),
-                              indices.begin() + firstDimToCollapse);
-
-      // Compute the remaining trailing index/offset required for reading from
-      // the collapsed memref:
-      //
-      //    offset = 0
-      //    for (i = firstDimToCollapse; i < outputRank; ++i)
-      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
-      //
-      // For this example:
-      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
-      //      memref<1x43x2xi32>, vector<1x2xi32>
-      // which would be collapsed to:
-      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
-      //      memref<1x86xi32>, vector<2xi32>
-      // one would get the following offset:
-      //    %offset = %arg0 * 43
-      OpFoldResult collapsedOffset =
-          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
-
-      auto sourceShape = sourceType.getShape();
-      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
-          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
-
-      // Compute the collapsed offset.
-      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
-                                        indices.end());
-      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
-          collapsedOffset, collapsedStrides, indicesToCollapse);
-      collapsedOffset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, collapsedExpr, collapsedVals);
-
-      if (collapsedOffset.is<Value>()) {
-        collapsedIndices.push_back(collapsedOffset.get<Value>());
-      } else {
-        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(collapsedOffset)));
-      }
-    }
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferReadOp.getIndices(), firstDimToCollapse);
 
     // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
@@ -697,8 +691,7 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -706,22 +699,23 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
 
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..b5f29b2ac958b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -471,16 +471,16 @@ func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, str
 }
 
 //       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-LABEL:    func.func @regression_non_contiguous_dim_read(
-//       CHECK:      %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:     %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK-LABEL:  func.func @regression_non_contiguous_dim_read(
+//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
-func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
                                                 %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
                                                 %idx0 : index, %idx1 : index) {
   %c0 = arith.constant 0 : index
@@ -488,8 +488,10 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
   return
 }
 
-// CHECK-LABEL:  func.func @unsupported_non_contiguous_dim_write(
-//   CHECK-NOT:    memref.collapse_shape
+//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL:  func.func @regression_non_contiguous_dim_write(
+//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 
-// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
-//   CHECK-128B-NOT:   memref.collapse_shape
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
+//       CHECK-128B:   memref.collapse_shape

``````````

</details>


https://github.com/llvm/llvm-project/pull/94051


More information about the Mlir-commits mailing list