[Mlir-commits] [mlir] dc53715 - [MLIR][Affine] Add utility to check if the slice is valid

Uday Bondhugula llvmlistbot at llvm.org
Thu Apr 1 02:22:49 PDT 2021


Author: Vinayaka Bandishti
Date: 2021-04-01T14:52:22+05:30
New Revision: dc537158d5372894b539b7cf90ace3cfe911a520

URL: https://github.com/llvm/llvm-project/commit/dc537158d5372894b539b7cf90ace3cfe911a520
DIFF: https://github.com/llvm/llvm-project/commit/dc537158d5372894b539b7cf90ace3cfe911a520.diff

LOG: [MLIR][Affine] Add utility to check if the slice is valid

Fixes a bug in the affine fusion pipeline where an incorrect slice is computed.
After the slice computation is done, the original domain of the source is
compared with the new domain that will result if the fusion succeeds. The
new domain must be a subset of the original domain for the slice to be
valid. If the slice computed is incorrect, fusion based on such a slice is
avoided.

Relevant test cases are added/edited.

Fixes https://bugs.llvm.org/show_bug.cgi?id=49203

Differential Revision: https://reviews.llvm.org/D98239

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Utils.h
    mlir/include/mlir/Transforms/LoopFusionUtils.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
    mlir/test/Transforms/loop-fusion-slice-computation.mlir
    mlir/test/Transforms/loop-fusion.mlir
    mlir/test/lib/Transforms/TestLoopFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index ee6f8095f25ef..ccedd17e0e6c5 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -54,6 +54,18 @@ unsigned getNestingDepth(Operation *op);
 void getSequentialLoops(AffineForOp forOp,
                         llvm::SmallDenseSet<Value, 8> *sequentialLoops);
 
+/// Enumerates different result statuses of slice computation by
+/// `computeSliceUnion`
+// TODO: Identify and add different kinds of failures during slice computation.
+struct SliceComputationResult {
+  enum ResultEnum {
+    Success,
+    IncorrectSliceFailure, // Slice is computed, but it is incorrect.
+    GenericFailure,        // Unable to compute src loop computation slice.
+  } value;
+  SliceComputationResult(ResultEnum v) : value(v) {}
+};
+
 /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
 /// associated operands for a set of loops within a loop nest (typically the
 /// set of loops surrounding a store operation). Loop bound AffineMaps which
