[Mlir-commits] [mlir] 5ee5bbd - [mlir][linalg] Extend tiled_loop to SCF conversion to generate scf.parallel.

Alexander Belyaev llvmlistbot at llvm.org
Fri Sep 3 09:06:09 PDT 2021

Author: Alexander Belyaev
Date: 2021-09-03T18:05:54+02:00
New Revision: 5ee5bbd0ffe162e604e75718fed987545c366359

URL: https://github.com/llvm/llvm-project/commit/5ee5bbd0ffe162e604e75718fed987545c366359
DIFF: https://github.com/llvm/llvm-project/commit/5ee5bbd0ffe162e604e75718fed987545c366359.diff

LOG: [mlir][linalg] Extend tiled_loop to SCF conversion to generate scf.parallel.

Differential Revision: https://reviews.llvm.org/D109230




diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 332993d8b002..4c82eafc9c97 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -732,6 +732,22 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
       if (it == outputs().end()) return nullptr;
       return it.getBase();
+    /// Return whether the op has only MemRef inputs and outputs.
+    bool hasBufferSemantics() {
+      Operation* op = this->getOperation();
+      return op->getNumResults() == 0 &&
+             llvm::all_of(op->getOpOperands(), [&](OpOperand & operand) {
+               return !operand.get().getType().template isa<ShapedType>() ||
+                      operand.get().getType().template isa<MemRefType>();
+             });
+    }
+    /// Return whether the loop dimension is parallel or not.
+    bool isParallelDimension(unsigned dim) {
+      StringAttr attr = this->iterator_types()[dim].cast<StringAttr>();
+      return attr.getValue() == getParallelIteratorTypeName();
+    }
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 320564c6559c..487ad383756d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -480,36 +480,67 @@ class LinalgRewritePattern : public RewritePattern {
+/// Converts tiled_loop to SCF loop nests. All parallel dimensions are collected
+/// into an scf.parallel loop and all sequential dimensions will result in the
+/// nested scf.for loop nest. The pattern assumes that a tiled loop with
+/// iterator_types ["reduction", "parallel", "reduction"] can be reordered. It
+/// is true for the tiling that is currently supported by Linalg.
 struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
   using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
                                 PatternRewriter &rewriter) const override {
-    Location loc = tiledLoop.getLoc();
     // Fail conversion if the `tiled_loop` has not been bufferized.
-    if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
-          return arg.getType().isa<MemRefType>();
-        }))
+    if (!tiledLoop.hasBufferSemantics())
       return failure();
-    // TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
-    // iterator type.
-    scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
-                       tiledLoop.upperBound(), tiledLoop.step(),
-                       [&](OpBuilder &builder, Location loc, ValueRange ivs) {
-                         // Move body without its terminator.
-                         SmallVector<Value> newBlockArgs;
-                         newBlockArgs.append(ivs.begin(), ivs.end());
-                         newBlockArgs.append(tiledLoop.inputs().begin(),
-                                             tiledLoop.inputs().end());
-                         newBlockArgs.append(tiledLoop.outputs().begin(),
-                                             tiledLoop.outputs().end());
-                         Block *newBody = rewriter.getInsertionBlock();
-                         rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
-                                              newBlockArgs);
-                         rewriter.eraseOp(newBody->getTerminator());
-                       });
+    // Collect loop control parameters for parallel and sequential dimensions.
+    SmallVector<Value, 3> seqLBs, seqUBs, seqSteps, seqIVs;
+    SmallVector<Value, 3> parLBs, parUBs, parSteps, parIVs;
+    for (auto en : llvm::enumerate(
+             llvm::zip(tiledLoop.lowerBound(), tiledLoop.upperBound(),
+                       tiledLoop.step(), tiledLoop.getInductionVars()))) {
+      Value lb, ub, step, iv;
+      std::tie(lb, ub, step, iv) = en.value();
+      if (tiledLoop.isParallelDimension(en.index())) {
+        parLBs.push_back(lb);
+        parUBs.push_back(ub);
+        parSteps.push_back(step);
+        parIVs.push_back(iv);
+      } else {
+        seqLBs.push_back(lb);
+        seqUBs.push_back(ub);
+        seqSteps.push_back(step);
+        seqIVs.push_back(iv);
+      }
+    }
+    Location loc = tiledLoop.getLoc();
+    auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc,
+                                               ValueRange ivs) {
+      BlockAndValueMapping bvm;
+      bvm.map(parIVs, ivs);
+      bvm.map(tiledLoop.getRegionInputArgs(), tiledLoop.inputs());
+      bvm.map(tiledLoop.getRegionOutputArgs(), tiledLoop.outputs());
+      // If not all dimensions of the tiled loop are parallel, an scf.for loop
+      // nest is generated.
+      if (!seqIVs.empty()) {
+        scf::LoopNest nest =
+            scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps,
+                               [&](OpBuilder &builder, Location loc,
+                                   ValueRange ivs) { bvm.map(seqIVs, ivs); });
+        builder.setInsertionPointToStart(nest.loops.back().getBody());
+      }
+      for (auto &op : tiledLoop.getBody()->without_terminator())
+        builder.clone(op, bvm);
+    };
+    if (parIVs.empty())
+      generateForLoopNestAndCloneBody(rewriter, loc, llvm::None);
+    else
+      rewriter.create<scf::ParallelOp>(loc, parLBs, parUBs, parSteps,
+                                       generateForLoopNestAndCloneBody);
     return success();

