[Mlir-commits] [mlir] [mlir][tensor] Move extract_slice reshaping into two functions (PR #153675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 14 13:45:59 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

Exposes the `tensor.extract_slice` reshaping logic in `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` through two corresponding utility functions. These compute the offsets, sizes, and strides of an extract_slice after either collapsing or expanding its source.

This should also make it easier to implement the two other bubbling cases: (1) when the `collapse_shape` is a consumer, or (2) when the `expand_shape` is a consumer.

---

Patch is 30.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153675.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+28) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+270-299) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..2602252916388 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,34 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
 FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
                                     ValueRange independencies);
 
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                             ArrayRef<ReassociationIndices> reassociation,
+                             SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+                             SmallVectorImpl<OpFoldResult> &collapsedSizes,
+                             SmallVectorImpl<OpFoldResult> &collapsedStrides,
+                             OpBuilder &b);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<int64_t> expandedShape,
+                            SmallVectorImpl<OpFoldResult> &expandedOffsets,
+                            SmallVectorImpl<OpFoldResult> &expandedSizes,
+                            SmallVectorImpl<OpFoldResult> &expandedStrides,
+                            OpBuilder &b);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..a93681b1fce92 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
                                 PatternRewriter &rewriter) const override {
     auto expandShapeOp =
         sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
 
-    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
-                                                 rewriter)
-            .failed())
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getCollapsedExtractSliceInfo(sliceOp, reassociation, offsets,
+                                            sizes, strides, rewriter)))
       return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
-    // referring to the state before applying the pattern are named with the
-    // prefix "expanded", and ones referring to the state after applying the
-    // pattern are named with the prefix "collapsed".
-    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> expandedShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    // Helper variables and function for accumulating the size values.
-    Location loc = expandShapeOp->getLoc();
-    AffineExpr d0, d1, d2;
-    bindDims(rewriter.getContext(), d0, d1, d2);
-    // Multiply two integers.
-    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
-      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
-      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
-                                                   {v1, v2});
-    };
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank of
-    // ReassociationIndices.size(). In the loop a single offset, size, and
-    // stride value is computed per reassociation group.
-    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
-        collapsedStrides;
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      // collapsedSize will hold the size of the single dim that represents the
-      // reassociation group in the non expanded tensor.
-      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
-      // The reassocGroupSizes and reassocGroupOffsets are used to create an
-      // affine.linearize_index op to linearize the single offset value required
-      // for this reassociation group.
-      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
-      for (long expandedDim : indices) {
-        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
-        // from the expanded state, but the collapsed size requires calculation
-        // as it did not previously exist.
-        reassocGroupSizes.push_back(expandedShape[expandedDim]);
-        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
-        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
-      }
-
-      SmallVector<Value> offsetVals =
-          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
-            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
-          });
-      OpFoldResult collapsedOffset =
-          affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
-                                                 reassocGroupSizes,
-                                                 /*disjoint=*/true)
-              .getResult();
-      collapsedOffsets.push_back(collapsedOffset);
-      collapsedSizes.push_back(collapsedSize);
-
-      // Only unit stride is supported.
-      collapsedStrides.push_back(rewriter.getIndexAttr(1));
-    }
-
     // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
-    RankedTensorType resultType = RankedTensorType::get(
-        shape, expandShapeOp.getResultType().getElementType());
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    RankedTensorType resultType = sliceOp.getResultType();
 
     // Create a new ExtractSliceOp and ExpandShapeOp.
+    Location loc = sliceOp.getLoc();
     Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
-        collapsedStrides);
+        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         sliceOp, resultType, newSliceOp,
         expandShapeOp.getReassociationIndices(), expandedSizes);
     return success();
   }