@@ -80,6 +92,12 @@ struct ComputationSliceState {
   // Returns failure if we cannot add loop bounds because of unsupported cases.
   LogicalResult getAsConstraints(FlatAffineConstraints *cst);
 
+  /// Adds to 'cst' constraints which represent the original loop bounds on
+  /// 'ivs' in 'this'. This corresponds to the original domain of the loop nest
+  /// from which the slice is being computed. Returns failure if we cannot add
+  /// loop bounds because of unsupported cases.
+  LogicalResult getSourceAsConstraints(FlatAffineConstraints &cst);
+
   // Clears all bounds and operands in slice state.
   void clearBounds();
 
@@ -93,6 +111,22 @@ struct ComputationSliceState {
   // information hasn't changed.
   Optional<bool> isMaximal() const;
 
+  /// Checks the validity of the slice computed. This is done using the
+  /// following steps:
+  /// 1. Get the new domain of the slice that would be created if fusion
+  /// succeeds. This domain gets constructed with source loop IVS and
+  /// destination loop IVS as dimensions.
+  /// 2. Project out the dimensions of the destination loop from the domain
+  /// above calculated in step(1) to express it purely in terms of the source
+  /// loop IVs.
+  /// 3. Calculate a set difference between the iterations of the new domain and
+  /// the original domain of the source loop.
+  /// If this difference is empty, the slice is declared to be valid. Otherwise,
+  /// return false as it implies that the effective fusion results in at least
+  /// one iteration of the slice that was not originally in the source's domain.
+  /// If the validity cannot be determined, returns llvm:None.
+  Optional<bool> isSliceValid();
+
   void dump() const;
 
 private:
@@ -151,21 +185,21 @@ void getComputationSliceState(Operation *depSourceOp, Operation *depSinkOp,
                               ComputationSliceState *sliceState);
 
 /// Computes in 'sliceUnion' the union of all slice bounds computed at
-/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
-/// The parameter 'numCommonLoops' is the number of loops common to the
-/// operations in 'opsA' and 'opsB'.
-/// If 'isBackwardSlice' is true, computes slice bounds for loop nest
-/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest
-/// surrounding ops in 'opsB' at 'loopDepth'.
-/// If 'isBackwardSlice' is false, computes slice bounds for loop nest
-/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest
-/// surrounding ops in 'opsA' at 'loopDepth'.
-/// Returns 'success' if union was computed, 'failure' otherwise.
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
+/// then verifies if it is valid. The parameter 'numCommonLoops' is the number
+/// of loops common to the operations in 'opsA' and 'opsB'. If 'isBackwardSlice'
+/// is true, computes slice bounds for loop nest surrounding ops in 'opsA', as a
+/// function of IVs and symbols of loop nest surrounding ops in 'opsB' at
+/// 'loopDepth'. If 'isBackwardSlice' is false, computes slice bounds for loop
+/// nest surrounding ops in 'opsB', as a function of IVs and symbols of loop
+/// nest surrounding ops in 'opsA' at 'loopDepth'. Returns
+/// 'SliceComputationResult::Success' if union was computed correctly, an
+/// appropriate 'failure' otherwise.
 // TODO: Change this API to take 'forOpA'/'forOpB'.
-LogicalResult computeSliceUnion(ArrayRef<Operation *> opsA,
-                                ArrayRef<Operation *> opsB, unsigned loopDepth,
-                                unsigned numCommonLoops, bool isBackwardSlice,
-                                ComputationSliceState *sliceUnion);
+SliceComputationResult
+computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
+                  unsigned loopDepth, unsigned numCommonLoops,
+                  bool isBackwardSlice, ComputationSliceState *sliceUnion);
 
 /// Creates a clone of the computation contained in the loop nest surrounding
 /// 'srcOpInst', slices the iteration space of src loop based on slice bounds

diff  --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h
index 10d6b83d022ff..b66d38fae3b84 100644
--- a/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -35,6 +35,7 @@ struct FusionResult {
     FailBlockDependence,  // Fusion would violate another dependence in block.
     FailFusionDependence, // Fusion would reverse dependences between loops.
     FailComputationSlice, // Unable to compute src loop computation slice.
+    FailIncorrectSlice,   // Slice is computed, but it is incorrect.
   } value;
   FusionResult(ResultEnum v) : value(v) {}
 };

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index e81d92ae4a009..30b00bd9b6b61 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2128,13 +2128,22 @@ LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
       continue;
     }
 
-    if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
-                                             /*lower=*/true)))
-      return failure();
-
-    if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
-                                             /*lower=*/false)))
-      return failure();
+    // If lower or upper bound maps are null or provide no results, it implies
+    // that the source loop was not at all sliced, and the entire loop will be a
+    // part of the slice.
+    if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
+        ubMap.getNumResults() != 0) {
+      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
+                                      /*lower=*/true)))
+        return failure();
+      if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
+                                      /*lower=*/false)))
+        return failure();
+    } else {
+      auto loop = getForInductionVarOwner(values[i]);
+      if (failed(this->addAffineForOpDomain(loop)))
+        return failure();
+    }
   }
   return success();
 }

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index c3c63f3a7f915..a4b8ccfc7ad14 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -61,6 +61,21 @@ void mlir::getEnclosingAffineForAndIfOps(Operation &op,
   std::reverse(ops->begin(), ops->end());
 }
 
+// Populates 'cst' with FlatAffineConstraints which represent original domain of
+// the loop bounds that define 'ivs'.
+LogicalResult
+ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) {
+  assert(!ivs.empty() && "Cannot have a slice without its IVs");
+  cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
+  for (Value iv : ivs) {
+    AffineForOp loop = getForInductionVarOwner(iv);
+    assert(loop && "Expected affine for");
+    if (failed(cst.addAffineForOpDomain(loop)))
+      return failure();
+  }
+  return success();
+}
+
 // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
 LogicalResult
 ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
@@ -75,9 +90,10 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
   values.append(lbOperands[0].begin(), lbOperands[0].end());
   cst->reset(numDims, numSymbols, 0, values);
 
