[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 16:32:05 PDT 2023


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/67083

>From eb8332775d1ac7f1bb28358889c5e09b8068a30e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 21 Sep 2023 16:24:11 -0700
Subject: [PATCH 1/2] [mlir][TilingInterface] Add `scf::tileUsingSCFForallOp`
 method to tile using the interface to generate `scf::forall`.

Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` operations, this
method tiles such operations using `scf.forall`. Most of this
implementation is derived from the `linalg::tileToForallOp` method.
Eventually that method will either be deprecated or moved to use the
method introduced here.
---
 .../SCF/Transforms/TileUsingInterface.h       |  17 +++
 .../SCF/Transforms/TileUsingInterface.cpp     | 133 ++++++++++++++++++
 .../TilingInterface/tile-using-scfforall.mlir |  37 +++++
 .../TilingInterface/TestTilingInterface.cpp   |  69 +++++++++
 4 files changed, 256 insertions(+)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f49d97e141e0c8..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
     interchangeVector = llvm::to_vector(interchange);
     return *this;
   }
+
+  /// Mapping of the generated loops to devices. It is only honored by loop
+  /// constructs that support such a mapping (like `scf.forall`) and is
+  /// ignored by loop constructs that don't (like `scf.for`).
+  /// `setMapping` overwrites any previously set mapping; the attributes are
+  /// stored as plain `Attribute`s.
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+    mappingVector.assign(mapping.begin(), mapping.end());
+    return *this;
+  }
 };
 
 /// Transformation information returned after tiling.
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Tiles an op that implements the `TilingInterface` using a single
+/// `scf.forall`; dimensions with a tile size of 0 are not tiled.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                     const SCFTilingOptions &options);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 96d6169111b3856..a58cd7a7541a515 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// Clones `op` and returns the clone. If the clone implements the
+/// `DestinationStyleOpInterface`, its inits are replaced by `newDestArgs`.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+                                                  Operation *op,
+                                                  ValueRange newDestArgs) {
+  Operation *clonedOp = rewriter.clone(*op);
+  if (auto destinationStyleOp =
+          dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+    // Note that this is assuming that
+    auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+    assert((end - start == newDestArgs.size()) &&
+           "expected as many new destination args as number of inits of the "
+           "operation");
+    clonedOp->setOperands(start, end - start, newDestArgs);
+  }
+  return clonedOp;
+}
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
                                    getAsOperations(forLoops), replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto nonUnitStride = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+  if (llvm::any_of(loopRanges, nonUnitStride))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+  // Note: tiled dimensions are not verified to be parallel; tiling a
+  // reduction dimension makes iterations insert into overlapping slices.
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the lower bounds, upper bounds and steps of the tiled loops.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [tileSize, loopRange] :
+       llvm::zip_equal(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Build the device mapping attribute (one entry per tiled loop).
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
+    if (options.mappingVector.size() != lbs.size())
+      return op->emitOpError("expected one mapping attribute per tiled loop");
+    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+  }
+
+  // 5. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 6. Create the ForallOp without a body builder: the body needs a
+  // RewriterBase, so the insertion point is moved into it manually below.
+  auto forallOp =
+      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+  // 7. Get the tile offsets and sizes.
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(loopRanges.size());
+  tiledSizes.reserve(loopRanges.size());
+  ValueRange ivs = forallOp.getInductionVars();
+  int materializedLoopNum = 0;
+  for (auto [tileSize, loopRange] :
+       llvm::zip_equal(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0)) {
+      tiledOffsets.push_back(loopRange.offset);
+      tiledSizes.push_back(loopRange.size);
+      continue;
+    }
+    Value iv = ivs[materializedLoopNum++];
+    tiledOffsets.push_back(iv);
+    tiledSizes.push_back(
+        getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+  }
+
+  // 8. Tile a clone of the op whose destination operands are the loop args.
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *clonedOp =
+      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+  FailureOr<TilingResult> tilingResult =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(
+          rewriter, tiledOffsets, tiledSizes);
+  // The clone is only needed for tiling; erase it on failure as well.
+  rewriter.eraseOp(clonedOp);
+  if (failed(tilingResult))
+    return op->emitOpError("failed to tile op");
+
+  // 9. Parallel insert back into the result tensor.
+  for (auto [index, tiledValue, destBBArg] :
+       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+    // 9.a. Partial subset information is inserted just before the terminator.
+    rewriter.setInsertionPoint(forallOp.getTerminator());
+
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+                                        tiledSizes, resultOffsets,
+                                        resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> strides(resultSizes.size(),
+                                      rewriter.getIndexAttr(1));
+
+    // 9.b. Parallel insertions go at the end of the combining terminator.
+    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+  }
+
+  // 10. Return the tiling result.
+  return scf::SCFTilingResult{
+      tilingResult->tiledOps,
+      {forallOp.getOperation()},
+      llvm::to_vector(llvm::map_range(forallOp.getResults(),
+                                      [](auto val) -> Value { return val; }))};
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..bfc352c764ad11a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//      CHECK: func.func @simple_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
+//      CHECK:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:         [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+//      CHECK:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+//      CHECK:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:   return %[[RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2573e11979dbc47..2bec859b50f26ba 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp
   TransformationFilter filter;
 };
 
+/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using
+/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles)
+/// while using a `filter` to avoid recursive application.
+struct TestTileUsingSCFForallOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
+                           TransformationFilter filter = TransformationFilter(),
+                           PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  /// `opName` is unused; any `TilingInterface` op is matched (see `filter`).
+  TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
+                           scf::SCFTilingOptions options,
+                           TransformationFilter filter = TransformationFilter(),
+                           PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCFForallOp(rewriter, op, options);
+    if (failed(tilingResult))
+      return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+    // An op without results has no uses to forward, so it is just erased.
+    if (op->getNumResults())
+      rewriter.replaceOp(op, tilingResult->replacements);
+    else
+      rewriter.eraseOp(op);
+
+    for (auto *tiledOp : tilingResult->tiledOps)
+      filter.replaceTransformationFilter(rewriter, tiledOp);
+    return success();
+  }
+
+private:
+  scf::SCFTilingOptions options;
+  TransformationFilter filter;
+};
+
 /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
 /// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -415,6 +460,12 @@ struct TestTilingInterfacePass
           "Test tiling using TilingInterface with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testTilingForAll{
+      *this, "tile-using-scf-forall",
+      llvm::cl::desc(
+          "Test tiling using TilingInterface with scf.forall operations"),
+      llvm::cl::init(false)};
+
   Option<bool> testTileConsumerFuseAndYieldProducer{
       *this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
       llvm::cl::desc(
@@ -455,6 +506,20 @@ static void addPatternForTiling(MLIRContext *context,
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
+static void addPatternForTilingUsingForall(MLIRContext *context,
+                                           RewritePatternSet &patterns,
+                                           StringRef filterName,
+                                           ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<int64_t> interchange = {}) {
+  auto tileSizeOfrs = getAsIndexOpFoldResult(context, tileSizes);
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizeOfrs).setInterchange(interchange);
+  // Match ops marked with `filterName`; tiled ops are re-marked "tiled".
+  TransformationFilter filter(StringAttr::get(context, filterName),
+                              StringAttr::get(context, "tiled"));
+  patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
+}
+
 static void addPatternForTileFuseAndYield(MLIRContext *context,
                                           RewritePatternSet &patterns,
                                           StringRef filterName,
@@ -514,6 +579,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
     return;
   }
+  if (testTilingForAll) {
+    addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+    return;
+  }
   if (testTileConsumerAndFuseProducer) {
     // 1. Tile and fuse of gemm with fill producer and bias-add consumer.
     addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});

>From b26e64370d04ba858f7a711d3f56d3af4a00c539 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Wed, 18 Oct 2023 16:20:39 -0700
Subject: [PATCH 2/2] Add lit tests.

---
 .../SCF/Transforms/TileUsingInterface.cpp     |  11 +-
 .../TilingInterface/tile-using-scfforall.mlir | 133 +++++++++++++++++-
 .../TilingInterface/TestTilingInterface.cpp   |  10 ++
 3 files changed, 144 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a58cd7a7541a515..a45918eb062ee06 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -101,10 +101,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                        Range loopRange, Value iv,
-                                       Value tileSize) {
+                                       OpFoldResult tileSize) {
   std::optional<int64_t> ts = getConstantIntValue(tileSize);
   if (ts && ts.value() == 1)
-    return getAsOpFoldResult(tileSize);
+    return tileSize;
 
   if (tileDividesIterationDomain(
           Range{loopRange.offset, loopRange.size, tileSize}))
@@ -130,12 +130,7 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
   Operation *clonedOp = rewriter.clone(*op);
   if (auto destinationStyleOp =
           dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
-    // Note that this is assuming that
-    auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
-    assert((end - start == newDestArgs.size()) &&
-           "expected as many new destination args as number of inits of the "
-           "operation");
-    clonedOp->setOperands(start, end - start, newDestArgs);
+    destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
   }
   return clonedOp;
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index bfc352c764ad11a..709ecb6a97e3c40 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+// RUN: mlir-opt  -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -35,3 +35,132 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 //      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
 // CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
 //      CHECK:   return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+  %init0 = tensor.empty() : tensor<128x300x200xf32>
+  %init1 = tensor.empty() : tensor<300x128x200xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      {__internal_transform__ = "parallel_generic_transpose"}
+      ins(%arg0 : tensor<128x200x300xf32>)
+      outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      linalg.yield %b0, %b0 : f32, f32
+    } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+  return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+//      CHECK-LABEL: func.func @multi_result(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (128, 300) step (10, 20)
+// CHECK-SAME:       shared_outs(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
+//      CHECK:     %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:         [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG2]]
+// CHECK-SAME:         [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:     %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:         ins(%[[ARG_TILE]] :
+// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//  CHECK-DAG:       tensor.parallel_insert_slice %[[RESULT_TILE]]#0 into %[[ARG1]][%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+//  CHECK-DAG:       tensor.parallel_insert_slice %[[RESULT_TILE]]#1 into %[[ARG2]][%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+//      CHECK:     }
+//      CHECK:   return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {
+      strides = dense<[2, 3]> : tensor<2xi64>,
+      dilation = dense<[4, 5]> : tensor<2xi64>,
+      __internal_transform__ = "simple_conv"}
+      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
+//  CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+//  CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+//      CHECK-LABEL: func.func @conv2D(
+// CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+//  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+//  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+//  CHECK-DAG:   %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+//  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+//  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+//  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//      CHECK:   %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]], %[[IV2:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       (0, 0, 0) to (%[[P]], %[[Q]], %[[C]]) step (10, 20, 30) shared_outs(%[[INIT0:.+]] = %[[INIT]])
+//  CHECK-DAG:     %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
+//  CHECK-DAG:     %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
+//  CHECK-DAG:     %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]]
+//  CHECK-DAG:     %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]]
+//  CHECK-DAG:     %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]]
+//  CHECK-DAG:     %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME:         [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+//  CHECK-DAG:     %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+//  CHECK-DAG:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]]
+// CHECK-SAME:         [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+//      CHECK:     %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:         dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME:         ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     scf.forall.in_parallel
+//      CHECK:       tensor.parallel_insert_slice %[[CONV_TILE]] into %[[INIT0]]
+// CHECK-SAME:           [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] [1, 1, 1, 1]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+
+func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // Check that we correctly amend "linalg.index" results.
+
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    {__internal_transform__ = "indexed_semantics"}
+    ins(%arg0: tensor<?x?xf32>)
+    outs(%arg1: tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = arith.addi %1, %2 : index
+    %4 = arith.index_cast %3 : index to i64
+    %5 = arith.uitofp %4 : i64 to f32
+    %6 = arith.addf %5, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @indexed_semantics
+//       CHECK: scf.forall (%[[I0:.+]], %[[I1:.+]]) =
+//       CHECK:   %[[INDEX0:.+]] = linalg.index 0
+//       CHECK:   %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+//       CHECK:   %[[INDEX1:.+]] = linalg.index 1
+//       CHECK:   %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+//       CHECK:   arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2bec859b50f26ba..04632567ee2a777 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -580,7 +580,17 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     return;
   }
   if (testTilingForAll) {
+    // 1. Tiling M and N dims of `linalg.matmul` on tensors.
     addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+    // 2. Tiling 3D parallel generic op which implements a transpose.
+    addPatternForTilingUsingForall(context, patterns,
+                                   "parallel_generic_transpose", {10, 0, 20});
+    // 3. Tiling 2D conv op.
+    addPatternForTilingUsingForall(context, patterns, "simple_conv",
+                                   {0, 0, 0, 0, 10, 20, 30});
+    // 4. Tiling a simple op with `linalg.index` inside.
+    addPatternForTilingUsingForall(context, patterns, "indexed_semantics",
+                                   {10, 20});
     return;
   }
   if (testTileConsumerAndFuseProducer) {



More information about the Mlir-commits mailing list