diff  --git a/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir b/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
new file mode 100644
index 000000000000..b9a847402003
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tiled-loop-to-scf.mlir
@@ -0,0 +1,184 @@
+// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf --split-input-file | FileCheck %s
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+func @tiled_loop(%A: memref<192x192xf32>,
+                 %B: memref<192x192xf32>,
+                 %C: memref<192x192xf32>) {
+  %cst = constant 0.000000e+00 : f32
+  %c24 = constant 24 : index
+  %c16 = constant 16 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<192x192xf32>) {
+    %0 = affine.min #map0(%i)
+    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
+      : memref<192x192xf32> to memref<?x192xf32, #map1>
+    %2 = affine.min #map2(%j)
+    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
+      : memref<192x192xf32> to memref<192x?xf32, #map1>
+    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
+      : memref<192x192xf32> to memref<?x?xf32, #map1>
+    linalg.fill(%cst, %4) : f32, memref<?x?xf32, #map1>
+    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
+                               memref<192x?xf32, #map1>)
+                  outs(%4 : memref<?x?xf32, #map1>)
+    linalg.yield
+  }
+  return
+// CHECK-LABEL: @tiled_loop
+// CHECK-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
+// CHECK-SAME:  %[[C:.*]]: memref<192x192xf32>) {
+// CHECK:       %[[C24:.*]] = constant 24 : index
+// CHECK:       %[[C16:.*]] = constant 16 : index
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C192:.*]] = constant 192 : index
+// CHECK:       scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME:      to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) {
+// CHECK:         %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
+// CHECK:         %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
+// CHECK:         %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
+// CHECK:         linalg.fill
+// CHECK:         linalg.matmul
+// -----
+func @tiled_loop_reduction(%A: memref<192x192xf32>,
+                           %B: memref<192x192xf32>,
+                           %C: memref<f32>) {
+   %c24 = constant 24 : index
+   %c16 = constant 16 : index
+   %c0 = constant 0 : index
+   %c192 = constant 192 : index
+   %cst = constant 0.000000e+00 : f32
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
+      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
+      outs (%C_ = %C: memref<f32>)
+      iterators["reduction", "reduction"] {
+    linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
+    linalg.yield
+  }
+  return
+// CHECK-LABEL: @tiled_loop_reduction
+// CHECK:       %[[C24:.*]] = constant 24 : index
+// CHECK:       %[[C16:.*]] = constant 16 : index
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C192:.*]] = constant 192 : index
+// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
+// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
+// CHECK:           linalg.fill
+// -----
+#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
+#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+func @tiled_loop_row_reduction(%A: memref<10x8xf32>,
+                               %B: memref<8xf32>) {
+   %c0 = constant 0 : index
+   %c2 = constant 2 : index
+   %c4 = constant 4 : index
+   %c8 = constant 8 : index
+   %c10 = constant 10 : index
+   %cst = constant 0.000000e+00 : f32
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
+      ins (%A_ = %A: memref<10x8xf32>)
+      outs (%B_ = %B: memref<8xf32>)
+      iterators["reduction", "parallel"] {
+    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
+      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
+    %B_sub = memref.subview %B_[%j][4][1]
+      : memref<8xf32> to memref<4xf32, #strided_1d>
+    linalg.generic {
+        indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                         affine_map<(i, j) -> (j)>],
+        iterator_types = ["reduction", "parallel"]}
+        ins(%A_sub : memref<2x4xf32, #strided_2d>)
+        outs(%B_sub : memref<4xf32, #strided_1d>) {
+      ^bb(%a: f32, %b: f32) :
+        %0 = addf %a, %b: f32
+        linalg.yield %0 : f32
+      }
+    linalg.yield
+  }
+  return
+// CHECK-LABEL: @tiled_loop_row_reduction
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C10:.*]] = constant 10 : index
+// CHECK:     scf.parallel (%[[J:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]])
+// CHECK-NEXT:  scf.for %[[I:.*]] = %[[C0]] to %[[C10]] step %[[C2]]
+// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
+// CHECK-SAME:      : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
+// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[J]]] [4] [1]
+// CHECK-SAME:      : memref<8xf32> to memref<4xf32, #map{{[0-9]+}}>
+// -----
+#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)>
+#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)>
+func @tiled_loop_col_reduction(%A: memref<10x8xf32>,
+                               %B: memref<10xf32>) {
+   %c0 = constant 0 : index
+   %c2 = constant 2 : index
+   %c4 = constant 4 : index
+   %c8 = constant 8 : index
+   %c10 = constant 10 : index
+   %cst = constant 0.000000e+00 : f32
+  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4)
+      ins (%A_ = %A: memref<10x8xf32>)
+      outs (%B_ = %B: memref<10xf32>)
+      iterators["parallel", "reduction"] {
+    %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1]
+      : memref<10x8xf32> to memref<2x4xf32, #strided_2d>
+    %B_sub = memref.subview %B_[%i][2][1]
+      : memref<10xf32> to memref<2xf32, #strided_1d>
+    linalg.generic {
+        indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                         affine_map<(i, j) -> (i)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%A_sub : memref<2x4xf32, #strided_2d>)
+        outs(%B_sub : memref<2xf32, #strided_1d>) {
+      ^bb(%a: f32, %b: f32) :
+        %0 = addf %a, %b: f32
+        linalg.yield %0 : f32
+      }
+    linalg.yield
+  }
+  return
+// CHECK-LABEL: @tiled_loop_col_reduction
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = constant 8 : index
+// CHECK-DAG: %[[C10:.*]] = constant 10 : index
+// CHECK:     scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C10]]) step (%[[C2]])
+// CHECK-NEXT:  scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]]
+// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1]
+// CHECK-SAME:      : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]+}}>
+// CHECK-NEXT:    memref.subview %arg{{[0-9]+}}[%[[I]]] [2] [1]
+// CHECK-SAME:      : memref<10xf32> to memref<2xf32, #map{{[0-9]+}}>

