[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 20 07:03:57 PDT 2025


================
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+/// Holds per-dimension information (offset and sizes) about an extract_slice.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Find the first DPS input operand of `genericOp` whose value is defined by
+/// a tensor::ExtractSliceOp. On success, returns that operand together with
+/// its position among the DPS inputs; returns failure if no input qualifies.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  for (auto [pos, input] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // The scan stops at the first match; later slice-produced inputs are
+    // intentionally not considered.
+    if (input->get().getDefiningOp<tensor::ExtractSliceOp>())
+      return std::make_tuple(input, static_cast<unsigned>(pos));
+  }
+  return failure();
+}
+
+// Return a map of the dims that have non-full slices on them so that other
+// operands can use this information. Also return a bool indicating whether a
+// reduction dim has a non-full slice, as that can be used to fold the original
+// extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
----------------
Max191 wrote:

 You can pass an ArrayRef to `getAsIndexOpFoldResult`, so you don't need to map_to_vector.

https://github.com/llvm/llvm-project/pull/154162


More information about the Mlir-commits mailing list