[Mlir-commits] [mlir] 09635dc - [mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).

Tobias Gysi llvmlistbot at llvm.org
Wed Jul 28 04:40:15 PDT 2021


Author: Tobias Gysi
Date: 2021-07-28T11:39:34Z
New Revision: 09635dc7bfa42bed2809e3ee4edc96d0decdb9db

URL: https://github.com/llvm/llvm-project/commit/09635dc7bfa42bed2809e3ee4edc96d0decdb9db
DIFF: https://github.com/llvm/llvm-project/commit/09635dc7bfa42bed2809e3ee4edc96d0decdb9db.diff

LOG: [mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).

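Specialize the DeduplicateInputs and RemoveIdentityLinalgOps patterns for GenericOp (now named DeduplicateGenericOpInputs and EraseIdentityGenericOp, and registered via GenericOp::getCanonicalizationPatterns) instead of implementing them for the LinalgOp interface.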

This revision is based on https://reviews.llvm.org/D105622, which moves the logic to erase identity CopyOps into a separate pattern.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105291

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 3ddc7b788d2c..33f4992e41f9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -658,6 +658,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
 
   let verifier = [{ return ::verify(*this); }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7931f9b91066..5ce4b3d27842 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -671,6 +671,138 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
 
 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
+namespace {
+// Deduplicate redundant args of a linalg generic op.
+// An arg is redundant if it has the same Value and indexing map as another.
+struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Associate each input to an equivalent "canonical" input that has the same
+    // Value and indexing map.
+    //
+    // In the non-duplicate case, input `i` will have canonical input `i`. But
+    // in the case of duplicated inputs, the canonical input could be some other
+    // input `< i`. That is, a later input will have some earlier input as its
+    // canonical input.
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
+    // For later remapping tasks like deduplicating payload block arguments,
+    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
+    // convenient.
+    SmallVector<unsigned> canonicalInputIndices;
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      // STL-like maps have a convenient behavior for our use case here. In the
+      // case of duplicate keys, the insertion is rejected, and the returned
+      // iterator gives access to the value already in the map.
+      auto pair = canonicalInput.insert(
+          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
+      canonicalInputIndices.push_back(pair.first->second);
+    }
+
+    // If there are no duplicate args, then bail out.
+    if (canonicalInput.size() == genericOp.getNumInputs())
+      return failure();
+
+    // The operands for the newly canonicalized op.
+    SmallVector<Value> newInputOperands;
+    for (OpOperand *opOperand : genericOp.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newInputOperands.push_back(opOperand->get());
+
+    // Repair the indexing maps by filtering out the ones that have been
+    // eliminated.
+    SmallVector<AffineMap> newIndexingMaps;
+    for (OpOperand *opOperand : genericOp.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+    for (OpOperand *opOperand : genericOp.getOutputOperands())
+      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+
+    // Clone the old op with new operands.
+    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    auto newOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+
+    // Repair the payload entry block by RAUW'ing redundant arguments and
+    // erasing them.
+    Block &payload = newOp.region().front();
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
+      // Iterate in reverse, so that we erase later args first, preventing the
+      // argument list from shifting unexpectedly and invalidating all our
+      // indices.
+      unsigned operandNumber = opOperand->getOperandNumber();
+      if (canonicalInputIndices[operandNumber] == operandNumber)
+        continue;
+      payload.getArgument(operandNumber)
+          .replaceAllUsesWith(
+              payload.getArgument(canonicalInputIndices[operandNumber]));
+      payload.eraseArgument(operandNumber);
+    }
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+
+/// Remove generic operations (on tensors) that are just copying
+/// the values from inputs to the results. Requirements are
+/// 1) All iterator types are parallel
+/// 2) The body contains just a yield operation with the yielded values being
+///    the arguments corresponding to the operands.
+struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    // Check all indexing maps are identity.
+    if (llvm::any_of(genericOp.getIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+
+    // Check that the body of the linalg operation is just a linalg.yield
+    // operation.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+
+    // Get the argument number of the returned values. That is the operand
+    // number to use for replacing uses of this operation.
+    SmallVector<Value> returnedArgs;
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+      unsigned argumentNumber = yieldArg.getArgNumber();
+      returnedArgs.push_back(genericOp->getOperand(argumentNumber));
+    }
+    if (returnedArgs.size() != genericOp->getNumResults())
+      return failure();
+    rewriter.replaceOp(genericOp, returnedArgs);
+    return success();
+  }
+};
+} // namespace
+
+void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
@@ -2539,143 +2671,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
 };
 } // namespace
 
-namespace {
-// Deduplicate redundant args of a linalg op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    // This pattern reduces the number of arguments of an op, which breaks
-    // the invariants of semantically charged named ops.
-    if (!isa<GenericOp>(op))
-      return failure();
-
-    // Associate each input to an equivalent "canonical" input that has the same
-    // Value and indexing map.
-    //
-    // In the non-duplicate case, input `i` will have canonical input `i`. But
-    // in the case of duplicated inputs, the canonical input could be some other
-    // input `< i`. That is, a later input will have some earlier input as its
-    // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
-    // For later remapping tasks like deduplicating payload block arguments,
-    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
-    // convenient.
-    SmallVector<unsigned> canonicalInputIndices;
-    for (OpOperand *opOperand : op.getInputOperands()) {
-      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
-      // STL-like maps have a convenient behavior for our use case here. In the
-      // case of duplicate keys, the insertion is rejected, and the returned
-      // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert(
-          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
-      canonicalInputIndices.push_back(pair.first->second);
-    }
-
-    // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == op.getNumInputs())
-      return failure();
-
-    // The operands for the newly canonicalized op.
-    SmallVector<Value> newOperands;
-    for (OpOperand *opOperand : op.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newOperands.push_back(opOperand->get());
-    SmallVector<Value> outputOperands = op.getOutputOperands();
-    llvm::append_range(newOperands, outputOperands);
-
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap> newIndexingMaps;
-    for (OpOperand *opOperand : op.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
-    for (OpOperand *opOperand : op.getOutputOperands())
-      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
-
-    // Clone the old op with new operands.
-    Operation *newOp =
-        op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
-    auto newLinalgOp = cast<LinalgOp>(newOp);
-    newOp->setAttr("indexing_maps",
-                   rewriter.getAffineMapArrayAttr(newIndexingMaps));
-
-    // Set the number of inputs to the new value. The `clone` call above kept
-    // the value from the original op.
-    newLinalgOp.setNumInputs(canonicalInput.size());
-
-    // Repair the payload entry block by RAUW'ing redundant arguments and
-    // erasing them.
-    Block &payload = newOp->getRegion(0).front();
-    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
-    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      unsigned operandNumber = opOperand->getOperandNumber();
-      if (canonicalInputIndices[operandNumber] == operandNumber)
-        continue;
-      payload.getArgument(operandNumber)
-          .replaceAllUsesWith(
-              payload.getArgument(canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(operandNumber);
-    }
-
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-
-/// Remove generic operations (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
-/// 1) All iterator types are parallel
-/// 2) The body contains just a yield operation with the yielded values being
-///    the arguments corresponding to the operands.
-struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!isa<GenericOp>(op))
-      return failure();
-    if (!op.hasTensorSemantics())
-      return failure();
-    // Check all indexing maps are identity.
-    if (llvm::any_of(op.getIndexingMaps(),
-                     [](AffineMap map) { return !map.isIdentity(); }))
-      return failure();
-
-    // Check that the body of the linalg operation is just a linalg.yield
-    // operation.
-    Block &body = op->getRegion(0).front();
-    if (!llvm::hasSingleElement(body))
-      return failure();
-    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
-    if (!yieldOp)
-      return failure();
-
-    // Get the argument number of the returned values. That is the operand
-    // number to use for replacing uses of this operation.
-    SmallVector<Value, 4> returnedArgs;
-    for (Value yieldVal : yieldOp.values()) {
-      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
-      if (!yieldArg || yieldArg.getOwner() != &body)
-        return failure();
-      unsigned argumentNumber = yieldArg.getArgNumber();
-      returnedArgs.push_back(op->getOperand(argumentNumber));
-    }
-    if (returnedArgs.size() != op.getOperation()->getNumResults())
-      return failure();
-    rewriter.replaceOp(op, returnedArgs);
-    return success();
-  }
-};
-} // namespace
-
 #define LINALGOP_FOLDERS(XXX)                                                  \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                           SmallVectorImpl<OpFoldResult> &) {                   \
@@ -2699,6 +2694,5 @@ LINALGOP_FOLDERS(GenericOp)
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
-              RemoveIdentityLinalgOps>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
 }


        


More information about the Mlir-commits mailing list