diff  --git a/mlir/test/Dialect/Linalg/tiled-loops.mlir b/mlir/test/Dialect/Linalg/tiled-loops.mlir
deleted file mode 100644
index 5798883ba255..000000000000
--- a/mlir/test/Dialect/Linalg/tiled-loops.mlir
+++ /dev/null
@@ -1,79 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf | FileCheck %s
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
-func @tiled_loop(%A: memref<192x192xf32>,
-                 %B: memref<192x192xf32>,
-                 %C: memref<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
-  %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<192x192xf32>) {
-    %0 = affine.min #map0(%i)
-    %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
-      : memref<192x192xf32> to memref<?x192xf32, #map1>
-    %2 = affine.min #map2(%j)
-    %3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
-      : memref<192x192xf32> to memref<192x?xf32, #map1>
-    %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
-      : memref<192x192xf32> to memref<?x?xf32, #map1>
-    linalg.fill(%cst, %4) : f32, memref<?x?xf32, #map1>
-    linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
-                               memref<192x?xf32, #map1>)
-                  outs(%4 : memref<?x?xf32, #map1>)
-    linalg.yield
-  }
-  return
-// CHECK-LABEL: @tiled_loop
-// CHECK-SAME:  %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
-// CHECK-SAME:  %[[C:.*]]: memref<192x192xf32>) {
-// CHECK:       %[[C24:.*]] = constant 24 : index
-// CHECK:       %[[C16:.*]] = constant 16 : index
-// CHECK:       %[[C0:.*]] = constant 0 : index
-// CHECK:       %[[C192:.*]] = constant 192 : index
-// CHECK:       scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
-// CHECK:         scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
-// CHECK:           %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
-// CHECK:           %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
-// CHECK:           %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
-// CHECK:           linalg.fill
-// CHECK:           linalg.matmul
-func @tiled_loop_reduction(%A: memref<192x192xf32>,
-                           %B: memref<192x192xf32>,
-                           %C: memref<f32>) {
-   %c24 = constant 24 : index
-   %c16 = constant 16 : index
-   %c0 = constant 0 : index
-   %c192 = constant 192 : index
-   %cst = constant 0.000000e+00 : f32
-  linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B:  memref<192x192xf32>)
-      outs (%C_ = %C: memref<f32>)
-      iterators["reduction", "reduction"] {
-    linalg.fill(%cst, %A_) : f32, memref<192x192xf32>
-    linalg.yield
-  }
-  return
-// CHECK-LABEL: @tiled_loop_reduction
-// CHECK:       %[[C24:.*]] = constant 24 : index
-// CHECK:       %[[C16:.*]] = constant 16 : index
-// CHECK:       %[[C0:.*]] = constant 0 : index
-// CHECK:       %[[C192:.*]] = constant 192 : index
-// CHECK:       scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
-// CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
-// CHECK:           linalg.fill


