[Mlir-commits] [mlir] 4b13b75 - [mlir] Add a pass to tile Linalg ops using `linalg.tiled_loop`.

Alexander Belyaev llvmlistbot at llvm.org
Tue Apr 27 03:33:41 PDT 2021


Author: Alexander Belyaev
Date: 2021-04-27T12:33:28+02:00
New Revision: 4b13b7581db59adbc0ee4bbf269f3eda96fc9bd7

URL: https://github.com/llvm/llvm-project/commit/4b13b7581db59adbc0ee4bbf269f3eda96fc9bd7
DIFF: https://github.com/llvm/llvm-project/commit/4b13b7581db59adbc0ee4bbf269f3eda96fc9bd7.diff

LOG: [mlir] Add a pass to tile Linalg ops using `linalg.tiled_loop`.

Differential Revision: https://reviews.llvm.org/D101084
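
For reference, a minimal usage sketch mirroring the RUN line added to
mlir/test/Dialect/Linalg/tile-tensors.mlir (the file name, function and SSA
value names below are illustrative, not part of this commit):

    // matmul.mlir, tiled with:
    //   mlir-opt matmul.mlir -linalg-tile-to-tiled-loop="linalg-tile-sizes=2,3,4"
    func @matmul_tensors(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>,
                         %c: tensor<?x?xf32>) -> tensor<?x?xf32> {
      %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                        outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
      return %0 : tensor<?x?xf32>
    }

The pass wraps the tiled body in a single linalg.tiled_loop of roughly this
shape (see the TLOOP checks in the test diff below for the exact pattern):

    %result = linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0) to (%M, %N, %K)
        step (%c2, %c3, %c4)
        ins (%A = %a: tensor<?x?xf32>, %B = %b: tensor<?x?xf32>)
        outs (%C = %c: tensor<?x?xf32>)
        iterators["parallel", "parallel", "reduction"] {
      // subtensor the current tiles, run linalg.matmul on them, and
      // subtensor_insert the partial result before yielding it.
      linalg.yield %updated_tile : tensor<?x?xf32>
    }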

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-tensors.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 17fcf08dba96b..3c7b7c146ccab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -29,6 +29,9 @@ createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes = {});
 
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes = {});
+
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
 std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 8d411d5964c5d..fe5ac6354f48b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -121,8 +121,7 @@ def LinalgTiling : FunctionPass<"linalg-tile"> {
     "scf::SCFDialect"
   ];
   let options = [
-    ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
-               "Test generation of dynamic promoted buffers",
+    ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Tile sizes",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
 }
@@ -132,8 +131,23 @@ def LinalgTilingToParallelLoops
   let summary = "Tile operations in the linalg dialect to parallel loops";
   let constructor = "mlir::createLinalgTilingToParallelLoopsPass()";
   let options = [
-    ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
-               "Test generation of dynamic promoted buffers",
+    ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Tile sizes",
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
+  ];
+  let dependentDialects = [
+    "AffineDialect",
+    "linalg::LinalgDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect"
+  ];
+}
+
+def LinalgTilingToTiledLoops
+    : FunctionPass<"linalg-tile-to-tiled-loop"> {
+  let summary = "Tile operations in the linalg dialect to linalg.tiled_loop";
+  let constructor = "mlir::createLinalgTilingToTiledLoopPass()";
+  let options = [
+    ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Tile sizes",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
   let dependentDialects = [

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8c5e618629874..d07d4f2ec8773 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -424,6 +424,7 @@ enum class LinalgTilingLoopType {
   Loops = 0,
   AffineLoops = 1,
   ParallelLoops = 2,
+  TiledLoops = 3,
 };
 
 using TileSizeComputationFunction =

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ba41e86fb79b5..b0dd38e754862 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -253,7 +253,7 @@ struct GenerateLoopNest {
                                 edsc::intrinsics::MemRefIndexedValue>::type;
 
   static void
-  doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
+  doit(ArrayRef<Range> loopRanges, LinalgOp linalgOp,
        ArrayRef<Attribute> iteratorTypes,
        function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
        Optional<LinalgLoopDistributionOptions> = None);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index f19493c3cca9c..920e77b4cdd52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -473,7 +473,7 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
 
   SmallVector<Value, 4> allIvs;
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
+      loopRanges, linalgOp, iteratorTypes,
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d8e1acd14e409..674ef93e4e9cd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -312,9 +312,8 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
   // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs, tensorResults;
-  auto outputTensors = op.getOutputTensors();
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, /*iterArgInitValues*/ outputTensors, iteratorTypes,
+      loopRanges, op, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
@@ -439,6 +438,8 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
   case LinalgTilingLoopType::ParallelLoops:
     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
+  case LinalgTilingLoopType::TiledLoops:
+    return tileLinalgOpImpl<linalg::TiledLoopOp>(b, op, options);
   default:;
   }
   return llvm::None;
@@ -567,6 +568,17 @@ struct LinalgTilingToParallelLoopsPass
   }
 };
 
+struct LinalgTilingToTiledLoopsPass
+    : public LinalgTilingToTiledLoopsBase<LinalgTilingToTiledLoopsPass> {
+  LinalgTilingToTiledLoopsPass() = default;
+  LinalgTilingToTiledLoopsPass(ArrayRef<int64_t> sizes) { tileSizes = sizes; }
+
+  void runOnFunction() override {
+    applyTilingToLoopPatterns(LinalgTilingLoopType::TiledLoops, getFunction(),
+                              tileSizes);
+  }
+};
+
 } // namespace
 
 std::unique_ptr<OperationPass<FuncOp>>
@@ -578,3 +590,8 @@ std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgTilingToParallelLoopsPass(ArrayRef<int64_t> tileSizes) {
   return std::make_unique<LinalgTilingToParallelLoopsPass>(tileSizes);
 }
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgTilingToTiledLoopPass(ArrayRef<int64_t> tileSizes) {
+  return std::make_unique<LinalgTilingToTiledLoopsPass>(tileSizes);
+}

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 1aa47ea3ac19c..2714856a24126 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -142,6 +142,7 @@ bool mlir::linalg::isWindowIteratorType(Attribute attr) {
 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
+template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;
 
 /// Given a list of subview ranges, extract individual values for lower, upper
 /// bounds and steps and put them into the corresponding vectors.
@@ -186,10 +187,11 @@ IntegerAttr getSmallestBoundingIndex(Value size) {
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
-    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, LinalgOp linalgOp,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions) {
+  auto iterArgInitValues = linalgOp.getOutputTensors();
   // Create procInfo so it dominates loops, if appropriate.
   OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
   Location loc = edsc::ScopedContext::getLocation();
@@ -216,10 +218,11 @@ void GenerateLoopNest<scf::ForOp>::doit(
 /// Specialization to build affine "for" nest.
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
-    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, LinalgOp linalgOp,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions>) {
+  auto iterArgInitValues = linalgOp.getOutputTensors();
   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
@@ -240,6 +243,44 @@ void GenerateLoopNest<AffineForOp>::doit(
                               bodyBuilderWithoutIterArgsFn);
 }
 
+/// Specialization to build a linalg.tiled_loop.
+template <>
+void GenerateLoopNest<TiledLoopOp>::doit(
+    ArrayRef<Range> loopRanges, LinalgOp linalgOp,
+    ArrayRef<Attribute> iteratorTypes,
+    function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
+    Optional<LinalgLoopDistributionOptions>) {
+  OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
+  Location loc = edsc::ScopedContext::getLocation();
+  SmallVector<ProcInfo, 2> procInfo;
+
+  SmallVector<Value, 4> lbs, ubs, steps;
+  unpackRanges(loopRanges, lbs, ubs, steps);
+
+  auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                              ValueRange ivs, ValueRange inputs,
+                              ValueRange outputs) {
+    ScopedContext context(nestedBuilder, nestedLoc);
+    scf::ValueVector results = bodyBuilderFn(ivs, linalgOp.getOutputTensors());
+    nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
+  };
+
+  auto tiledLoop = builder.create<TiledLoopOp>(
+      loc, lbs, ubs, steps, linalgOp.getInputs(), linalgOp.getOutputs(),
+      builder.getArrayAttr(iteratorTypes), wrappedBuilderFn);
+
+  // Replace inputs/outputs with the corresponding region args.
+  auto isInsideTiledLoop = [&](OpOperand &operand) {
+    return operand.getOwner()->getBlock() == tiledLoop.getBody();
+  };
+  for (auto it :
+       llvm::zip(linalgOp.getInputs(), tiledLoop.getRegionInputArgs()))
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
+  for (auto it :
+       llvm::zip(linalgOp.getOutputs(), tiledLoop.getRegionOutputArgs()))
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
+}
+
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
 void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
                                        Value procId, Value nprocs, Value &lb,
@@ -373,10 +414,11 @@ generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
 /// Specialization for generating a mix of parallel and sequential scf loops.
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
-    ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
+    ArrayRef<Range> loopRanges, LinalgOp linalgOp,
     ArrayRef<Attribute> iteratorTypes,
     function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
     Optional<LinalgLoopDistributionOptions> distributionOptions) {
+  auto iterArgInitValues = linalgOp.getOutputTensors();
   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&

diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 8a4478ee28121..744cc59d79fd3 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile-to-tiled-loop="linalg-tile-sizes=2,3,4" -split-input-file | FileCheck %s -check-prefix=TLOOP
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
@@ -27,6 +28,38 @@ func @matmul_tensors(
   return %0 : tensor<?x?xf32>
 }
 
+// TLOOP-LABEL: func @matmul_tensors
+// TLOOP-SAME: (%[[ARG_0:.*]]: [[TY:.*]], %[[ARG_1:.*]]: [[TY]],
+// TLOOP-SAME: %[[ARG_2:.*]]: [[TY]]) -> [[TY]] {
+
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+// TLOOP-DAG: %[[C2:.*]] = constant 2 : index
+// TLOOP-DAG: %[[C3:.*]] = constant 3 : index
+// TLOOP-DAG: %[[C4:.*]] = constant 4 : index
+
+// TLOOP: %[[ARG_0_X:.*]] = memref.dim %[[ARG_0]], %[[C0]] : [[TY]]
+// TLOOP: %[[ARG_0_Y:.*]] = memref.dim %[[ARG_0]], %[[C1]] : [[TY]]
+// TLOOP: %[[ARG_1_Y:.*]] = memref.dim %[[ARG_1]], %[[C1]] : [[TY]]
+
+// TLOOP: %{{.*}} = linalg.tiled_loop (%[[I:.*]], %[[J:.*]], %[[K:.*]]) =
+// TLOOP-SAME: (%[[C0]], %[[C0]], %[[C0]])
+// TLOOP-SAME: to (%[[ARG_0_X]], %[[ARG_1_Y]], %[[ARG_0_Y]])
+// TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]])
+// TLOOP-SAME: ins (%[[A0:.*]] = %[[ARG_0]]: [[TY]], %[[A1:.*]] = %[[ARG_1]]: [[TY]])
+// TLOOP-SAME: outs (%[[A2:.*]] = %[[ARG_2]]: [[TY]])
+// TLOOP-SAME: iterators["parallel", "parallel", "reduction"] {
+
+// TLOOP: %[[SUB_ARG_0:.*]] = subtensor %[[A0]][%[[I]], %[[K]]]
+// TLOOP: %[[SUB_ARG_1:.*]] = subtensor %[[A1]][%[[K]], %[[J]]]
+// TLOOP: %[[SUB_ARG_2:.*]] = subtensor %[[A2]][%[[I]], %[[J]]]
+
+// TLOOP: %[[PROD:.*]] = linalg.matmul ins(%[[SUB_ARG_0]], %[[SUB_ARG_1]]
+// TLOOP-SAME: outs(%[[SUB_ARG_2]] : [[TY]]) -> [[TY]]
+
+// TLOOP: %[[O:.*]] = subtensor_insert %[[PROD]] into %[[A2]][%[[I]], %[[J]]]
+// TLOOP: linalg.yield %[[O]] : [[TY]]
+
 // -----
 
 func @generic_op_tensors(
@@ -74,6 +107,28 @@ func @generic_op_tensors(
 //       CHECK: }
 //       CHECK: return %[[TD0]]
 
+// TLOOP-LABEL: func @generic_op_tensors(
+// TLOOP-SAME:    %[[ARG_0:.*]]: [[TY:.*]],
+// TLOOP-SAME:    %[[ARG_1:.*]]: [[TY]]) -> [[TY]] {
+
+// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
+// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
+// TLOOP-DAG: %[[C2:.*]] = constant 2 : index
+// TLOOP-DAG: %[[C3:.*]] = constant 3 : index
+// TLOOP-DAG: %[[C4:.*]] = constant 4 : index
+
+// TLOOP:     %[[INIT:.*]] = linalg.init_tensor
+// TLOOP:     %[[ARG_0_X:.*]] = memref.dim %[[ARG_0]], %[[C0]] : [[TY]]
+// TLOOP:     %[[ARG_0_Y:.*]] = memref.dim %[[ARG_0]], %[[C1]] : [[TY]]
+// TLOOP:     %[[ARG_0_Z:.*]] = memref.dim %[[ARG_0]], %[[C2]] : [[TY]]
+
+// TLOOP:     %{{.*}} = linalg.tiled_loop (%{{.*}}, %{{.*}}, %{{.*}}) =
+// TLOOP-SAME: (%[[C0]], %[[C0]], %[[C0]])
+// TLOOP-SAME: to (%[[ARG_0_X]], %[[ARG_0_Y]], %[[ARG_0_Z]])
+// TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]])
+// TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]])
+// TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]])
+
 // -----
 
 func @indexed_generic_op_tensors(


        

