[Mlir-commits] [mlir] b8a1f00 - [mlir][TilingInterface] Add support for interchange to tiling patterns that use the `TilingInterface`.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Jul 19 22:25:29 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-07-20T05:24:17Z
New Revision: b8a1f00d414ece5597423449d433fbb26a217626
URL: https://github.com/llvm/llvm-project/commit/b8a1f00d414ece5597423449d433fbb26a217626
DIFF: https://github.com/llvm/llvm-project/commit/b8a1f00d414ece5597423449d433fbb26a217626.diff
LOG: [mlir][TilingInterface] Add support for interchange to tiling patterns that use the `TilingInterface`.
Differential Revision: https://reviews.llvm.org/D129956
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f3ee8a5b27f6..52cc52325eb9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
namespace scf {
using SCFTileSizeComputationFunction =
- std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+ std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
/// Options to use to control tiling.
struct SCFTilingOptions {
@@ -51,6 +51,13 @@ struct SCFTilingOptions {
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+
+  /// Loop interchange to apply to the tiled loop nest: loop `i` of the
+  /// generated nest iterates over dimension `interchangeVector[i]` of the
+  /// operation's iteration domain. An empty vector means "no interchange".
+  SmallVector<unsigned> interchangeVector = {};
+
+  /// Replaces the interchange vector with a copy of `interchange`. Returns
+  /// `*this` so that it can be chained with the other option setters.
+  SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
+    interchangeVector.assign(interchange.begin(), interchange.end());
+    return *this;
+  }
};
struct SCFTilingResult {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 5e0ea65dd20df..3bad54327e078 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -29,7 +29,7 @@ using namespace mlir;
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
assert(!tileSizeComputationFunction && "tile sizes already set");
- SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+ SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(
@@ -42,6 +42,49 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
return *this;
}
+/// Returns `interchangeVector` adjusted to have exactly `iterationDomainSize`
+/// entries. A vector that is too long is truncated. A vector that is too
+/// short is completed with the missing positions
+/// `[interchangeVector.size(), iterationDomainSize)` in increasing order.
+static SmallVector<unsigned>
+fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
+                      size_t iterationDomainSize) {
+  // Keep at most `iterationDomainSize` user-provided entries.
+  ArrayRef<unsigned> prefix = interchangeVector.take_front(iterationDomainSize);
+  SmallVector<unsigned> result(prefix.begin(), prefix.end());
+  // Append the identity for any trailing dimensions the user did not list.
+  for (size_t pos = result.size(); pos < iterationDomainSize; ++pos)
+    result.push_back(static_cast<unsigned>(pos));
+  return result;
+}
+
+/// Returns `vector` reordered according to `interchange`: element `i` of the
+/// result is `vector[interchange[i]]`. Both inputs must have the same size
+/// (asserted). The entries of `interchange` are not range-checked here.
+template <typename T>
+static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
+                                               ArrayRef<unsigned> interchange) {
+  assert(interchange.size() == vector.size());
+  SmallVector<T> permuted;
+  permuted.reserve(interchange.size());
+  for (unsigned srcPos : interchange)
+    permuted.push_back(vector[srcPos]);
+  return permuted;
+}
+/// Returns the inverse of the permutation `interchange`, i.e. the vector `inv`
+/// with `inv[interchange[i]] == i` for every position `i`.
+/// `interchange` is expected to already be a valid permutation of
+/// `[0, interchange.size())`; no bounds checking is done here.
+static SmallVector<unsigned>
+invertPermutationVector(ArrayRef<unsigned> interchange) {
+  SmallVector<unsigned> inversion(interchange.size());
+  for (unsigned i = 0, e = interchange.size(); i < e; ++i)
+    inversion[interchange[i]] = i;
+  return inversion;
+}
+/// Returns true if `interchange` is a permutation of `[0, interchange.size())`,
+/// i.e. every entry is in range and appears exactly once.
+/// The range check is essential. The vector is later used to index the
+/// iteration domain, tile sizes, offsets and sizes. An out-of-range entry
+/// must therefore be rejected here rather than read out of bounds. Examples:
+/// `{5}` filled to `{5, 1}` for a 2-d op, or `{2, 0, 1}` truncated to `{2, 0}`.
+static bool isPermutation(ArrayRef<unsigned> interchange) {
+  llvm::SmallDenseSet<unsigned, 4> seenVals;
+  for (unsigned val : interchange) {
+    // Reject out-of-range entries and duplicates.
+    if (val >= interchange.size() || !seenVals.insert(val).second)
+      return false;
+  }
+  // n distinct values, all in [0, n): this is a permutation.
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// TileUsingSCFForOp pattern implementation.
//===----------------------------------------------------------------------===//
@@ -137,7 +180,7 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- SmallVector<Value, 4> tileSizeVector =
+ SmallVector<Value> tileSizeVector =
options.tileSizeComputationFunction(rewriter, op);
if (tileSizeVector.size() < iterationDomain.size()) {
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -147,12 +190,38 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
scf::SCFTilingResult tilingResult;
SmallVector<OpFoldResult> offsets, sizes;
{
+ // If there is an interchange specified, permute the iteration domain and
+ // the tile sizes.
+ SmallVector<unsigned> interchangeVector;
+ if (!options.interchangeVector.empty()) {
+ interchangeVector = fillInterchangeVector(options.interchangeVector,
+ iterationDomain.size());
+ }
+ if (!interchangeVector.empty()) {
+ if (!isPermutation(interchangeVector)) {
+ return rewriter.notifyMatchFailure(
+ op, "invalid intechange vector, not a permutation of the entire "
+ "iteration space");
+ }
+
+ iterationDomain =
+ applyPermutationToVector(iterationDomain, interchangeVector);
+ tileSizeVector =
+ applyPermutationToVector(tileSizeVector, interchangeVector);
+ }
+
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
tilingResult.loops = generateTileLoopNest(
rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+ if (!interchangeVector.empty()) {
+ auto inversePermutation = invertPermutationVector(interchangeVector);
+ offsets = applyPermutationToVector(offsets, inversePermutation);
+ sizes = applyPermutationToVector(sizes, inversePermutation);
+ }
+
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
llvm::errs() << "LoopNest shell :\n";
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index dd77211d8ccc6..81e2bfbe2d9ff 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -183,3 +183,50 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK scf.yield %[[INSERT]]
+
+// -----
+
+// Tile-and-fuse with loop interchange. The consumer `linalg.generic`
+// (filter "gemm_interchange_fusion") is tiled with sizes {10, 20} and
+// interchange {1, 0}. The outer loop (IV0) therefore walks dimension 1 and
+// the inner loop (IV1) walks dimension 0. The producers `linalg.fill` and
+// `linalg.matmul` are fused into the inner loop body.
+func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %2 = linalg.matmul
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = linalg.generic {
+ __internal_linalg_transform__ = "gemm_interchange_fusion",
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %4 = arith.addf %b0, %b0 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ return %3 : tensor<?x?xf32>
+}
+// CHECK: func.func @interchange_matmul_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
+// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
+// CHECK: %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[FILL_TILE]] :
+// CHECK: %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GEMM_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE_2]] :
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+// CHECK: scf.yield %[[INSERT]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index d8ec2c56409e2..ad5307c43e265 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -226,3 +226,52 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
} -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+// Tiling with loop interchange. The "gemm_interchange" filter tiles
+// (M, N, K) with sizes {10, 20, 30} and interchange {1, 2, 0}. The loop nest
+// is therefore ordered N (step 20), K (step 30), M (step 10). Operand slices
+// still use offsets/sizes in the original dimension order.
+func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul {__internal_linalg_transform__ = "gemm_interchange"}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func.func @interchange_matmul(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]])
+// CHECK: %[[TS_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C20]], %[[N]]]
+// CHECK: %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
+// CHECK: %[[TS_K:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C30]], %[[K]]]
+// CHECK: %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]])
+// CHECK-DAG: %[[TS_M:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C10]], %[[M]]]
+// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME: [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME: outs(%[[INIT_TILE]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT2]]
+// CHECK-SAME: [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[INNER2]]
+// CHECK: scf.yield %[[INNER1]]
+// CHECK: return %[[OUTER]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index cebe7b1086ee1..214a4053e232c 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -147,10 +147,11 @@ struct TestTilingInterfacePass
template <class Pattern>
static void
-addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
- StringRef filterName, RewritePatternSet &patterns) {
+addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns,
+ StringRef filterName, ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes);
+ tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
linalg::LinalgTransformationFilter filter(
StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
patterns.add<Pattern>(context, tilingOptions, filter);
@@ -161,29 +162,35 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
if (testTiling) {
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 20}, "simple_gemm", patterns);
+ context, patterns, "simple_gemm", {10, 20});
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 20, 30}, "simple_gemm_memref", patterns);
+ context, patterns, "simple_gemm_memref", {10, 20, 30});
// 3. Tiling 3D parallel generic op which implements a transpose
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 0, 20}, "parallel_generic_transpose", patterns);
+ context, patterns, "parallel_generic_transpose", {10, 0, 20});
// 4. Tiling 2D conv op.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
+ context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30});
// 5. Tiling a simple op with `linalg.index` inside.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 20}, "indexed_semantics", patterns);
+ context, patterns, "indexed_semantics", {10, 20});
+ // 6. Tiling + interchange of an operation
+ addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
return;
}
if (testTileConsumerAndFuseProducer) {
// 1. Tile and fuse of gemm with bias-add operation.
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, {10, 20}, "fusion", patterns);
+ context, patterns, "fusion", {10, 20});
+ addPatternForTiling<
+ TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+ context, patterns, "gemm_fusion", {10});
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, {10}, "gemm_fusion", patterns);
+ context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
return;
}
}
More information about the Mlir-commits mailing list