[Mlir-commits] [mlir] 16488dc - [mlir][linalg] Pass all operands to tile to the tile loop region builder (NFC).

Tobias Gysi llvmlistbot at llvm.org
Fri Sep 10 01:45:24 PDT 2021


Author: Tobias Gysi
Date: 2021-09-10T08:35:11Z
New Revision: 16488dc300d088ee5c01af15cff07d349f18cd6a

URL: https://github.com/llvm/llvm-project/commit/16488dc300d088ee5c01af15cff07d349f18cd6a
DIFF: https://github.com/llvm/llvm-project/commit/16488dc300d088ee5c01af15cff07d349f18cd6a.diff

LOG: [mlir][linalg] Pass all operands to tile to the tile loop region builder (NFC).

Extend the signature of the tile loop nest region builder so that it takes all operand values to use, not just the scf::For iterArgs. This change allows us to pass in all block arguments of TiledLoop and use them directly, instead of replacing them after loop generation.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D109569

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index c451d0d677d5..937330c90bf2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -263,7 +263,7 @@ struct RegionMatcher {
 /// Utility class used to generate nested loops with ranges described by
 /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
 /// is used to generate the body of the innermost loop. It is passed a range
-/// of loop induction variables and a range of iterArgs.
+/// of loop induction variables and a range of operand values to use.
 template <typename LoopTy>
 struct GenerateLoopNest {
   static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 71a363cb4ccd..07d9eef56b8e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -431,8 +431,9 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
   GenerateLoopNest<LoopTy>::doit(
       rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange ivs,
-          ValueRange iterArgs) -> scf::ValueVector {
-        assert(iterArgs.empty() && "unexpected iterArgs");
+          ValueRange operandValuesToUse) -> scf::ValueVector {
+        assert(operandValuesToUse == linalgOp->getOperands() &&
+               "expect operands are captured and not passed by loop argument");
         allIvs.append(ivs.begin(), ivs.end());
         llvm::TypeSwitch<Operation *>(linalgOp)
             .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index c685f36d712e..57e77c0bec05 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -227,9 +227,9 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
   // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs, tensorResults;
-  auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc,
-                                  ValueRange localIvs,
-                                  ValueRange iterArgs) -> scf::ValueVector {
+  auto tiledLoopBodyBuilder =
+      [&](OpBuilder &b, Location loc, ValueRange localIvs,
+          ValueRange operandValuesToUse) -> scf::ValueVector {
     ivs.assign(localIvs.begin(), localIvs.end());
 
     // When an `interchangeVector` is present, it has been applied to the
@@ -241,20 +241,16 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
     else
       interchangedIvs.assign(ivs.begin(), ivs.end());
 
-    assert(op.getOutputTensorOperands().size() == iterArgs.size() &&
-           "num output tensors must match number of loop iter arguments");
-
-    SmallVector<Value> operands = op.getInputOperands();
-    SmallVector<Value> outputBuffers = op.getOutputBufferOperands();
-    // TODO: thanks to simplifying assumption we do not need to worry about
-    // order of output buffers and tensors: there is only ever one kind.
-    assert(outputBuffers.empty() || iterArgs.empty());
-    operands.append(outputBuffers.begin(), outputBuffers.end());
-    operands.append(iterArgs.begin(), iterArgs.end());
+    // Tile the `operandValuesToUse` that either match the `op` operands
+    // themselves or the tile loop arguments forwarding them.
+    assert(operandValuesToUse.size() ==
+               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
+           "expect the number of operands and inputs and outputs to match");
+    SmallVector<Value> valuesToTile = operandValuesToUse;
     auto sizeBounds =
         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
+        b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);
 
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 9237c003297b..3317a38ed026 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -225,7 +225,18 @@ void GenerateLoopNest<scf::ForOp>::doit(
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
   LoopNest loopNest = mlir::scf::buildLoopNest(
-      b, loc, lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
+      b, loc, lbs, ubs, steps, iterArgInitValues,
+      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
+        assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
+               "expect the number of output tensors and iter args to match");
+        SmallVector<Value> operandValuesToUse =
+            linalgOp.getInputAndOutputOperands();
+        if (!iterArgs.empty()) {
+          operandValuesToUse = linalgOp.getInputOperands();
+          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
+        }
+        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
+      });
 
   if (!distributionOptions || loopNest.loops.empty())
     return;
@@ -268,7 +279,9 @@ void GenerateLoopNest<AffineForOp>::doit(
 
   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
-                              bodyBuilderFn(b, loc, ivs, {});
+                              SmallVector<Value> operandValuesToUse =
+                                  linalgOp.getInputAndOutputOperands();
+                              bodyBuilderFn(b, loc, ivs, operandValuesToUse);
                             });
 }
 
@@ -289,9 +302,10 @@ void GenerateLoopNest<TiledLoopOp>::doit(
   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                               ValueRange ivs, ValueRange inputs,
                               ValueRange outputs) {
-    SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands();
+    SmallVector<Value> operandValuesToUse = inputs;
+    operandValuesToUse.append(outputs.begin(), outputs.end());
     scf::ValueVector results =
-        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors);
+        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
     nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
   };
 
@@ -302,15 +316,6 @@ void GenerateLoopNest<TiledLoopOp>::doit(
                             b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
   if (!distributionTypes.empty())
     tiledLoop.setDistributionTypes(b, distributionTypes);
-
-  // Replace inputs/outputs with the corresponding region args.
-  auto isInsideTiledLoop = [&](OpOperand &operand) {
-    return operand.getOwner()->getBlock() == tiledLoop.getBody();
-  };
-  for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs()))
-    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
-  for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs()))
-    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
 }
 
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
@@ -505,7 +510,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   generateParallelLoopNest(
       b, loc, lbs, ubs, steps, iteratorTypes,
       [&](OpBuilder &b, Location loc, ValueRange ivs) {
-        bodyBuilderFn(b, loc, ivs, {});
+        SmallVector<Value> operandValuesToUse =
+            linalgOp.getInputAndOutputOperands();
+        bodyBuilderFn(b, loc, ivs, operandValuesToUse);
       },
       ivs, distributionMethod);
 


        


More information about the Mlir-commits mailing list