-  // Add loop bound constraints for values which are loop IVs and equality
-  // constraints for symbols which are constants.
-  for (const auto &value : values) {
+  // Add loop bound constraints for values which are loop IVs of the destination
+  // of fusion and equality constraints for symbols which are constants.
+  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
+    Value value = values[i];
     assert(cst->containsId(value) && "value expected to be present");
     if (isValidSymbol(value)) {
       // Check if the symbol is a constant.
@@ -196,6 +212,76 @@ Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
   return true;
 }
 
+/// Returns true if it is deterministically verified that the original iteration
+/// space of the slice is contained within the new iteration space that is
+/// created after fusing 'this' slice into its destination.
+Optional<bool> ComputationSliceState::isSliceValid() {
+  // Fast check to determine if the slice is valid. If the following conditions
+  // are verified to be true, slice is declared valid by the fast check:
+  // 1. Each slice loop is a single iteration loop bound in terms of a single
+  //    destination loop IV.
+  // 2. Loop bounds of the destination loop IV (from above) and those of the
+  //    source loop IV are exactly the same.
+  // If the fast check is inconclusive or false, we proceed with a more
+  // expensive analysis.
+  // TODO: Store the result of the fast check, as it might be used again in
+  // `canRemoveSrcNodeAfterFusion`.
+  Optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
+  if (isValidFastCheck.hasValue() && isValidFastCheck.getValue())
+    return true;
+
+  // Create constraints for the source loop nest using which slice is computed.
+  FlatAffineConstraints srcConstraints;
+  // TODO: Store the source's domain to avoid computation at each depth.
+  if (failed(getSourceAsConstraints(srcConstraints))) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+    return llvm::None;
+  }
+  // As the set difference utility currently cannot handle symbols in its
+  // operands, validity of the slice cannot be determined.
+  if (srcConstraints.getNumSymbolIds() > 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+    return llvm::None;
+  }
+  // TODO: Handle local ids in the source domains while using the 'projectOut'
+  // utility below. Currently, aligning is not done assuming that there will be
+  // no local ids in the source domain.
+  if (srcConstraints.getNumLocalIds() != 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+    return llvm::None;
+  }
+
+  // Create constraints for the slice loop nest that would be created if the
+  // fusion succeeds.
+  FlatAffineConstraints sliceConstraints;
+  if (failed(getAsConstraints(&sliceConstraints))) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+    return llvm::None;
+  }
+
+  // Projecting out every dimension other than the 'ivs' to express slice's
+  // domain completely in terms of source's IVs.
+  sliceConstraints.projectOut(ivs.size(),
+                              sliceConstraints.getNumIds() - ivs.size());
+
+  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
+  LLVM_DEBUG(srcConstraints.dump());
+  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
+                             "(expressed in terms of its source's IVs):\n");
+  LLVM_DEBUG(sliceConstraints.dump());
+
+  // TODO: Store 'srcSet' to avoid recalculating for each depth.
+  PresburgerSet srcSet(srcConstraints);
+  PresburgerSet sliceSet(sliceConstraints);
+  PresburgerSet diffSet = sliceSet.subtract(srcSet);
+
+  if (!diffSet.isIntegerEmpty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+    return false;
+  }
+  return true;
+}
+
 /// Returns true if the computation slice encloses all the iterations of the
 /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
 /// cannot determine if the slice is maximal or not.
@@ -715,14 +801,14 @@ unsigned mlir::getInnermostCommonLoopDepth(
 }
 
 /// Computes in 'sliceUnion' the union of all slice bounds computed at
-/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
-/// Returns 'Success' if union was computed, 'failure' otherwise.
-LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
-                                      ArrayRef<Operation *> opsB,
-                                      unsigned loopDepth,
-                                      unsigned numCommonLoops,
-                                      bool isBackwardSlice,
-                                      ComputationSliceState *sliceUnion) {
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
+/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
+/// union was computed correctly, an appropriate failure otherwise.
+SliceComputationResult
+mlir::computeSliceUnion(ArrayRef<Operation *> opsA, ArrayRef<Operation *> opsB,
+                        unsigned loopDepth, unsigned numCommonLoops,
+                        bool isBackwardSlice,
+                        ComputationSliceState *sliceUnion) {
   // Compute the union of slice bounds between all pairs in 'opsA' and
   // 'opsB' in 'sliceUnionCst'.
   FlatAffineConstraints sliceUnionCst;
@@ -738,7 +824,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
       if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) ||
           (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) {
         LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
-        return failure();
+        return SliceComputationResult::GenericFailure;
       }
 
       bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
@@ -751,7 +837,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
           /*allowRAR=*/readReadAccesses);
       if (result.value == DependenceResult::Failure) {
         LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
-        return failure();
+        return SliceComputationResult::GenericFailure;
       }
       if (result.value == DependenceResult::NoDependence)
         continue;
@@ -768,7 +854,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
         if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
           LLVM_DEBUG(llvm::dbgs()
                      << "Unable to compute slice bound constraints\n");
-          return failure();
+          return SliceComputationResult::GenericFailure;
         }
         assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
         continue;
@@ -779,7 +865,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
       if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
         LLVM_DEBUG(llvm::dbgs()
                    << "Unable to compute slice bound constraints\n");
