[Mlir-commits] [mlir] de9caf2 - [mlir][Interfaces] Add `promoteIfSingleIteration` to `LoopLikeOpInterface`

Matthias Springer llvmlistbot at llvm.org
Mon Jul 3 09:03:58 PDT 2023


Author: Matthias Springer
Date: 2023-07-03T18:03:35+02:00
New Revision: de9caf2f3244849cf0fe29bbf7cc67bae9035329

URL: https://github.com/llvm/llvm-project/commit/de9caf2f3244849cf0fe29bbf7cc67bae9035329
DIFF: https://github.com/llvm/llvm-project/commit/de9caf2f3244849cf0fe29bbf7cc67bae9035329.diff

LOG: [mlir][Interfaces] Add `promoteIfSingleIteration` to `LoopLikeOpInterface`

There are existing `promoteIfSingleIteration` implementations for `scf.for`, `scf.forall` and `affine.for`. This revision adds a corresponding interface method to the `LoopLikeOpInterface`.

* `scf.forall` now implements the `LoopLikeOpInterface`.
* The implementations of `scf.for` and `scf.forall` become interface method implementations. `affine.for` remains as is for the moment. (The implementation of `promoteIfSingleIteration` depends on helper functions from `MLIRAffineAnalysis`, which cannot be used from `MLIRAffineDialect`, where the interface is currently implemented.)
* More efficient implementations of `promoteIfSingleIteration`. In particular, the `scf.forall` operation now inlines operations instead of cloning them. This also preserves handles when used from the transform dialect.

Differential Revision: https://reviews.llvm.org/D154343

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCF.h
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/include/mlir/Interfaces/LoopLikeInterface.h
    mlir/include/mlir/Interfaces/LoopLikeInterface.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index cb399b78c406d9..915ab3016b688e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -62,13 +62,8 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 // TODO: Consider moving this functionality to RegionBranchOpInterface.
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
 
-/// Promotes the loop body of a scf::ForallOp to its containing block if the
-/// loop was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
-                                       scf::ForallOp forallOp);
-
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);
+void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
 
 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 1eed4cec1e4207..58b720c6e39637 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,7 +121,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-        "getSingleUpperBound"]>,
+        "getSingleUpperBound", "promoteIfSingleIteration"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -361,6 +361,8 @@ def ForOp : SCF_Op<"for",
 def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
+       DeclareOpInterfaceMethods<LoopLikeOpInterface,
+          ["promoteIfSingleIteration"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 0b251a6fec9c18..2e299fd357f282 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -141,10 +141,6 @@ LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
 void collapseParallelLoops(scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
-/// Promotes the loop body of a scf::ForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(scf::ForOp forOp);
-
 /// Unrolls this for operation by the specified unroll factor. Returns failure
 /// if the loop cannot be unrolled either due to restrictions or due to invalid
 /// unroll factors. Requires positive loop bounds and step. If specified,

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 48399ad0d53a84..9d81a61fac8856 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -15,6 +15,10 @@
 
 #include "mlir/IR/OpDefinition.h"
 
+namespace mlir {
+class RewriterBase;
+} // namespace mlir
+
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/LoopLikeInterface.h.inc"
 

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index df7315c4ecf045..a88d7e45e5d67f 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -48,6 +48,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         op->moveBefore($_op);
       }]
     >,
+    InterfaceMethod<[{
+        Promotes the loop body to its containing block if the loop is known to
+        have a single iteration. Returns "success" if the promotion was
+        successful.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"promoteIfSingleIteration",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
     InterfaceMethod<[{
         If there is a single induction variable return it, otherwise return
         std::nullopt.

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index db69195a7c7046..4f805d692637ea 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -385,6 +385,35 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
   return OpFoldResult(getUpperBound());
 }
 
+/// Promotes the loop body of a forOp to its containing block if it can be
+/// determined that the loop has a single iteration.
+LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  std::optional<int64_t> tripCount =
+      constantTripCount(getLowerBound(), getUpperBound(), getStep());
+  if (!tripCount.has_value() || tripCount != 1)
+    return failure();
+
+  // Replace all results with the yielded values.
+  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
+  rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
+
+  // Replace block arguments with lower bound (replacement for IV) and
+  // iter_args.
+  SmallVector<Value> bbArgReplacements;
+  bbArgReplacements.push_back(getLowerBound());
+  bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end());
+
+  // Move the loop body operations to the loop's containing block.
+  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
+                             getOperation()->getIterator(), bbArgReplacements);
+
+  // Erase the old terminator and the loop.
+  rewriter.eraseOp(yieldOp);
+  rewriter.eraseOp(*this);
+
+  return success();
+}
+
 /// Prints the initialization list in the form of
 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
 /// where 'inner' values are assumed to be region arguments and 'outer' values
@@ -536,59 +565,64 @@ void ForOp::getSuccessorRegions(std::optional<unsigned> index,
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+Region &ForallOp::getLoopBody() { return getRegion(); }
+
 /// Promotes the loop body of a forallOp to its containing block if it can be
 /// determined that the loop has a single iteration.
-LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
-                                                  scf::ForallOp forallOp) {
+LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   for (auto [lb, ub, step] :
-       llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-                 forallOp.getMixedStep())) {
+       llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
     auto tripCount = constantTripCount(lb, ub, step);
     if (!tripCount.has_value() || *tripCount != 1)
       return failure();
   }
 
