[Mlir-commits] [mlir] 7f2236c - [mlir][linalg] Add output tensor args folding for linalg.tiled_loop.

Alexander Belyaev llvmlistbot at llvm.org
Thu Mar 25 10:12:07 PDT 2021


Author: Alexander Belyaev
Date: 2021-03-25T18:11:05+01:00
New Revision: 7f2236cf581e6d666e4c3eb512a76f1608fe0bf7

URL: https://github.com/llvm/llvm-project/commit/7f2236cf581e6d666e4c3eb512a76f1608fe0bf7
DIFF: https://github.com/llvm/llvm-project/commit/7f2236cf581e6d666e4c3eb512a76f1608fe0bf7.diff

LOG: [mlir][linalg] Add output tensor args folding for linalg.tiled_loop.

Folds away TiledLoopOp output tensors when both of these conditions hold:
* the corresponding result of `linalg.tiled_loop` has no uses
* the output tensor is yielded unchanged by `linalg.yield`

Example:

```mlir
%0 = linalg.tiled_loop ...  outs (%out, %out_buf:tensor<...>, memref<...>) {
  ...
  linalg.yield %out : tensor<...>
}
```

Becomes

```mlir
linalg.tiled_loop ...  outs (%out_buf:memref<...>) {
  ...
  linalg.yield
}
```
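
For contrast, here is a hypothetical sketch (not from the patch) where
neither condition holds, so the output tensor is kept: the loop result has
uses, and the body yields a value other than `%out`.

```mlir
// Not folded: %0 has uses, and the body yields %updated rather than %out.
%0 = linalg.tiled_loop ... outs (%out : tensor<...>) {
  ...
  %updated = subtensor_insert %tile into %out[...] : tensor<...> into tensor<...>
  linalg.yield %updated : tensor<...>
}
```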

Differential Revision: https://reviews.llvm.org/D99333

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index d54efbe37a57..fe6720761abe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
     }
     unsigned getNumLoops() { return step().size(); }
   }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fdb2e4f4603e..744f0276daa8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1943,6 +1943,87 @@ bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
 
 static LogicalResult verify(TiledLoopOp op) { return success(); }
 
+namespace {
+
+// Folds away a TiledLoopOp output tensor when both of these conditions hold:
+// * the corresponding `linalg.tiled_loop` result has no uses, and
+// * the output tensor is yielded unchanged by `linalg.yield`.
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ...  outs (%out, %out_buf:tensor<...>, memref<...>) {
+//   ...
+//   linalg.yield %out : tensor<...>
+// }
+//
+// Becomes
+//
+// linalg.tiled_loop ...  outs (%out_buf:memref<...>) {
+//   ...
+//   linalg.yield
+// }
+struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const final {
+    if (tiledLoop.getNumResults() == 0)
+      return failure();
+
+    Block *block = tiledLoop.getBody();
+    auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+
+    // Match the pattern and collect the output operands and yield arguments
+    // to keep; foldable output tensors and their yield args are dropped.
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+    int resultId = 0;
+    for (Value out : tiledLoop.outputs()) {
+      if (!out.getType().isa<RankedTensorType>()) {
+        newOutputOperands.push_back(out);
+        continue;
+      }
+      Value result = tiledLoop.getResult(resultId);
+      Value yieldArg = yieldOp.getOperand(resultId);
+      if (yieldArg != out || !result.use_empty()) {
+        newOutputOperands.push_back(out);
+        newYieldArgs.push_back(yieldArg);
+      }
+      ++resultId;
+    }
+    if (newOutputOperands.size() == tiledLoop.outputs().size())
+      return failure();
+
+    Location loc = tiledLoop.getLoc();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+
+    // Clone the loop body; ops that only fed the dropped yield args (e.g.
+    // `subtensor_insert`, `tensor_load` and `cast`) are left for later DCE.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+    for (auto &op : tiledLoop.getBody()->without_terminator())
+      innerBuilder.clone(op, bvm);
+    innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
+    rewriter.eraseOp(tiledLoop);
+
+    return success();
+  }
+};
+} // namespace
+
+void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<TiledLoopResultsFolder>(context);
+}
+
+LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template <typename LinalgPoolingOp>
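
The `fold` hook delegates to `foldMemRefCast`, which absorbs a compatible
`memref.cast` on any operand of the loop. A minimal sketch of the effect,
assuming a static-to-dynamic cast that the consumer can fold (the `%buf` and
`%cast` names are illustrative only):

```mlir
%cast = memref.cast %buf : memref<192x192xf32> to memref<?x?xf32>
linalg.tiled_loop ... outs (%cast : memref<?x?xf32>) {
  ...
}
```

After folding, the loop takes `%buf : memref<192x192xf32>` directly and the
cast becomes dead.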

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5ec93dda59d0..44f9dbd49cd5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -818,3 +818,46 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
   // CHECK: return %[[FILL]] : tensor<6x4xf32>
   return %reshape : tensor<6x4xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                  %C: memref<192x192xf32>) -> ()
+
+func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                              %C: memref<192x192xf32>,
+                              %C_tensor: tensor<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+      step (%c24, %c16)
+      ins (%A, %B : memref<192x192xf32>, memref<192x192xf32>)
+      outs (%C_tensor, %C : tensor<192x192xf32>, memref<192x192xf32>) {
+    call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> ()
+    linalg.yield %C_tensor : tensor<192x192xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_results(
+// CHECK-SAME:    %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
+// CHECK-SAME:    %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK:  %[[C0:.*]] = constant 0 : index
+// CHECK:  %[[C192:.*]] = constant 192 : index
+
+// CHECK-NOT: %{{.*}} = linalg.tiled_loop
+// CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
+// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
+// CHECK-NEXT:   call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-NEXT:   linalg.yield
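
For reference, the canonicalized loop that these CHECK lines describe looks
roughly as follows (reconstructed from the test expectations, not verbatim
`mlir-opt` output). Both the unused `%useless` result and the `%C_tensor`
output are gone:

```mlir
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
    ins (%A, %B : memref<192x192xf32>, memref<192x192xf32>)
    outs (%C : memref<192x192xf32>) {
  call @foo(%A, %B, %C)
      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> ()
  linalg.yield
}
```
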
More information about the Mlir-commits mailing list