[Mlir-commits] [mlir] [mlir][tensor] Move extract_slice reshaping into two functions (PR #153675)

Ian Wood llvmlistbot at llvm.org
Thu Aug 14 13:43:19 PDT 2025


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/153675

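Exposes the `tensor.extract_slice` reshaping logic in `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` through two utility functions: `getCollapsedExtractSliceInfo` (used by the `expand_shape` pattern) and `getExpandedExtractSliceInfo` (used by the `collapse_shape` pattern). Given a `tensor.extract_slice`, they compute the offsets, sizes, and strides of the equivalent slice on the collapsed or expanded form of its source.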

This should also make it easier to implement the two remaining bubbling cases, in which the reshape consumes the `tensor.extract_slice` instead of producing its source: (1) the `collapse_shape` is a consumer, or (2) the `expand_shape` is a consumer.

From bee87f64b5fef6c35e5c578c44fa7a90f2e87d57 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Thu, 14 Aug 2025 11:55:21 -0700
Subject: [PATCH] [mlir][tensor] Refactor extract_slice reshaping into util
 funcs

Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
---
 .../Dialect/Tensor/Transforms/Transforms.h    |  28 +
 .../Tensor/Transforms/ReshapePatterns.cpp     | 569 +++++++++---------
 2 files changed, 298 insertions(+), 299 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..2602252916388 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,34 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
 FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
                                     ValueRange independencies);
 
+/// Computes the offsets, sizes, and strides of a slice, taken from the source
+/// of `sliceOp` collapsed by `reassociation`, covering the same elements as
+/// `sliceOp`; results are appended to the output vectors and ops may be built
+/// with `b`. Fails, before creating any ops, on non-unit strides,
+/// rank-reducing slices, or slices that are not contiguous within a group.
+LogicalResult
+getCollapsedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                             ArrayRef<ReassociationIndices> reassociation,
+                             SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+                             SmallVectorImpl<OpFoldResult> &collapsedSizes,
+                             SmallVectorImpl<OpFoldResult> &collapsedStrides,
+                             OpBuilder &b);
+
+/// Computes the offsets, sizes, and strides of a slice, taken from the source
+/// of `sliceOp` expanded per `reassociation` into `expandedShape`, covering
+/// the same elements as `sliceOp`. Pass empty output vectors; `b` is only
+/// used to build index attributes. Fails on non-unit strides, rank-reducing
+/// slices, and slices that cannot be represented as a contiguous slice of the
+/// expanded tensor.
+LogicalResult
+getExpandedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<int64_t> expandedShape,
+                            SmallVectorImpl<OpFoldResult> &expandedOffsets,
+                            SmallVectorImpl<OpFoldResult> &expandedSizes,
+                            SmallVectorImpl<OpFoldResult> &expandedStrides,
+                            OpBuilder &b);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..a93681b1fce92 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
                                 PatternRewriter &rewriter) const override {
     auto expandShapeOp =
         sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
 
-    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
-                                                 rewriter)
-            .failed())
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getCollapsedExtractSliceInfo(sliceOp, reassociation, offsets,
+                                            sizes, strides, rewriter)))
       return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
-    // referring to the state before applying the pattern are named with the
-    // prefix "expanded", and ones referring to the state after applying the
-    // pattern are named with the prefix "collapsed".
-    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> expandedShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    // Helper variables and function for accumulating the size values.
-    Location loc = expandShapeOp->getLoc();
-    AffineExpr d0, d1, d2;
-    bindDims(rewriter.getContext(), d0, d1, d2);
-    // Multiply two integers.
-    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
-      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
-      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
-                                                   {v1, v2});
-    };
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank of
-    // ReassociationIndices.size(). In the loop a single offset, size, and
-    // stride value is computed per reassociation group.
-    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
-        collapsedStrides;
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      // collapsedSize will hold the size of the single dim that represents the
-      // reassociation group in the non expanded tensor.
-      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
-      // The reassocGroupSizes and reassocGroupOffsets are used to create an
-      // affine.linearize_index op to linearize the single offset value required
-      // for this reassociation group.
-      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
-      for (long expandedDim : indices) {
-        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
-        // from the expanded state, but the collapsed size requires calculation
-        // as it did not previously exist.
-        reassocGroupSizes.push_back(expandedShape[expandedDim]);
-        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
-        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
-      }
-
-      SmallVector<Value> offsetVals =
-          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
-            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
-          });
-      OpFoldResult collapsedOffset =
-          affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
-                                                 reassocGroupSizes,
-                                                 /*disjoint=*/true)
-              .getResult();
-      collapsedOffsets.push_back(collapsedOffset);
-      collapsedSizes.push_back(collapsedSize);
-
-      // Only unit stride is supported.
-      collapsedStrides.push_back(rewriter.getIndexAttr(1));
-    }
-
     // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
-    RankedTensorType resultType = RankedTensorType::get(
-        shape, expandShapeOp.getResultType().getElementType());
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    RankedTensorType resultType = sliceOp.getResultType();
 
     // Create a new ExtractSliceOp and ExpandShapeOp.
+    Location loc = sliceOp.getLoc();
     Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
-        collapsedStrides);
+        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         sliceOp, resultType, newSliceOp,
         expandShapeOp.getReassociationIndices(), expandedSizes);
     return success();
   }
-
-  // Helper function to check if all the required conditions for the
-  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
-  // met.
-  LogicalResult
-  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
-                                           tensor::ExpandShapeOp expandShapeOp,
-                                           PatternRewriter &rewriter) const {
-
-    if (!expandShapeOp) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "tensor.extract_slice source not produced by expand_shape");
-    }
-
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
-
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        sizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
-
-    SmallVector<OpFoldResult> outputShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
-        isZeroOffsetAndFullSize =
-            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isZeroInteger(offset))
-                return false;
-              FailureOr<bool> maybeEqual =
-                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
-              return llvm::succeeded(maybeEqual) && maybeEqual.value();
-            };
-
-    // Check that the slice is contiguous within each reassociation group.
-    // The slice is contiguous only if after the first dimension where a non
-    // unit slice is taken, the slice size on all subsequent dimensions of the
-    // group is equal to the entire size of the dimension.
-    // Examples of contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
-    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
-    // Examples of non contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
-    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      int64_t i = 0;
-      int64_t e = indices.size();
-      // Find the first expanded dim after the first dim with non-unit extracted
-      // size.
-      for (; i < e; ++i) {
-        if (!isOneInteger(sizes[indices[i]])) {
-          // +1 to skip the first non-unit size dim.
-          i++;
-          break;
-        }
-      }
-
-      // Verify that all subsequent dimensions extract the full size of the
-      // source tensor.
-      for (; i < e; ++i) {
-        int64_t expandedDim = indices[i];
-        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
-                                     outputShape[expandedDim])) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "Not a contiguous slice of the expanded tensor.");
-        }
-      }
-    }
-
-    return success();
-  }
 };
 
 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,282 @@ struct BubbleUpCollapseShapeThroughExtractSlice
           "tensor.extract_slice source not produced by tensor.collapse_shape");
     }
 
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getExpandedExtractSliceInfo(
+            sliceOp, collapseShapeOp.getReassociationIndices(),
+            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
+            rewriter)))
+      return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.collapse_shape, so variables (i.e. inputs for
-    // ExtractSliceOp) referring to the state before applying the pattern are
-    // named with the prefix "collapsed", and ones referring to the state after
-    // applying the pattern are named with the prefix "expanded".
-    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        collapsedSizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
+    Value newSliceOp = tensor::ExtractSliceOp::create(
+        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+        sizes, strides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
 
-    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
-    SmallVector<ReassociationIndices, 4> reassociationIndices =
-        collapseShapeOp.getReassociationIndices();
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank
-    // equal to the rank of the src of the collapse_shape. In each iteration of
-    // the loop, the offsets and sizes will be computed per reassociation group.
-    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
-    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
-                                              rewriter.getIndexAttr(1));
-
-    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
-         llvm::zip_equal(collapsedSizes, collapsedOffsets,
-                         collapseShapeOp.getReassociationIndices())) {
-      // CASE #1 - size and/or offset are dynamic.
-      // In this case, the slice can be represented as a contiguous slice only
-      // if there is a single dimension in the reassociation group that has a
-      // size not equal to 1.
-      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
-        int nonUnitSizeCount = 0;
-        for (int64_t expandedShapeIdx : reassocIndices) {
-          if (srcShape[expandedShapeIdx] != 1) {
-            nonUnitSizeCount++;
-            expandedSizes.push_back(collapsedSize);
-            expandedOffsets.push_back(collapsedOffset);
-            continue;
-          }
-
-          expandedSizes.push_back(rewriter.getIndexAttr(1));
-          expandedOffsets.push_back(rewriter.getIndexAttr(0));
-        }
+    return success();
+  }
+};
 
-        if (nonUnitSizeCount != 1) {
-          return rewriter.notifyMatchFailure(
-              sliceOp,
-              "unsupported: slice cannot be verified to be contiguous");
-        }
-        continue;
-      }
+} // namespace
 
-      // CASE #2 = size and offset are static.
-      // Verify that the slice can be represented as a contiguous slice of the
-      // src of the collapse_shape.
-      // Checking this is done on order of most internal dimensions first,
-      // so traversal is done in reverse order of the reassociation group.
-      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
-      // ...,An] then we first find the size and offset for n...k+1 then for k
-      // and then for k-1...0.
-
-      // currentCollapsedsize and currentCollapsedOffset are initialized with
-      // the original collapsed size and offset and divided by the expanded
-      // shape size in each dimension as we go along the reassociation group.
-      // In essence we are spreading the original collapsed size and offset over
-      // the various expanded slice dimensions.
-      // The variables are used both to check the validity of the slice and to
-      // compute the expanded sizes and offsets.
-      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
-      int64_t currentCollapsedOffset =
-          getConstantIntValue(collapsedOffset).value();
-
-      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
-      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
-                                                  reassocIndices.rend());
-      int64_t idx = 0;
-      int64_t reassocGroupSize = reassocIndices.size();
-
-      // First handle the trailing dimensions where the slice size should be
-      // equal to the tensor shape and the offset should be 0 (n...k+1).
-      for (; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
-        if (currentCollapsedsize < expandedShapeSize)
-          break;
-
-        // We need to make sure that the slice size can be set to the shape size
-        // and the offset to 0.
-        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
-            (currentCollapsedOffset % expandedShapeSize) != 0) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
-                       "of the src of the collapse_shape");
-        }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+    tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+    SmallVectorImpl<OpFoldResult> &collapsedSizes,
+    SmallVectorImpl<OpFoldResult> &collapsedStrides, OpBuilder &b) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
 
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+    return failure();
+  }
 
-        currentCollapsedsize /= expandedShapeSize;
-        currentCollapsedOffset /= expandedShapeSize;
+  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+                                     OpFoldResult sliceSize, int64_t inputDim) {
+    if (!isZeroInteger(offset))
+      return false;
+    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+    FailureOr<bool> maybeEqual =
+        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+    return llvm::succeeded(maybeEqual) && maybeEqual.value();
+  };
+
+  // Check that the slice is contiguous within each reassociation group.
+  // The slice is contiguous only if after the first dimension where a non
+  // unit slice is taken, the slice size on all subsequent dimensions of the
+  // group is equal to the entire size of the dimension.
+  // Examples of contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+  // Examples of non contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+  for (const ReassociationIndices &indices : reassociation) {
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Find the first expanded dim after the first dim with non-unit extracted
+    // size.
+    for (; i < e; ++i) {
+      if (!isOneInteger(sizes[indices[i]])) {
+        // +1 to skip the first non-unit size dim.
+        i++;
+        break;
       }
+    }
+
+    // Verify that all subsequent dimensions extract the full size of the
+    // source tensor.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                   expandedDim)) {
+        return failure();
+      }
+    }
+  }
+
+  // `sliceOp` works on the expanded tensor, so values describing it are
+  // prefixed with "expanded", and values computed for the equivalent slice of
+  // the collapsed tensor are prefixed with "collapsed". All failure checks
+  // are done above, so creating IR from here on cannot leave stray ops behind
+  // on a failed match.
+  Location loc = sliceOp.getLoc();
+  ArrayRef<OpFoldResult> expandedOffsets = offsets;
+  ArrayRef<OpFoldResult> expandedSizes = sizes;
+  SmallVector<OpFoldResult> expandedShape =
+      getMixedSizes(b, loc, sliceOp.getSource());
+
+  // Helper variables and function for accumulating the size values.
+  AffineExpr d0, d1;
+  bindDims(b.getContext(), d0, d1);
+  // Multiplies two index values, folding the product when possible.
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+    return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+  };
+
+  // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank of
+  // reassociation.size(). In the loop a single offset, size, and
+  // stride value is computed per reassociation group.
+  for (const ReassociationIndices &indices : reassociation) {
+    // collapsedSize will hold the size of the single dim that represents the
+    // reassociation group in the non expanded tensor.
+    OpFoldResult collapsedSize = b.getIndexAttr(1);
+    // The reassocGroupSizes and reassocGroupOffsets are used to create an
+    // affine.linearize_index op to linearize the single offset value required
+    // for this reassociation group.
+    SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+    for (int64_t expandedDim : indices) {
+      // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+      // from the expanded state, but the collapsed size requires calculation
+      // as it did not previously exist.
+      reassocGroupSizes.push_back(expandedShape[expandedDim]);
+      reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+      collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+    }
+
+    SmallVector<Value> offsetVals =
+        llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+          return getValueOrCreateConstantIndexOp(b, loc, ofr);
+        });
+    OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+                                       b, loc, offsetVals, reassocGroupSizes,
+                                       /*disjoint=*/true)
+                                       .getResult();
+    collapsedOffsets.push_back(collapsedOffset);
+    collapsedSizes.push_back(collapsedSize);
+
+    // Only unit stride is supported.
+    collapsedStrides.push_back(b.getIndexAttr(1));
+  }
+  return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+    tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<int64_t> expandedShape,
+    SmallVectorImpl<OpFoldResult> &expandedOffsets,
+    SmallVectorImpl<OpFoldResult> &expandedSizes,
+    SmallVectorImpl<OpFoldResult> &expandedStrides, OpBuilder &b) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  // `sliceOp` works on the collapsed tensor, so values describing it are
+  // prefixed with "collapsed", and values computed for the equivalent slice
+  // of the expanded tensor are prefixed with "expanded". Only index
+  // attributes are built with `b`; no ops are created, so returning failure
+  // part-way through leaves the IR untouched.
+  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+      collapsedSizes.size()) {
+    return failure();
+  }
 
-      // Now handle the first dim where slicing occurs on (k).
-      if (idx < reassocGroupSize) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
-        // We need to make sure that the slice size in this dim + offset will
-        // not exceed the shape size.
-        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
-                       "slice of the src of the collapse_shape");
+  // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank
+  // equal to expandedShape.size(). In each iteration of the loop, the
+  // offsets and sizes will be computed per reassociation group.
+  expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+       llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+    // CASE #1 - dynamic size/offset, or a group dim that is dynamic (encoded
+    // as a negative kDynamic) or 0: only one dim of the group may be non-unit.
+    if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset) ||
+        llvm::any_of(reassocIndices,
+                     [&](int64_t dim) { return expandedShape[dim] <= 0; })) {
+      int nonUnitSizeCount = 0;
+      for (int64_t expandedShapeIdx : reassocIndices) {
+        if (expandedShape[expandedShapeIdx] != 1) {
+          nonUnitSizeCount++;
+          expandedSizes.push_back(collapsedSize);
+          expandedOffsets.push_back(collapsedOffset);
+          continue;
         }
 
-        groupExpandedSizes.push_back(
-            rewriter.getIndexAttr(currentCollapsedsize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+        expandedSizes.push_back(b.getIndexAttr(1));
+        expandedOffsets.push_back(b.getIndexAttr(0));
+      }
 
-        currentCollapsedOffset /= expandedShapeSize;
+      if (nonUnitSizeCount != 1) {
+        return failure();
       }
+      continue;
+    }
 
-      // Now handle the leading dimensions where the slice size is equal to 1
-      // (k-1...0).
-      // The size for these dimensions must be 1 because of how we constructed
-      // the slice size of the expanded shape. We spread the original collapsed
-      // size over the expanded shape sizes until we reached dimension k where
-      // the remaining size was smaller than the expanded shape size, and spread
-      // the remaining size on it. So, now we are left with only 1s.
-      for (idx++; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
-        currentCollapsedOffset /= expandedShapeSize;
+    // CASE #2 - size and offset are static and every dim of the group is
+    // static and positive, so the modulo/division arithmetic below is safe.
+    // Verify that the slice can be represented as a contiguous slice of the
+    // expanded tensor. This is checked innermost dimension first, so the
+    // traversal is done in reverse order of the reassociation group.
+    // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+    // ...,An] then we first find the size and offset for n...k+1 then for k
+    // and then for k-1...0.
+
+    // currentCollapsedsize and currentCollapsedOffset are initialized with
+    // the original collapsed size and offset and divided by the expanded
+    // shape size in each dimension as we go along the reassociation group.
+    // In essence we are spreading the original collapsed size and offset over
+    // the various expanded slice dimensions.
+    // The variables are used both to check the validity of the slice and to
+    // compute the expanded sizes and offsets.
+    int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+    int64_t currentCollapsedOffset =
+        getConstantIntValue(collapsedOffset).value();
+    SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+    ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                                reassocIndices.rend());
+    int64_t idx = 0;
+    int64_t reassocGroupSize = reassocIndices.size();
+
+    // First handle the trailing dimensions where the slice size should be
+    // equal to the tensor shape and the offset should be 0 (n...k+1).
+    for (; idx < reassocGroupSize; ++idx) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+      if (currentCollapsedsize < expandedShapeSize)
+        break;
+
+      // We need to make sure that the slice size can be set to the shape size
+      // and the offset to 0.
+      if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+          (currentCollapsedOffset % expandedShapeSize) != 0) {
+        return failure();
       }
 
-      expandedSizes.append(groupExpandedSizes.rbegin(),
-                           groupExpandedSizes.rend());
-      expandedOffsets.append(groupExpandedOffsets.rbegin(),
-                             groupExpandedOffsets.rend());
+      groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+      groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+      currentCollapsedsize /= expandedShapeSize;
+      currentCollapsedOffset /= expandedShapeSize;
     }
 
-    Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
-        expandedOffsets, expandedSizes, expandedStrides);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        sliceOp, sliceOp.getResultType(), newSliceOp,
-        collapseShapeOp.getReassociationIndices());
+    // Now handle the first dim where slicing occurs on (k).
+    if (idx < reassocGroupSize) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+      // The slice size in this dim plus the offset must not exceed the shape
+      // size; a slice ending exactly at the last element is still contiguous.
+      if ((currentCollapsedsize + offsetInDim) > expandedShapeSize) {
+        return failure();
+      }
+      groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+      currentCollapsedOffset /= expandedShapeSize;
+    }
 
-    return success();
+    // Now handle the leading dimensions where the slice size is equal to 1
+    // (k-1...0).
+    // The size for these dimensions must be 1 because of how we constructed
+    // the slice size of the expanded shape. We spread the original collapsed
+    // size over the expanded shape sizes until we reached dimension k where
+    // the remaining size was smaller than the expanded shape size, and spread
+    // the remaining size on it. So, now we are left with only 1s.
+    for (idx++; idx < reassocGroupSize; ++idx) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+      groupExpandedSizes.push_back(b.getIndexAttr(1));
+      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+      currentCollapsedOffset /= expandedShapeSize;
+    }
+    expandedSizes.append(groupExpandedSizes.rbegin(),
+                         groupExpandedSizes.rend());
+    expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                           groupExpandedOffsets.rend());
   }
-};
-
-} // namespace
+  return success();
+}
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {



More information about the Mlir-commits mailing list