[Mlir-commits] [mlir] ec33916 - [MLIR][Linalg] Lower `linalg.tiled_loop` in a separate pass

Frederik Gossen llvmlistbot at llvm.org
Mon May 3 12:02:22 PDT 2021


Author: Frederik Gossen
Date: 2021-05-03T21:02:02+02:00
New Revision: ec339163a7a5bad134eef869356b359cc44dfc97

URL: https://github.com/llvm/llvm-project/commit/ec339163a7a5bad134eef869356b359cc44dfc97
DIFF: https://github.com/llvm/llvm-project/commit/ec339163a7a5bad134eef869356b359cc44dfc97.diff

LOG: [MLIR][Linalg] Lower `linalg.tiled_loop` in a separate pass

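Add a dedicated pass, `convert-linalg-tiled-loops-to-scf`, to lower
`linalg.tiled_loop` ops. The generic Linalg-to-loops lowering no longer
handles them; the corresponding tests move to `tiled-loops.mlir`.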

Differential Revision: https://reviews.llvm.org/D101768

Added: 
    mlir/test/Dialect/Linalg/tiled-loops.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 3c7b7c146ccab..b81ea52ba3357 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -36,6 +36,10 @@ std::unique_ptr<OperationPass<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
 
+/// Create a pass to convert Linalg tiled loops to SCF loop nests. The loop
+/// body is placed inside the nest; Linalg ops in it are not lowered further.
+std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
+
 /// Create a pass to convert Linalg operations to scf.for loops and
 /// memref.load/memref.store accesses.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index fe5ac6354f48b..c529dfded3eab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,17 @@ def LinalgFoldReshapeOpsByLinearization :
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
+def LinalgLowerTiledLoopsToSCF
+    : FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
+  let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
+  let constructor = "mlir::createConvertLinalgTiledLoopsToSCFPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "scf::SCFDialect",
+    "AffineDialect"
+  ];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
@@ -76,16 +87,6 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
   ];
 }
 
-def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
-  let summary = "Bufferize the linalg dialect";
-  let constructor = "mlir::createLinalgBufferizePass()";
-  let dependentDialects = [
-    "linalg::LinalgDialect",
-    "AffineDialect",
-    "memref::MemRefDialect"
-  ];
-}
-
 def LinalgLowerToParallelLoops
     : FunctionPass<"convert-linalg-to-parallel-loops"> {
   let summary = "Lower the operations from the linalg dialect into parallel "
@@ -99,6 +100,16 @@ def LinalgLowerToParallelLoops
   ];
 }
 
+def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
+  let summary = "Bufferize the linalg dialect";
+  let constructor = "mlir::createLinalgBufferizePass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "AffineDialect",
+    "memref::MemRefDialect"
+  ];
+}
+
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
   let summary = "Promote subview ops to local buffers";
   let constructor = "mlir::createLinalgPromotionPass()";

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 2cbefdaca98bf..f03e590bbde34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -555,7 +555,7 @@ class LinalgRewritePattern : public RewritePattern {
   }
 };
 
-struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
+struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
   using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
@@ -597,7 +597,7 @@ template <typename LoopType>
 static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet patterns(context);
-  patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
+  patterns.add<LinalgRewritePattern<LoopType>>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);
@@ -668,8 +668,23 @@ struct LowerToParallelLoops
     lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
   }
 };
+
+struct LowerTiledLoopsToSCF
+    : public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<TiledLoopToSCFPattern>(context);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
 } // namespace
 
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertLinalgTiledLoopsToSCFPass() {
+  return std::make_unique<LowerTiledLoopsToSCF>();
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
   return std::make_unique<LowerToLoops>();
 }

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 5a280384726c3..b0ffc7f0053e5 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1522,78 +1522,3 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-
-
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-
-func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
-                             %B: memref<192x192xf32>,
-                             %C: memref<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<192x192xf32>) {
-    %0 = affine.min #map0(%i)
-    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
-      : memref<192x192xf32> to memref<?x192xf32, #map1>
-    %2 = affine.min #map2(%j)
-    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
-      : memref<192x192xf32> to memref<192x?xf32, #map1>
-    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
-      : memref<192x192xf32> to memref<?x?xf32, #map1>
-    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
-    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
-                               memref<192x?xf32, #map1>)
-                  outs(%4 : memref<?x?xf32, #map1>)
-    linalg.yield
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @tiled_loop_to_parallel
-// CHECKLOOP-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
-// CHECKLOOP-SAME:  %[[C:.*]]: memref<192x192xf32>) {
-// CHECKLOOP:       %[[C24:.*]] = constant 24 : index
-// CHECKLOOP:       %[[C16:.*]] = constant 16 : index
-// CHECKLOOP:       %[[C192:.*]] = constant 192 : index
-// CHECKLOOP:       %[[C0:.*]] = constant 0 : index
-// CHECKLOOP:       scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
-// CHECKLOOP:         scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
-// CHECKLOOP:           %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
-// CHECKLOOP:           %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
-// CHECKLOOP:           %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
-
-
-func @tiled_loop_to_for(%A: memref<192x192xf32>,
-                        %B: memref<192x192xf32>,
-                        %C: memref<f32>) {
-   %c24 = constant 24 : index
-   %c16 = constant 16 : index
-   %c0 = constant 0 : index
-   %c192 = constant 192 : index
-   %cst = constant 0.000000e+00 : f32
-
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<f32>)
-      iterators["reduction", "reduction"] {
-    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
-    linalg.yield
-  }
-  return
-}
-
-// CHECKLOOP-LABEL: @tiled_loop_to_for
-// CHECKLOOP:       %[[C24:.*]] = constant 24 : index
-// CHECKLOOP:       %[[C16:.*]] = constant 16 : index
-// CHECKLOOP:       %[[C192:.*]] = constant 192 : index
-// CHECKLOOP:       %[[C0:.*]] = constant 0 : index
-// CHECKLOOP:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
-// CHECKLOOP:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]

diff  --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir
new file mode 100644
index 0000000000000..7d22699088131
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tiled-loops.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf | FileCheck %s
+
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func @tiled_loop(%A: memref<192x192xf32>,
+                 %B: memref<192x192xf32>,
+                 %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop
+// CHECK-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECK-SAME:  %[[C:.*]]: memref<192x192xf32>) {
+// CHECK:       %[[C24:.*]] = constant 24 : index
+// CHECK:       %[[C16:.*]] = constant 16 : index
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C192:.*]] = constant 192 : index
+// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
+// CHECK:         scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
+// CHECK:           %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECK:           %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECK:           %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+// CHECK:           linalg.fill
+// CHECK:           linalg.matmul
+
+
+func @tiled_loop_reduction(%A: memref<192x192xf32>,
+                           %B: memref<192x192xf32>,
+                           %C: memref<f32>) {
+   %c24 = constant 24 : index
+   %c16 = constant 16 : index
+   %c0 = constant 0 : index
+   %c192 = constant 192 : index
+   %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
+    linalg.yield
+  }
+  return
+}
+
+// CHECK-LABEL: @tiled_loop_reduction
+// CHECK:       %[[C24:.*]] = constant 24 : index
+// CHECK:       %[[C16:.*]] = constant 16 : index
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C192:.*]] = constant 192 : index
+// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
+// CHECK:           linalg.fill


        


More information about the Mlir-commits mailing list