[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Wed Aug 20 07:04:00 PDT 2025
    
    
  
================
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains information about a single dimension of a
+// tensor.extract_slice op that is not a full slice.
+struct SliceDimInfo {
+  // Offset of the slice along this dimension.
+  OpFoldResult offset;
+  // Size of the slice along this dimension.
+  OpFoldResult sliceSize;
+  // Size of the slice's source tensor along this dimension.
+  OpFoldResult outputSize;
+};
+
+/// Find the first DPS input operand of `genericOp` whose value is defined by
+/// a tensor.extract_slice op. On success, returns that operand together with
+/// its position among the DPS input operands; returns failure if no input is
+/// produced by an extract slice.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  for (auto [inputPos, input] :
+       llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // Skip inputs that are not produced by an extract slice.
+    if (!input->get().getDefiningOp<tensor::ExtractSliceOp>())
+      continue;
+    // Only the first match is reported.
+    return std::make_tuple(input, static_cast<unsigned>(inputPos));
+  }
+  return failure();
+}
+
+// Return a map from iteration-space dim position to the slice information
+// (offset, slice size, source size) for every dim of `sliceOperand` that is
+// not sliced in full, so that other operands can use this information. Also
+// return a bool mentioning if a reduction dim has a non full slice as that can
+// be used to fold the original extract slice.
+//
+// Returns failure if a non full slice sits on an indexing-map result that is
+// not a plain AffineDimExpr, or if any other operand of `genericOp` uses one
+// of the sliced dims inside a non-AffineDimExpr indexing expression.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Shape of the slice source, converted to index OpFoldResults so it can be
+  // compared against the mixed slice sizes.
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  // Walk the results of the slice operand's indexing map; result `idx` is
+  // used to index the slice's offsets/sizes and the source shape.
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // A zero offset with a size equal to the source size is a full slice along
+    // this dim; nothing to record.
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    // A non full slice is only supported when the result is a plain dim.
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next check if the dims with non full slice info are used inside a non
+  // AffineDimExpr by any other operand and if they are then bail-out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          // Plain dim results are always acceptable.
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          // For any other expression, look for a recorded (sliced) dim
+          // anywhere inside it.
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the extract slice operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid UnPackOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Check if we can support the propagation of this extractSlice
+  // through the generic op and if so return the dimensions that
+  // have non full slices on them.
+
+  auto maybeNonZeroSliceDimMap =
+      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
----------------
Max191 wrote:
nit: Use early `continue` to save nesting
https://github.com/llvm/llvm-project/pull/154162
    
    
More information about the Mlir-commits
mailing list