-  promote(rewriter, forallOp);
+  promote(rewriter, *this);
   return success();
 }
 
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
-  IRMapping mapping;
-  mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
-  mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
-  for (auto &bodyOp : forallOp.getBody()->without_terminator())
-    rewriter.clone(bodyOp, mapping);
+void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+  OpBuilder::InsertionGuard g(rewriter);
+  scf::InParallelOp terminator = forallOp.getTerminator();
+
+  // Replace block arguments with lower bounds (replacements for IVs) and
+  // outputs.
+  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
+  bbArgReplacements.append(forallOp.getOutputs().begin(),
+                           forallOp.getOutputs().end());
 
+  // Move the loop body operations to the loop's containing block.
+  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
+                             forallOp->getIterator(), bbArgReplacements);
+
+  // Replace the terminator with tensor.insert_slice ops.
+  rewriter.setInsertionPointAfter(forallOp);
   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
-  scf::InParallelOp terminator = forallOp.getTerminator();
   for (auto &yieldingOp : terminator.getYieldingOps()) {
     auto parallelInsertSliceOp =
         cast<tensor::ParallelInsertSliceOp>(yieldingOp);
 
     Value dst = parallelInsertSliceOp.getDest();
     Value src = parallelInsertSliceOp.getSource();
-
-    auto getMappedValues = [&](ValueRange values) {
-      return llvm::to_vector(llvm::map_range(
-          values, [&](Value value) { return mapping.lookupOrDefault(value); }));
-    };
-
-    Value srcVal = mapping.lookupOrDefault(src);
-    if (llvm::isa<TensorType>(srcVal.getType())) {
+    if (llvm::isa<TensorType>(src.getType())) {
       results.push_back(rewriter.create<tensor::InsertSliceOp>(
-          forallOp.getLoc(), dst.getType(), srcVal,
-          mapping.lookupOrDefault(dst),
-          getMappedValues(parallelInsertSliceOp.getOffsets()),
-          getMappedValues(parallelInsertSliceOp.getSizes()),
-          getMappedValues(parallelInsertSliceOp.getStrides()),
+          forallOp.getLoc(), dst.getType(), src, dst,
+          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
+          parallelInsertSliceOp.getStrides(),
           parallelInsertSliceOp.getStaticOffsets(),
           parallelInsertSliceOp.getStaticSizes(),
           parallelInsertSliceOp.getStaticStrides()));
+    } else {
+      llvm_unreachable("unsupported terminator");
     }
   }
-  rewriter.replaceOp(forallOp, results);
+  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
+
+  // Erase the old terminator and the loop.
+  rewriter.eraseOp(terminator);
+  rewriter.eraseOp(forallOp);
 }
 
 LoopNest mlir::scf::buildLoopNest(

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 34225fd133b2e3..96e4f890a9c12d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -362,44 +362,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
-/// Helper to replace uses of loop carried values (iter_args) and loop
-/// yield values while promoting single iteration scf.for ops.
-static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
-  // Replace uses of iter arguments with iter operands (initial values).
-  auto iterOperands = forOp.getIterOperands();
-  auto iterArgs = forOp.getRegionIterArgs();
-  for (auto e : llvm::zip(iterOperands, iterArgs))
-    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
-
-  // Replace uses of loop results with the values yielded by the loop.
-  auto outerResults = forOp.getResults();
-  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
-  for (auto e : llvm::zip(outerResults, innerResults))
-    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
-}
-
-/// Promotes the loop body of a forOp to its containing block if the forOp
-/// it can be determined that the loop has a single iteration.
-LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
-  std::optional<int64_t> tripCount = constantTripCount(
-      forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
-  if (!tripCount.has_value() || tripCount != 1)
-    return failure();
-  auto iv = forOp.getInductionVar();
-  iv.replaceAllUsesWith(forOp.getLowerBound());
-
-  replaceIterArgsAndYieldResults(forOp);
-
-  // Move the loop body operations, except for its terminator, to the loop's
-  // containing block.
-  auto *parentBlock = forOp->getBlock();
-  forOp.getBody()->getTerminator()->erase();
-  parentBlock->getOperations().splice(Block::iterator(forOp),
-                                      forOp.getBody()->getOperations());
-  forOp.erase();
-  return success();
-}
-
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -469,6 +431,7 @@ LogicalResult mlir::loopUnrollByFactor(
   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
   OpBuilder boundsBuilder(forOp);
+  IRRewriter rewriter(forOp.getContext());
   auto loc = forOp.getLoc();
   Value step = forOp.getStep();
   Value upperBoundUnrolled;
@@ -488,7 +451,7 @@ LogicalResult mlir::loopUnrollByFactor(
     int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
 
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
+      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
@@ -553,7 +516,7 @@ LogicalResult mlir::loopUnrollByFactor(
     }
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                                epilogueForOp.getNumIterOperands(), results);
-    (void)promoteIfSingleIteration(epilogueForOp);
+    (void)epilogueForOp.promoteIfSingleIteration(rewriter);
   }
 
   // Create unrolled loop.
@@ -573,7 +536,7 @@ LogicalResult mlir::loopUnrollByFactor(
       },
       annotateFn, iterArgs, yieldedValues);
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(forOp);
+  (void)forOp.promoteIfSingleIteration(rewriter);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index e23e86184c3118..e2b45e85e4f8c7 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -545,9 +545,9 @@ func.func @parallel_insert_slice_no_conflict(
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
   // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
   %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
-      // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
       %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
       // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
       %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
@@ -591,9 +591,9 @@ func.func @parallel_insert_slice_with_conflict(
   // CHECK: %[[alloc1:.*]] = memref.alloc
   // CHECK: memref.copy %[[arg2]], %[[alloc1]]
 
+  // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
   // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
   %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
-      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
       %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
 
       // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref<?xf32


        


More information about the Mlir-commits mailing list