[Mlir-commits] [mlir] c5ae550 - [mlir][scf] Refactor and improve ParallelLoopFusion (#179284)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 18 06:54:17 PST 2026


Author: fabrizio-indirli
Date: 2026-02-18T14:54:12Z
New Revision: c5ae550344d5da2d2f44523829a9af979d8af6f4

URL: https://github.com/llvm/llvm-project/commit/c5ae550344d5da2d2f44523829a9af979d8af6f4
DIFF: https://github.com/llvm/llvm-project/commit/c5ae550344d5da2d2f44523829a9af979d8af6f4.diff

LOG: [mlir][scf] Refactor and improve ParallelLoopFusion (#179284)

Refactor and extend the scf::ParallelLoopFusion pass:
- Refactor code, rename functions and add comments to improve
readability
- Make the dependency analysis safer by also checking for read-after-write
dependencies involving vector.load/store and vector.transfer_read/write
ops, in addition to memref.load/store, and by bailing out when other
unsupported ops with memory effects are found.
- Extend the cases in which fusion is applied: also allow fusion when
one of the two loops reads/writes memory through a full view/alias of
the buffer (read/written by the dual operation in the other loop) that
can be trivially resolved, including rank-reducing full subviews.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 5d2429bb476e6..9af0f301d763c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -157,6 +157,21 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                        ValueRange indices,
                                        SmallVectorImpl<Value> &sourceIndices);
 
+/// Given the 'indices' of a load/store operation whose memref is the result
+/// of a rank-reducing full subview op, returns the indices w.r.t. the source
+/// memref of the memref.subview op. For example:
+///
+///  %alias = memref.subview %src[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+///                           to memref<2x2xf32>
+///  %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+///
+/// could be folded into
+///
+///  %val = memref.load %src[0, %i, %j] : memref<1x2x2xf32>
+LogicalResult resolveSourceIndicesRankReducingSubview(
+    Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index c85f3b02c4a44..a758032ef69b4 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include <optional>
+#include <tuple>
 
 namespace mlir {
 class Location;
@@ -248,6 +249,12 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
     IRMapping *clonedToSrcOpsMap = nullptr);
 
+/// Get constant loop bounds and steps for each of the induction variables of
+/// the given loop operation, if all the loop's ranges are constant. Each entry
+/// in the returned vector is a tuple (lowerBound, upperBound, step).
+llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
+getConstLoopBounds(mlir::LoopLikeOpInterface loopOp);
+
 /// Get constant trip counts for each of the induction variables of the given
 /// loop operation. If any of the loop's trip counts is not constant, return an
 /// empty vector.

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 2d341dce665e5..cf126cd85ddce 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -286,5 +286,46 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
   }
 }
 
+LogicalResult resolveSourceIndicesRankReducingSubview(
+    Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices) {
+  if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
+    return failure();
+
+  MemRefType srcType = subViewOp.getSourceType();
+  MemRefType resType = subViewOp.getType();
+  unsigned srcRank = srcType.getRank();
+  unsigned resRank = resType.getRank();
+  if (srcRank <= resRank || indices.size() != resRank)
+    return failure();
+
+  auto droppedDims = subViewOp.getDroppedDims();
+  if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
+    return failure();
+
+  auto mixedSizes = subViewOp.getMixedSizes();
+  if (mixedSizes.size() != srcRank)
+    return failure();
+
+  unsigned resultDim = 0;
+  for (unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
+    if (droppedDims.test(sourceDim)) {
+      auto sizeCst = getConstantIntValue(mixedSizes[sourceDim]);
+      if (!sizeCst || *sizeCst != 1)
+        return failure();
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(b, loc, b.getIndexAttr(0)));
+      continue;
+    }
+    if (resultDim >= indices.size())
+      return failure();
+    sourceIndices.push_back(indices[resultDim++]);
+  }
+  if (resultDim != indices.size())
+    return failure();
+
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a9ffa9dc208a0..fb9aa4018d263 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   MLIRBufferizationTransforms
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
+  MLIRIndexDialect
   MLIRIR
   MLIRMemRefDialect
   MLIRPass

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 4ea832177c4f9..0b132e9109492 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -13,15 +13,31 @@
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
 #include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+#include <tuple>
