[llvm-branch-commits] [mlir] 80f0785 - [mlir][Linalg] NFC - Refactor fusion APIs
Nicolas Vasilache via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 12 06:36:12 PST 2021
Author: Nicolas Vasilache
Date: 2021-01-12T14:27:15Z
New Revision: 80f078548868d0dd3d74ab8a1deb8aa46870cdf3
URL: https://github.com/llvm/llvm-project/commit/80f078548868d0dd3d74ab8a1deb8aa46870cdf3
DIFF: https://github.com/llvm/llvm-project/commit/80f078548868d0dd3d74ab8a1deb8aa46870cdf3.diff
LOG: [mlir][Linalg] NFC - Refactor fusion APIs
This revision unifies the fusion APIs so that they accept OpOperand and OpResult arguments, and adds finer-grained control over fusion.
Differential Revision: https://reviews.llvm.org/D94493
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 3fc3fa4a5556..f3b7181d71a5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -726,6 +726,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
getNumShapedOperands());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the OpOperand for the `i`-th shaped operand.
+ }],
+ /*retTy=*/" OpOperand&",
+ /*methodName=*/"getShapedOpOperand",
+ /*args=*/(ins "unsigned":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return *(this->getShapedOpOperands().begin() + i);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the range over input and output operands.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d816414ef8b4..de1658f96a87 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -35,6 +35,7 @@ struct TiledLinalgOp {
LinalgOp op;
SmallVector<Operation *, 8> loops;
SmallVector<Value, 4> tensorResults;
+ TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
};
/// Populates patterns for vectorization of all ConvN-D ops.
@@ -412,9 +413,8 @@ struct LinalgBaseTilingPattern : public RewritePattern {
LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
PatternBenefit benefit = 1);
- LogicalResult
- matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
- SmallVectorImpl<Value> &tensorResults) const;
+ LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+ TiledLinalgOp &result) const;
private:
/// LinalgTransformMarker handles special attribute manipulations.
@@ -432,14 +432,14 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
marker, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- SmallVector<Value, 4> tensorResults;
+ TiledLinalgOp tiledLinalgOp;
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
- tensorResults)))
+ tiledLinalgOp)))
return failure();
- if (tensorResults.empty())
+ if (tiledLinalgOp.tensorResults.empty())
rewriter.eraseOp(op);
else
- rewriter.replaceOp(op, tensorResults);
+ rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
return success();
}
};
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2ef32cfe378b..f194209f1910 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -92,26 +92,31 @@ findAllFusableDependences(ArrayRef<LinalgOp> ops,
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
-/// Implements the fusion part of the "tileAndFuse on buffers"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be a `subview` op (generally obtained by applying the tiling
-/// transformation).
-Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
+/// Implements the fusion part of the "tileAndFuse on buffers" transformation
+/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
+/// obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
+ OpOperand &consumerOpOperand,
const LinalgDependenceGraph &graph);
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be the result of a `subtensor` op (generally obtained by applying the
-/// tiling transformation).
-Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx);
+/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
+/// op (generally obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+ OpOperand &consumerOpOperand);
+/// Variant of `fuseProducerOfTensor` that takes the producing `OpResult`.
+/// This implements the fusion part of the "tileAndFuse on tensors"
+/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
+/// op (generally obtained by applying the tiling transformation).
+/// Assumes `producerOpResult` is a Linalg op result that produces `consumerOpOperand`.
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+ OpResult producerOpResult,
+ OpOperand &consumerOpOperand);
/// Fuse linalg operation on tensors, with the producer of the operand at
/// position `consumerIdx` of the consumer.
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
- Operation *consumer,
- unsigned consumerIdx);
+ OpOperand &consumerOpOperand);
/// Like `getShape`, but only returns statically-known information, without
/// generating any new IR. For each shape dimension, returns >=0 if that
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f9908af29313..8f02f3d83cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -258,20 +258,19 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
/// `producer.getOutputBuffers()`.
/// 2. Tensor case: `producerIdx` is the index of the tensor in
/// `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
- LinalgOp consumer, unsigned consumerIdx) {
- AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
- LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp,
+ unsigned producerOutNumber, OpOperand &consumerOpOperand) {
+ AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber);
+ LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber
<< ", producer map: " << producerMap << "\n");
DenseMap<unsigned, Range> fusedLoopsAndRanges;
- Location loc = consumer.getLoc();
- Value shapedOperand = consumer.getShapedOperand(consumerIdx);
+ Value shapedOperand = consumerOpOperand.get();
for (auto en : llvm::enumerate(producerMap.getResults())) {
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
- fusedLoopsAndRanges[posInProducerLoop] =
- getRangeFromOperandShape(b, loc, shapedOperand, en.index());
+ fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
+ b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
}
- return fuse(b, producer, fusedLoopsAndRanges);
+ return fuse(b, producerOp, fusedLoopsAndRanges);
}
// Encode structural fusion safety preconditions.
@@ -378,9 +377,10 @@ static bool isSameSubView(Value a, Value b) {
}
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
+findFusableProducer(OpOperand &consumerOpOperand,
const LinalgDependenceGraph &dependenceGraph) {
- assert(consumer.hasBufferSemantics() && "revisit usage of shaped operand");
+ LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+ assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
// Only consider RAW and WAW atm.
for (auto depType : {
@@ -388,21 +388,16 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
LinalgDependenceGraph::DependenceType::WAW,
}) {
for (auto dependence : llvm::make_filter_range(
- dependenceGraph.getDependencesInto(consumer, depType),
- [consumerIdx](
- LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
- return elem.indexingOpView->getOperandNumber() == consumerIdx;
+ dependenceGraph.getDependencesInto(consumerOp, depType),
+ [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+ return elem.indexingOpView->get() == consumerOpOperand.get() &&
+ elem.indexingOpView->getOperandNumber() ==
+ consumerOpOperand.getOperandNumber();
})) {
- // Check that the dependence is indeed on the input `consumerIdx` view.
- Value consumedView = dependence.indexingOpView->get();
- if (!isSameSubView(consumer.getShapedOperand(consumerIdx), consumedView))
- continue;
-
// Consumer consumes this view, `isStructurallyFusableProducer` also
// checks whether it is a strict subview of the producer view.
auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
- Value producedView = dependence.dependentOpView->get();
LLVM_DEBUG(llvm::dbgs()
<< "\n"
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
@@ -412,10 +407,10 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
<< dependence.dependentOpView->getOperandNumber() -
producer.getNumInputs()
<< "\n");
- (void)producedView;
// Simple fusability checks.
- if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
+ if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
+ producer))
continue;
return dependence;
@@ -425,29 +420,28 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
}
Optional<FusionInfo>
-mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
- unsigned consumerIdx,
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
const LinalgDependenceGraph &graph) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
- findFusableProducer(consumer, consumerIdx, graph);
+ findFusableProducer(consumerOpOperand, graph);
if (!fusableDependence)
return {};
LinalgOp producerOp =
cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
// If producer is already in the same block as consumer, we are done.
- if (consumer->getBlock() == producerOp->getBlock())
+ if (consumerOpOperand.get().getParentBlock() ==
+ fusableDependence->dependentOpView->get().getParentBlock())
return {};
unsigned producerIdx =
fusableDependence->dependentOpView->getOperandNumber() -
producerOp.getNumInputs();
- Value consumerView = consumer.getShapedOperand(consumerIdx);
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
- auto subView = consumerView.getDefiningOp<SubViewOp>();
- auto slice = consumerView.getDefiningOp<SliceOp>();
+ auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
+ auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
if (!subView && !slice) {
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
return {};
@@ -455,25 +449,25 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
// Fuse `producer` just before `consumer`.
OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(consumer.getOperation());
- ScopedContext scope(b, consumer.getLoc());
- LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
+ b.setInsertionPoint(consumerOpOperand.getOwner());
+ ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
+ LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
+ << *consumerOpOperand.getOwner() << "\n");
- auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+ auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand);
return FusionInfo{producerOp, fusedProducer};
}
/// Walk back use-def chain through scf::For yields.
/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
-static void getProducerOfTensor(Value tensor, LinalgOp &producer,
- unsigned &outputIndex) {
+static void getProducerOfTensor(Value tensor, OpResult &opResult) {
if (!tensor.getType().isa<RankedTensorType>())
return;
while (true) {
+ LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
- producer = linalgOp;
- outputIndex = tensor.cast<OpResult>().getResultNumber();
+ opResult = tensor.cast<OpResult>();
return;
}
if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
@@ -482,7 +476,7 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
}
if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
- tensor = forOp.getResult(blockArg.getArgNumber());
+ tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
continue;
}
}
@@ -490,45 +484,58 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
}
}
-Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
- LinalgOp consumer,
- unsigned consumerIdx) {
- Value inputTensor = consumer.getInput(consumerIdx);
- LinalgOp producerOp;
- unsigned producerIdx;
- getProducerOfTensor(inputTensor, producerOp, producerIdx);
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
+ Value inputTensor = consumerOpOperand.get();
+ OpResult producerOpResult;
+ getProducerOfTensor(inputTensor, producerOpResult);
+ if (!producerOpResult) {
+ LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
+ return {};
+ }
+ return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
+}
+
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
+ OpOperand &consumerOpOperand) {
+ auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
+ assert(producerOp && "expected Linalg producer");
+ LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+ Value inputTensor = consumerOpOperand.get();
// Must be a subtensor to guarantee there are loops we can fuse into.
auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
- if (!subTensor || !producerOp) {
- LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
+ if (!subTensor) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "\nNot fusable, not a subtensor: " << inputTensor);
return {};
}
// If producer is already in the same block as consumer, we are done.
- if (consumer->getBlock() == producerOp->getBlock())
+ if (consumerOpOperand.get().getParentBlock() ==
+ producerOpResult.getParentBlock())
return {};
// Insert fused `producer` just before `consumer`.
OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(consumer.getOperation());
- ScopedContext scope(b, consumer.getLoc());
- LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
- LinalgOp fusedProducer =
- fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+ b.setInsertionPoint(consumerOp);
+ ScopedContext scope(b, consumerOp->getLoc());
+ LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+ LinalgOp fusedProducer = fuse(
+ b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand);
// Replace use.
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
- Value def = fusedProducer->getResult(producerIdx);
- OpOperand &use = consumer->getOpOperand(consumerIdx);
- Type consumerType = use.get().getType();
+ Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+ Type consumerType = consumerOpOperand.get().getType();
if (consumerType != def.getType())
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
- use.set(def);
- return FusionInfo{producerOp, fusedProducer};
+ consumerOpOperand.set(def);
+ return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}
/// Prune all dimensions that are of reduction iterator type from `map`.
@@ -734,11 +741,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
// in the meanwhile disallow such a fusion.
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
for (LinalgOp op : reverse(ops)) {
- for (auto operandIndex :
- llvm::seq<unsigned>(0, op.getNumShapedOperands())) {
+ for (OpOperand &opOperand : op.getShapedOpOperands()) {
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
- fusableDependence =
- findFusableProducer(op, operandIndex, dependenceGraph);
+ fusableDependence = findFusableProducer(opOperand, dependenceGraph);
if (!fusableDependence)
continue;
LinalgOp producerOp =
@@ -759,7 +764,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
op.emitRemark(
"unhandled non permutation indexing map for fused view in "
"producer for operand at index ")
- << operandIndex;
+ << opOperand.getOperandNumber();
return FusableOpDependencesTy{};
}
@@ -770,7 +775,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
op.emitRemark(
"unhandled case where indexing map for fused view in the consumer "
"is not a projected permutation while fusing at index ")
- << operandIndex;
+ << opOperand.getOperandNumber();
return FusableOpDependencesTy{};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 833662d282b6..670d456ad2f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -178,8 +178,10 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
}
static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
+fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
PatternRewriter &rewriter) {
+ LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+ unsigned consumerIdx = consumerOpOperand.getOperandNumber();
if (!areTensorOpsFusable(producer, consumer, consumerIdx))
return llvm::None;
@@ -1027,21 +1029,19 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
} // namespace
Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
- unsigned consumerIdx) {
- if (consumerIdx >= consumer->getNumOperands())
- return llvm::None;
- Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
+mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
+ OpOperand &consumerOpOperand) {
+ Operation *producer = consumerOpOperand.get().getDefiningOp();
if (!producer || producer->getNumResults() != 1)
return llvm::None;
// Fuse when consumer is GenericOp or IndexedGenericOp.
- if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
+ if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
!isa<GenericOp, IndexedGenericOp>(producer))
return llvm::None;
- return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
- consumerIdx, rewriter);
+ return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+ rewriter);
}
namespace {
@@ -1053,12 +1053,12 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
LogicalResult matchAndRewrite(LinalgOpTy op,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
- for (auto operandNum : llvm::seq<unsigned>(0, op->getNumOperands())) {
- Operation *producer = op->getOperand(operandNum).getDefiningOp();
+ for (OpOperand &opOperand : op.getShapedOpOperands()) {
+ Operation *producer = opOperand.get().getDefiningOp();
if (!producer)
continue;
Optional<SmallVector<Value, 1>> fusedOpResults =
- fuseTensorOps(rewriter, op, operandNum);
+ fuseTensorOps(rewriter, opOperand);
if (fusedOpResults) {
rewriter.replaceOp(op, *fusedOpResults);
if (producer->use_empty())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c5d811c41edb..5b6302a7e5a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -117,8 +117,7 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
options(options) {}
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
- Operation *op, PatternRewriter &rewriter,
- SmallVectorImpl<Value> &tensorResults) const {
+ Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
@@ -131,7 +130,7 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
return failure();
// Return relevant information to derived pattern.
- tensorResults = res->tensorResults;
+ result = *res;
// New marker if specified.
marker.replaceLinalgMarker(rewriter, res->op.getOperation());
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 046fad43c3bf..5d55f0375f37 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -135,14 +135,14 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
- for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
- if (en.value().getType().isa<MemRefType>()) {
+ for (OpOperand &opOperand : linalgOp.getShapedOpOperands()) {
+ if (opOperand.get().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
// The current naive and expensive reconstruction of the graph should be
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
+ if (auto info = fuseProducerOfBuffer(b, opOperand, graph)) {
auto *originalOp = info->originalProducer.getOperation();
eraseSet.insert(originalOp);
auto *originalOpInLinalgOpsVector =
@@ -151,11 +151,11 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
changed = true;
}
} else {
- assert(en.value().getType().isa<RankedTensorType>());
- // Tile and Fuse tensor input (TODO: init_tensors too).
- if (en.index() >= linalgOp.getNumInputs())
+ assert(opOperand.get().getType().isa<RankedTensorType>());
+ // Tile and Fuse tensor input.
+ if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
continue;
- if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
+ if (auto info = fuseProducerOfTensor(b, opOperand)) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
More information about the llvm-branch-commits
mailing list