[Mlir-commits] [mlir] [mlir][scf] Refactor and improve ParallelLoopFusion (PR #179284)

Ivan Butygin llvmlistbot at llvm.org
Mon Feb 16 11:35:43 PST 2026


================
@@ -55,124 +71,690 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
          matchOperands(firstPloop.getStep(), secondPloop.getStep());
 }
 
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same type of memory write op and
+/// write to the same memory location (same buffer and same indices).
+///
+/// Supported ops are memref.store, vector.transfer_write and vector.store.
+/// Any other op kind yields `false`, with one exception: an op is always
+/// considered to write the same location as itself.
+///
+/// For vector writes, every attribute that affects which elements are
+/// written must also match: vector type, mask and permutation map
+/// (transfer_write), and the in-bounds / alignment / nontemporal attributes.
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+  if (!op1 || !op2 || op1->getName() != op2->getName())
+    return false;
+  if (op1 == op2)
+    return true;
+  // Both ops have the same name here, so casting op2 to op1's type is safe.
+  // Unsupported op kinds fall into `Default` and are rejected.
+  return llvm::TypeSwitch<Operation *, bool>(op1)
+      .Case([&](memref::StoreOp storeOp1) {
+        auto storeOp2 = cast<memref::StoreOp>(op2);
+        return storeOp1.getMemRef() == storeOp2.getMemRef() &&
+               storeOp1.getIndices() == storeOp2.getIndices();
+      })
+      .Case([&](vector::TransferWriteOp writeOp1) {
+        auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+        // The permutation map selects the memref dimensions the vector is
+        // written along. Two writes that differ only in this map touch
+        // different (or differently ordered) elements.
+        return writeOp1.getBase() == writeOp2.getBase() &&
+               writeOp1.getIndices() == writeOp2.getIndices() &&
+               writeOp1.getPermutationMap() ==
+                   writeOp2.getPermutationMap() &&
+               writeOp1.getMask() == writeOp2.getMask() &&
+               writeOp1.getValueToStore().getType() ==
+                   writeOp2.getValueToStore().getType() &&
+               writeOp1.getInBounds() == writeOp2.getInBounds();
+      })
+      .Case([&](vector::StoreOp vecStoreOp1) {
+        auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+        return vecStoreOp1.getBase() == vecStoreOp2.getBase() &&
+               vecStoreOp1.getIndices() == vecStoreOp2.getIndices() &&
+               vecStoreOp1.getValueToStore().getType() ==
+                   vecStoreOp2.getValueToStore().getType() &&
+               vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment() &&
+               vecStoreOp1.getNontemporal() == vecStoreOp2.getNontemporal();
+      })
+      .Default([](Operation *) { return false; });
+}
+
+/// Decide whether `val1` (a value of the first parallel loop) and `val2` (a
+/// value of the second) compute the same thing.
+///
+/// Two values match when:
+///  - they are identical, or
+///  - the induction-variable mapping `loopsIVsMap` takes one onto the other
+///    (in either direction), or
+///  - both are results of memory-effect-free ops that are structurally
+///    equivalent (ignoring locations), with each pair of operands related
+///    through the same mapping.
+static bool valsAreEquivalent(Value val1, Value val2,
+                              const IRMapping &loopsIVsMap) {
+  // Bidirectional "a maps to b" test through the IV mapping; unmapped
+  // values map to themselves.
+  auto mapsTo = [&](Value a, Value b) {
+    return loopsIVsMap.lookupOrDefault(a) == b ||
+           loopsIVsMap.lookupOrDefault(b) == a;
+  };
+
+  if (val1 == val2 || mapsTo(val1, val2))
+    return true;
+
+  // Otherwise compare the producers, but only when both are side-effect-free
+  // ops (block arguments have no producer to compare).
+  Operation *lhsDef = val1.getDefiningOp();
+  Operation *rhsDef = val2.getDefiningOp();
+  if (!lhsDef || !rhsDef || !isMemoryEffectFree(lhsDef) ||
+      !isMemoryEffectFree(rhsDef))
+    return false;
+
+  auto operandsMatch = [&](Value a, Value b) { return success(mapsTo(a, b)); };
+  return OperationEquivalence::isEquivalentTo(
+      lhsDef, rhsDef, operandsMatch,
+      /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// Return the integer value of `value` if it is produced by a constant op
+/// (e.g. `arith.constant N : index` or `index.constant N`), and
+/// `std::nullopt` otherwise.
+///
+/// This uses the same matcher (`getConstantIntValue`) as the other index
+/// helpers in this file. Recognizing only `arith.constant` of index type
+/// would make constant indices from other dialects look non-constant, and
+/// range-overlap checks would then miss real overlaps.
+static std::optional<int64_t> getConstantIndex(Value value) {
+  return getConstantIntValue(value);
+}
+
+/// If `expr` is `base + C` (or `C + base`) for an integer constant `C`,
+/// return `C`. `base` is matched modulo the IV mapping via
+/// `valsAreEquivalent`.
+///
+/// Recognized producers of `expr`:
+///  - `arith.addi` and `index.add`, with the constant on either side;
+///  - `affine.apply` of a single-result, single-dim, symbol-free map whose
+///    result is `d0 + C` or `C + d0`, applied to `base`.
+/// Anything else yields `std::nullopt`.
+static std::optional<int64_t> getAddConstant(Value expr, Value base,
+                                             const IRMapping &loopsIVsMap) {
+  // Shared matcher for two-operand integer adds; the constant-on-LHS form is
+  // tried before the constant-on-RHS form.
+  auto matchConstantPlusBase = [&](Value lhs,
+                                   Value rhs) -> std::optional<int64_t> {
+    std::optional<int64_t> lhsCst = getConstantIntValue(lhs);
+    if (lhsCst && valsAreEquivalent(rhs, base, loopsIVsMap))
+      return lhsCst;
+    std::optional<int64_t> rhsCst = getConstantIntValue(rhs);
+    if (rhsCst && valsAreEquivalent(lhs, base, loopsIVsMap))
+      return rhsCst;
+    return std::nullopt;
+  };
+
+  if (auto addi = expr.getDefiningOp<arith::AddIOp>())
+    return matchConstantPlusBase(addi.getLhs(), addi.getRhs());
+  if (auto indexAdd = expr.getDefiningOp<index::AddOp>())
+    return matchConstantPlusBase(indexAdd.getLhs(), indexAdd.getRhs());
+
+  auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>();
+  if (!applyOp)
+    return std::nullopt;
+
+  // Only `(d0) -> (d0 + C)`-shaped maps whose single dim operand is `base`.
+  AffineMap map = applyOp.getAffineMap();
+  if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
+      map.getNumSymbols() != 0)
+    return std::nullopt;
+  if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
+    return std::nullopt;
+
+  auto sum = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
+  if (!sum || sum.getKind() != AffineExprKind::Add)
+    return std::nullopt;
+  if (auto cst = dyn_cast<AffineConstantExpr>(sum.getLHS());
+      cst && isa<AffineDimExpr>(sum.getRHS()))
+    return cst.getValue();
+  if (auto cst = dyn_cast<AffineConstantExpr>(sum.getRHS());
+      cst && isa<AffineDimExpr>(sum.getLHS()))
+    return cst.getValue();
+  return std::nullopt;
+}
+
+// Return true if the scalar load index may hit any element covered by a
+// vector.store/transfer_write along a single memref dimension. Supported cases:
+//
+// 1) Direct index match (with optional offset):
+//    vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
+//    %x = memref.load %A[%i] : memref<...>
+//
+// 2) Loop IV range intersects the write range:
+//    vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+//    scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
+//
+// 3) Constant index (or IV + constant) within the write range:
+//    vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+//    %x = memref.load %A[%c2] : memref<...>
+//    %y = memref.load %A[%i + %c1] : memref<...>
+//
+// Args:
+// - loadIndex: index used by the scalar load for this dimension.
+// - offset: subview offset for the base memref dimension (if any).
+// - writeIndex: index used by the transfer_write for this dimension.
+// - extent: vector size along this dimension (number of elements written).
+// - loopsIVsMap: IV equivalence map between fused loops.
+static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
+                                      Value writeIndex, int64_t extent,
+                                      const IRMapping &loopsIVsMap) {
+  if (extent <= 0)
+    return false;
+
+  // Extract constant loop bounds for IVs (scf.for / scf.parallel).
+  auto getConstLoopBoundsForIV =
+      [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
+    auto blockArg = dyn_cast<BlockArgument>(index);
+    if (!blockArg)
+      return std::nullopt;
+    auto *parentOp = blockArg.getOwner()->getParentOp();
+    auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
+    if (!loopLike)
+      return std::nullopt;
+    auto ranges = getConstLoopBounds(loopLike);
+    if (ranges.empty())
+      return std::nullopt;
+
+    auto ivs = loopLike.getLoopInductionVars();
+    if (!isa<scf::ForOp, scf::ParallelOp>(parentOp) || !ivs)
+      return std::nullopt;
+    auto it = llvm::find(*ivs, blockArg);
+    if (it == ivs->end())
+      return std::nullopt;
+    unsigned pos = std::distance(ivs->begin(), it);
+    if (pos >= ranges.size())
+      return std::nullopt;
+    auto [lb, ub, step] = ranges[pos];
+    return std::make_tuple(lb, ub, step);
+  };
+
+  std::optional<int64_t> offsetConst = getConstantIntValue(offset);
+  std::optional<int64_t> writeConst =
+      writeIndex ? getConstantIndex(writeIndex) : std::optional<int64_t>(0);
+  if (!writeConst && writeIndex) {
+    // Treat single-iteration IVs as constants for matching.
+    if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
+      auto [lb, ub, step] = *bounds;
+      if (step > 0 && ub == lb + step)
+        writeConst = lb;
+    }
+  }
+
+  // Check whether a loop IV is fully contained in a constant write range.
+  auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
+                              int64_t rangeStart, int64_t rangeExtent) -> bool {
+    if (rangeExtent <= 0 || step <= 0)
+      return false;
+    if (ub <= lb)
+      return false;
+    int64_t rangeEnd = rangeStart + rangeExtent;
+    return lb >= rangeStart && ub <= rangeEnd;
+  };
+
+  if (offsetConst && writeConst) {
+    // Constant start of the write range; check constant load or loop IV range.
+    int64_t start = *offsetConst + *writeConst;
+    if (auto loadConst = getConstantIndex(loadIndex))
+      return (*loadConst >= start && *loadConst < start + extent);
+    if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
+      auto [lb, ub, step] = *bounds;
+      return loopIVWithinRange(lb, ub, step, start, extent);
+    }
+  }
+
+  if (writeIndex) {
+    // Direct IV match (or IV + constant) against the write index.
+    if (offsetConst && *offsetConst == 0 &&
+        valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
+      return true;
+    if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
+      // Match load index of the form writeIndex + C within the write extent.
+      if (offsetConst) {
+        int64_t start = *offsetConst;
+        return (*addConst >= start && *addConst < start + extent);
+      }
+    }
+    return false;
+  }
+
+  if (auto offsetVal = offset.dyn_cast<Value>()) {
----------------
Hardcode84 wrote:

The member-style `x.dyn_cast<T>()` is deprecated in favor of the free function `dyn_cast<T>(x)`; here that would be `dyn_cast<Value>(offset)`.

https://github.com/llvm/llvm-project/pull/179284


More information about the Mlir-commits mailing list