[Mlir-commits] [mlir] 09635dc - [mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).
Tobias Gysi
llvmlistbot at llvm.org
Wed Jul 28 04:40:15 PDT 2021
Author: Tobias Gysi
Date: 2021-07-28T11:39:34Z
New Revision: 09635dc7bfa42bed2809e3ee4edc96d0decdb9db
URL: https://github.com/llvm/llvm-project/commit/09635dc7bfa42bed2809e3ee4edc96d0decdb9db
DIFF: https://github.com/llvm/llvm-project/commit/09635dc7bfa42bed2809e3ee4edc96d0decdb9db.diff
LOG: [mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).
Specialize the DeduplicateInputs and RemoveIdentityLinalgOps patterns for GenericOp instead of implementing them for the LinalgOp interface.
This revision is based on https://reviews.llvm.org/D105622, which moves the logic to erase identity CopyOps into a separate pattern.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D105291
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 3ddc7b788d2c..33f4992e41f9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -658,6 +658,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
let verifier = [{ return ::verify(*this); }];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7931f9b91066..5ce4b3d27842 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -671,6 +671,138 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
+namespace {
+// Deduplicate redundant args of a linalg generic op.
+// An arg is redundant if it has the same Value and indexing map as another.
+struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ // Associate each input to an equivalent "canonical" input that has the same
+ // Value and indexing map.
+ //
+ // In the non-duplicate case, input `i` will have canonical input `i`. But
+ // in the case of duplicated inputs, the canonical input could be some other
+ // input `< i`. That is, a later input will have some earlier input as its
+ // canonical input.
+ llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
+ // For later remapping tasks like deduplicating payload block arguments,
+ // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
+ // convenient.
+ SmallVector<unsigned> canonicalInputIndices;
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ // STL-like maps have a convenient behavior for our use case here. In the
+ // case of duplicate keys, the insertion is rejected, and the returned
+ // iterator gives access to the value already in the map.
+ auto pair = canonicalInput.insert(
+ {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
+ canonicalInputIndices.push_back(pair.first->second);
+ }
+
+ // If there are no duplicate args, then bail out.
+ if (canonicalInput.size() == genericOp.getNumInputs())
+ return failure();
+
+ // The operands for the newly canonicalized op.
+ SmallVector<Value> newInputOperands;
+ for (OpOperand *opOperand : genericOp.getInputOperands())
+ if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+ opOperand->getOperandNumber())
+ newInputOperands.push_back(opOperand->get());
+
+ // Repair the indexing maps by filtering out the ones that have been
+ // eliminated.
+ SmallVector<AffineMap> newIndexingMaps;
+ for (OpOperand *opOperand : genericOp.getInputOperands())
+ if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+ opOperand->getOperandNumber())
+ newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+ for (OpOperand *opOperand : genericOp.getOutputOperands())
+ newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+
+ // Clone the old op with new operands.
+ SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+ auto newOp = rewriter.create<GenericOp>(
+ genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+ outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ genericOp.iterator_types(), genericOp.docAttr(),
+ genericOp.library_callAttr());
+ rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+ newOp.region().begin());
+
+ // Repair the payload entry block by RAUW'ing redundant arguments and
+ // erasing them.
+ Block &payload = newOp.region().front();
+ SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+ for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
+ // Iterate in reverse, so that we erase later args first, preventing the
+ // argument list from shifting unexpectedly and invalidating all our
+ // indices.
+ unsigned operandNumber = opOperand->getOperandNumber();
+ if (canonicalInputIndices[operandNumber] == operandNumber)
+ continue;
+ payload.getArgument(operandNumber)
+ .replaceAllUsesWith(
+ payload.getArgument(canonicalInputIndices[operandNumber]));
+ payload.eraseArgument(operandNumber);
+ }
+
+ rewriter.replaceOp(genericOp, newOp->getResults());
+ return success();
+ }
+};
+
+/// Remove generic operations (on tensors) that are just copying
+/// the values from inputs to the results. Requirements are
+/// 1) All iterator types are parallel
+/// 2) The body contains just a yield operation with the yielded values being
+/// the arguments corresponding to the operands.
+struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ // Check all indexing maps are identity.
+ if (llvm::any_of(genericOp.getIndexingMaps(),
+ [](AffineMap map) { return !map.isIdentity(); }))
+ return failure();
+
+ // Check that the body of the linalg operation is just a linalg.yield
+ // operation.
+ Block &body = genericOp.region().front();
+ if (!llvm::hasSingleElement(body))
+ return failure();
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+ if (!yieldOp)
+ return failure();
+
+ // Get the argument number of the returned values. That is the operand
+ // number to use for replacing uses of this operation.
+ SmallVector<Value> returnedArgs;
+ for (Value yieldVal : yieldOp.values()) {
+ auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+ if (!yieldArg || yieldArg.getOwner() != &body)
+ return failure();
+ unsigned argumentNumber = yieldArg.getArgNumber();
+ returnedArgs.push_back(genericOp->getOperand(argumentNumber));
+ }
+ if (returnedArgs.size() != genericOp->getNumResults())
+ return failure();
+ rewriter.replaceOp(genericOp, returnedArgs);
+ return success();
+ }
+};
+} // namespace
+
+void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
@@ -2539,143 +2671,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
};
} // namespace
-namespace {
-// Deduplicate redundant args of a linalg op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
- // This pattern reduces the number of arguments of an op, which breaks
- // the invariants of semantically charged named ops.
- if (!isa<GenericOp>(op))
- return failure();
-
- // Associate each input to an equivalent "canonical" input that has the same
- // Value and indexing map.
- //
- // In the non-duplicate case, input `i` will have canonical input `i`. But
- // in the case of duplicated inputs, the canonical input could be some other
- // input `< i`. That is, a later input will have some earlier input as its
- // canonical input.
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
- // For later remapping tasks like deduplicating payload block arguments,
- // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
- // convenient.
- SmallVector<unsigned> canonicalInputIndices;
- for (OpOperand *opOperand : op.getInputOperands()) {
- AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
- // STL-like maps have a convenient behavior for our use case here. In the
- // case of duplicate keys, the insertion is rejected, and the returned
- // iterator gives access to the value already in the map.
- auto pair = canonicalInput.insert(
- {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
- canonicalInputIndices.push_back(pair.first->second);
- }
-
- // If there are no duplicate args, then bail out.
- if (canonicalInput.size() == op.getNumInputs())
- return failure();
-
- // The operands for the newly canonicalized op.
- SmallVector<Value> newOperands;
- for (OpOperand *opOperand : op.getInputOperands())
- if (canonicalInputIndices[opOperand->getOperandNumber()] ==
- opOperand->getOperandNumber())
- newOperands.push_back(opOperand->get());
- SmallVector<Value> outputOperands = op.getOutputOperands();
- llvm::append_range(newOperands, outputOperands);
-
- // Repair the indexing maps by filtering out the ones that have been
- // eliminated.
- SmallVector<AffineMap> newIndexingMaps;
- for (OpOperand *opOperand : op.getInputOperands())
- if (canonicalInputIndices[opOperand->getOperandNumber()] ==
- opOperand->getOperandNumber())
- newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
- for (OpOperand *opOperand : op.getOutputOperands())
- newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
-
- // Clone the old op with new operands.
- Operation *newOp =
- op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
- auto newLinalgOp = cast<LinalgOp>(newOp);
- newOp->setAttr("indexing_maps",
- rewriter.getAffineMapArrayAttr(newIndexingMaps));
-
- // Set the number of inputs to the new value. The `clone` call above kept
- // the value from the original op.
- newLinalgOp.setNumInputs(canonicalInput.size());
-
- // Repair the payload entry block by RAUW'ing redundant arguments and
- // erasing them.
- Block &payload = newOp->getRegion(0).front();
- SmallVector<OpOperand *> inputOperands = op.getInputOperands();
- for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
- // Iterate in reverse, so that we erase later args first, preventing the
- // argument list from shifting unexpectedly and invalidating all our
- // indices.
- unsigned operandNumber = opOperand->getOperandNumber();
- if (canonicalInputIndices[operandNumber] == operandNumber)
- continue;
- payload.getArgument(operandNumber)
- .replaceAllUsesWith(
- payload.getArgument(canonicalInputIndices[operandNumber]));
- payload.eraseArgument(operandNumber);
- }
-
- rewriter.replaceOp(op, newOp->getResults());
- return success();
- }
-};
-
-/// Remove generic operations (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
-/// 1) All iterator types are parallel
-/// 2) The body contains just a yield operation with the yielded values being
-/// the arguments corresponding to the operands.
-struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
- if (!isa<GenericOp>(op))
- return failure();
- if (!op.hasTensorSemantics())
- return failure();
- // Check all indexing maps are identity.
- if (llvm::any_of(op.getIndexingMaps(),
- [](AffineMap map) { return !map.isIdentity(); }))
- return failure();
-
- // Check that the body of the linalg operation is just a linalg.yield
- // operation.
- Block &body = op->getRegion(0).front();
- if (!llvm::hasSingleElement(body))
- return failure();
- auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
- if (!yieldOp)
- return failure();
-
- // Get the argument number of the returned values. That is the operand
- // number to use for replacing uses of this operation.
- SmallVector<Value, 4> returnedArgs;
- for (Value yieldVal : yieldOp.values()) {
- auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
- if (!yieldArg || yieldArg.getOwner() != &body)
- return failure();
- unsigned argumentNumber = yieldArg.getArgNumber();
- returnedArgs.push_back(op->getOperand(argumentNumber));
- }
- if (returnedArgs.size() != op.getOperation()->getNumResults())
- return failure();
- rewriter.replaceOp(op, returnedArgs);
- return success();
- }
-};
-} // namespace
-
#define LINALGOP_FOLDERS(XXX) \
LogicalResult XXX::fold(ArrayRef<Attribute>, \
SmallVectorImpl<OpFoldResult> &) { \
@@ -2699,6 +2694,5 @@ LINALGOP_FOLDERS(GenericOp)
void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
- results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
- RemoveIdentityLinalgOps>(getContext());
+ results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
}
More information about the Mlir-commits
mailing list