[Mlir-commits] [mlir] 7915027 - [mlir][Linalg] Retire LinalgStrategyTileAndFusePass and filter-based pattern.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 10 07:04:11 PDT 2022
Author: Nicolas Vasilache
Date: 2022-10-10T07:04:01-07:00
New Revision: 79150279268ae7b4da95750585a71f3df405fa6e
URL: https://github.com/llvm/llvm-project/commit/79150279268ae7b4da95750585a71f3df405fa6e
DIFF: https://github.com/llvm/llvm-project/commit/79150279268ae7b4da95750585a71f3df405fa6e.diff
LOG: [mlir][Linalg] Retire LinalgStrategyTileAndFusePass and filter-based pattern.
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785
In the process, also retire `tileConsumerAndFuseProducers`, which is now replaced by `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`.
Context: https://discourse.llvm.org/t/psa-retire-tileandfuselinalgops-method/63850
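For reference, here is a minimal C++ sketch of driving the new entry point, mirroring the updated `FuseOp::apply` in LinalgTransformOps.cpp further below (the surrounding rewriter and payload-op setup is assumed, not part of this patch's prose):

  // Configure the tiling: non-zero tile sizes select which loops are
  // materialized, and the interchange vector permutes them.
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes);
  tilingOptions.interchangeVector = tileInterchange;
  // Greedily fuse producers into the scf.for nest generated around the consumer.
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions = tilingOptions;
  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
      scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
          rewriter, cast<TilingInterface>(target), tileAndFuseOptions);
  if (failed(tiledResults))
    return failure();
  // tiledResults->tiledAndFusedOps, ->loops and ->replacements hold the tiled
  // root and fused ops, the generated loops, and the values that replace the
  // untiled ops.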
When performing this replacement, a change of behavior surfaced: the older `tileConsumerAndFuseProducers` automatically split the
parallel and non-parallel dimensions, performing a first level of tile-and-fuse on the parallel dimensions only and then introducing a
second, tiling-only level on the reduction dimensions. The newer `tileConsumerAndFuseProducerGreedilyUsingSCFForOp`, on the other hand,
does not perform this breakdown. As a consequence, the transform specification in the tests is updated to produce the same output.
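Concretely, the updated transform-op-fuse.mlir test below captures this evolution: the fuse op now tiles only the parallel dimensions and a follow-up tile op handles the reduction dimension (snippet taken from the test diff in this commit):

  // Before: a single fuse op tiled all three loops, including the reduction.
  %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]}
  // After: fuse only the parallel dimensions, then tile the reduction separately.
  %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
  %2, %loops_2 = transform.structured.tile %1 [0, 4]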
Additionally, replace some uses of `unsigned` with `int64_t` where this is possible without pulling in larger interface changes (left for a future PR).
Context: https://www.youtube.com/watch?v=Puio5dly9N8
Lastly, the tests that performed tile, fuse, and distribute on tensors are retired: the generated IR, which mixed scf.for, tensors, and
distributed processor ids, was racy at best.
Differential Revision: https://reviews.llvm.org/D135559
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Linalg/transform-op-fuse.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 719a79290620d..6e41f05cc36d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -76,13 +76,6 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
//===----------------------------------------------------------------------===//
/// Linalg strategy passes.
//===----------------------------------------------------------------------===//
-/// Create a LinalgStrategyTileAndFusePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyTileAndFusePass(
- StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {},
- const linalg::LinalgTransformationFilter &filter =
- linalg::LinalgTransformationFilter());
-
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyTilePass(
StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 1889d1e0cab90..40a2f112f0a15 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -162,18 +162,6 @@ def LinalgDetensorize : Pass<"linalg-detensorize", ""> {
];
}
-def LinalgStrategyTileAndFusePass
- : Pass<"linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> {
- let summary = "Configurable pass to apply pattern-based tiling and fusion.";
- let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
- let options = [
- Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
- "Which func op is the anchor to latch on.">,
- Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
- "Which linalg op within the func is the anchor to latch on.">,
- ];
-}
-
def LinalgStrategyTilePass
: Pass<"linalg-strategy-tile-pass", "func::FuncOp"> {
let summary = "Configurable pass to apply pattern-based linalg tiling.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 6f56702dd2b2e..d7c0d22031692 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -30,23 +30,6 @@ struct Transformation {
LinalgTransformationFilter::FilterFunction filter = nullptr;
};
-/// Represent one application of LinalgStrategyTileAndFusePass.
-struct TileAndFuse : public Transformation {
- TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr)
- : Transformation(std::move(f)), opName(name),
- options(std::move(options)) {}
-
- void addToPassPipeline(OpPassManager &pm,
- LinalgTransformationFilter m) const override {
- pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
- }
-
-private:
- std::string opName;
- linalg::LinalgTilingAndFusionOptions options;
-};
-
/// Represent one application of LinalgStrategyTilePass.
struct Tile : public Transformation {
Tile(StringRef name, linalg::LinalgTilingOptions options,
@@ -66,22 +49,6 @@ struct Tile : public Transformation {
/// Codegen strategy controls how a Linalg op is progressively lowered.
struct CodegenStrategy {
- /// Append a pattern to tile the Op `opName` and fuse its producers with
- /// tiling and fusion `options`.
- CodegenStrategy &
- tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options,
- const LinalgTransformationFilter::FilterFunction &f = nullptr) {
- transformationSequence.emplace_back(
- std::make_unique<TileAndFuse>(opName, options, f));
- return *this;
- }
- /// Conditionally append a pattern to tile the Op `opName` and fuse its
- /// producers with tiling and fusion `options`.
- CodegenStrategy &
- tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter::FilterFunction f = nullptr) {
- return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this;
- }
/// Append a pattern to add a level of tiling for Op `opName` with tiling
/// `options`.
CodegenStrategy &
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b7f99ab6add02..62dcc8e786877 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -787,42 +787,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
}
};
-///
-/// Linalg tile and fuse tensor ops pattern.
-///
-/// Apply tiling and fusion as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `tileConsumerAndFuseProducers` for more details.
-struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
- // Entry point to match any LinalgOp.
- LinalgTileAndFuseTensorOpsPattern(
- MLIRContext *context, LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
- // Entry point to match a specific LinalgOp.
- LinalgTileAndFuseTensorOpsPattern(
- StringRef opName, MLIRContext *context,
- LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<TileLoopNest>
- returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
- /// Tile sizes and interchange used to tile the root operation.
- LinalgTilingAndFusionOptions options;
-};
-
///
/// Linalg generalization pattern.
///
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 3ec6fc7522d23..305b859ac13d1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -445,14 +445,6 @@ class TileLoopNest {
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};
-/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
-/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
-/// the tiling.
-FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
- OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange,
- const Optional<LinalgLoopDistributionOptions> &tileDistribution);
-
//===----------------------------------------------------------------------===//
// Generic op region utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1c374d62425d1..174b39ce89684 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -53,8 +53,8 @@ struct SCFTilingOptions {
SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
/// The interchange vector to reorder the tiled loops.
- SmallVector<unsigned> interchangeVector = {};
- SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
+ SmallVector<int64_t> interchangeVector = {};
+ SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
interchangeVector = llvm::to_vector(interchange);
return *this;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 99f93ed8a5ae9..5b825201ebcb2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
@@ -99,45 +100,63 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
-
//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
- unsigned numLoops,
- transform::TransformResults &transformResults,
- function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+static LogicalResult applyTilingToAll(
+ Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
+ transform::TransformResults &transformResults,
+ function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+ applyFn) {
SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].reserve(payloadOps.size());
for (Operation *target : payloadOps) {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
- if (!linalgOp)
- return transformOp->emitError("only LinalgOps are supported");
-
- FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
- if (failed(tiled))
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp)
+ return transformOp->emitError("only TilingInterface ops are supported");
+
+ SimpleRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+ applyFn(tilingInterfaceOp);
+ if (failed(tiledResults))
return failure();
- tiledLinalgOps.push_back(tiled->op);
- if (tiled->loops.size() != numLoops)
- // Not enough loops were generated. This usually means that the input size
- // was smaller than the tiling size.
- // TODO: LinalgTilingPattern should return failure().
- return failure();
+ // Perform the replacement of tiled and fused values.
+ SmallVector<Operation *> opsToReplace{target};
+ llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+ for (Operation *toReplace : opsToReplace) {
+ SmallVector<Value> replacements;
+ replacements.reserve(toReplace->getNumResults());
+ for (OpResult res : toReplace->getResults()) {
+ auto it = tiledResults->replacements.find(res);
+ if (it == tiledResults->replacements.end())
+ replacements.push_back(res);
+ else
+ replacements.push_back(it->getSecond());
+ }
+ rewriter.replaceOp(toReplace, replacements);
+ }
+
+ // Report back the relevant handles to the transform op.
+ tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+ assert(tiledResults->loops.size() == numLoops &&
+ "Mismatched number of loops, tile and fuse transform should have "
+ "failed");
for (unsigned int i = 0; i < numLoops; ++i)
- loopOps[i].push_back(tiled->loops[i]);
+ loopOps[i].push_back(tiledResults->loops[i]);
}
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+
return success();
}
@@ -172,27 +191,23 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
DiagnosedSilenceableFailure
transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- LinalgTilingAndFusionOptions fusionOptions;
- fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
- fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
+ SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+ SmallVector<int64_t> tileInterchange =
+ extractFromI64ArrayAttr(getTileInterchange());
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.interchangeVector = tileInterchange;
+ tilingOptions = tilingOptions.setTileSizes(tileSizes);
+ scf::SCFTileAndFuseOptions tileAndFuseOptions;
+ tileAndFuseOptions.tilingOptions = tilingOptions;
LogicalResult result = applyTilingToAll(
getOperation(), state.getPayloadOps(getTarget()),
- fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
- transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
- LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+ tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ [&](TilingInterface tilingInterfaceOp)
+ -> FailureOr<scf::SCFTileAndFuseResult> {
SimpleRewriter rewriter(getContext());
- rewriter.setInsertionPoint(linalgOp);
- FailureOr<TileLoopNest> tileLoopNest =
- pattern.returningMatchAndRewrite(linalgOp, rewriter);
- if (failed(tileLoopNest))
- return failure();
-
- TiledLinalgOp tiledLinalgOp;
- tiledLinalgOp.op = tileLoopNest->getRootOp();
- tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
- tileLoopNest->getLoopOps().end()};
- return tiledLinalgOp;
+ return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+ rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
return DiagnosedSilenceableFailure(result);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d29b767df9d71..2451c79a35052 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -414,68 +414,3 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
}
return result;
}
-
-//===----------------------------------------------------------------------===//
-// Tile and fuse entry-points.
-//===----------------------------------------------------------------------===//
-
-FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
- OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> tileInterchange,
- const Optional<LinalgLoopDistributionOptions> &tileDistribution) {
- assert(tileSizes.size() == tileInterchange.size() &&
- "expect the number of tile sizes and interchange dims to match");
- assert(isPermutation(tileInterchange) &&
- "expect tile interchange is a permutation");
-
- // Create an empty tile loop nest.
- TileLoopNest tileLoopNest(consumerOp);
-
- // Search the number of outer parallel loops to separate them from possible
- // inner reduction dimensions.
- SmallVector<StringRef> iterTypes = consumerOp.getIteratorTypesArray();
- applyPermutationToVector(iterTypes, tileInterchange);
- auto *it = find_if_not(iterTypes, isParallelIterator);
- int64_t split = std::distance(iterTypes.begin(), it);
-
- // Helper to fuse the producers greedily using a queue of fusion candidates.
- auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
- SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
- while (!candidates.empty()) {
- FailureOr<LinalgOp> fusedProducer =
- tileLoopNest.fuseProducer(b, candidates.pop_back_val());
- if (failed(fusedProducer))
- continue;
- candidates.append(fusedProducer->getInputAndOutputOperands());
- }
- };
-
- // Perform tiling and fusion in two steps. We need to respect the loop
- // interchange here; filter parellel dimensions based on their order *after*
- // permutation but pass in the original configuration *before* permuation,
- // given the tiling and interchange happen together.
- SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
- SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
- for (int64_t i : tileInterchange.take_front(split))
- outerTileSizes[i] = tileSizes[i];
- for (int64_t i : tileInterchange.drop_front(split))
- innerTileSizes[i] = tileSizes[i];
-
- // Tile the outer parallel loops and fuse the output operands.
- if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
- tileDistribution)))
- return failure();
- fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
-
- // Tile the remaining loops and fuse the input operands.
- if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
- tileDistribution)))
- return failure();
- fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
-
- // Exit if the tile loop nest is empty since all tile sizes are zero.
- if (tileLoopNest.isEmpty())
- return failure();
-
- return tileLoopNest;
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 3faf45e8caa5d..162e74f1ba31d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -51,44 +51,6 @@ using namespace linalg;
namespace {
-/// Configurable pass to apply pattern-based tiling and fusion.
-struct LinalgStrategyTileAndFusePass
- : public impl::LinalgStrategyTileAndFusePassBase<
- LinalgStrategyTileAndFusePass> {
-
- LinalgStrategyTileAndFusePass() = default;
-
- LinalgStrategyTileAndFusePass(StringRef opName,
- LinalgTilingAndFusionOptions opt,
- LinalgTransformationFilter filt)
- : options(std::move(opt)), filter(std::move(filt)) {
- this->anchorOpName.setValue(opName.str());
- }
-
- void runOnOperation() override {
- auto funcOp = getOperation();
- if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
- return;
-
- RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
- if (!anchorOpName.empty()) {
- tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
- anchorOpName, funcOp.getContext(), options, filter);
- } else {
- tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
- funcOp.getContext(), options, filter);
- }
- // Search the root operation using bottom up traversal.
- GreedyRewriteConfig config;
- config.useTopDownTraversal = false;
- (void)applyPatternsAndFoldGreedily(
- funcOp, std::move(tilingAndFusionPattern), config);
- }
-
- LinalgTilingAndFusionOptions options;
- LinalgTransformationFilter filter;
-};
-
/// Configurable pass to apply pattern-based linalg tiling.
struct LinalgStrategyTilePass
: public impl::LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
@@ -139,15 +101,6 @@ struct LinalgStrategyRemoveMarkersPass
};
} // namespace
-/// Create a LinalgStrategyTileAndFusePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyTileAndFusePass(
- StringRef opName, const LinalgTilingAndFusionOptions &options,
- const LinalgTransformationFilter &filter) {
- return std::make_unique<LinalgStrategyTileAndFusePass>(opName, options,
- filter);
-}
-
/// Create a LinalgStrategyTilePass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createLinalgStrategyTilePass(StringRef opName,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b3062f53b5e0d..938b9e736bf72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -447,82 +447,6 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
return paddedOp;
}
-/// Linalg tile and fuse tensor ops pattern.
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
- LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
- LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter f,
- PatternBenefit benefit)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
- filter(std::move(f)), options(std::move(options)) {}
-
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
- LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
- LinalgTilingAndFusionOptions options,
- LinalgTransformationFilter f,
- PatternBenefit benefit)
- : RewritePattern(opName, benefit, context), filter(std::move(f)),
- options(std::move(options)) {}
-
-FailureOr<mlir::linalg::TileLoopNest>
-mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- LinalgOp rootOp = dyn_cast<LinalgOp>(op);
- if (!rootOp)
- return failure();
- if (failed(filter.checkAndNotify(rewriter, op)))
- return failure();
-
- // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
- if (options.tileSizes.size() < rootOp.getNumLoops())
- return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
-
- // Check `tileInterchange` contains no entries or as many as `tileSizes`.
- if (!options.tileInterchange.empty() &&
- options.tileInterchange.size() != options.tileSizes.size())
- return rewriter.notifyMatchFailure(
- op, "expect the number of tile sizes and interchange dims to match");
-
- // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
- SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
- options.tileSizes.begin() +
- rootOp.getNumLoops());
- SmallVector<int64_t> rootInterchange =
- options.tileInterchange.empty()
- ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
- : SmallVector<int64_t>(options.tileInterchange.begin(),
- options.tileInterchange.begin() +
- rootOp.getNumLoops());
-
- // Check `rootTileSizes` contains non-zero tile sizes.
- if (llvm::count(rootTileSizes, 0) == static_cast<long>(rootTileSizes.size()))
- return rewriter.notifyMatchFailure(
- op, "expect at least one non-zero tile size");
-
- // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
- // It has to be a permutation since the tiling cannot tile the same loop
- // dimension multiple times.
- if (!isPermutation(rootInterchange))
- return rewriter.notifyMatchFailure(
- op, "expect the tile interchange permutes the root loops");
-
- // Tile `rootOp` and fuse its producers.
- FailureOr<TileLoopNest> tileLoopNest =
- tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
- rootInterchange, options.tileDistribution);
- if (failed(tileLoopNest))
- return rewriter.notifyMatchFailure(
- op, "tileConsumerAndFuseProducers failed unexpectedly");
-
- // Replace all uses of the tiled loop operation.
- rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
-
- // Apply the filter if specified.
- for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
- filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
- return tileLoopNest;
-}
-
/// Linalg generalization pattern.
mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 65bc941e0e0fa..2630da381b382 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -45,12 +45,12 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
/// Helper method to adjust the interchange vector to match the iteration
/// domain.
-static SmallVector<unsigned>
-fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
+static SmallVector<int64_t>
+fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
size_t iterationDomainSize) {
- SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector);
+ SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
if (filledVector.size() < iterationDomainSize) {
- auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
+ auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
filledVector.append(range.begin(), range.end());
}
if (filledVector.size() > iterationDomainSize)
@@ -61,23 +61,23 @@ fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
/// Helper method to apply permutation to a vector
template <typename T>
static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
- ArrayRef<unsigned> interchange) {
+ ArrayRef<int64_t> interchange) {
assert(interchange.size() == vector.size());
return llvm::to_vector(
- llvm::map_range(interchange, [&](unsigned val) { return vector[val]; }));
+ llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
}
/// Helper method to apply to invert a permutation.
-static SmallVector<unsigned>
-invertPermutationVector(ArrayRef<unsigned> interchange) {
- SmallVector<unsigned> inversion(interchange.size());
+static SmallVector<int64_t>
+invertPermutationVector(ArrayRef<int64_t> interchange) {
+ SmallVector<int64_t> inversion(interchange.size());
for (const auto &pos : llvm::enumerate(interchange)) {
inversion[pos.value()] = pos.index();
}
return inversion;
}
/// Method to check if an interchange vector is a permutation.
-static bool isPermutation(ArrayRef<unsigned> interchange) {
- llvm::SmallDenseSet<unsigned, 4> seenVals;
+static bool isPermutation(ArrayRef<int64_t> interchange) {
+ llvm::SmallDenseSet<int64_t, 4> seenVals;
for (auto val : interchange) {
if (seenVals.count(val))
return false;
@@ -298,7 +298,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
{
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
- SmallVector<unsigned> interchangeVector;
+ SmallVector<int64_t> interchangeVector;
if (!options.interchangeVector.empty()) {
interchangeVector = fillInterchangeVector(options.interchangeVector,
iterationDomain.size());
@@ -365,7 +365,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 5. Yield all the results of the tiled operation. The surrounding loop
// nest is modified to insert a destructive update pattern to yield
// from the loop nest values to replace the untiled op with.
- unsigned numResults = op->getNumResults();
+ int64_t numResults = op->getNumResults();
SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
resultSizesList(numResults);
for (auto result : llvm::enumerate(op->getResults())) {
@@ -443,7 +443,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
// 1. First tile the consumer.
scf::SCFTileAndFuseResult tileAndFuseResult;
- llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber;
+ llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
{
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
@@ -566,7 +566,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
*destinationIterArg.value());
}
if (iterArgNumber) {
- unsigned resultNumber = fusableProducer.getResultNumber();
+ int64_t resultNumber = fusableProducer.getResultNumber();
if (auto producerOp =
dyn_cast<TilingInterface>(fusableProducer.getOwner())) {
SmallVector<Value> destination =
diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
deleted file mode 100644
index 01f219528109d..0000000000000
--- a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
+++ /dev/null
@@ -1,55 +0,0 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
-
-// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK: func @fill_matmul_tensors(
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-func.func @fill_matmul_tensors(
- %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
- -> tensor<?x?xf32> {
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
-// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
-// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
-// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
-// CHECK-DAG: %[[INIT:.+]] = tensor.empty
-// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
-// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
-// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
-// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
-// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[OUTSLICEA:.+]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[OUTSLICEB:.+]] = tensor.extract_slice %{{.*}}[0, %{{.*}}] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SLICE]]
-// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[OUTSLICEA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[OUTSLICEB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
-// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
-// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
-// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
-// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
- %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
- ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%3: tensor<?x?xf32>)
- -> tensor<?x?xf32>
-
-// CHECK: return %[[TD0]] : tensor<?x?xf32>
- return %4 : tensor<?x?xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index f26462b5f228b..e9801d8742cdd 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -3,10 +3,11 @@
// CHECK-LABEL: func.func @fuse_unary
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK: scf.for
- // CHECK: scf.for
+ // CHECK: %[[RES:.*]] = scf.for
+ // CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
+ // CHECK: return %[[RES]]
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -28,14 +29,15 @@ transform.with_pdl_patterns {
// CHECK-LABEL: func.func @fuse_unary
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK: scf.for
+ // CHECK: %[[PARTIAL_RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
- // CHECK: scf.for
+ // CHECK: %[[RES:.*]] = scf.for {{.*}}%[[PARTIAL_RES]]
// CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
+ // CHECK: return %[[RES]]
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -61,19 +63,23 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3
%five = arith.constant 5.0 : f32
%init = tensor.empty() : tensor<12x25xf32>
-// CHECK: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
+// CHECK: %[[RES:.*]] = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
// CHECK: %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], 0, %[[IV1]]]
// CHECK: %[[OUT_SLICE1:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE1]] : tensor<?x?xf32>)
+//
+// Extra 4 constant is introduced, discard it.
+// CHECK: arith.constant 4 : index
// CHECK: %[[C4:.+]] = arith.constant 4 : index
// CHECK: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[OUT_SLICE0]]
// CHECK: %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
// CHECK: linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)
+// CHECK: return %[[RES]]
%fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
%0 = linalg.generic {
@@ -92,6 +98,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]}
+ %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
+ %2, %loops_2 = transform.structured.tile %1 [0, 4]
}
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 01ea7e1e3a148..781936f56f61c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -65,10 +65,6 @@ struct TestLinalgTransforms
*this, "test-tile-and-distribute-options",
llvm::cl::desc("Test tile and distribute options"),
llvm::cl::init(false)};
- Option<bool> testTileFuseAndDistributionOptions{
- *this, "test-tile-fuse-and-distribute-options",
- llvm::cl::desc("Test tile, fuse and distribute options"),
- llvm::cl::init(false)};
Option<bool> testVectorTransferForwardingPatterns{
*this, "test-vector-transfer-forwarding-patterns",
llvm::cl::desc(
@@ -415,27 +411,6 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
}
}
-static void fillTileFuseAndDistributePatterns(MLIRContext *context,
- RewritePatternSet &patterns) {
- LinalgLoopDistributionOptions cyclicNprocsEqNiters;
- SmallVector<linalg::DistributionMethod> distributionMethod = {
- DistributionMethod::Cyclic, DistributionMethod::Cyclic};
- cyclicNprocsEqNiters.procInfo =
- [distributionMethod](OpBuilder &b, Location loc,
- ArrayRef<Range> parallelLoopRanges) {
- return getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>(
- b, loc, parallelLoopRanges, distributionMethod);
- };
- patterns.add<LinalgTileAndFuseTensorOpsPattern>(
- MatmulOp::getOperationName(), context,
- LinalgTilingAndFusionOptions()
- .setTileSizes({8, 8, 4})
- .setDistributionOptions(cyclicNprocsEqNiters),
- LinalgTransformationFilter(
- StringAttr::get(context, "tensors_fuse_distribute1"),
- StringAttr::get(context, "tensors_after_fuse_distribute1")));
-}
-
static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
RewritePatternSet forwardPattern(funcOp.getContext());
forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
@@ -552,12 +527,6 @@ void TestLinalgTransforms::runOnOperation() {
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
- if (testTileFuseAndDistributionOptions) {
- RewritePatternSet patterns(&getContext());
- fillTileFuseAndDistributePatterns(&getContext(), patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- return;
- }
if (testPatterns)
return applyPatterns(getOperation());
if (testVectorTransferForwardingPatterns)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 977a0541ec083..8e3b9765f0882 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -199,7 +199,7 @@ static void addPatternForTiling(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
ArrayRef<int64_t> tileSizes,
- ArrayRef<unsigned> interchange = {}) {
+ ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
linalg::LinalgTransformationFilter filter(
@@ -211,7 +211,7 @@ static void addPatternForTileAndFuse(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
ArrayRef<int64_t> tileSizes,
- ArrayRef<unsigned> interchange = {}) {
+ ArrayRef<int64_t> interchange = {}) {
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
interchange);