[Mlir-commits] [mlir] e3d386e - [mlir][linalg] Add a tile and fuse on tensors pattern.
Tobias Gysi
llvmlistbot at llvm.org
Mon Nov 22 03:14:14 PST 2021
Author: Tobias Gysi
Date: 2021-11-22T11:13:21Z
New Revision: e3d386ea27336edc04ae4fd324ab4337b9f3cf16
URL: https://github.com/llvm/llvm-project/commit/e3d386ea27336edc04ae4fd324ab4337b9f3cf16
DIFF: https://github.com/llvm/llvm-project/commit/e3d386ea27336edc04ae4fd324ab4337b9f3cf16.diff
LOG: [mlir][linalg] Add a tile and fuse on tensors pattern.
Add a pattern to apply the new tile and fuse on tensors method. Integrate the pattern into the CodegenStrategy and use the CodegenStrategy to implement the tests.
Depends On D114012
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D114067
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index c0173ec2f443a..b25ba69a42a38 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -81,6 +81,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
//===----------------------------------------------------------------------===//
/// Linalg strategy passes.
//===----------------------------------------------------------------------===//
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTileAndFusePass(
+ StringRef opName = "", linalg::LinalgTilingAndFusionOptions opt = {},
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
+
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c9bcfebecb022..4c55b4e349d74 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -235,6 +235,18 @@ def LinalgTileAndFuseTensorOps
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
}
+def LinalgStrategyTileAndFusePass
+ : FunctionPass<"linalg-strategy-tile-and-fuse-pass"> {
+ let summary = "Configurable pass to apply pattern-based tiling and fusion.";
+ let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
+ let options = [
+ Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+ "Which func op is the anchor to latch on.">,
+ Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+ "Which linalg op within the func is the anchor to latch on.">,
+ ];
+}
+
def LinalgStrategyTilePass
: FunctionPass<"linalg-strategy-tile-pass"> {
let summary = "Configurable pass to apply pattern-based linalg tiling.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index f4579a6150489..ea2afa01889b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -30,6 +30,22 @@ struct Transformation {
LinalgTransformationFilter::FilterFunction filter = nullptr;
};
+/// Represent one application of LinalgStrategyTileAndFusePass.
+struct TileAndFuse : public Transformation {
+ TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter::FilterFunction f = nullptr)
+ : Transformation(f), opName(name), options(options) {}
+
+ void addToPassPipeline(OpPassManager &pm,
+ LinalgTransformationFilter m) const override {
+ pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
+ }
+
+private:
+ std::string opName;
+ linalg::LinalgTilingAndFusionOptions options;
+};
+
/// Represent one application of LinalgStrategyTilePass.
struct Tile : public Transformation {
Tile(StringRef name, linalg::LinalgTilingOptions options,
@@ -147,6 +163,22 @@ struct VectorLowering : public Transformation {
/// Codegen strategy controls how a Linalg op is progressively lowered.
struct CodegenStrategy {
+ /// Append a pattern to tile the Op `opName` and fuse its producers with
+ /// tiling and fusion `options`.
+ CodegenStrategy &
+ tileAndFuse(StringRef opName, LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
+ transformationSequence.emplace_back(
+ std::make_unique<TileAndFuse>(opName, options, f));
+ return *this;
+ }
+ /// Conditionally append a pattern to tile the Op `opName` and fuse its
+ /// producers with tiling and fusion `options`.
+ CodegenStrategy &
+ tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter::FilterFunction f = nullptr) {
+ return b ? tileAndFuse(opName, options, f) : *this;
+ }
/// Append a pattern to add a level of tiling for Op `opName` with tiling
/// `options`.
CodegenStrategy &
@@ -161,7 +193,7 @@ struct CodegenStrategy {
CodegenStrategy &
tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? tile(opName, options) : *this;
+ return b ? tile(opName, options, f) : *this;
}
/// Append a pattern to pad and hoist the operands of Op `opName` with padding
/// `options`.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c044bc26c934..cfdd560f0cacc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -517,6 +517,13 @@ struct LinalgPaddingOptions {
}
};
+struct LinalgTilingAndFusionOptions {
+ /// Tile sizes used to tile the root operation.
+ SmallVector<int64_t> tileSizes;
+ /// Tile interchange used to permute the tile loops.
+ SmallVector<int64_t> tileInterchange;
+};
+
struct LinalgTilingOptions {
/// Computation function that returns the tile sizes for each operation.
/// Delayed construction of constant tile sizes should occur to interoperate
@@ -767,6 +774,34 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
};
+///
+/// Linalg tile and fuse tensor ops pattern.
+///
+/// Apply tiling and fusion as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tileConsumerAndFuseProducers` for more details.
+struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
+ // Entry point to match any LinalgOp.
+ LinalgTileAndFuseTensorOpsPattern(
+ MLIRContext *context, LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ // Entry point to match a specific LinalgOp.
+ LinalgTileAndFuseTensorOpsPattern(
+ StringRef opName, MLIRContext *context,
+ LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+ /// Tile sizes and interchange used to tile the root operation.
+ LinalgTilingAndFusionOptions options;
+};
+
///
/// Linalg generic interchange pattern.
///
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 25aee5f23a513..c1cdd3eda2cb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -235,6 +235,9 @@ class TileLoopNest {
/// Returns the tiled root operation.
LinalgOp getRootOp() { return rootOp; }
+ /// Returns the tiled root operation and the fused producers.
+ SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
/// Returns the loop ops generated from tiling.
ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 00904a48712b7..68d02db728df7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -390,6 +390,17 @@ ValueRange TileLoopNest::getRootOpReplacementResults() {
return tileLoopOps.front()->getOpResults();
}
+SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
+ SmallVector<LinalgOp> result;
+ for (const auto &kvp : tiledRootAndFusedOpsLoops) {
+ auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
+ assert(linalgOp &&
+ "expect all tiled and fused operations are linalg operations");
+ result.push_back(linalgOp);
+ }
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 24cec12cec62d..acbc30a93c6c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -36,6 +36,43 @@ using namespace linalg;
namespace {
+/// Configurable pass to apply pattern-based tiling and fusion.
+struct LinalgStrategyTileAndFusePass
+ : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
+
+ LinalgStrategyTileAndFusePass() = default;
+
+ LinalgStrategyTileAndFusePass(StringRef opName,
+ LinalgTilingAndFusionOptions opt,
+ LinalgTransformationFilter filt)
+ : options(opt), filter(filt) {
+ this->anchorOpName.setValue(opName.str());
+ }
+
+ void runOnFunction() override {
+ auto funcOp = getFunction();
+ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+ return;
+
+ RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
+ if (!anchorOpName.empty()) {
+ tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+ anchorOpName, funcOp.getContext(), options, filter);
+ } else {
+ tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+ funcOp.getContext(), options, filter);
+ }
+ // Search the root operation using bottom up traversal.
+ GreedyRewriteConfig grc;
+ grc.useTopDownTraversal = false;
+ (void)applyPatternsAndFoldGreedily(funcOp,
+ std::move(tilingAndFusionPattern), grc);
+ }
+
+ LinalgTilingAndFusionOptions options;
+ LinalgTransformationFilter filter;
+};
+
/// Configurable pass to apply pattern-based linalg tiling.
struct LinalgStrategyTilePass
: public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
@@ -380,6 +417,15 @@ struct LinalgStrategyRemoveMarkersPass
};
} // namespace
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
+ LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
+ filter);
+}
+
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 36bb0171823f7..20073359a68ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -520,6 +520,75 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
return success();
}
+/// Linalg tile and fuse tensor ops pattern.
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+ LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
+ LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter filter,
+ PatternBenefit benefit)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+ options(options) {}
+
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+ LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
+ LinalgTilingAndFusionOptions options,
+ LinalgTransformationFilter filter,
+ PatternBenefit benefit)
+ : RewritePattern(opName, benefit, context), filter(filter),
+ options(options) {}
+
+LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ LinalgOp rootOp = dyn_cast<LinalgOp>(op);
+ if (!rootOp)
+ return failure();
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+ if (options.tileSizes.size() < rootOp.getNumLoops())
+ return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
+
+ // Check `tileInterchange` contains no entries or as many as `tileSizes`.
+ if (!options.tileInterchange.empty() &&
+ options.tileInterchange.size() != options.tileSizes.size())
+ return rewriter.notifyMatchFailure(
+ op, "expect the number of tile sizes and interchange dims to match");
+
+ // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
+ SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
+ options.tileSizes.begin() +
+ rootOp.getNumLoops());
+ SmallVector<int64_t> rootInterchange =
+ options.tileInterchange.empty()
+ ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
+ : SmallVector<int64_t>(options.tileInterchange.begin(),
+ options.tileInterchange.begin() +
+ rootOp.getNumLoops());
+
+ // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+ // It has to be a permutation since the tiling cannot tile the same loop
+ // dimension multiple times.
+ if (!isPermutation(rootInterchange))
+ return rewriter.notifyMatchFailure(
+ op, "expect the tile interchange permutes the root loops");
+
+ // Tile `rootOp` and fuse its producers.
+ FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
+ rewriter, rootOp, rootTileSizes, rootInterchange);
+ if (failed(tileLoopNest))
+ return rewriter.notifyMatchFailure(
+ op, "tileConsumerAndFuseProducers failed unexpectedly");
+
+ // Replace all uses of the tiled loop operation.
+ rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
+
+ // Apply the filter if specified.
+ for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
+ filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+ return failure();
+}
+
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index 90b9ad60b97f3..e46218995ffb8 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC %s
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// MATMUL-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
+// MATMUL-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
-// CHECK: fuse_input
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// MATMUL: fuse_input
+// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -18,31 +19,31 @@ builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
- // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
- // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
- // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+ // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // MATMUL: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+ // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // MATMUL: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
// Tile both input operand dimensions.
- // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
- // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV1]], %[[IV2]]
- // CHECK-SAME: %[[UB1]], %[[UB2]]
- // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
- // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]]
+ // MATMUL: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
+ // MATMUL: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // MATMUL-SAME: %[[IV1]], %[[IV2]]
+ // MATMUL-SAME: %[[UB1]], %[[UB2]]
+ // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // MATMUL: %{{.*}} = linalg.matmul ins(%[[T1]]
%1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %1 : tensor<24x25xf32>
}
// -----
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
-// CHECK: fuse_output
-// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+// MATMUL: fuse_output
+// MATMUL-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -55,34 +56,34 @@ builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
%0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
// Update the iteration argument of the outermost tile loop.
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
- // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
- // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
- // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
+ // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+ // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+ // MATMUL: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+ // MATMUL: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
// Tile both output operand dimensions.
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
- // CHECK-SAME: %[[IV1]], %[[IV0]]
- // CHECK-SAME: %[[TS1]], %[[TS0]]
- // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
- // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
- // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+ // MATMUL-SAME: %[[IV1]], %[[IV0]]
+ // MATMUL-SAME: %[[TS1]], %[[TS0]]
+ // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+ // MATMUL: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %1 : tensor<24x25xf32>
}
// -----
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// MATMUL-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
+// MATMUL-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK: fuse_reduction
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+// MATMUL: fuse_reduction
+// MATMUL-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// MATMUL-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>,
@@ -98,23 +99,23 @@ builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
linalg.yield %2 : f32
} -> tensor<12x25xf32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
- // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
- // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
- // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
- // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
- // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
+ // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // MATMUL: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+ // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // MATMUL: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+ // MATMUL: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+ // MATMUL: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
// Tile only the parallel dimensions but not the reduction dimension.
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
- // CHECK-SAME: %[[IV2]], 0, %[[IV0]]
- // CHECK-SAME: %[[UB2]], 7, %[[UB0]]
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
- // CHECK-SAME: %[[IV2]], %[[IV0]]
- // CHECK-SAME: %[[UB2]], %[[UB0]]
- // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
- // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+ // MATMUL-SAME: %[[IV2]], 0, %[[IV0]]
+ // MATMUL-SAME: %[[UB2]], 7, %[[UB0]]
+ // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // MATMUL-SAME: %[[IV2]], %[[IV0]]
+ // MATMUL-SAME: %[[UB2]], %[[UB0]]
+ // MATMUL: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+ // MATMUL: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
%1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %1 : tensor<24x25xf32>
}
@@ -124,9 +125,9 @@ builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: fuse_transposed
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
+// MATMUL: fuse_transposed
+// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// MATMUL-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>,
@@ -142,26 +143,26 @@ builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
linalg.yield %2 : f32
} -> tensor<24x12xf32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] =
// Swap the input operand slice offsets due to the transposed indexing map.
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
- // CHECK-SAME: %[[IV2]], %[[IV1]]
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV1]], %[[IV2]]
- // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
- // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]]
+ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+ // MATMUL-SAME: %[[IV2]], %[[IV1]]
+ // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+ // MATMUL-SAME: %[[IV1]], %[[IV2]]
+ // MATMUL: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+ // MATMUL: %{{.*}} = linalg.matmul ins(%[[T2]]
%1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %1 : tensor<24x25xf32>
}
// -----
-// CHECK: fuse_input_and_output
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+// MATMUL: fuse_input_and_output
+// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// MATMUL-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
@@ -175,27 +176,27 @@ builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
%1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
// Fuse both producers to the appropriate tile loops.
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
- // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
- // CHECK-SAME: %[[IV1]], %[[IV0]]
- // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
- // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
- // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV1]], %[[IV2]]
- // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
- // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+ // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+ // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+ // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+ // MATMUL-SAME: %[[IV1]], %[[IV0]]
+ // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+ // MATMUL: %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
+ // MATMUL-SAME: %[[IV1]], %[[IV2]]
+ // MATMUL: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+ // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
%2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
return %2 : tensor<24x25xf32>
}
// -----
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
#map0 = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK: fuse_indexed
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
+// MATMUL: fuse_indexed
+// MATMUL-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
%arg1: tensor<12x25xi32>,
%arg2: tensor<24x25xi32>) -> tensor<24x25xi32> {
@@ -213,19 +214,19 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
linalg.yield %9 : i32
} -> tensor<12x25xi32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
- // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] =
// Shift the indexes by the slice offsets and swap the offsets due to the transposed indexing map.
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
- // CHECK-SAME: %[[IV2]], %[[IV0]]
- // CHECK: linalg.generic {{.*}} outs(%[[T1]]
- // CHECK: %[[IDX0:.*]] = linalg.index 0
- // CHECK: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
- // CHECK: %[[IDX1:.*]] = linalg.index 1
- // CHECK: %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
- // CHECK: %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
+ // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // MATMUL-SAME: %[[IV2]], %[[IV0]]
+ // MATMUL: linalg.generic {{.*}} outs(%[[T1]]
+ // MATMUL: %[[IDX0:.*]] = linalg.index 0
+ // MATMUL: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
+ // MATMUL: %[[IDX1:.*]] = linalg.index 1
+ // MATMUL: %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
+ // MATMUL: %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
%1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xi32>, tensor<12x25xi32>) outs(%arg2 : tensor<24x25xi32>) -> tensor<24x25xi32>
return %1 : tensor<24x25xi32>
}
@@ -235,28 +236,28 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
-// CHECK: fuse_outermost_reduction
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
+// GENERIC: fuse_outermost_reduction
+// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+// GENERIC-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
%arg1: tensor<10xf32>) -> tensor<10xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
// Cannot fuse the output fill since the reduction loop is the outermost loop.
- // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
+ // GENERIC: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
%1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+ // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
+ // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
- // Check the input fill has been fused.
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV1]], %[[IV0]]
- // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
- // CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
- // CHECK-SAME: %[[IV1]]
- // CHECK: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
+ // Check the input fill has been fused.
+ // GENERIC: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+ // GENERIC-SAME: %[[IV1]], %[[IV0]]
+ // GENERIC: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+ // GENERIC: %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
+ // GENERIC-SAME: %[[IV1]]
+ // GENERIC: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
%2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) {
^bb0(%arg2: f32, %arg3: f32): // no predecessors
%3 = arith.addf %arg2, %arg3 : f32
@@ -267,39 +268,39 @@ func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
// -----
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>
+// GENERIC-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// GENERIC-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
+// GENERIC-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>
#map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: fuse_non_rectangular
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+// GENERIC: fuse_non_rectangular
+// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
func @fuse_non_rectangular(%arg0: tensor<10x17xf32>,
%arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
- // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+ // GENERIC-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // GENERIC-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // GENERIC-DAG: %[[C5:.*]] = arith.constant 5 : index
+ // GENERIC-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // GENERIC-DAG: %[[C10:.*]] = arith.constant 10 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
- // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
- // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
+ // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
+ // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
// Compute producer on a hyper rectangular bounding box. Along the second dimension,
// the offset is set to the sum of the induction variables, and the upper bound
// to either 8 (the tile size) or 17 (the sum of the maximum indices 9 and 7, plus 1) minus the
// induction variables.
- // CHECK: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
- // CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
- // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV1]], %[[SUM]]
- // CHECK-SAME: , %[[UB1]]
- // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+ // GENERIC-DAG: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
+ // GENERIC-DAG: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
+ // GENERIC-DAG: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
+ // GENERIC: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // GENERIC-SAME: %[[IV1]], %[[SUM]]
+ // GENERIC-SAME: , %[[UB1]]
+ // GENERIC: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
%1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) {
^bb0(%arg2: f32, %arg3: f32): // no predecessors
%2 = arith.addf %arg2, %arg3 : f32
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
index 1578d230017ba..ef28641f4062c 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir
@@ -1,11 +1,12 @@
-// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.conv_2d fuse tile-sizes=4,4,0,0 tile-interchange=0,1,2,3 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=CONV %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=4,4,0 tile-interchange=0,1,2 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
-// CHECK: fuse_conv_chain
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
-// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+// CONV: fuse_conv_chain
+// CONV-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32>
+// CONV-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32>
+// CONV-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32>
+// CONV-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32>
+// CONV-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32>
builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
%arg1: tensor<11x11xf32>,
%arg2: tensor<10x10xf32>,
@@ -14,34 +15,34 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
%cst = arith.constant 1.0 : f32
// Do not tile the filter fill since the filter dimensions are not tiled.
- // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+ // CONV: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
%0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32>
// Fuse all other operations.
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
- // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
+ // CONV: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]]
+ // CONV: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]]
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
- // CHECK-SAME: %[[IV0]], %[[IV1]]
- // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
- // CHECK-SAME: %[[IV0]], %[[IV1]]
- // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
- // CHECK: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
+ // CONV: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+ // CONV-SAME: %[[IV0]], %[[IV1]]
+ // CONV: %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
+ // CONV-SAME: %[[IV0]], %[[IV1]]
+ // CONV: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+ // CONV: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]]
%1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32>
%2 = linalg.conv_2d ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32>
- // CHECK: %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
- // CHECK-SAME: %[[IV0]], %[[IV1]]
- // CHECK: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
- // CHECK: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
+ // CONV: %[[T5:.*]] = tensor.extract_slice %[[ARG3]]
+ // CONV-SAME: %[[IV0]], %[[IV1]]
+ // CONV: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]])
+ // CONV: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]]
%3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32>
%4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32>
// Use the argument passed in by iteration argument.
- // CHECK: %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
- // CHECK-SAME: %[[IV0]], %[[IV1]]
- // CHECK: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
- // CHECK: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
+ // CONV: %[[T8:.*]] = tensor.extract_slice %[[ARG6]]
+ // CONV-SAME: %[[IV0]], %[[IV1]]
+ // CONV: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]])
+ // CONV: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]]
%5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
%6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32>
return %6 : tensor<8x8xf32>
@@ -49,8 +50,8 @@ builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>,
// -----
-// CHECK: fuse_matmul_chain
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
+// MATMUL: fuse_matmul_chain
+// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32>
builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
@@ -60,24 +61,24 @@ builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%cst = arith.constant 0.000000e+00 : f32
// Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled.
- // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
+ // MATMUL: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]])
%0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32>
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
- // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
+ // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]]
+ // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]]
// Only the outermost loop of the producer matmul is tiled.
- // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
- // CHECK-SAME: %[[IV0]], 0
- // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
- // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
+ // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+ // MATMUL-SAME: %[[IV0]], 0
+ // MATMUL: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+ // MATMUL: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}}
%1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
// Use the argument passed in by iteration argument.
- // CHECK: %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
- // CHECK-SAME: %[[IV0]], %[[IV1]]
- // CHECK: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
- // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
+ // MATMUL: %[[T4:.*]] = tensor.extract_slice %[[ARG2]]
+ // MATMUL-SAME: %[[IV0]], %[[IV1]]
+ // MATMUL: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]])
+ // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]]
%2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32>
return %2 : tensor<8x8xf32>
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
index 49b2bb9b4bcd6..c2c563abe4cfc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -52,16 +52,21 @@ struct TestLinalgCodegenStrategy
void runOnFunction() override;
- void runStrategy(LinalgTilingOptions tilingOptions,
+ void runStrategy(LinalgTilingAndFusionOptions tilingAndFusionOptions,
+ LinalgTilingOptions tilingOptions,
LinalgTilingOptions registerTilingOptions,
LinalgPaddingOptions paddingOptions,
vector::VectorContractLowering vectorContractLowering,
vector::VectorTransferSplit vectorTransferSplit);
+ Option<bool> fuse{
+ *this, "fuse",
+ llvm::cl::desc("Fuse the producers after tiling the root op."),
+ llvm::cl::init(false)};
ListOption<int64_t> tileSizes{*this, "tile-sizes",
llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Specifies the tile sizes.")};
- ListOption<unsigned> tileInterchange{
+ ListOption<int64_t> tileInterchange{
*this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Specifies the tile interchange.")};
@@ -148,6 +153,7 @@ struct TestLinalgCodegenStrategy
};
void TestLinalgCodegenStrategy::runStrategy(
+ LinalgTilingAndFusionOptions tilingAndFusionOptions,
LinalgTilingOptions tilingOptions,
LinalgTilingOptions registerTilingOptions,
LinalgPaddingOptions paddingOptions,
@@ -156,7 +162,10 @@ void TestLinalgCodegenStrategy::runStrategy(
assert(!anchorOpName.empty());
CodegenStrategy strategy;
StringRef genericOpName = GenericOp::getOperationName();
- strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
+ strategy
+ .tileAndFuseIf(fuse && !tileSizes.empty(), anchorOpName,
+ tilingAndFusionOptions)
+ .tileIf(!fuse && !tileSizes.empty(), anchorOpName, tilingOptions)
.promoteIf(promote, anchorOpName,
LinalgPromotionOptions()
.setAlignment(16)
@@ -204,11 +213,17 @@ void TestLinalgCodegenStrategy::runOnFunction() {
if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName())
return;
+ LinalgTilingAndFusionOptions tilingAndFusionOptions;
+ tilingAndFusionOptions.tileSizes = {tileSizes.begin(), tileSizes.end()};
+ tilingAndFusionOptions.tileInterchange = {tileInterchange.begin(),
+ tileInterchange.end()};
+
LinalgTilingOptions tilingOptions;
if (!tileSizes.empty())
tilingOptions = tilingOptions.setTileSizes(tileSizes);
if (!tileInterchange.empty())
- tilingOptions = tilingOptions.setInterchange(tileInterchange);
+ tilingOptions = tilingOptions.setInterchange(
+ SmallVector<unsigned>(tileInterchange.begin(), tileInterchange.end()));
LinalgTilingOptions registerTilingOptions;
if (!registerTileSizes.empty())
@@ -245,8 +260,8 @@ void TestLinalgCodegenStrategy::runOnFunction() {
.Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
.Default(vector::VectorTransferSplit::None);
- runStrategy(tilingOptions, registerTilingOptions, paddingOptions,
- vectorContractLowering, vectorTransferSplit);
+ runStrategy(tilingAndFusionOptions, tilingOptions, registerTilingOptions,
+ paddingOptions, vectorContractLowering, vectorTransferSplit);
}
namespace mlir {
More information about the Mlir-commits
mailing list