[Mlir-commits] [mlir] d2a291a - [MLIR][Linalg] Lower `linalg.tiled_loop` to `scf` loops

Frederik Gossen llvmlistbot at llvm.org
Mon May 3 09:47:34 PDT 2021


Author: Frederik Gossen
Date: 2021-05-03T18:47:12+02:00
New Revision: d2a291a5f81af7c99351ecf619c7115ac652b115

URL: https://github.com/llvm/llvm-project/commit/d2a291a5f81af7c99351ecf619c7115ac652b115
DIFF: https://github.com/llvm/llvm-project/commit/d2a291a5f81af7c99351ecf619c7115ac652b115.diff

LOG: [MLIR][Linalg] Lower `linalg.tiled_loop` to `scf` loops

Differential Revision: https://reviews.llvm.org/D101747
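
As an illustrative sketch (not part of the original commit message; operand
names and bounds are taken loosely from the tests added below): after
bufferization, a `linalg.tiled_loop` such as

  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
      ins (%A_ = %A: memref<192x192xf32>)
      outs (%C_ = %C: memref<192x192xf32>) {
    // ... body using %i, %j, %A_ and %C_ ...
    linalg.yield
  }

is rewritten into a plain `scf.for` nest: the body is merged into the new loop
nest with the induction variables, inputs, and outputs substituted for the
block arguments, and the `linalg.yield` terminator is erased:

  scf.for %i = %c0 to %c192 step %c24 {
    scf.for %j = %c0 to %c192 step %c16 {
      // ... same body, now using %A and %C directly ...
    }
  }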

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 920e77b4cdd52..2cbefdaca98bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -555,6 +555,41 @@ class LinalgRewritePattern : public RewritePattern {
   }
 };
 
+struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
+  using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
+                                PatternRewriter &rewriter) const override {
+    Location loc = tiledLoop.getLoc();
+
+    // Fail conversion if the `tiled_loop` has not been bufferized.
+    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
+          return arg.getType().isa<MemRefType>();
+        }))
+      return failure();
+
+    // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
+    // iterator type.
+    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
+                       tiledLoop.upperBound(), tiledLoop.step(),
+                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+                         // Move body without its terminator.
+                         SmallVector<Value, 16> newBlockArgs;
+                         newBlockArgs.append(ivs.begin(), ivs.end());
+                         newBlockArgs.append(tiledLoop.inputs().begin(),
+                                             tiledLoop.inputs().end());
+                         newBlockArgs.append(tiledLoop.outputs().begin(),
+                                             tiledLoop.outputs().end());
+                         Block *newBody = rewriter.getInsertionBlock();
+                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
+                                              newBlockArgs);
+                         rewriter.eraseOp(newBody->getTerminator());
+                       });
+    rewriter.eraseOp(tiledLoop);
+    return success();
+  }
+};
+
 struct FoldAffineOp;
 } // namespace
 
@@ -562,7 +597,7 @@ template <typename LoopType>
 static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet patterns(context);
-  patterns.add<LinalgRewritePattern<LoopType>>(context);
+  patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
   memref::DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);

diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b0ffc7f0053e5..5a280384726c3 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1522,3 +1522,78 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
 //       CHECKPARALLEL:         %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
 //       CHECKPARALLEL:         %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECKPARALLEL:         store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
+                             %B: memref<192x192xf32>,
+                             %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @tiled_loop_to_parallel
+// CHECKLOOP-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECKLOOP-SAME:  %[[C:.*]]: memref<192x192xf32>) {
+// CHECKLOOP:       %[[C24:.*]] = constant 24 : index
+// CHECKLOOP:       %[[C16:.*]] = constant 16 : index
+// CHECKLOOP:       %[[C192:.*]] = constant 192 : index
+// CHECKLOOP:       %[[C0:.*]] = constant 0 : index
+// CHECKLOOP:       scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
+// CHECKLOOP:         scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
+// CHECKLOOP:           %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECKLOOP:           %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECKLOOP:           %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+
+
+func @tiled_loop_to_for(%A: memref<192x192xf32>,
+                        %B: memref<192x192xf32>,
+                        %C: memref<f32>) {
+   %c24 = constant 24 : index
+   %c16 = constant 16 : index
+   %c0 = constant 0 : index
+   %c192 = constant 192 : index
+   %cst = constant 0.000000e+00 : f32
+
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
+    linalg.yield
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @tiled_loop_to_for
+// CHECKLOOP:       %[[C24:.*]] = constant 24 : index
+// CHECKLOOP:       %[[C16:.*]] = constant 16 : index
+// CHECKLOOP:       %[[C192:.*]] = constant 192 : index
+// CHECKLOOP:       %[[C0:.*]] = constant 0 : index
+// CHECKLOOP:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECKLOOP:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]

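A side note on the TODO left in TiledLoopPattern above: for "parallel" iterator
types the lowering could eventually build an `scf.parallel` loop instead of an
`scf.for` nest. A hypothetical shape of that output (not produced by this
commit) would be:

  scf.parallel (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) {
    // ... body ...
  }

As added here, both new tests check the `scf.for` form through the CHECKLOOP
prefix.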

More information about the Mlir-commits mailing list