[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 23:01:41 PDT 2023
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/67083
>From fd48251127299ed9fc0b55d97e21fea94954b110 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 21 Sep 2023 16:24:11 -0700
Subject: [PATCH 1/3] [mlir][TilingInterface] Add `scf::tileUsingSCFForallOp`
method to tile using the interface to generate `scf::forall`.
Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` operations, this method
tiles such operations using `scf.forall`. Most of the implementation is
derived from the `linalg::tileToForallOp` method. Eventually that method
will either be deprecated or reimplemented on top of the method
introduced here.
---
.../SCF/Transforms/TileUsingInterface.h | 17 +++
.../SCF/Transforms/TileUsingInterface.cpp | 133 ++++++++++++++++++
.../TilingInterface/tile-using-scfforall.mlir | 37 +++++
.../TilingInterface/TestTilingInterface.cpp | 69 +++++++++
4 files changed, 256 insertions(+)
create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f49d97e141e0c8..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
interchangeVector = llvm::to_vector(interchange);
return *this;
}
+
+ /// Mapping of loops to devices; `setMapping` overwrites any previous value.
+ /// Only respected by loop constructs that support it (like `scf.forall`);
+ /// ignored by loop constructs that don't support such a mapping (like
+ /// `scf.for`)
+ SmallVector<Attribute> mappingVector = {};
+ SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+ mappingVector = llvm::to_vector(
+ llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+ return *this;
+ }
};
/// Transformation information returned after tiling.
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
}
};
+/// Method to tile and op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const SCFTilingOptions &options);
+
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 96d6169111b3856..a58cd7a7541a515 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}
+/// Clones `op` and, if the clone implements `DestinationStyleOpInterface`,
+/// replaces its init (destination) operands with `newDestArgs`.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+ Operation *op,
+ ValueRange newDestArgs) {
+ Operation *clonedOp = rewriter.clone(*op);
+ if (auto destinationStyleOp =
+ dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+ // Note that this is assuming that
+ auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+ assert((end - start == newDestArgs.size()) &&
+ "expected as many new destination args as number of inits of the "
+ "operation");
+ clonedOp->setOperands(start, end - start, newDestArgs);
+ }
+ return clonedOp;
+}
+
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
getAsOperations(forLoops), replacements};
}
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const scf::SCFTilingOptions &options) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 1. Get the iteration domain. `hasStrideOne` below is true for NON-unit strides.
+ SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+ if (loopRanges.empty())
+ return op->emitOpError("expected non-empty loop ranges");
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+ if (llvm::any_of(loopRanges, hasStrideOne))
+ return op->emitOpError("only stride-1 supported atm");
+
+ // 2. Get the tile sizes, padded/truncated to loopRanges.size() with 0 (= not
+ // tiled/distributed). NOTE(review): tiled loops aren't checked to be parallel.
+ SmallVector<OpFoldResult> tileSizeVector =
+ options.tileSizeComputationFunction(rewriter, op);
+ tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+ // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+
+ // 4. Gather destination tensors; they become the forall's `shared_outs`.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+ return op->emitOpError("failed to get destination tensors");
+
+ // 5. Build the device mapping attribute;
+ std::optional<ArrayAttr> mappingAttr;
+ if (!options.mappingVector.empty()) {
+ mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+ }
+
+ // 6. Create the ForallOp (one loop per non-zero tile size). The lambda
+ // body-builder is avoided since the body is built with the RewriterBase below.
+ // NOTE(review): later failure paths return without erasing `forallOp`.
+ auto forallOp =
+ rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+ // 7. Get the tile offset and sizes.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(loopRanges.size());
+ tiledSizes.reserve(loopRanges.size());
+ ValueRange ivs = forallOp.getInductionVars();
+ {
+ int materializedLoopNum = 0;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ tiledOffsets.push_back(loopRange.offset);
+ tiledSizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ tiledOffsets.push_back(iv);
+ tiledSizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
+ }
+
+ // 8. Tile the operation. Clone the operation to allow fix up of destination
+ // operands
+ ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+ Operation *clonedOp =
+ cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+ FailureOr<TilingResult> tilingResult =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(
+ rewriter, tiledOffsets, tiledSizes);
+ if (failed(tilingResult))
+ return clonedOp->emitError("Failed to tile op: ");
+ rewriter.eraseOp(clonedOp);
+
+ // 9. Parallel insert back into the result tensor.
+ for (auto [index, tiledValue, destBBArg] :
+ llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+ // 9.a. Result-tile position ops are inserted just before the terminator.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+ tiledSizes, resultOffsets,
+ resultSizes)))
+ return op->emitOpError("output offsets couldn't be calculated");
+ SmallVector<OpFoldResult> strides(resultSizes.size(),
+ rewriter.getIndexAttr(1));
+
+ // 5.b. Parallel insertions are inserted at the end of the combining
+ // terminator.
+ rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+ }
+
+ // 10. Return the tiling result;
+ return scf::SCFTilingResult{
+ tilingResult->tiledOps,
+ {forallOp.getOperation()},
+ llvm::to_vector(llvm::map_range(forallOp.getResults(),
+ [](auto val) -> Value { return val; }))};
+}
+
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..bfc352c764ad11a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK: func.func @simple_matmul(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
+// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2573e11979dbc47..2bec859b50f26ba 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp
TransformationFilter filter;
};
+/// Pattern for testing `tileUsingSCFForallOp`, which tiles `TilingInterface`
+/// ops using `scf.forall` to iterate over the tiles. `filter` is checked before
+/// and re-applied to the tiled ops afterwards to avoid recursive application.
+struct TestTileUsingSCFForallOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ /// Overload taking `opName`, which is unused: any `TilingInterface` op matches.
+ TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ TransformationFilter filter = TransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ scf::tileUsingSCFForallOp(rewriter, op, options);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+ if (op->getNumResults()) {
+ rewriter.replaceOp(op, tilingResult->replacements);
+ } else {
+ rewriter.eraseOp(op);
+ }
+
+ for (auto *tiledOp : tilingResult->tiledOps)
+ filter.replaceTransformationFilter(rewriter, tiledOp);
+ return success();
+ }
+
+private:
+ scf::SCFTilingOptions options;
+ TransformationFilter filter;
+};
+
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -415,6 +460,12 @@ struct TestTilingInterfacePass
"Test tiling using TilingInterface with scf.for operations"),
llvm::cl::init(false)};
+ Option<bool> testTilingForAll{
+ *this, "tile-using-scf-forall",
+ llvm::cl::desc(
+ "Test tiling using TilingInterface with scf.forall operations"),
+ llvm::cl::init(false)};
+
Option<bool> testTileConsumerFuseAndYieldProducer{
*this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
llvm::cl::desc(
@@ -455,6 +506,20 @@ static void addPatternForTiling(MLIRContext *context,
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
+static void addPatternForTilingUsingForall(MLIRContext *context,
+ RewritePatternSet &patterns,
+ StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> interchange = {}) {
+ scf::SCFTilingOptions tilingOptions;
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+ TransformationFilter filter(StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
+ patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
+}
+
static void addPatternForTileFuseAndYield(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
@@ -514,6 +579,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
return;
}
+ if (testTilingForAll) {
+ addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+ return;
+ }
if (testTileConsumerAndFuseProducer) {
// 1. Tile and fuse of gemm with fill producer and bias-add consumer.
addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});
>From d1f1103792c206c65b756bb3b935c528a5485681 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Wed, 18 Oct 2023 16:20:39 -0700
Subject: [PATCH 2/3] Add lit tests.
---
.../SCF/Transforms/TileUsingInterface.cpp | 11 +-
.../TilingInterface/tile-using-scfforall.mlir | 133 +++++++++++++++++-
.../TilingInterface/TestTilingInterface.cpp | 10 ++
3 files changed, 144 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a58cd7a7541a515..a45918eb062ee06 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -101,10 +101,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
Range loopRange, Value iv,
- Value tileSize) {
+ OpFoldResult tileSize) {
std::optional<int64_t> ts = getConstantIntValue(tileSize);
if (ts && ts.value() == 1)
- return getAsOpFoldResult(tileSize);
+ return tileSize;
if (tileDividesIterationDomain(
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -130,12 +130,7 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
Operation *clonedOp = rewriter.clone(*op);
if (auto destinationStyleOp =
dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
- // Note that this is assuming that
- auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
- assert((end - start == newDestArgs.size()) &&
- "expected as many new destination args as number of inits of the "
- "operation");
- clonedOp->setOperands(start, end - start, newDestArgs);
+ destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
}
return clonedOp;
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index bfc352c764ad11a..709ecb6a97e3c40 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -35,3 +35,132 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
// CHECK: return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+ %init0 = tensor.empty() : tensor<128x300x200xf32>
+ %init1 = tensor.empty() : tensor<300x128x200xf32>
+ %0:2 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ {__internal_transform__ = "parallel_generic_transpose"}
+ ins(%arg0 : tensor<128x200x300xf32>)
+ outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+ linalg.yield %b0, %b0 : f32, f32
+ } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
+ return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
+}
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+// CHECK-LABEL: func.func @multi_result(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
+// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = (0, 0) to (128, 300) step (10, 20)
+// CHECK-SAME: shared_outs(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
+// CHECK: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG2]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[ARG_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#0 into %[[ARG1]][%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+// CHECK-DAG: tensor.parallel_insert_slice %[[RESULT_TILE]]#1 into %[[ARG2]][%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: }
+// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
+
+// -----
+
+func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+ %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf {
+ strides = dense<[2, 3]> : tensor<2xi64>,
+ dilation = dense<[4, 5]> : tensor<2xi64>,
+ __internal_transform__ = "simple_conv"}
+ ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
+// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
+// CHECK-LABEL: func.func @conv2D(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
+// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
+// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
+// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
+// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
+// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
+// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]], %[[IV2:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: (0, 0, 0) to (%[[P]], %[[Q]], %[[C]]) step (10, 20, 30) shared_outs(%[[INIT0:.+]] = %[[INIT]])
+// CHECK-DAG: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
+// CHECK-DAG: %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]]
+// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]]
+// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[$MAP3]](%[[TS_P]])[%[[R]]]
+// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[$MAP4]](%[[TS_Q]])[%[[S]]]
+// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
+// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT0]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
+// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice %[[CONV_TILE]] into %[[INIT0]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] [1, 1, 1, 1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+
+func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // Check that tiling offsets each "linalg.index" result by the forall IV.
+
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ {__internal_transform__ = "indexed_semantics"}
+ ins(%arg0: tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %1 = linalg.index 0 : index
+ %2 = linalg.index 1 : index
+ %3 = arith.addi %1, %2 : index
+ %4 = arith.index_cast %3 : index to i64
+ %5 = arith.uitofp %4 : i64 to f32
+ %6 = arith.addf %5, %arg2 : f32
+ linalg.yield %6 : f32
+ } -> (tensor<?x?xf32>)
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @indexed_semantics
+// CHECK: scf.forall (%[[I0:.+]], %[[I1:.+]]) =
+// CHECK: %[[INDEX0:.+]] = linalg.index 0
+// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+// CHECK: %[[INDEX1:.+]] = linalg.index 1
+// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2bec859b50f26ba..04632567ee2a777 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -580,7 +580,17 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
return;
}
if (testTilingForAll) {
+ // 1. Tiling M and N dims of `linalg.matmul` on tensors.
addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+ // 2. Tiling 3D parallel generic op which implements a transpose.
+ addPatternForTilingUsingForall(context, patterns,
+ "parallel_generic_transpose", {10, 0, 20});
+ // 3. Tiling 2D conv op; NOTE(review): the tiled loops 4-6 appear to be reductions.
+ addPatternForTilingUsingForall(context, patterns, "simple_conv",
+ {0, 0, 0, 0, 10, 20, 30});
+ // 4. Tiling an op using `linalg.index` (indices must be offset by the IVs).
+ addPatternForTilingUsingForall(context, patterns, "indexed_semantics",
+ {10, 20});
return;
}
if (testTileConsumerAndFuseProducer) {
>From 55f9518c728af5be1cef142e592ad612438e0edc Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 19 Oct 2023 23:00:35 -0700
Subject: [PATCH 3/3] Address comments.
---
.../SCF/Transforms/TileUsingInterface.h | 6 ++---
.../SCF/Transforms/TileUsingInterface.cpp | 27 +++++++++----------
.../TilingInterface/tile-using-scfforall.mlir | 1 +
.../TilingInterface/TestTilingInterface.cpp | 23 +++++++++-------
4 files changed, 30 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 06cce19894e9f5a..81325b62791c44b 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -58,8 +58,8 @@ struct SCFTilingOptions {
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
- mappingVector = llvm::to_vector(
- llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+ mappingVector = llvm::map_to_vector(
+ mapping, [](auto attr) -> Attribute { return attr; });
return *this;
}
};
@@ -93,7 +93,7 @@ struct SCFTileAndFuseOptions {
}
};
-/// Method to tile and op that implements the `TilingInterface` using
+/// Method to tile an op that implements the `TilingInterface` using
/// `scf.forall`.
FailureOr<SCFTilingResult>
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a45918eb062ee06..2c6e66de6dc60f4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -767,8 +767,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
// 3. Build the offsets, sizes and steps for the tile and distributed loops.
SmallVector<OpFoldResult> lbs, ubs, steps;
- for (auto [index, tileSize, loopRange] :
- llvm::enumerate(tileSizeVector, loopRanges)) {
+ for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
if (isConstantIntValue(tileSize, 0))
continue;
lbs.push_back(loopRange.offset);
@@ -781,7 +780,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
return op->emitOpError("failed to get destination tensors");
- // 5. Build the device mapping attribute;
+ // 5. Build the device mapping attribute.
std::optional<ArrayAttr> mappingAttr;
if (!options.mappingVector.empty()) {
mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
@@ -796,13 +795,10 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
// 7. Get the tile offset and sizes.
rewriter.setInsertionPoint(forallOp.getTerminator());
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- tiledOffsets.reserve(loopRanges.size());
- tiledSizes.reserve(loopRanges.size());
ValueRange ivs = forallOp.getInductionVars();
{
int materializedLoopNum = 0;
- for (auto [index, tileSize, loopRange] :
- llvm::enumerate(tileSizeVector, loopRanges)) {
+ for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
if (isConstantIntValue(tileSize, 0)) {
tiledOffsets.push_back(loopRange.offset);
tiledSizes.push_back(loopRange.size);
@@ -816,7 +812,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
}
// 8. Tile the operation. Clone the operation to allow fix up of destination
- // operands
+ // operands.
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
Operation *clonedOp =
cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
@@ -824,7 +820,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
cast<TilingInterface>(clonedOp).getTiledImplementation(
rewriter, tiledOffsets, tiledSizes);
if (failed(tilingResult))
- return clonedOp->emitError("Failed to tile op: ");
+ return clonedOp->emitError("failed to tile op: ");
rewriter.eraseOp(clonedOp);
// 9. Parallel insert back into the result tensor.
@@ -836,24 +832,25 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
tiledSizes, resultOffsets,
- resultSizes)))
+ resultSizes))) {
return op->emitOpError("output offsets couldn't be calculated");
+ }
+
SmallVector<OpFoldResult> strides(resultSizes.size(),
rewriter.getIndexAttr(1));
-
- // 5.b. Parallel insertions are inserted at the end of the combining
+ // 9.b. Parallel insertions are inserted at the end of the combining
// terminator.
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
rewriter.create<tensor::ParallelInsertSliceOp>(
loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
}
- // 10. Return the tiling result;
+ // 10. Return the tiling result.
return scf::SCFTilingResult{
tilingResult->tiledOps,
{forallOp.getOperation()},
- llvm::to_vector(llvm::map_range(forallOp.getResults(),
- [](auto val) -> Value { return val; }))};
+ llvm::map_to_vector(forallOp.getResults(),
+ [](auto val) -> Value { return val; })};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index 709ecb6a97e3c40..314efde45720a2f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -34,6 +34,7 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+// CHECK: mapping = [#gpu.block<y>, #gpu.block<x>]
// CHECK: return %[[RESULT]]
// -----
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 04632567ee2a777..e5d7dc54409e447 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -443,9 +444,9 @@ struct TestTilingInterfacePass
TestTilingInterfacePass(const TestTilingInterfacePass &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<affine::AffineDialect, linalg::LinalgDialect,
- memref::MemRefDialect, scf::SCFDialect,
- tensor::TensorDialect>();
+ registry.insert<affine::AffineDialect, gpu::GPUDialect,
+ linalg::LinalgDialect, memref::MemRefDialect,
+ scf::SCFDialect, tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
}
@@ -506,15 +507,16 @@ static void addPatternForTiling(MLIRContext *context,
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
-static void addPatternForTilingUsingForall(MLIRContext *context,
- RewritePatternSet &patterns,
- StringRef filterName,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> interchange = {}) {
+static void addPatternForTilingUsingForall(
+ MLIRContext *context, RewritePatternSet &patterns, StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<DeviceMappingAttrInterface> mapping = {},
+ ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+ tilingOptions.setMapping(mapping);
TransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
@@ -581,7 +583,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
}
if (testTilingForAll) {
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
- addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+ addPatternForTilingUsingForall(
+ context, patterns, "simple_gemm", {10, 20},
+ {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY),
+ gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)});
// 2. Tiling 3D parallel generic op which implements a transpose.
addPatternForTilingUsingForall(context, patterns,
"parallel_generic_transpose", {10, 0, 20});
More information about the Mlir-commits
mailing list