[Mlir-commits] [mlir] 2ed7c3f - [MLIR][SCF] Enable better bufferization for `TileConsumerAndFuseProducersUsingSCFForOp`

lorenzo chelini llvmlistbot at llvm.org
Thu Jul 21 01:14:38 PDT 2022


Author: lorenzo chelini
Date: 2022-07-21T10:14:26+02:00
New Revision: 2ed7c3fd841db1ae2a1ae2c3df865b04a890bb0d

URL: https://github.com/llvm/llvm-project/commit/2ed7c3fd841db1ae2a1ae2c3df865b04a890bb0d
DIFF: https://github.com/llvm/llvm-project/commit/2ed7c3fd841db1ae2a1ae2c3df865b04a890bb0d.diff

LOG: [MLIR][SCF] Enable better bufferization for `TileConsumerAndFuseProducersUsingSCFForOp`

Replace iterators of the outermost loop with region arguments of the innermost
one. This change prevents later `bufferization` passes from inserting
allocations within the body of the innermost loop.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D130083

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3bad54327e078..c62f27d8a22f4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -355,6 +355,23 @@ static Optional<OpResult> getFusableProducer(Value v) {
   return v.cast<OpResult>();
 }
 
+// Replace uses of the outermost loop's iter operands directly in the body of
+// the innermost loop with the innermost loop's region iter args.
+static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
+                            PatternRewriter &rewriter) {
+  assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
+         "expect same number of iter args");
+  Block *block = &(*innerFor.getRegion().begin());
+  for (auto it :
+       llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
+    Value source = std::get<0>(it);
+    Value target = std::get<1>(it);
+    source.replaceUsesWithIf(target, [&](OpOperand &use) {
+      return use.getOwner()->getBlock() == block;
+    });
+  }
+}
+
 FailureOr<scf::SCFTileAndFuseResult>
 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     TilingInterface op, PatternRewriter &rewriter) const {
@@ -470,5 +487,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
       }
     }
   }
+  replaceIterArgs(tileAndFuseResult.loops.front(),
+                  tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index d1ca2d2c4625f..61aa706b10ae4 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -23,7 +23,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -123,7 +123,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
 // CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
 // CHECK-SAME:         outs(%[[FILL0_TILE]] :
 //  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
-//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
 //      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
 // CHECK-SAME:         outs(%[[INIT1_TILE]] :
 //      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
@@ -218,7 +218,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul


        


More information about the Mlir-commits mailing list