-        return failure();
+        return SliceComputationResult::GenericFailure;
       }
 
       // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
@@ -802,9 +888,9 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
         // to unionBoundingBox below expects constraints for each Loop IV, even
         // if they are the unsliced full loop bounds added here.
         if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
-          return failure();
+          return SliceComputationResult::GenericFailure;
         if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
-          return failure();
+          return SliceComputationResult::GenericFailure;
       }
       // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
       if (sliceUnionCst.getNumLocalIds() > 0 ||
@@ -812,14 +898,14 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
           failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
         LLVM_DEBUG(llvm::dbgs()
                    << "Unable to compute union bounding box of slice bounds\n");
-        return failure();
+        return SliceComputationResult::GenericFailure;
       }
     }
   }
 
   // Empty union.
   if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
-    return failure();
+    return SliceComputationResult::GenericFailure;
 
   // Gather loops surrounding ops from loop nest where slice will be inserted.
   SmallVector<Operation *, 4> ops;
@@ -831,7 +917,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
       getInnermostCommonLoopDepth(ops, &surroundingLoops);
   if (loopDepth > innermostCommonLoopDepth) {
     LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
-    return failure();
+    return SliceComputationResult::GenericFailure;
   }
 
   // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
@@ -868,7 +954,18 @@ LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
   // canonicalization.
   sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
   sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
-  return success();
+
+  // Check if the slice computed is valid. Return success only if it is verified
+  // that the slice is valid, otherwise return appropriate failure status.
+  Optional<bool> isSliceValid = sliceUnion->isSliceValid();
+  if (!isSliceValid.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+    return SliceComputationResult::GenericFailure;
+  }
+  if (!isSliceValid.getValue())
+    return SliceComputationResult::IncorrectSliceFailure;
+
+  return SliceComputationResult::Success;
 }
 
 const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";

diff  --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index 61dad71c20d8b..511f6a572f05e 100644
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -347,12 +347,18 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
 
   // Compute union of computation slices computed between all pairs of ops
   // from 'forOpA' and 'forOpB'.
-  if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth,
-                                     numCommonLoops, isSrcForOpBeforeDstForOp,
-                                     srcSlice))) {
+  SliceComputationResult sliceComputationResult =
+      mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
+                              isSrcForOpBeforeDstForOp, srcSlice);
+  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
     LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
     return FusionResult::FailPrecondition;
   }
+  if (sliceComputationResult.value ==
+      SliceComputationResult::IncorrectSliceFailure) {
+    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
+    return FusionResult::FailIncorrectSlice;
+  }
 
   return FusionResult::Success;
 }