+
 namespace mlir {
 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -55,114 +71,670 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
          matchOperands(firstPloop.getStep(), secondPloop.getStep());
 }
 
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same type of memory write op and
+/// write to the same memory location (same buffer and same indices).
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+  if (!op1 || !op2 || op1->getName() != op2->getName())
+    return false;
+  if (op1 == op2)
+    return true;
+  // support only these memory-writing ops for now
+  if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
+    return false;
+  bool opsAreIdentical =
+      llvm::TypeSwitch<Operation *, bool>(op1)
+          .Case([&](memref::StoreOp storeOp1) {
+            auto storeOp2 = cast<memref::StoreOp>(op2);
+            return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
+                   (storeOp1.getIndices() == storeOp2.getIndices());
+          })
+          .Case([&](vector::TransferWriteOp writeOp1) {
+            auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+            return (writeOp1.getBase() == writeOp2.getBase()) &&
+                   (writeOp1.getIndices() == writeOp2.getIndices()) &&
+                   (writeOp1.getMask() == writeOp2.getMask()) &&
+                   (writeOp1.getValueToStore().getType() ==
+                    writeOp2.getValueToStore().getType()) &&
+                   (writeOp1.getInBounds() == writeOp2.getInBounds());
+          })
+          .Case([&](vector::StoreOp vecStoreOp1) {
+            auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+            return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
+                   (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
+                   (vecStoreOp1.getValueToStore().getType() ==
+                    vecStoreOp2.getValueToStore().getType()) &&
+                   (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
+                   (vecStoreOp1.getNontemporal() ==
+                    vecStoreOp2.getNontemporal());
+          })
+          .Default([](Operation *) { return false; });
+  return opsAreIdentical;
+}
+
+/// Check if val1 (from the first parallel loop) and val2 (from the
+/// second) are equivalent, considering the mapping of induction variables from
+/// the first to the second parallel loop.
+static bool valsAreEquivalent(Value val1, Value val2,
+                              const IRMapping &loopsIVsMap) {
+  if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
+      loopsIVsMap.lookupOrDefault(val2) == val1)
+    return true;
+  Operation *val1DefOp = val1.getDefiningOp();
+  Operation *val2DefOp = val2.getDefiningOp();
+  if (!val1DefOp || !val2DefOp)
+    return false;
+  if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
+    return false;
+  return OperationEquivalence::isEquivalentTo(
+      val1DefOp, val2DefOp,
+      [&](Value v1, Value v2) {
+        return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
+                       loopsIVsMap.lookupOrDefault(v2) == v1);
+      },
+      /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// If `expr` is `base + C` (arith.addi, index.add or a single-dim affine.apply
+/// of the form d0 + C) for a constant C, return C; otherwise std::nullopt.
+static std::optional<int64_t> getAddConstant(Value expr, Value base,
+                                             const IRMapping &loopsIVsMap) {
+  if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
+    if (auto constOp = getConstantIntValue(addOp.getLhs());
+        constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
+      return constOp.value();
+    if (auto constOp = getConstantIntValue(addOp.getRhs());
+        constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
+      return constOp.value();
+    return std::nullopt;
+  }
+
+  if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
+    if (auto constOp = getConstantIntValue(addOp.getLhs());
+        constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
+      return constOp.value();
+    if (auto constOp = getConstantIntValue(addOp.getRhs());
+        constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
+      return constOp.value();
+    return std::nullopt;
+  }
+
+  if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
+    AffineMap map = applyOp.getAffineMap();
+    if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
+        map.getNumSymbols() != 0)
+      return std::nullopt;
+    if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
+      return std::nullopt;
+    AffineExpr result = map.getResult(0);
+    auto bin = dyn_cast<AffineBinaryOpExpr>(result);
+    if (!bin || bin.getKind() != AffineExprKind::Add)
+      return std::nullopt;
+    auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
+    auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
+    auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
+    auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
+    if (lhsConst && rhsDim)
+      return lhsConst.getValue();
+    if (rhsConst && lhsDim)
+      return rhsConst.getValue();
+  }
+  return std::nullopt;
+}
+
+// Return true if the scalar load index is guaranteed to hit an element written
+// by a vector.store/transfer_write along one memref dimension. Supported cases:
+//
+// 1) Direct index match (with optional offset):
+//    vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
+//    %x = memref.load %A[%i] : memref<...>
+//
+// 2) Loop IV range intersects the write range:
+//    vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+//    scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
+//
+// 3) Constant index (or IV + constant) within the write range:
+//    vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+//    %x = memref.load %A[%c2] : memref<...>
+//    %y = memref.load %A[%i + %c1] : memref<...>
+//
+// Args:
+// - loadIndex: index used by the scalar load for this dimension.
+// - offset: subview offset for the base memref dimension (if any).
+// - writeIndex: index used by the transfer_write for this dimension. Can be
+// null if the dim was dropped by a rank reducing subview, whose result is
+// written by the vector.write.
+// - extent: vector size along this dimension (number of elements written).
+// - loopsIVsMap: IV equivalence map between fused loops.
+static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
+                                      Value writeIndex, int64_t extent,
+                                      const IRMapping &loopsIVsMap) {
+  if (extent <= 0)
+    return false;
+
+  // Extract constant loop bounds for loop IVs (e.g. from scf.for).
+  auto getConstLoopBoundsForIV =
+      [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
+    auto blockArg = dyn_cast<BlockArgument>(index);
+    if (!blockArg)
+      return std::nullopt;
+    auto *parentOp = blockArg.getOwner()->getParentOp();
+    auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
+    if (!loopLike)
+      return std::nullopt;
+    auto ranges = getConstLoopBounds(loopLike);
+    if (ranges.empty())
+      return std::nullopt;
+
+    auto ivs = loopLike.getLoopInductionVars();
+    if (!ivs)
+      return std::nullopt;
+    auto it = llvm::find(*ivs, blockArg);
+    if (it == ivs->end())
+      return std::nullopt;
+    unsigned pos = std::distance(ivs->begin(), it);
+    if (pos >= ranges.size())
+      return std::nullopt;
+    auto [lb, ub, step] = ranges[pos];
+    return std::make_tuple(lb, ub, step);
+  };
+
+  std::optional<int64_t> offsetConst = getConstantIntValue(offset);
+  std::optional<int64_t> writeConst =
+      writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
+  if (!writeConst && writeIndex) {
+    // Treat single-iteration IVs as constants for matching.
+    if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
+      auto [lb, ub, step] = *bounds;
+      if (step > 0 && ub == lb + step)
+        writeConst = lb;
+    }
+  }
+
+  // Check whether a loop IV is fully contained in a constant write range.
+  auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
+                              int64_t rangeStart, int64_t rangeExtent) -> bool {
+    if (rangeExtent <= 0 || step <= 0)
+      return false;
+    if (ub <= lb)
+      return false;
+    int64_t rangeEnd = rangeStart + rangeExtent;
+    return lb >= rangeStart && ub <= rangeEnd;
+  };
+
+  if (offsetConst && writeConst) {
+    // Constant start of the write range; check constant load or loop IV range.
+    int64_t start = *offsetConst + *writeConst;
+    if (auto loadConst = getConstantIntValue(loadIndex))
+      return (*loadConst >= start && *loadConst < start + extent);
+    if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
+      auto [lb, ub, step] = *bounds;
+      return loopIVWithinRange(lb, ub, step, start, extent);
+    }
+  }
+
+  if (writeIndex) {
+    // Direct IV match (or IV + constant) against the write index.
+    if (offsetConst && *offsetConst == 0 &&
+        valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
+      return true;
+    if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
+      // Match load index of the form writeIndex + C within the write extent.
+      if (offsetConst) {
+        int64_t start = *offsetConst;
+        return (*addConst >= start && *addConst < start + extent);
+      }
+    }
+    return false;
+  }
+
+  if (auto offsetVal = dyn_cast<Value>(offset)) {
+    // Exact match when extent is 1 and the load hits the offset value.
+    if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
+      return true;
+  }
+
+  return false;
+}
+
+/// Return the base memref value used by the given memory op.
+static Value getBaseMemref(Operation *op) {
+  // TODO: use the common interface for memory ops once available.
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .Case([&](memref::LoadOp load) { return load.getMemRef(); })
+      .Case([&](memref::StoreOp store) { return store.getMemRef(); })
+      .Case([&](vector::TransferReadOp read) { return read.getBase(); })
+      .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
+      .Case([&](vector::LoadOp load) { return load.getBase(); })
+      .Case([&](vector::StoreOp store) { return store.getBase(); })
+      .Default([](Operation *) { return Value(); });
+}
+
+/// Recognize scalar memref.load of an element produced by a vector write
+/// (vector.transfer_write or vector.store, optionally through a rank-reducing
+/// unit-stride subview) of the same buffer. This covers the pattern where a
+/// vector write stores a full lane pack and a subsequent scalar load reads an
+/// element from that lane pack. EXAMPLE:
+///  vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
+///             vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+///  scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
+///    %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
+///    ...
+///  }
+///
+static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
+                                  ValueRange writeIndices, VectorType vecTy,
+                                  ArrayRef<int64_t> vectorDimForWriteDim,
+                                  const IRMapping &ivsMap) {
+  if (!vecTy)
+    return false;
+
+  Value base = writeBase;
+  // The write base if there is no subview, or the subview source otherwise.
+  MemrefValue baseMemref = nullptr;
+  SmallVector<OpFoldResult> offsets;
+  llvm::SmallBitVector droppedDims;
+  bool hasSubview = false;
+  auto *ctx = loadOp.getContext();
+  if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
+    if (!subView.hasUnitStride())
+      return false;
+    baseMemref = cast<MemrefValue>(subView.getSource());
+    offsets = llvm::to_vector(subView.getMixedOffsets());
+    droppedDims = subView.getDroppedDims();
+    hasSubview = true;
+  } else {
+    baseMemref = dyn_cast<MemrefValue>(base);
+    if (!baseMemref)
+      return false;
+  }
+
+  auto loadIndices = loadOp.getIndices();
+  unsigned baseRank = baseMemref.getType().getRank();
+  if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
+    return false;
+
+  unsigned writeRank = writeIndices.size();
+  if ((!hasSubview && writeRank != baseRank) ||
+      (hasSubview && offsets.size() != baseRank) ||
+      (vectorDimForWriteDim.size() != writeRank))
+    return false;
+
+  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
+  unsigned writeMemrefDim = 0;
+  for (unsigned baseDim : llvm::seq(baseRank)) {
+    bool wasDropped = (hasSubview && droppedDims.test(baseDim));
+    int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
+    int64_t extent = 1;
+    if (vectorDim >= 0) {
+      int64_t dimSize = vecTy.getDimSize(vectorDim);
+      if (dimSize == ShapedType::kDynamic)
+        return false;
+      extent = dimSize;
+    }
+    Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
+    OpFoldResult offset =
+        hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
+    if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
+                                   extent, ivsMap))
+      return false;
+    if (!wasDropped)
+      ++writeMemrefDim;
+  }
+
+  return true;
+}
+
+/// Recognize scalar memref.load of an element produced by a
+/// vector.transfer_write
+static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
+                                   vector::TransferWriteOp writeOp,
+                                   const IRMapping &ivsMap) {
+  auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
+  if (!vecTy)
+    return false;
+
+  unsigned writeRank = writeOp.getIndices().size();
+  AffineMap permutationMap = writeOp.getPermutationMap();
+  if (!permutationMap.isProjectedPermutation() ||
+      permutationMap.getNumResults() != vecTy.getRank() ||
+      permutationMap.getNumDims() != writeRank)
+    return false;
+
+  SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
+  for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
+    if (!dimExpr)
+      return false;
+    unsigned writeDim = dimExpr.getPosition();
+    if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
+      return false;
+    vectorDimForWriteDim[writeDim] = vecDim;
+  }
+
+  return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
+                               vecTy, vectorDimForWriteDim, ivsMap);
+}
+
+/// Recognize scalar memref.load of an element produced by a vector.store
+static bool loadMatchesVectorStore(memref::LoadOp loadOp,
+                                   vector::StoreOp storeOp,
+                                   const IRMapping &ivsMap) {
+  auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
+  if (!vecTy)
+    return false;
+
+  unsigned writeRank = storeOp.getIndices().size();
+  if (vecTy.getRank() > writeRank)
+    return false;
+
+  SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
+  unsigned vecRank = vecTy.getRank();
+  for (unsigned i = 0; i < vecRank; ++i) {
+    unsigned writeDim = writeRank - vecRank + i;
+    vectorDimForWriteDim[writeDim] = i;
+  }
+
+  return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
+                               vecTy, vectorDimForWriteDim, ivsMap);
+}
+
+/// Check if both operations access the same positions of the same
+/// buffer, but one of the two does it through a rank-reducing full subview of
+/// the buffer (the other's base). EXAMPLE:
+///  memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+///  %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+///                           to memref<2x2xf32>
+///  %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndicesViaRankReducingSubview(
+    OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
+    OpBuilder &b) {
+  auto base1 = cast<MemrefValue>(getBaseMemref(op1));
+  auto base2 = cast<MemrefValue>(getBaseMemref(op2));
+  if (!base1 || !base2)
+    return false;
+
+  auto accessThroughTrivialSubviewIsSame =
+      [&b](memref::SubViewOp subView, ValueRange subViewAccess,
+           ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
+    SmallVector<Value> resolvedSubviewAccess;
+    LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
+        subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
+    if (failed(resolved) ||
+        (resolvedSubviewAccess.size() != sourceAccess.size()))
+      return false;
+    for (auto [dimIdx, resolvedIndex] :
+         llvm::enumerate(resolvedSubviewAccess)) {
+      if (!matchPattern(resolvedIndex, m_Zero()) &&
+          !valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
+        return false;
+    }
+    return true;
+  };
+
+  // Case 1: op1 uses a subview of op2's base.
+  if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
+      subView &&
+      memref::isSameViewOrTrivialAlias(
+          base2, cast<MemrefValue>(subView.getSource())) &&
+      accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
+                                        op2.getIndices(),
+                                        firstToSecondPloopIVsMap))
+    return true;
+
+  // Case 2: op2 uses a subview of op1's base.
+  if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
+      subView &&
+      memref::isSameViewOrTrivialAlias(
+          base1, cast<MemrefValue>(subView.getSource())) &&
+      accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
+                                        op1.getIndices(),
+                                        firstToSecondPloopIVsMap))
+    return true;
+
+  return false;
+}
+
+/// Check if both memory read/write operations access the same indices
+/// (considering also the mapping of induction variables from the first to the
+/// second parallel loop).
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
+                                 const IRMapping &loopsIVsMap, OpBuilder &b) {
+  auto indices1 = op1.getIndices();
+  auto indices2 = op2.getIndices();
+  if (indices1.size() != indices2.size())
+    return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
+  for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
+    if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
+      return false;
+  }
+  return true;
+}
+
+/// Check if the loadOp reads from the same memory location (same buffer,
+/// same indices and same properties) as written by the storeOp.
+static bool
+loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp,
+                                     const IRMapping &firstToSecondPloopIVsMap,
+                                     OpBuilder &b) {
+  if (!loadOp || !storeOp)
+    return false;
+  // Support only these memory-reading ops for now
+  if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
+    return false;
+  bool accessSameMemory =
+      llvm::TypeSwitch<Operation *, bool>(loadOp)
+          .Case([&](memref::LoadOp memLoadOp) {
+            if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
+              return opsAccessSameIndices(memLoadOp, memStoreOp,
+                                          firstToSecondPloopIVsMap, b);
+            if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
+              return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
+                                            firstToSecondPloopIVsMap);
+            if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
+              return loadMatchesVectorStore(memLoadOp, vecStoreOp,
+                                            firstToSecondPloopIVsMap);
+            return false;
+          })
+          .Case([&](vector::TransferReadOp vecReadOp) {
+            auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+            if (!vecWriteOp)
+              return false;
+            return opsAccessSameIndices(vecReadOp, vecWriteOp,
+                                        firstToSecondPloopIVsMap, b) &&
+                   (vecReadOp.getMask() == vecWriteOp.getMask()) &&
+                   (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
+          })
+          .Case([&](vector::LoadOp vecLoadOp) {
+            auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+            if (!vecStoreOp)
+              return false;
+            return opsAccessSameIndices(vecLoadOp, vecStoreOp,
+                                        firstToSecondPloopIVsMap, b) &&
+                   (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
+          })
+          .Default([](Operation *) { return false; });
+  return accessSameMemory;
+}
+
+static Value getStoreOpTargetBuffer(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
+      .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
+      .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
+      .Default([](Operation *) { return Value(); });
+}
+
+/// To be called when the buffers of `loadOp` and `storeOp` may alias. Check if
+/// the potential aliasing can be resolved by analyzing their access patterns,
+/// i.e. if the scalar load provably reads an element written by the vector op.
+static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
+                            const IRMapping &loopsIVsMap) {
+  if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+      transfWriteOp && isa<memref::LoadOp>(loadOp))
+    return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
+                                  loopsIVsMap);
+  if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+      vecStoreOp && isa<memref::LoadOp>(loadOp))
+    return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
+                                  loopsIVsMap);
+  return false;
+}
+
+/// Check that the parallel loops have no mixed access to the same buffers.
+/// Return `true` if the second parallel loop does not read or write the buffers
+/// written by the first loop using different indices.
+static bool haveNoDataDependenciesExceptSameIndex(
     ParallelOp firstPloop, ParallelOp secondPloop,
     const IRMapping &firstToSecondPloopIndices,
-    llvm::function_ref<bool(Value, Value)> mayAlias) {
-  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
-  SmallVector<Value> bufferStoresVec;
-  firstPloop.getBody()->walk([&](memref::StoreOp store) {
-    bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
-  });
-  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
-    Value loadMem = load.getMemRef();
-    // Stop if the memref is defined in secondPloop body. Careful alias analysis
-    // is needed.
-    auto *memrefDef = loadMem.getDefiningOp();
-    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+    llvm::function_ref<bool(Value, Value)> mayAlias, OpBuilder &b) {
+  // Map buffers to their store/write ops in the firstPloop
+  DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
+  // Record all the memory buffers used in store/write ops found in firstPloop
+  llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
+
+  auto collectStoreOpsInWalk = [&](Operation *op) {
+    auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
+    // Ignore ops that don't write to memory
+    if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
+                         !memOpInterf.hasEffect<MemoryEffects::Free>()))
+      return WalkResult::advance();
+
+    // Only these memory-writing ops are supported for now:
+    // memref.store, vector.transfer_write, vector.store
+    Value storeOpBase = getStoreOpTargetBuffer(op);
+    if (!storeOpBase)
       return WalkResult::interrupt();
 
