[Mlir-commits] [mlir] 8fb6c31 - [mlir][linalg] Cleanup LinalgOp usage in op declarations.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jun 3 07:06:08 PDT 2021
Author: Tobias Gysi
Date: 2021-06-03T14:04:44Z
New Revision: 8fb6c31cbba51b494f232273cdc54dc0788fcd59
URL: https://github.com/llvm/llvm-project/commit/8fb6c31cbba51b494f232273cdc54dc0788fcd59
DIFF: https://github.com/llvm/llvm-project/commit/8fb6c31cbba51b494f232273cdc54dc0788fcd59.diff
LOG: [mlir][linalg] Cleanup LinalgOp usage in op declarations.
Replace the uses of deprecated Structured Op Interface methods in LinalgOps.cpp. This patch is based on https://reviews.llvm.org/D103394.
Differential Revision: https://reviews.llvm.org/D103506
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 04c6e4b9d4029..89387b08c11c1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -375,11 +375,12 @@ ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
- auto outputViewType = op.getOutputShapedType(0);
- auto inputViewType = op.getInputShapedType(0);
- if (inputViewType.getElementType() != outputViewType.getElementType())
+ OpOperand *output = op.getOutputOperand(0);
+ OpOperand *input = op.getInputOperand(0);
+ if (getElementTypeOrSelf(input->get().getType()) !=
+ getElementTypeOrSelf(output->get().getType()))
return op.emitOpError("expects views of the same type");
- if (inputViewType.getRank() != outputViewType.getRank())
+ if (op.getRank(input) != op.getRank(output))
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
@@ -449,11 +450,11 @@ ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
- auto viewType = op.getOutputShapedType(0);
- auto fillType = op.value().getType();
- if (viewType.getElementType() != fillType)
+ OpOperand *output = op.getOutputOperand(0);
+ Type fillType = op.value().getType();
+ if (getElementTypeOrSelf(output->get().getType()) != fillType)
return op.emitOpError("expects fill type to match view elemental type");
- if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+ if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
@@ -739,11 +740,13 @@ struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
// Create a generic replacement operation and clone the body.
rewriter.setInsertionPointAfter(indexedOp);
+ SmallVector<Value> inputOperands = indexedOp.getInputOperands();
+ SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
SmallVector<StringRef> iterators = llvm::to_vector<4>(
indexedOp.iterator_types().getAsValueRange<StringAttr>());
GenericOp genericOp = rewriter.create<GenericOp>(
- indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(),
- indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators);
+ indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
+ outputOperands, indexedOp.getIndexingMaps(), iterators);
Region &genericRegion = genericOp.region();
Region &indexedRegion = indexedOp.region();
rewriter.cloneRegionBefore(indexedRegion, genericRegion,
@@ -2107,21 +2110,21 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(linalg::YieldOp op,
- LinalgOp linalgOpInterface) {
- auto nOutputs = linalgOpInterface.getNumOutputs();
- if (op.getNumOperands() != nOutputs)
+static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
+ if (op.getNumOperands() != linalgOp.getNumOutputs())
return op.emitOpError("expected number of yield values (")
- << nOutputs << ") to match the number of operands of the enclosing "
+ << linalgOp.getNumOutputs()
+ << ") to match the number of operands of the enclosing "
<< "LinalgOp (" << op.getNumOperands() << ")";
- for (unsigned i = 0; i != nOutputs; ++i) {
- auto elementType =
- linalgOpInterface.getOutputShapedType(i).getElementType();
- if (op.getOperand(i).getType() != elementType)
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ OpOperand *outputOperand =
+ linalgOp.getOutputOperand(opOperand.getOperandNumber());
+ Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+ if (opOperand.get().getType() != elementType)
return op.emitOpError("type of yield operand ")
- << (i + 1) << " (" << op.getOperand(i).getType()
- << ") doesn't match "
+ << (opOperand.getOperandNumber() + 1) << " ("
+ << opOperand.get().getType() << ") doesn't match "
<< "the element type of the enclosing linalg.generic op ("
<< elementType << ")";
}
@@ -3096,14 +3099,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- for (Value v : op.getShapedOperands()) {
+ for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
- auto mt = v.getType().dyn_cast<MemRefType>();
+ auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
if (!mt)
continue;
- if (llvm::is_contained(mt.getShape(), 0)) {
+ if (llvm::is_contained(op.getShape(opOperand), 0)) {
rewriter.eraseOp(op);
return success();
}
@@ -3119,10 +3122,10 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
PatternRewriter &rewriter) const override {
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
- llvm::any_of(op.getShapedOperands(), [&](Value v) {
- if (v.isa<BlockArgument>())
+ llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
+ if (opOperand->get().isa<BlockArgument>())
return false;
- auto castOp = v.getDefiningOp<tensor::CastOp>();
+ auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@@ -3133,16 +3136,18 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
- for (Value v : op.getInputs()) {
- auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
- newOperands.push_back(
- canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+ for (OpOperand *opOperand : op.getInputOperands()) {
+ auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
+ ? tensorCastOp.source()
+ : opOperand->get());
}
// Init tensors may fold, in which case the resultType must also change.
- for (Value v : op.getOutputs()) {
- auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
+ for (OpOperand *opOperand : op.getOutputOperands()) {
+ auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
- newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+ newOperands.push_back(fold ? tensorCastOp.getOperand()
+ : opOperand->get());
newResultTypes.push_back(newOperands.back().getType());
}
auto extraOperands = op.getAssumedNonShapedOperands();
@@ -3189,18 +3194,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// in the case of duplicated inputs, the canonical input could be some other
// input `< i`. That is, a later input will have some earlier input as its
// canonical input.
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
// For later remapping tasks like deduplicating payload block arguments,
// having a simple "inputIndex -> canonicalInputIndex" integer mapping is
// convenient.
- SmallVector<int, 6> canonicalInputIndices;
- for (int i = 0, e = op.getNumInputs(); i != e; i++) {
- Value input = op.getInput(i);
- AffineMap indexingMap = op.getInputIndexingMap(i);
+ SmallVector<unsigned> canonicalInputIndices;
+ for (OpOperand *opOperand : op.getInputOperands()) {
+ AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
// STL-like maps have a convenient behavior for our use case here. In the
// case of duplicate keys, the insertion is rejected, and the returned
// iterator gives access to the value already in the map.
- auto pair = canonicalInput.insert({{input, indexingMap}, i});
+ auto pair = canonicalInput.insert(
+ {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
canonicalInputIndices.push_back(pair.first->second);
}
@@ -3209,26 +3214,29 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
return failure();
// The operands for the newly canonicalized op.
- SmallVector<Value, 6> newOperands;
- for (auto v : llvm::enumerate(op.getInputs()))
- if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
- newOperands.push_back(v.value());
- llvm::append_range(newOperands, op.getOutputs());
+ SmallVector<Value> newOperands;
+ for (OpOperand *opOperand : op.getInputOperands())
+ if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+ opOperand->getOperandNumber())
+ newOperands.push_back(opOperand->get());
+ SmallVector<Value> outputOperands = op.getOutputOperands();
+ llvm::append_range(newOperands, outputOperands);
llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
+ // Repair the indexing maps by filtering out the ones that have been
+ // eliminated.
+ SmallVector<AffineMap> newIndexingMaps;
+ for (OpOperand *opOperand : op.getInputOperands())
+ if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+ opOperand->getOperandNumber())
+ newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+ for (OpOperand *opOperand : op.getOutputOperands())
+ newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+
// Clone the old op with new operands.
Operation *newOp =
op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
auto newLinalgOp = cast<LinalgOp>(newOp);
-
- // Repair the indexing maps by filtering out the ones that have been
- // eliminated.
- SmallVector<AffineMap, 6> newIndexingMaps;
- for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
- if (canonicalInputIndices[i] == i)
- newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
- for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
- newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
newOp->setAttr("indexing_maps",
rewriter.getAffineMapArrayAttr(newIndexingMaps));
@@ -3243,18 +3251,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp->getRegion(0).front();
- for (int i = 0, e = op.getNumInputs(); i < e; i++) {
+ SmallVector<OpOperand *> inputOperands = op.getInputOperands();
+ for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
// Iterate in reverse, so that we erase later args first, preventing the
// argument list from shifting unexpectedly and invalidating all our
// indices.
- int reversed = e - i - 1;
- int canonicalIndex = canonicalInputIndices[reversed];
- if (canonicalInputIndices[reversed] == reversed)
+ unsigned operandNumber = opOperand->getOperandNumber();
+ if (canonicalInputIndices[operandNumber] == operandNumber)
continue;
- payload.getArgument(bbArgBaseOffset + reversed)
- .replaceAllUsesWith(
- payload.getArgument(bbArgBaseOffset + canonicalIndex));
- payload.eraseArgument(bbArgBaseOffset + reversed);
+ payload.getArgument(bbArgBaseOffset + operandNumber)
+ .replaceAllUsesWith(payload.getArgument(
+ bbArgBaseOffset + canonicalInputIndices[operandNumber]));
+ payload.eraseArgument(bbArgBaseOffset + operandNumber);
}
rewriter.replaceOp(op, newOp->getResults());
More information about the Mlir-commits mailing list