-
-  // Helper function to check if all the required conditions for the
-  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
-  // met.
-  LogicalResult
-  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
-                                           tensor::ExpandShapeOp expandShapeOp,
-                                           PatternRewriter &rewriter) const {
-
-    if (!expandShapeOp) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "tensor.extract_slice source not produced by expand_shape");
-    }
-
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
-
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        sizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
-
-    SmallVector<OpFoldResult> outputShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
-        isZeroOffsetAndFullSize =
-            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isZeroInteger(offset))
-                return false;
-              FailureOr<bool> maybeEqual =
-                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
-              return llvm::succeeded(maybeEqual) && maybeEqual.value();
-            };
-
-    // Check that the slice is contiguous within each reassociation group.
-    // The slice is contiguous only if after the first dimension where a non
-    // unit slice is taken, the slice size on all subsequent dimensions of the
-    // group is equal to the entire size of the dimension.
-    // Examples of contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
-    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
-    // Examples of non contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
-    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      int64_t i = 0;
-      int64_t e = indices.size();
-      // Find the first expanded dim after the first dim with non-unit extracted
-      // size.
-      for (; i < e; ++i) {
-        if (!isOneInteger(sizes[indices[i]])) {
-          // +1 to skip the first non-unit size dim.
-          i++;
-          break;
-        }
-      }
-
-      // Verify that all subsequent dimensions extract the full size of the
-      // source tensor.
-      for (; i < e; ++i) {
-        int64_t expandedDim = indices[i];
-        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
-                                     outputShape[expandedDim])) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "Not a contiguous slice of the expanded tensor.");
-        }
-      }
-    }
-
-    return success();
-  }
 };
 
 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,282 @@ struct BubbleUpCollapseShapeThroughExtractSlice
           "tensor.extract_slice source not produced by tensor.collapse_shape");
     }
 
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getExpandedExtractSliceInfo(
+            sliceOp, collapseShapeOp.getReassociationIndices(),
+            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
+            rewriter)))
+      return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.collapse_shape, so variables (i.e. inputs for
-    // ExtractSliceOp) referring to the state before applying the pattern are
-    // named with the prefix "collapsed", and ones referring to the state after
-    // applying the pattern are named with the prefix "expanded".
-    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        collapsedSizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
+    Value newSliceOp = tensor::ExtractSliceOp::create(
+        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+        sizes, strides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
 
-    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
-    SmallVector<ReassociationIndices, 4> reassociationIndices =
-        collapseShapeOp.getReassociationIndices();
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank
-    // equal to the rank of the src of the collapse_shape. In each iteration of
-    // the loop, the offsets and sizes will be computed per reassociation group.
-    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
-    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
-                                              rewriter.getIndexAttr(1));
-
-    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
-         llvm::zip_equal(collapsedSizes, collapsedOffsets,
-                         collapseShapeOp.getReassociationIndices())) {
-      // CASE #1 - size and/or offset are dynamic.
-      // In this case, the slice can be represented as a contiguous slice only
-      // if there is a single dimension in the reassociation group that has a
-      // size not equal to 1.
-      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
-        int nonUnitSizeCount = 0;
-        for (int64_t expandedShapeIdx : reassocIndices) {
-          if (srcShape[expandedShapeIdx] != 1) {
-            nonUnitSizeCount++;
-            expandedSizes.push_back(collapsedSize);
-            expandedOffsets.push_back(collapsedOffset);
-            continue;
-          }
-
-          expandedSizes.push_back(rewriter.getIndexAttr(1));
-          expandedOffsets.push_back(rewriter.getIndexAttr(0));
-        }
+    return success();
+  }
+};
 
-        if (nonUnitSizeCount != 1) {
-          return rewriter.notifyMatchFailure(
-              sliceOp,
-              "unsupported: slice cannot be verified to be contiguous");
-        }
-        continue;
-      }
+} // namespace
 
-      // CASE #2 = size and offset are static.
-      // Verify that the slice can be represented as a contiguous slice of the
-      // src of the collapse_shape.
-      // Checking this is done on order of most internal dimensions first,
-      // so traversal is done in reverse order of the reassociation group.
-      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
-      // ...,An] then we first find the size and offset for n...k+1 then for k
-      // and then for k-1...0.
-
-      // currentCollapsedsize and currentCollapsedOffset are initialized with
-      // the original collapsed size and offset and divided by the expanded
-      // shape size in each dimension as we go along the reassociation group.
-      // In essence we are spreading the original collapsed size and offset over
-      // the various expanded slice dimensions.
-      // The variables are used both to check the validity of the slice and to
-      // compute the expanded sizes and offsets.
-      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
-      int64_t currentCollapsedOffset =
-          getConstantIntValue(collapsedOffset).value();
-
-      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
-      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
-                                                  reassocIndices.rend());
-      int64_t idx = 0;
-      int64_t reassocGroupSize = reassocIndices.size();
-
-      // First handle the trailing dimensions where the slice size should be
-      // equal to the tensor shape and the offset should be 0 (n...k+1).
-      for (; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
-        if (currentCollapsedsize < expandedShapeSize)
-          break;
-
-        // We need to make sure that the slice size can be set to the shape size
-        // and the offset to 0.
-        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
-            (currentCollapsedOffset % expandedShapeSize) != 0) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
-                       "of the src of the collapse_shape");
-        }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+    tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+    SmallVectorImpl<OpFoldResult> &collapsedSizes,
+    SmallVectorImpl<OpFoldResult> &collapsedStrides, OpBuilder &b) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
 
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+    return failure();
+  }
 
-        currentCollapsedsize /= expandedShapeSize;
-        currentCollapsedOffset /= expandedShapeSize;
+  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+                                     OpFoldResult sliceSize, int64_t inputDim) {
+    if (!isZeroInteger(offset))
+      return false;
+    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+    FailureOr<bool> maybeEqual =
+        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+    return llvm::succeeded(maybeEqual) && maybeEqual.value();
+  };
+
+  // Check that the slice is contiguous within each reassociation group.
+  // The slice is contiguous only if after the first dimension where a non
+  // unit slice is taken, the slice size on all subsequent dimensions of the
+  // group is equal to the entire size of the dimension.
+  // Examples of contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+  // Examples of non contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+  for (const ReassociationIndices &indices : reassociation) {
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Find the first expanded dim after the first dim with non-unit extracted
+    // size.
+    for (; i < e; ++i) {
+      if (!isOneInteger(sizes[indices[i]])) {
+        // +1 to skip the first non-unit size dim.
+        i++;
+        break;
       }
+    }
+
+    // Verify that all subsequent dimensions extract the full size of the
+    // source tensor.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                   expandedDim)) {
+        return failure();
+      }
+    }
+  }
+
+  // The tensor.extract_slice before applying the pattern works on the result
+  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+  // referring to the state before applying the pattern are named with the
+  // prefix "expanded", and ones referring to the state after applying the
+  // pattern are named with the prefix "collapsed".
+  Location loc = sliceOp.getLoc();
+  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> expandedShape =
+      getMixedSizes(b, loc, sliceOp.getSource());
+
+  // Helper variables and function for accumulat...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/153675


More information about the Mlir-commits mailing list