-    for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
-        return WalkResult::interrupt();
+    // Expect the base operand to be a Memref
+    MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
+    if (!storeOpBaseMemref)
+      return WalkResult::interrupt();
+    // Get the original memref buffer, skipping full view-like ops
+    Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
+    bufferStoresInFirstPloop[buffer].push_back(op);
+    buffersWrittenInFirstPloop.insert(buffer);
+    return WalkResult::advance();
+  };
 
-    auto write = bufferStores.find(loadMem);
-    if (write == bufferStores.end())
-      return WalkResult::advance();
+  // Walk the first parallel loop to collect all store/write ops and their
+  // target buffers
+  if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
+    return false;
 
-    // Check that at last one store was retrieved
-    if (write->second.empty())
+  // Check that this load/read op encountered while walking the second parallel
+  // loop does not have incompatible data dependencies with the store/write ops
+  // collected from the first parallel loop: the loops can be fused only if in
+  // the 2nd loop there are no loads/stores from/to the buffers written in the
+  // 1st loop, except when on the same exact memory location (same indices) as
+  // written in the 1st loop.
+  auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
+    auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
+    // To be conservative, we should stop on ops that don't advertise their
+    // memory effects. However, many ops don't implement MemoryEffectOpInterface
+    // yet, so for now we just skip them.
+    // TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
+    if (!memOpInterf &&
+        !loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
+      return WalkResult::advance();
+    // Ignore ops that don't read from memory, and wrapping ops that have nested
+    // memory effects (e.g. loops, conditionals) as they will be analyzed when
+    // visiting their nested ops.
+    if ((!memOpInterf &&
+         loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
+        (memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
+      return WalkResult::advance();
+    // Support only these memory-reading ops for now
+    if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
+        !isa<MemrefValue>(loadOp->getOperand(0)))
       return WalkResult::interrupt();
 
-    auto storeIndices = write->second.front();
+    MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
+    MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
 
-    // Multiple writes to the same memref are allowed only on the same indices
-    for (const auto &othStoreIndices : write->second) {
-      if (othStoreIndices != storeIndices)
+    for (Value storedMem : buffersWrittenInFirstPloop)
+      if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
+          !llvm::all_of(bufferStoresInFirstPloop[storedMem],
+                        [&](Operation *storeOp) {
+                          return canResolveAlias(loadOp, storeOp,
+                                                 firstToSecondPloopIndices);
+                        })) {
         return WalkResult::interrupt();
+      }
+
+    auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
+    if (writeOpsIt == bufferStoresInFirstPloop.end())
+      return WalkResult::advance();
+    // Store/write ops to this buffer in the firstPloop
+    SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
+
+    // If the first loop has no writes to this buffer, continue
+    if (writeOps.empty())
+      return WalkResult::advance();
+
+    Operation *writeOp = writeOps.front();
+
+    // In the first parallel loop, multiple writes to the same memref are
+    // allowed only on the same memory location
+    if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
+          return opsWriteSameMemLocation(writeOp, otherWriteOp);
+        })) {
+      return WalkResult::interrupt();
     }
 
