[Mlir-commits] [mlir] 9b17bf2 - [mlir][Linalg] Make Linalg fusion a test pass
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Oct 29 08:19:53 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-29T15:18:51Z
New Revision: 9b17bf2e54c71b36bf28fbab05698fb73ea8dda9
URL: https://github.com/llvm/llvm-project/commit/9b17bf2e54c71b36bf28fbab05698fb73ea8dda9
DIFF: https://github.com/llvm/llvm-project/commit/9b17bf2e54c71b36bf28fbab05698fb73ea8dda9.diff
LOG: [mlir][Linalg] Make Linalg fusion a test pass
Linalg "tile-and-fuse" is currently exposed as the Linalg pass "-linalg-fusion", but only the mechanics of the transformation are relevant.
Instead, turn it into a "-test-linalg-greedy-fusion" test pass that performs canonicalizations so that more fusions can compose.
This allows dropping the OperationFolder, which is not meant to be used with the pattern rewrite infrastructure.
Differential Revision: https://reviews.llvm.org/D90394
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/fusion-2-level.mlir
mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
mlir/test/Dialect/Linalg/fusion.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index a0235cf87fdb..24570d3c4ec6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,7 +18,6 @@
namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
-std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2df2051255c2..7446ca8f6636 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -23,12 +23,6 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
-def LinalgFusion : FunctionPass<"linalg-fusion"> {
- let summary = "Fuse operations in the linalg dialect";
- let constructor = "mlir::createLinalgFusionPass()";
- let dependentDialects = ["linalg::LinalgDialect"];
-}
-
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9b343b3b04ab..d8c595fb91fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -88,27 +88,22 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
/// to be a `subview` op (generally obtained by applying the tiling
/// transformation).
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
unsigned consumerIdx,
- const LinalgDependenceGraph &graph,
- OperationFolder *folder = nullptr);
+ const LinalgDependenceGraph &graph);
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
/// to be the result of a `subtensor` op (generally obtained by applying the
/// tiling transformation).
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
- OperationFolder *folder);
+ unsigned consumerIdx);
/// Fuse linalg operation on tensors, with the producer of the operand at
/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>>
-fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
- unsigned consumerIdx, OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
+ Operation *consumer,
+ unsigned consumerIdx);
/// Returns the linearized list of all shape dimensions in a `linalgOp`.
/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
@@ -122,17 +117,12 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
/// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
/// concatenated indexing maps to the result of `getShape`. Returns None if
/// the bounds computation fails.
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
- OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+ LinalgOp linalgOp);
/// Returns the values obtained by applying `map` to the list of values.
-/// When non-null, the optional pointer `folder` is used to call into the
-/// `createAndFold` builder method. If `folder` is null, the regular `create`
-/// method is called.
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map, ValueRange values,
- OperationFolder *folder = nullptr);
+ AffineMap map, ValueRange values);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 00eb9a2fe834..ac35d87a8413 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -13,7 +13,6 @@
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
@@ -24,7 +23,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
@@ -37,8 +35,6 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
-using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
-
using llvm::dbgs;
/// Implements a simple high-level fusion pass on linalg structured operations.
@@ -201,8 +197,7 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
/// 2. Tensor case: `producerIdx` is the index of the tensor in
/// `producer.getResults()`.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx,
- OperationFolder *folder = nullptr) {
+ LinalgOp consumer, unsigned consumerIdx) {
Operation *shapeProducingOp =
consumer.getShapedOperand(consumerIdx).getDefiningOp();
assert((isa<SubViewOp>(shapeProducingOp) ||
@@ -244,9 +239,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
- loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+ loopRanges[i] = Range{std_constant_index(0),
std_dim(shapeDim.shape, shapeDim.dimension),
- folded_std_constant_index(folder, 1)};
+ std_constant_index(1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
@@ -396,15 +391,21 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
return {};
}
-Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
- OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
- const LinalgDependenceGraph &graph, OperationFolder *folder) {
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+ unsigned consumerIdx,
+ const LinalgDependenceGraph &graph) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
findFusableProducer(consumer, consumerIdx, graph);
if (!fusableDependence)
return {};
LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+ // If producer is already in the same block as consumer, we are done.
+ if (consumer.getOperation()->getBlock() ==
+ producerOp.getOperation()->getBlock())
+ return {};
+
Value producerView = fusableDependence->dependentOpView.view;
Value consumerView = fusableDependence->indexingView;
@@ -427,8 +428,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
assert(producerIdxOpt.hasValue() && "incorrect operand index");
unsigned producerIdx = producerIdxOpt.getValue();
- auto fusedProducer =
- fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+ auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
return FusionInfo{producerOp, fusedProducer};
}
@@ -459,10 +459,9 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
}
}
-Optional<FusionInfo>
-mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
- OperationFolder *folder) {
+Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
+ LinalgOp consumer,
+ unsigned consumerIdx) {
Value inputTensor = consumer.getInput(consumerIdx);
LinalgOp producerOp;
unsigned producerIdx;
@@ -475,13 +474,18 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
return {};
}
+ // If producer is already in the same block as consumer, we are done.
+ if (consumer.getOperation()->getBlock() ==
+ producerOp.getOperation()->getBlock())
+ return {};
+
// Insert fused `producer` just before `consumer`.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumer.getOperation());
ScopedContext scope(b, consumer.getLoc());
LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
LinalgOp fusedProducer =
- fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+ fuse(b, producerOp, producerIdx, consumer, consumerIdx);
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
@@ -796,72 +800,3 @@ mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
}
return llvm::None;
}
-
-static void fuseLinalgOpsGreedily(FuncOp f) {
- LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
-
- OpBuilder b(f);
- OperationFolder folder(f.getContext());
- DenseSet<Operation *> eraseSet;
-
- // Save original Linalg ops, we only want to make a pass over those.
- SmallVector<Operation *, 8> linalgOps;
- f.walk([&](LinalgOp op) {
- // TODO: support multi-results.
- if (op.getOperation()->getNumResults() <= 1)
- linalgOps.push_back(op);
- });
-
- // Tile and Fuse for tensors inputs (TODO: all tensor operands).
- for (auto *op : llvm::reverse(linalgOps)) {
- LinalgOp linalgOp = cast<LinalgOp>(op);
- for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
- if (en.value().getType().isa<MemRefType>()) {
- // TODO: LinalgDependenceGraph should be able to update itself.
- // The current naive and expensive reconstruction of the graph should be
- // removed.
- linalg::Aliases aliases;
- linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- if (auto info =
- fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
- auto *originalOp = info->originalProducer.getOperation();
- eraseSet.insert(originalOp);
- auto *originalOpInLinalgOpsVector =
- std::find(linalgOps.begin(), linalgOps.end(), originalOp);
- *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
- }
- } else {
- assert(en.value().getType().isa<RankedTensorType>());
- // Tile and Fuse tensor input (TODO: init_tensors too).
- if (en.index() >= linalgOp.getNumInputs())
- continue;
- if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
- auto *originalOp = info->originalProducer.getOperation();
- auto *originalOpInLinalgOpsVector =
- std::find(linalgOps.begin(), linalgOps.end(), originalOp);
- *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
- // Don't mark for erasure in the tensor case, let DCE handle this.
- }
- }
- }
- }
- // The `fuseProducerOfBuffer` function performs structural checks and in
- // particular that no covering read or write exist between the consumer and
- // the producer. As a consequence, the only fusions that may occur preserve
- // subsequent dependences and are guaranteed by construction to produce the
- // whole view. We may thus erase the producer once it is fused.
- for (auto *e : eraseSet)
- e->erase();
-
- LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
-}
-
-namespace {
-struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
- void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
- return std::make_unique<LinalgFusionPass>();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d30713d76e36..3e3392d84975 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -177,8 +177,7 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
static Optional<SmallVector<Value, 1>>
fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
- PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ PatternRewriter &rewriter) {
if (!areTensorOpsFusable(producer, consumer, consumerIdx))
return llvm::None;
@@ -440,8 +439,8 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
/// conditions have been satisfied.
static Optional<SmallVector<Value, 1>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
- unsigned fusedTensorIndex, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ unsigned fusedTensorIndex,
+ PatternRewriter &rewriter) {
assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
@@ -929,7 +928,7 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
Optional<SmallVector<Value, 1>>
mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
- unsigned consumerIdx, OperationFolder *folder) {
+ unsigned consumerIdx) {
if (consumerIdx >= consumer->getNumOperands())
return llvm::None;
Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
@@ -942,7 +941,7 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
return llvm::None;
return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
- consumerIdx, rewriter, folder);
+ consumerIdx, rewriter);
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3f29949ffe63..210d17516718 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -24,7 +24,6 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
using namespace mlir::linalg;
@@ -57,30 +56,27 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
return llvm::None;
}
-static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
- AffineMap map,
- ValueRange operandsRef,
- OperationFolder *folder) {
+static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange operandsRef) {
SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
fullyComposeAffineMapAndOperands(&map, &operands);
canonicalizeMapAndOperands(&map, &operands);
- return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
- : b.create<AffineApplyOp>(loc, map, operands);
+ return b.createOrFold<AffineApplyOp>(loc, map, operands);
}
SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
AffineMap map,
- ValueRange values,
- OperationFolder *folder) {
+ ValueRange values) {
SmallVector<Value, 4> res;
res.reserve(map.getNumResults());
unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
// For each `expr` in `map`, applies the `expr` to the values extracted from
// ranges. If the resulting application can be folded into a Value, the
- // folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
+ // folding occurs eagerly.
for (auto expr : map.getResults()) {
AffineMap map = AffineMap::get(numDims, numSym, expr);
- res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, folder));
+ res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
}
return res;
}
@@ -159,15 +155,14 @@ SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
return res;
}
-Optional<SmallVector<Value, 4>>
-getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
+Optional<SmallVector<Value, 4>> getLoopRanges(OpBuilder &builder,
+ LinalgOp linalgOp) {
SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
AffineMap invertedMap =
inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
if (!invertedMap)
return {};
- return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes,
- folder);
+ return applyMapToValues(builder, linalgOp.getLoc(), invertedMap, viewSizes);
}
/// Specialization to build an scf "for" nest.
diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 0c9b9ca0dca7..27154a5e277c 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, offset: ?, strides: [?, 1]>, %C: memref<?x?xf32, offset: ?, strides: [?, 1]>, %D: memref<?x?xf32, offset: ?, strides: [?, 1]>, %E: memref<?x?xf32, offset: ?, strides: [?, 1]>) -> memref<?x?xf32, offset: ?, strides: [?, 1]> {
%c1 = constant 1 : index
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
index ee256170bad0..3b4948a3b4f1 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
@@ -82,8 +82,11 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
+ %j_int = index_cast %j: index to i32
+ %j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
- %out = addf %ab, %i_float : f32
+ %tmp = addf %ab, %i_float : f32
+ %out = addf %tmp, %j_float : f32
linalg.yield %out : f32
}
%C_X = dim %C, %c0 : memref<?x?xf32>
@@ -115,6 +118,7 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
// CHECK: [[i_new:%.*]] = addi [[i]], [[I]] : index
// CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index
// CHECK: {{.*}} = index_cast [[i_new]] : index to i32
+// CHECK: {{.*}} = index_cast [[j_new]] : index to i32
// CHECK: linalg.generic
// CHECK: addf
@@ -137,10 +141,13 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
+ %i_int = index_cast %i: index to i32
+ %i_float = sitofp %i_int : i32 to f32
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
- %out = addf %ab, %j_float : f32
+ %tmp = addf %ab, %i_float : f32
+ %out = addf %tmp, %j_float : f32
linalg.yield %out : f32
}
%C_X = dim %C, %c0 : memref<?x?xf32>
@@ -176,8 +183,8 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
// CHECK-NOT: scf.parallel
// CHECK: linalg.indexed_generic
// CHECK: ^bb0([[i:%.*]]: index, [[j:%.*]]: index
-// CHECK: [[i_new:%.*]] = addi [[i]], [[C0]] : index
// CHECK: [[j_new:%.*]] = addi [[j]], [[J]] : index
+// CHECK: {{.*}} = index_cast [[i]] : index to i32
// CHECK: {{.*}} = index_cast [[j_new]] : index to i32
// CHECK: linalg.generic
// CHECK: addf
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 788cb89b40f7..9a4a1c5f3f6f 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>,
%B: memref<?x?xf32, offset: 0, strides: [?, 1]>,
@@ -98,6 +98,8 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
// -----
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -137,9 +139,11 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f3
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[D_0:.*]] = dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
@@ -148,6 +152,8 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
// -----
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
+
func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%B: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%C: memref<?x?xf32, offset: 0, strides: [?, ?]>,
@@ -190,9 +196,11 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f4
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
@@ -246,26 +254,24 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f5
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK-DAG: %[[B_1:.*]] = dim %[[B]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_0:.*]] = dim %[[D]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
-// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG: %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG: %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
-// CHECK-DAG: %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
-// CHECK: dim
-// CHECK-DAG: %[[C_I0:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK-DAG: %[[B_0K:.*]] = subview %[[B]][%{{.*}}, %[[K]]]
-// CHECK-DAG: %[[D_IK_:.*]] = subview %[[D]][%[[I]], %[[K]]]
-// CHECK: dim
-// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
-// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
-// CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
-// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
-// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[B_1:.*]] = dim %[[B]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_0:.*]] = dim %[[D]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][0, 0]{{.*}}
+// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
+// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], 0]
+// CHECK-DAG: %[[C_I0:.*]] = subview %[[C]][%[[I]], 0]
+// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
+// CHECK: %[[E_IJ:.*]] = subview %[[E]][%[[I]], %[[J]]]
+// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
+// CHECK-DAG: %[[D_IK:.*]] = subview %[[D]][%[[I]], %[[K]]]
+// CHECK-DAG: %[[B_0K:.*]] = subview %[[B]][0, %[[K]]]
+// CHECK-DAG: %[[B_KJ:.*]] = subview %[[B]][%[[K]], %[[J]]]
+// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
+// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
@@ -390,11 +396,13 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
// CHECK-LABEL: func @f7
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
-// CHECK: %[[A_0:.*]] = dim %[[A]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[A_1:.*]] = dim %[[A]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
-// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[A_0:.*]] = dim %[[A]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[A_1:.*]] = dim %[[A]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_1:.*]] = dim %[[C]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[C_0:.*]] = dim %[[C]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK: %[[D_1:.*]] = dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index e43f261632e9..41adff7d46c3 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED
+// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
@@ -41,44 +40,19 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
// CHECK: scf.for %[[I:[0-9a-z]*]]
+// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
-// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]]
-//
-// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value.
-// CHECK: subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
-// CHECK: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-//
-// subtensors of the producing matmul.
-// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: %[[stD2:.*]] = tensor_cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor<?x4xf32>, tensor<4x?xf32>) init(%[[stF]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK-NEXT: subtensor_insert %[[stG]]
-
-
-// CANONICALIZED-LABEL: func @matmul_tensors(
-// CANONICALIZED-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
-// CANONICALIZED: %[[C0:.*]] = constant 0 : index
-// CANONICALIZED: %[[C1:.*]] = constant 1 : index
-// CANONICALIZED: scf.for %[[I:[0-9a-z]*]]
-// CANONICALIZED-NEXT: scf.for %[[J:[0-9a-z]*]]
-// CANONICALIZED-NEXT: scf.for %[[K:[0-9a-z]*]]
-//
-// CANONICALIZED: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
-// CANONICALIZED: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
+// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
+// CHECK-DAG: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
+// CHECK-DAG: %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
//
// subtensors of the producing matmul.
-// CANONICALIZED: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CANONICALIZED: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
-// CANONICALIZED-NEXT: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
-// CANONICALIZED-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
-// CANONICALIZED-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
-// CANONICALIZED-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
-// CANONICALIZED-NEXT: subtensor_insert %[[stG]]
+// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 4dfb653ac858..33f9429bdb27 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -13,7 +13,9 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::linalg;
@@ -104,10 +106,96 @@ void TestLinalgFusionTransforms::runOnFunction() {
applyFusionPatterns(&getContext(), getFunction());
}
+/// Greedily fuses producers into the Linalg ops found in `f`.
+///
+/// Only the Linalg ops present in `f` on entry that have at most one result
+/// are visited, in reverse walk order. For each shaped operand of such an op:
+///   - a memref operand is fused with `fuseProducerOfBuffer`, using a
+///     dependence graph rebuilt from the tracked ops; the original producer is
+///     erased after all fusions have been attempted;
+///   - a tensor operand that is an input is fused with `fuseProducerOfTensor`;
+///     the original producer is left in place for DCE to remove.
+///
+/// Returns success if at least one fusion happened and failure otherwise, so
+/// callers can iterate until nothing more fuses.
+static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
+  OpBuilder b(f);
+  DenseSet<Operation *> eraseSet;
+
+  // Save original Linalg ops, we only want to make a pass over those.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk([&](LinalgOp op) {
+    // TODO: support multi-results.
+    if (op.getOperation()->getNumResults() <= 1)
+      linalgOps.push_back(op);
+  });
+
+  // Swaps `originalOp` for `fusedOp` in the tracked list. The original
+  // producer may legitimately be absent: it is never tracked when it has
+  // several results, and in the tensor case (where the producer is not erased)
+  // its slot may already have been overwritten by an earlier fusion. Writing
+  // through the end iterator would be undefined behavior, so skip in that case.
+  auto replaceTrackedOp = [&](Operation *originalOp, Operation *fusedOp) {
+    auto *it = std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+    if (it != linalgOps.end())
+      *it = fusedOp;
+  };
+
+  // Tile and Fuse for tensors inputs (TODO: all tensor operands).
+  bool changed = false;
+  for (auto *op : llvm::reverse(linalgOps)) {
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should be
+        // removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          eraseSet.insert(originalOp);
+          replaceTrackedOp(originalOp, info->fusedProducer.getOperation());
+          changed = true;
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+        if (en.index() >= linalgOp.getNumInputs())
+          continue;
+        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+          // Don't mark for erasure in the tensor case, let DCE handle this.
+          replaceTrackedOp(info->originalProducer.getOperation(),
+                           info->fusedProducer.getOperation());
+          changed = true;
+        }
+      }
+    }
+  }
+  // The `fuseProducerOfBuffer` function performs structural checks and in
+  // particular that no covering read or write exists between the consumer and
+  // the producer. As a consequence, the only fusions that may occur preserve
+  // subsequent dependences and are guaranteed by construction to produce the
+  // whole view. We may thus erase the producer once it is fused.
+  for (auto *e : eraseSet)
+    e->erase();
+
+  return changed ? success() : failure();
+}
+
+namespace {
+/// Test pass that alternates greedy Linalg fusion with cleanup until no more
+/// fusion applies.
+///
+/// Each round that fuses something is followed by the Linalg tiling
+/// canonicalization patterns (plus `AffineMinSCFCanonicalizationPattern`) and
+/// by a LICM + canonicalizer + CSE pipeline, so that later rounds can compose
+/// further fusions.
+struct TestLinalgGreedyFusion
+    : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns =
+        linalg::getLinalgTilingCanonicalizationPatterns(context);
+    patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
+    FrozenRewritePatternList frozenPatterns(std::move(patterns));
+    while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
+      applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
+      PassManager pm(context);
+      pm.addPass(createLoopInvariantCodeMotionPass());
+      pm.addPass(createCanonicalizerPass());
+      pm.addPass(createCSEPass());
+      // The cleanup pipeline is run on the module enclosing this function.
+      LogicalResult res = pm.run(getFunction().getParentOfType<ModuleOp>());
+      if (failed(res)) {
+        // Stop right away: keep neither fusing nor re-running the pipeline on
+        // IR for which a pass has just reported failure.
+        this->signalPassFailure();
+        return;
+      }
+    }
+  }
+};
+} // namespace
+
namespace mlir {
void registerTestLinalgFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
"test-linalg-fusion-transform-patterns",
"Test Linalg fusion transformation patterns by applying them greedily.");
}
+/// Registers the `TestLinalgGreedyFusion` pass under the
+/// `test-linalg-greedy-fusion` command-line flag.
+void registerTestLinalgGreedyFusion() {
+  PassRegistration<TestLinalgGreedyFusion> testGreedyFusionPass(
+      "test-linalg-greedy-fusion",
+      "Test Linalg fusion by applying a greedy test transformation.");
+}
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ac29305e3c6f..196bda69dbaf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
+void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTransforms();
void registerTestLivenessPass();
@@ -121,6 +122,7 @@ void registerTestPasses() {
registerTestInterfaces();
registerTestLinalgCodegenStrategy();
registerTestLinalgFusionTransforms();
+ registerTestLinalgGreedyFusion();
registerTestLinalgHoisting();
registerTestLinalgTransforms();
registerTestLivenessPass();
More information about the Mlir-commits
mailing list