[Mlir-commits] [mlir] 29dbac0 - [mlir] Add folding for tensor inputs and memref.cast in linalg.tiled_loop.

Alexander Belyaev llvmlistbot at llvm.org
Wed Apr 28 05:36:48 PDT 2021


Author: Alexander Belyaev
Date: 2021-04-28T14:36:07+02:00
New Revision: 29dbac0ae29576176318525c9af65a15429c9466

URL: https://github.com/llvm/llvm-project/commit/29dbac0ae29576176318525c9af65a15429c9466
DIFF: https://github.com/llvm/llvm-project/commit/29dbac0ae29576176318525c9af65a15429c9466.diff

LOG: [mlir] Add folding for tensor inputs and memref.cast in linalg.tiled_loop.

Tensor inputs, if not used in the body of TiledLoopOp, can be removed.
memref::CastOp can be folded into TiledLoopOp as well.

Differential Revision: https://reviews.llvm.org/D101445

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5a6d498a65b49..8a80036688cd9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -136,7 +136,7 @@ static void dispatchIndexOpFoldResult(OpFoldResult ofr,
 
 /// This is a common class used for patterns of the form
 /// ```
-///    someop(memrefcast) -> someop
+///    someop(memrefcast(%src)) -> someop(%src)
 /// ```
 /// It folds the source of the memref.cast into the root operation directly.
 static LogicalResult foldMemRefCast(Operation *op) {
@@ -151,6 +151,44 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+/// This is a specialization of `foldMemRefCast` used for patterns of the form
+/// ```
+///    tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
+/// ```
+/// Folded bbArgs are retyped and cast back; fails if nothing was folded.
+static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
+  bool folded = false;
+  Location loc = op->getLoc();
+  // Casts back to the old bbArg types are created at the start of the body.
+  Block *body = op.getBody();
+  OpBuilder b = OpBuilder::atBlockBegin(body);
+  // NOTE(review): this fold hook creates new ops (memref.cast); confirm OK.
+  // Update `input` and `output` operands and block arguments if necessary.
+  // Operands list: [lbs, ubs, steps, inputs, outputs].
+  // Block args list: [ivs, inputs, outputs].
+  for (size_t operandIndex = op.getNumControlOperands(),
+              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
+       operandIndex < e; ++operandIndex, ++bbArgIndex) {
+    OpOperand &operand = op->getOpOperand(operandIndex);
+
+    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
+    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      auto newBbArg =
+          body->insertArgument(bbArgIndex, castOp.getOperand().getType());
+      auto oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
+
+      // Cast the new bbArg back to the old type for all existing uses.
+      oldBbArg.replaceAllUsesWith(
+          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
+      body->eraseArgument(oldBbArg.getArgNumber());
+
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -2054,6 +2092,63 @@ static LogicalResult verify(TiledLoopOp op) {
 
 namespace {
 
+static constexpr int64_t kNoMatch = -1;
+
+// Folds away TiledLoopOp input tensors if they have no uses within the body.
+// Non-tensor inputs (e.g. memrefs) are always kept, even when unused.
+// Example:
+//
+// %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
+//                                  %in_buf_ = %in_buf: memref<...>) {...}
+// Becomes
+//
+// %0 = linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
+struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const final {
+    SmallVector<Value, 2> newInputs;
+    // Maps old input index -> new input index, or kNoMatch if it is dropped.
+    SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
+                                            kNoMatch);
+    for (auto en : llvm::enumerate(
+             llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
+      Value in, bbArg;
+      size_t index = en.index();
+      std::tie(in, bbArg) = en.value();
+      // Keep every non-tensor input and every tensor input used in the body.
+      if (!in.getType().isa<RankedTensorType>() || !bbArg.use_empty()) {
+        oldInputIdToNew[index] = newInputs.size();
+        newInputs.push_back(in);
+      }
+    }
+    if (newInputs.size() == tiledLoop.inputs().size())
+      return failure();
+    Location loc = tiledLoop.getLoc();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types());
+
+    // Clone the whole body (terminator included); dropped bbArgs are unused.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(tiledLoop.getRegionOutputArgs(),
+            newTiledLoop.getRegionOutputArgs());
+    for (const auto &en : llvm::enumerate(oldInputIdToNew))
+      if (en.value() != kNoMatch)
+        bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
+                newTiledLoop.getRegionInputArgs()[en.value()]);
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+    for (auto &op : *tiledLoop.getBody())
+      innerBuilder.clone(op, bvm);
+    // Outputs are unchanged, so the old results map 1:1 to the new ones.
+    rewriter.replaceOp(tiledLoop, newTiledLoop->getResults());
+    return success();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2085,27 +2180,26 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 
     // Match the pattern and collect output buffers that will replace the output
     // tensors and also the ops that will be ignored when cloning the body.
-    SmallVector<Value, 2> newOutputOperands, newYieldArgs,
-        regionOutputTensorArgs;
+    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
     int resultId = 0;
     // Store ids of the corresponding old and new output operands.
-    SmallVector<std::pair<size_t, size_t>, 2> old_out_id_to_new;
-    for (auto item : llvm::enumerate(
+    SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
+                                             kNoMatch);
+    for (auto en : llvm::enumerate(
              llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
-      size_t index = item.index();
-      Value out = std::get<0>(item.value());
-      Value outRegionArg = std::get<1>(item.value());
+      size_t index = en.index();
+      Value out = std::get<0>(en.value());
+      Value outRegionArg = std::get<1>(en.value());
 
       if (!out.getType().isa<RankedTensorType>()) {
-        old_out_id_to_new.push_back({index, newOutputOperands.size()});
+        oldOutputIdToNew[index] = newOutputOperands.size();
         newOutputOperands.push_back(out);
-        regionOutputTensorArgs.push_back(outRegionArg);
         continue;
       }
       Value result = tiledLoop.getResult(resultId);
       Value yieldArg = yieldOp.getOperand(resultId);
       if (yieldArg != outRegionArg || !result.use_empty()) {
-        old_out_id_to_new.push_back({index, newOutputOperands.size()});
+        oldOutputIdToNew[index] = newOutputOperands.size();
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2119,14 +2213,18 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
         tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
 
-    // Clone the region ignoring the def-chain for linalg.yield args:
-    // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
+    // Clone the body; bbArgs of dropped outputs map to the outer tensors.
     BlockAndValueMapping bvm;
     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
     bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
-    for (const auto &item : old_out_id_to_new)
-      bvm.map(tiledLoop.getRegionOutputArgs()[item.first],
-              newTiledLoop.getRegionOutputArgs()[item.second]);
+    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
+      if (en.value() != kNoMatch)
+        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+                newTiledLoop.getRegionOutputArgs()[en.value()]);
+      else
+        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+                tiledLoop.outputs()[en.index()]);
+    }
     OpBuilder innerBuilder =
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())
@@ -2141,12 +2239,12 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 
 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<TiledLoopResultsFolder>(context);
+  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
+  return foldMemRefCastInTiledLoopOp(*this);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index afdfe6fb98a81..e66ee388c65eb 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -18,6 +18,31 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
 
 // -----
 
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// memref.cast only adds a layout map here, so it folds into the tiled_loop.
+// CHECK-LABEL: func @memref_cast_into_tiled_loop(
+func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>)  {
+  %0 = memref.cast %arg0
+    : memref<192xf32> to memref<192xf32, #map>
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  // CHECK: linalg.tiled_loop
+  // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
+  linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
+    outs (%out = %0: memref<192xf32, #map>) {
+    %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
+    %16 = memref.subview %out[%arg3] [%14] [1]
+      : memref<192xf32, #map> to memref<?xf32, #map>
+    linalg.fill(%16, %cst) : memref<?xf32, #map>, f32
+    linalg.yield
+  }
+  return
+}
+
+// -----
+
 func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
 {
   %0 = linalg.tensor_reshape %arg0
@@ -889,6 +914,47 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
 
 // -----
 
+func private @foo(%A: memref<192xf32>) -> ()
+
+// The unused tensor input %AT_ is dropped; the used memref input %A_ stays.
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+  %c0 = constant 0 : index
+  %c24 = constant 24 : index
+  %c192 = constant 192 : index
+  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
+        call @foo(%A_) : (memref<192xf32>)-> ()
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_inputs
+// CHECK: linalg.tiled_loop
+// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
+
+// Dropping an unused tensor input must keep the loop results alive.
+func @fold_tiled_loop_unused_tensor_input_with_result(%A: memref<192xf32>,
+    %A_tensor: tensor<192xf32>, %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
+  %c0 = constant 0 : index
+  %c24 = constant 24 : index
+  %c192 = constant 192 : index
+  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
+      outs (%BT_ = %B_tensor: tensor<192xf32>) {
+        call @foo(%A_) : (memref<192xf32>)-> ()
+    linalg.yield %BT_ : tensor<192xf32>
+  }
+  return %result : tensor<192xf32>
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_unused_tensor_input_with_result
+// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
+// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
     %arg3: f32) -> (index, index, index)
 {


        


More information about the Mlir-commits mailing list