-    // Check that the load indices of secondPloop coincide with store indices of
-    // firstPloop for the same memrefs.
-    auto loadIndices = load.getIndices();
-    if (storeIndices.size() != loadIndices.size())
+    // Check that the load in secondPloop reads from the same memory location as
+    // written by the corresponding store in firstPloop
+    if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
+                                              firstToSecondPloopIndices, b)) {
       return WalkResult::interrupt();
-    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
-      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
-          loadIndices[i]) {
-        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
-        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
-        if (storeIndexDefOp && loadIndexDefOp) {
-          if (!isMemoryEffectFree(storeIndexDefOp))
-            return WalkResult::interrupt();
-          if (!isMemoryEffectFree(loadIndexDefOp))
-            return WalkResult::interrupt();
-          if (!OperationEquivalence::isEquivalentTo(
-                  storeIndexDefOp, loadIndexDefOp,
-                  [&](Value storeIndex, Value loadIndex) {
-                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
-                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
-                      return failure();
-                    else
-                      return success();
-                  },
-                  /*markEquivalent=*/nullptr,
-                  OperationEquivalence::Flags::IgnoreLocations)) {
-            return WalkResult::interrupt();
-          }
-        } else {
-          return WalkResult::interrupt();
-        }
-      }
     }
+
     return WalkResult::advance();
