[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (PR #73523)

Han-Chung Wang llvmlistbot at llvm.org
Fri Dec 1 11:25:56 PST 2023


================
@@ -562,45 +613,93 @@ class FlattenContiguousRowMajorTransferReadPattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //   memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //   memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      int64_t outputRank = transferReadOp.getIndices().size();
+      Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        auto sourceDimSize =
+            rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+        offset = rewriter.create<arith::AddIOp>(
+            loc,
+            rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+                                           sourceDimSize),
+            offset);
+      }
----------------
hanhanW wrote:

We already have better helpers that use OpFoldResult and `affine.apply` ops. They make this code more idiomatic and produce more concise IR, and they fold to constant values when all the inputs are static. Could you try using `makeComposedFoldedAffineApply`?

https://github.com/llvm/llvm-project/pull/73523


More information about the Mlir-commits mailing list