[Mlir-commits] [mlir] 98835e3 - [mlir][Linalg] Enable TileAndFusePattern to work with tensors.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 28 14:13:30 PST 2021
Author: MaheshRavishankar
Date: 2021-01-28T14:13:01-08:00
New Revision: 98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07
URL: https://github.com/llvm/llvm-project/commit/98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07
DIFF: https://github.com/llvm/llvm-project/commit/98835e3d9849a17e9f8eb5dcd8aee3c9d32e1e07.diff
LOG: [mlir][Linalg] Enable TileAndFusePattern to work with tensors.
Differential Revision: https://reviews.llvm.org/D94531
Added:
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 4fe90897873b..7441a54cdc05 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -732,45 +732,6 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
return fusableLoops;
}
-// /// For `consumer` with tensor semantics, find the Linalg operation on
-// tensors
-// /// producer the operand at position `consumerIdx`. This is a simple use-def
-// /// chain using the SSA value, but returned as an element of the
-// /// `LinalgDependenceGraphElem` to use the same analysis for both tensors and
-// /// buffers.
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducerForTensorOp(OpOperand &consumerOpOperand) {
-// // For now only looking for cases where the operand is produced by another
-// // Linalg structured operation.
-// LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-// if (!consumer || !consumer.hasTensorSemantics())
-// return llvm::None;
-// unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-// Value value = consumerOpOperand.get();
-// if (auto linalgOp = value.getDefiningOp<LinalgOp>()) {
-// return LinalgDependenceGraph::LinalgDependenceGraphElem{
-// &(linalgOp
-// .getOutputOpOperands()[value.cast<OpResult>().getResultNumber()]),
-// &(consumer.getInputOpOperands()[consumerIdx]),
-// LinalgDependenceGraph::DependenceType::RAW};
-// }
-// return llvm::None;
-// }
-
-// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-// findFusableProducer(OpOperand &consumerOpOperand,
-// const LinalgDependenceGraph &dependenceGraph) {
-// LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
-// if (!consumer)
-// return llvm::None;
-// if (consumer.hasBufferSemantics())
-// return findFusableProducerForBufferOp(consumerOpOperand,
-// dependenceGraph);
-// if (consumer.hasTensorSemantics())
-// return findFusableProducerForTensorOp(consumerOpOperand);
-// return llvm::None;
-// }
-
/// Find all dependences that are fusable.
FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7260bb4aca6b..fc647ea45478 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -283,6 +283,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
return success();
}
+static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
+ if (tiledOp.loops.empty())
+ return tiledOp.op.getOperation()->getResults();
+ return tiledOp.loops.front()->getResults();
+}
+
+static ValueRange
+getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
+ if (tiledAndFusedOp.fusedLoops.empty())
+ return tiledAndFusedOp.op.getOperation()->getResults();
+ return tiledAndFusedOp.fusedLoops.front()->getResults();
+}
+
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
@@ -301,8 +314,6 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
- if (!linalgOp.hasBufferSemantics())
- return failure();
DenseSet<Operation *> producers;
producers.insert(linalgOp);
@@ -359,9 +370,11 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
if (!unfusedTiledOp)
return failure();
- rewriter.eraseOp(tiledAndFusedOps->op);
+ rewriter.replaceOp(tiledAndFusedOps->op,
+ getTiledOpResult(unfusedTiledOp.getValue()));
tiledAndFusedOps->op = unfusedTiledOp->op;
}
+ op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
marker.replaceLinalgTransformationFilter(rewriter,
tiledAndFusedOps->op.getOperation());
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
new file mode 100644
index 000000000000..10cd5b454a4a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -0,0 +1,142 @@
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
+
+module {
+ func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
+ %arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
+ %1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+ ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
+ return %1 : tensor<?x?xf32>
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
+// CHECK: func @matmul_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C64:.+]] = constant 64 : index
+// CHECK-DAG: %[[C16:.+]] = constant 16 : index
+// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[M_2:.+]] = dim %[[ARG6]], %[[C0]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[M_2]], %[[IV0]])
+// CHECK: %[[N3:.+]] = dim %[[ARG6]], %[[C1]]
+// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
+// CHECK: %[[N2:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[N1:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
+// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
+// CHECK-SAME: [%[[N1]], %[[N2]]]
+// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK: %[[N3_2:.+]] = dim %[[ARG3]], %[[C1]]
+// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] =
+// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]]
+// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
+// CHECK: %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
+// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]]
+// CHECK: %[[N2_3:.+]] = dim %[[ARG3]], %[[C0]]
+// CHECK: %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
+// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
+// CHECK: %[[ST_ARG3:.+]] = subtensor %[[ARG3]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME: [%[[TILE_N2_2]], %[[TILE_N3]]]
+// CHECK: %[[M_4:.+]] = dim %[[ARG10]], %[[C0]]
+// CHECK: %[[N3_3:.+]] = dim %[[ARG10]], %[[C1]]
+// CHECK: %[[TILE_N3_2:.+]] = affine.min #[[MAP4]](%[[N3_3]], %[[IV1]])
+// CHECK: %[[ST_ARG4:.+]] = subtensor %[[ARG10]][0, %[[IV1]]]
+// CHECK-SAME: [%[[M_4]], %[[TILE_N3_2]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.matmul
+// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion"
+// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG4]] : tensor<?x?xf32>)
+// CHECK: %[[UPDATE1:.+]] = subtensor_insert %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3_2]]]
+// CHECK: scf.yield %[[UPDATE1]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD1]]
+// CHECK: }
+// CHECK: %[[UPDATE0:.+]] = subtensor_insert %[[YIELD0]] into
+// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]]
+// CHECK: scf.yield %[[UPDATE0]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+ func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg2, %c0 : tensor<?x?xf32>
+ %1 = dim %arg2, %c1 : tensor<?x?xf32>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %3 = dim %2, %c0 : tensor<?x?xf32>
+ %4 = dim %2, %c1 : tensor<?x?xf32>
+ %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+ %6 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"],
+ __internal_linalg_transform__ = "transpose_fusion"}
+ ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%5 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+ %7 = addf %arg3, %arg4 : f32
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %6 : tensor<?x?xf32>
+ }
+}
+// CHECK: func @matmul_plus_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, %[[IV1]]]
+// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG6]] : tensor<?x?xf32>)
+// CHECK: %[[UPDATE:.+]] = subtensor_insert %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index f2c9067d5cc2..20e0166d7331 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -20,30 +20,14 @@
using namespace mlir;
using namespace mlir::linalg;
-namespace {
-struct TestLinalgFusionTransforms
- : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
- TestLinalgFusionTransforms() = default;
- TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
- StandardOpsDialect>();
- }
-
- void runOnFunction() override;
-};
-} // namespace
-
+template <LinalgTilingLoopType LoopType>
static void fillFusionPatterns(MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTileAndFusePattern<MatmulOp>,
LinalgTileAndFusePattern<ConvOp>>(
context, dependenceGraph,
- LinalgTilingOptions()
- .setTileSizes({32, 64, 16})
- .setLoopType(LinalgTilingLoopType::ParallelLoops),
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({2}),
LinalgTransformationFilter(
Identifier::get("basic_fusion", context),
@@ -57,9 +41,7 @@ static void fillFusionPatterns(MLIRContext *context,
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
- LinalgTilingOptions()
- .setTileSizes({32, 64, 16})
- .setLoopType(LinalgTilingLoopType::ParallelLoops),
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0}),
LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
Identifier::get("after_lhs_fusion", context)),
@@ -72,9 +54,7 @@ static void fillFusionPatterns(MLIRContext *context,
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
- LinalgTilingOptions()
- .setTileSizes({32, 64, 16})
- .setLoopType(LinalgTilingLoopType::ParallelLoops),
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({1}),
LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
Identifier::get("after_rhs_fusion", context)),
@@ -87,9 +67,7 @@ static void fillFusionPatterns(MLIRContext *context,
patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
context, dependenceGraph,
- LinalgTilingOptions()
- .setTileSizes({32, 64, 16})
- .setLoopType(LinalgTilingLoopType::ParallelLoops),
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0, 2}),
LinalgTransformationFilter(
Identifier::get("two_operand_fusion", context),
@@ -103,8 +81,7 @@ static void fillFusionPatterns(MLIRContext *context,
patterns.insert<LinalgTileAndFusePattern<GenericOp>>(
context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
- LinalgTilingLoopType::ParallelLoops),
+ LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
LinalgFusionOptions().setIndicesToFuse({0, 1}),
LinalgTransformationFilter(
Identifier::get("transpose_fusion", context),
@@ -117,18 +94,30 @@ static void fillFusionPatterns(MLIRContext *context,
Identifier::get("after_transpose_fusion_original", context)));
}
-static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
- OwningRewritePatternList fusionPatterns;
- Aliases alias;
- LinalgDependenceGraph dependenceGraph =
- LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
- fillFusionPatterns(context, dependenceGraph, fusionPatterns);
- applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
-}
+namespace {
+template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
+struct TestLinalgFusionTransforms
+ : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
+ TestLinalgFusionTransforms() = default;
+ TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-void TestLinalgFusionTransforms::runOnFunction() {
- applyFusionPatterns(&getContext(), getFunction());
-}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
+ StandardOpsDialect>();
+ }
+
+ void runOnFunction() override {
+ MLIRContext *context = &this->getContext();
+ FuncOp funcOp = this->getFunction();
+ OwningRewritePatternList fusionPatterns;
+ Aliases alias;
+ LinalgDependenceGraph dependenceGraph =
+ LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
+ fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
+ applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
+ }
+};
+} // namespace
static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
OpBuilder b(f);
@@ -237,7 +226,7 @@ struct TestLinalgTileAndFuseSequencePass
LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
OpBuilder builder(funcOp.getContext());
linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
- if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) {
+ if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
return linalgOp.hasTensorSemantics();
}))
loopType = LinalgTilingLoopType::Loops;
@@ -260,10 +249,17 @@ struct TestLinalgTileAndFuseSequencePass
namespace mlir {
namespace test {
void registerTestLinalgFusionTransforms() {
- PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
+ PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass(
"test-linalg-fusion-transform-patterns",
"Test Linalg fusion transformation patterns by applying them greedily.");
}
+void registerTestLinalgTensorFusionTransforms() {
+ PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>>
+ testTensorFusionTransformsPass(
+ "test-linalg-tensor-fusion-transform-patterns",
+ "Test Linalg on tensor fusion transformation "
+ "patterns by applying them greedily.");
+}
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
"test-linalg-greedy-fusion",
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dc68f8f4d778..7336ae18e3c6 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,6 +74,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
+void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTileAndFuseSequencePass();
@@ -145,6 +146,7 @@ void registerTestPasses() {
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgFusionTransforms();
+ test::registerTestLinalgTensorFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
test::registerTestLinalgTileAndFuseSequencePass();
More information about the Mlir-commits
mailing list