@@ -400,7 +406,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
     auto *parentForOp = forOp->getParentOp();
     if (!llvm::isa<FuncOp>(parentForOp)) {
       if (!isa<AffineForOp>(parentForOp)) {
-        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
         return WalkResult::interrupt();
       }
       // Add mapping to 'forOp' from its parent AffineForOp.
@@ -421,7 +427,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
     if (!maybeConstTripCount.hasValue()) {
       // Currently only constant trip count loop nests are supported.
-      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
       return WalkResult::interrupt();
     }
 
@@ -519,7 +525,11 @@ static bool buildSliceTripCountMap(
     auto *op = forOp.getOperation();
     AffineMap lbMap = slice.lbs[i];
     AffineMap ubMap = slice.ubs[i];
-    if (lbMap == AffineMap() || ubMap == AffineMap()) {
+    // If lower or upper bound maps are null or provide no results, it implies
+    // that source loop was not at all sliced, and the entire loop will be a
+    // part of the slice.
+    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
+        ubMap.getNumResults() == 0) {
       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
         (*tripCountMap)[op] =

diff  --git a/mlir/test/Transforms/loop-fusion-slice-computation.mlir b/mlir/test/Transforms/loop-fusion-slice-computation.mlir
index 7a29ebab932df..75b6787572cfd 100644
--- a/mlir/test/Transforms/loop-fusion-slice-computation.mlir
+++ b/mlir/test/Transforms/loop-fusion-slice-computation.mlir
@@ -7,7 +7,7 @@ func @slice_depth1_loop_nest() {
   %0 = memref.alloc() : memref<100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
     affine.store %cst, %0[%i0] : memref<100xf32>
   }
   affine.for %i1 = 0 to 5 {
@@ -19,6 +19,23 @@ func @slice_depth1_loop_nest() {
 
 // -----
 
+// CHECK-LABEL: func @forward_slice_slice_depth1_loop_nest() {
+func @forward_slice_slice_depth1_loop_nest() {
+  %0 = memref.alloc() : memref<100xf32>
+  %cst = constant 7.000000e+00 : f32
+  affine.for %i0 = 0 to 5 {
+    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
+    affine.store %cst, %0[%i0] : memref<100xf32>
+  }
+  affine.for %i1 = 0 to 16 {
+    // expected-remark@-1 {{Incorrect slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
+    %1 = affine.load %0[%i1] : memref<100xf32>
+  }
+  return
+}
+
+// -----
+
 // Loop %i0 writes to locations [2, 17] and loop %i0 reads from locations [3, 6]
 // Slice loop bounds should be adjusted such that the load/store are for the
 // same location.
@@ -27,7 +44,7 @@ func @slice_depth1_loop_nest_with_offsets() {
   %0 = memref.alloc() : memref<100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}}
     %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
     affine.store %cst, %0[%a0] : memref<100xf32>
   }
@@ -48,8 +65,8 @@ func @slice_depth2_loop_nest() {
   %0 = memref.alloc() : memref<100x100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
+    // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -75,8 +92,8 @@ func @slice_depth2_loop_nest_two_loads() {
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (8)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
+    // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (16)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -103,7 +120,7 @@ func @slice_depth2_loop_nest_two_stores() {
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -128,8 +145,8 @@ func @slice_loop_nest_with_smaller_outer_trip_count() {
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
+    // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}}
+    // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 74f60e2e1e73f..f3fcae25b0fa2 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3068,3 +3068,50 @@ func @call_op_does_not_prevent_fusion(%arg0: memref<16xf32>){
 // CHECK-LABEL: func @call_op_does_not_prevent_fusion
 // CHECK:         affine.for
 // CHECK-NOT:     affine.for
+
+// -----
+
+// Fusion is avoided when the slice computed is invalid. Comments below describe
+// incorrect backward slice computation. Similar logic applies for forward slice
+// as well.
+func @no_fusion_cannot_compute_valid_slice() {
+  %A = memref.alloc() : memref<5xf32>
+  %B = memref.alloc() : memref<6xf32>
+  %C = memref.alloc() : memref<5xf32>
+  %cst = constant 0. : f32
+
+  affine.for %arg0 = 0 to 5 {
+    %a = affine.load %A[%arg0] : memref<5xf32>
+    affine.store %a, %B[%arg0 + 1] : memref<6xf32>
+  }
+
+  affine.for %arg0 = 0 to 5 {
+    // Backward slice computed will be:
+    // slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0)
+    // loop bounds: [(d0) -> (d0 - 1), (d0) -> (d0)] )
+
+    // Resulting fusion would be as below. It is easy to note the out-of-bounds
+    // access by 'affine.load'.
+
+    // #map0 = affine_map<(d0) -> (d0 - 1)>
+    // #map1 = affine_map<(d0) -> (d0)>
+    // affine.for %arg1 = #map0(%arg0) to #map1(%arg0) {
+    //   %5 = affine.load %1[%arg1] : memref<5xf32>
+    //   ...
+    //   ...
+    // }
+
+    %a = affine.load %B[%arg0] : memref<6xf32>
+    %b = mulf %a, %cst : f32
+    affine.store %b, %C[%arg0] : memref<5xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @no_fusion_cannot_compute_valid_slice
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      affine.store
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      mulf
+// CHECK-NEXT:      affine.store

diff  --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp
index b28e52851a608..ed439122e5a19 100644
--- a/mlir/test/lib/Transforms/TestLoopFusion.cpp
+++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp
@@ -99,10 +99,11 @@ static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
   return os.str();
 }
 
-// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
-// in range ['loopDepth' + 1, 'maxLoopDepth'].
-// Emits a string representation of the slice union as a remark on 'loops[j]'.
-// Returns false as IR is not transformed.
+/// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
+/// in range ['loopDepth' + 1, 'maxLoopDepth'].
+/// Emits a string representation of the slice union as a remark on 'loops[j]'
+/// and marks this as incorrect slice if the slice is invalid. Returns false as
+/// IR is not transformed.
 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
                                  unsigned i, unsigned j, unsigned loopDepth,
                                  unsigned maxLoopDepth) {
@@ -113,6 +114,10 @@ static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
       forOpB->emitRemark("slice (")
           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
           << " : " << getSliceStr(sliceUnion) << ")";
+    } else if (result.value == FusionResult::FailIncorrectSlice) {
+      forOpB->emitRemark("Incorrect slice (")
+          << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
+          << " : " << getSliceStr(sliceUnion) << ")";
     }
   }
   return false;


        


More information about the Mlir-commits mailing list