-  });
-  return !walkResult.wasInterrupted();
+  };
+
+  // Walk the second parallel loop to check load/read ops against the stores
+  // collected from the first parallel loop.
+  return !secondPloop.getBody()
+              ->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
+              .wasInterrupted();
 }
 
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
-                   const IRMapping &firstToSecondPloopIndices,
-                   llvm::function_ref<bool(Value, Value)> mayAlias) {
-  if (!haveNoReadsAfterWriteExceptSameIndex(
-          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
-    return failure();
+/// Check that in each loop there are no read ops on the buffers written
+/// by the other loop, except when reading from the same exact memory location
+/// (same indices) as written in the other loop.
+static bool
+noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+                               const IRMapping &firstToSecondPloopIndices,
+                               llvm::function_ref<bool(Value, Value)> mayAlias,
+                               OpBuilder &b) {
+  if (!haveNoDataDependenciesExceptSameIndex(
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
+    return false;
 
   IRMapping secondToFirstPloopIndices;
   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
                                 firstPloop.getBody()->getArguments());
-  return success(haveNoReadsAfterWriteExceptSameIndex(
-      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+  return haveNoDataDependenciesExceptSameIndex(
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
 }
 
+/// Check if fusion of the two parallel loops is legal:
+/// i.e. no nested parallel loops, equal iteration spaces,
+/// and no incompatible data dependencies between the loops.
 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
                           const IRMapping &firstToSecondPloopIndices,
-                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+                          llvm::function_ref<bool(Value, Value)> mayAlias,
+                          OpBuilder &b) {
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
          equalIterationSpaces(firstPloop, secondPloop) &&
-         succeeded(verifyDependencies(firstPloop, secondPloop,
-                                      firstToSecondPloopIndices, mayAlias));
+         noIncompatibleDataDependencies(firstPloop, secondPloop,
+                                        firstToSecondPloopIndices, mayAlias, b);
 }
 
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
+/// Prepend operations of firstPloop's body into secondPloop's body.
+/// Update secondPloop with new loop.
 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
                         OpBuilder builder,
                         llvm::function_ref<bool(Value, Value)> mayAlias) {
@@ -172,7 +744,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
   firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
-                     mayAlias))
+                     mayAlias, builder))
     return;
 
   DominanceInfo dom;
@@ -272,6 +844,18 @@ struct ParallelLoopFusion
     auto &aa = getAnalysis<AliasAnalysis>();
 
     auto mayAlias = [&](Value val1, Value val2) -> bool {
+      // If the memref is defined in one of the parallel loops body, careful
+      // alias analysis is needed.
+      // TODO: check if this is still needed as a separate check.
+      auto val1Def = val1.getDefiningOp();
+      auto val2Def = val2.getDefiningOp();
+      auto val1Loop =
+          val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
+      auto val2Loop =
+          val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
+      if (val1Loop != val2Loop)
+        return true;
+
       return !aa.alias(val1, val2).isNo();
     };
 

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index f8a4f057c9f0d..e795f3f0b019b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1560,6 +1560,25 @@ bool mlir::isPerfectlyNestedForLoops(
   return true;
 }
 
+llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
+mlir::getConstLoopBounds(mlir::LoopLikeOpInterface loopOp) {
+  std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+  std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+  std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+  if (!loBnds || !upBnds || !steps)
+    return {};
+  llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>> loopRanges;
+  for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+    auto lbCst = getConstantIntValue(lb);
+    auto ubCst = getConstantIntValue(ub);
+    auto stepCst = getConstantIntValue(step);
+    if (!lbCst || !ubCst || !stepCst)
+      return {};
+    loopRanges.emplace_back(*lbCst, *ubCst, *stepCst);
+  }
+  return loopRanges;
+}
+
 llvm::SmallVector<llvm::APInt>
 mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
   std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 0d4ea6f20e8d9..d876062b704f2 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -314,23 +314,24 @@ func.func @do_not_fuse_unmatching_read_write_patterns(
 
 // -----
 
-func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
+func.func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies() {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
   %buffer  = memref.alloc() : memref<2x2xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
+    memref.store %c1fp, %buffer[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
-    %A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
-      : memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-    %A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
+    %A = memref.subview %buffer[%i, %c0][2, 1][1, 1] : memref<2x2xf32> to memref<2x1xf32, strided<[2, 1], offset: ?>>
+    %A_elem = memref.load %A[%i, %j] : memref<2x1xf32, strided<[2, 1], offset: ?>>
     scf.reduce
   }
   return
 }
-// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
+// CHECK-LABEL: func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies
 // CHECK:        scf.parallel
 // CHECK:        scf.parallel
 
@@ -604,6 +605,415 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
 
 // -----
 
+func.func @fuse_trivial_rank_reducing_subview() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c1fp = arith.constant 1.0 : f32
+  %buf = memref.alloc() : memref<1x2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c1fp, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+    scf.reduce
+  }
+  %sub = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]
+      : memref<1x2x2xf32> to memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %v = memref.load %sub[%i, %j] : memref<2x2xf32>
+    memref.store %v, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %buf : memref<1x2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_trivial_rank_reducing_subview
+// CHECK:       %[[BUF:.*]] = memref.alloc() : memref<1x2x2xf32>
+// CHECK:       %[[SUB:.*]] = memref.subview %[[BUF]]
+// CHECK:       scf.parallel
+// CHECK:         memref.store {{.*}}, %[[BUF]]
+// CHECK:         %[[L:.*]] = memref.load %[[SUB]]
+// CHECK:         memref.store %[[L]], %[[BUF]]
+// CHECK-NOT:   scf.parallel
+// CHECK:       memref.dealloc %[[BUF]] : memref<1x2x2xf32>
+
+// -----
+
+func.func @do_not_fuse_nontrivial_subview_offset() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c1fp = arith.constant 1.0 : f32
+  %buf = memref.alloc() : memref<2x2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c1fp, %buf[%c0, %i, %j] : memref<2x2x2xf32>
+    scf.reduce
+  }
+  %sub = memref.subview %buf[1, 0, 0][1, 2, 2][1, 1, 1]
+      : memref<2x2x2xf32> to memref<2x2xf32, strided<[2, 1], offset: 4>>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %v = memref.load %sub[%i, %j]
+        : memref<2x2xf32, strided<[2, 1], offset: 4>>
+    memref.store %v, %buf[%c0, %i, %j] : memref<2x2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %buf : memref<2x2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nontrivial_subview_offset
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @fuse_vector_load_store(%A: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %vec0 = arith.constant dense<0.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %v = vector.load %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+    vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_load_store
+// CHECK:       scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
+// CHECK:         vector.store
+// CHECK:         %[[V:.*]] = vector.load
+// CHECK:         vector.store %[[V]]
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_different_indices(%A: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %vec0 = arith.constant dense<0.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
+    %v = vector.load %A[%j, %c0] : memref<4x4xf32>, vector<4xf32>
+    vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_different_indices
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_same_indices(%A: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+    vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+    vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_same_indices
+// CHECK:       scf.parallel
+// CHECK:         vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK:         vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK:         vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK:         vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_different_indices(%A: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+    vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+    %j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
+    %v = vector.transfer_read %A[%j, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+    vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_different_indices
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_with_subview(%A: memref<1x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  %vec = arith.constant dense<1.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sub = memref.subview %A[0, 0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32>
+    vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+      %v = memref.load %A[%c0, %k] : memref<1x4xf32>
+      %n = arith.addf %v, %acc : f32
+      scf.yield %n : f32
+    }
+    memref.store %sum, %A[%c0, %c0] : memref<1x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_with_subview
+// CHECK:       scf.parallel
+// CHECK:         vector.transfer_write
+// CHECK:         scf.for
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_nontrivial_subview(%A: memref<2x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %zero = arith.constant 0.0 : f32
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v = vector.transfer_read %A[%c0, %i], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<2x4xf32>, vector<1xf32>
+    vector.transfer_write %v, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<2x4xf32>
+    scf.reduce
+  }
+    %sub = memref.subview %A[1, 0][1, 4][1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: 4>>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v = vector.transfer_read %sub[%i], %zero {in_bounds = [true]} : memref<4xf32, strided<[1], offset: 4>>, vector<1xf32>
+    vector.transfer_write %v, %sub[%i] {in_bounds = [true]} : vector<1xf32>, memref<4xf32, strided<[1], offset: 4>>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_nontrivial_subview
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_different_masks(%A: memref<1x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %zero = arith.constant 0.0 : f32
+  %mask_true = vector.create_mask %c1 : vector<1xi1>
+  %mask_false = vector.create_mask %c0 : vector<1xi1>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v = vector.transfer_read %A[%c0, %i], %zero, %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
+    vector.transfer_write %v, %A[%c0, %i], %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v = vector.transfer_read %A[%c0, %i], %zero, %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
+    vector.transfer_write %v, %A[%c0, %i], %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_different_masks
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_subview_rank_reducing(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  %vec = arith.constant dense<1.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+      %v = memref.load %A[%i, %k] : memref<1x4xf32>
+      %n = arith.addf %v, %acc : f32
+      scf.yield %n : f32
+    }
+    memref.store %sum, %B[%i, %c0] : memref<1x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_subview_rank_reducing
+// CHECK:       scf.parallel
+// CHECK:         vector.transfer_write
+// CHECK:         scf.for
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_subview_offset(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  %vec = arith.constant dense<1.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+    vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+      %v = memref.load %A[%i, %k] : memref<1x4xf32>
+      %n = arith.addf %v, %acc : f32
+      scf.yield %n : f32
+    }
+    // Read from an offset alias to prevent fusion.
+    %off = memref.subview %A[%i, %c1][1, 3][1, 1] : memref<1x4xf32> to memref<3xf32, strided<[1], offset: ?>>
+    %v0 = memref.load %off[%c0] : memref<3xf32, strided<[1], offset: ?>>
+    %res = arith.addf %sum, %v0 : f32
+    memref.store %res, %B[%i, %c0] : memref<1x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_subview_offset
+// CHECK:       scf.parallel
+// CHECK:       scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_no_subview(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  %vec = arith.constant dense<2.0> : vector<4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    vector.transfer_write %vec, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<1x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+      %v = memref.load %A[%c0, %k] : memref<1x4xf32>
+      %n = arith.addf %v, %acc : f32
+      scf.yield %n : f32
+    }
+    memref.store %sum, %B[%c0, %c0] : memref<1x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_no_subview
+// CHECK:       vector.transfer_write
+// CHECK:       scf.for
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %vec = arith.constant dense<1.0> : vector<2x4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v0 = memref.load %A[%c0, %c1] : memref<2x4xf32>
+    %v1 = memref.load %A[%c1, %c2] : memref<2x4xf32>
+    %sum = arith.addf %v0, %v1 : f32
+    memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_rank2
+// CHECK:       scf.parallel
+// CHECK:         vector.transfer_write
+// CHECK:         memref.load
+// CHECK:         memref.load
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_scalar_load_loop_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %zero = arith.constant 0.0 : f32
+  %vec = arith.constant dense<2.0> : vector<2x4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+      %v = memref.load %A[%c1, %k] : memref<2x4xf32>
+      %n = arith.addf %v, %acc : f32
+      scf.yield %n : f32
+    }
+    memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_loop_rank2
+// CHECK:       scf.parallel
+// CHECK:         vector.transfer_write
+// CHECK:         scf.for
+// CHECK-NOT:   scf.parallel
+
+// -----
+
+func.func @fuse_vector_store_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %vec = arith.constant dense<3.0> : vector<2x4xf32>
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    vector.store %vec, %A[%c0, %c0] : memref<2x4xf32>, vector<2x4xf32>
+    scf.reduce
+  }
+  scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+    %v0 = memref.load %A[%c1, %c2] : memref<2x4xf32>
+    %v1 = memref.load %A[%c0, %c3] : memref<2x4xf32>
+    %sum = arith.addf %v0, %v1 : f32
+    memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+    scf.reduce
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_vector_store_scalar_load_rank2
+// CHECK:       scf.parallel
+// CHECK:         vector.store
+// CHECK:         memref.load
+// CHECK:         memref.load
+// CHECK-NOT:   scf.parallel
+
+// -----
+
 func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list