[Mlir-commits] [mlir] 8be7e6f - [mlir][Linalg] Combine canonicalizers that deal with removing dead/redundant args.

Mahesh Ravishankar llvmlistbot at llvm.org
Wed May 11 22:22:58 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-05-12T05:22:30Z
New Revision: 8be7e6f56ac0f553a88c759d206ec51c6510bf08

URL: https://github.com/llvm/llvm-project/commit/8be7e6f56ac0f553a88c759d206ec51c6510bf08
DIFF: https://github.com/llvm/llvm-project/commit/8be7e6f56ac0f553a88c759d206ec51c6510bf08.diff

LOG: [mlir][Linalg] Combine canonicalizers that deal with removing dead/redundant args.

`linalg.generic` ops have canonicalizers that either remove arguments
not used in the payload, or redundant arguments. Combine these and
enhance the canonicalization to also remove results that have no use.
This is effectively dead code elimination for Linalg ops.

Differential Revision: https://reviews.llvm.org/D123632

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index b0c705a87a35..83972935035d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -31,6 +31,14 @@ struct OpOperandVector : public SmallVector<OpOperand *> {
   operator SmallVector<Value>();
 };
 
+namespace detail {
+/// Implementation of the method that checks whether the given operands
+/// can be dropped, i.e. whether the remaining operands can still compute
+/// the loop bounds of the op.
+bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp,
+                                ArrayRef<OpOperand *> droppedOperands);
+} // namespace detail
+
 /// Checks whether `linalgOp` conforms to ContractionOpInterface.
 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
 bool isaContractionOpInterface(LinalgOp linalgOp);

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 18e65c849e28..6e039f5c4a61 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -958,6 +958,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Checks if the given operands can be dropped, and the remaining
+        operands can still compute the bounds of the op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canOpOperandsBeDropped",
+      /*args=*/(ins "ArrayRef<OpOperand *>":$droppedOperands),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range of position in the result of the affine map

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 25b8247be3a1..babde8c6082c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -165,6 +165,12 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
   let regions = (region AnyRegion:$region);
 
   let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      "ArrayAttr":$iteratorTypes, "StringAttr":$doc,
+      "StringAttr":$libraryCall,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>",
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c916d15c9d86..fc99e290e5cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -23,6 +23,27 @@ using namespace mlir::linalg;
 /// Include the definitions of the copy operation interface.
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

+//===----------------------------------------------------------------------===//
+// Interface utility functions
+//===----------------------------------------------------------------------===//
+/// Returns true if the concatenated indexing maps of all operands that are
+/// not in `droppedOperands` have an inverse permutation, i.e. the remaining
+/// operands can still compute the loop bounds of `linalgOp`.
+bool linalg::detail::canOpOperandsBeDroppedImpl(
+    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    if (llvm::is_contained(droppedOperands, opOperand))
+      continue;
+    indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
+  }
+  // `concatAffineMaps` needs at least one map. With every operand dropped,
+  // the loop bounds are only computable if the op has no loops at all.
+  if (indexingMaps.empty())
+    return linalgOp.getNumLoops() == 0;
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d58c253bbfce..0584043d8093 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -266,33 +267,6 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
-/// Helper function to find if there is atleast one dimension in an AffineMap
-/// testMap that is contained in `testMapLocation` of  `maps` but not in any
-/// other locations
-static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
-  AffineMap testMap = maps[testMapLocation];
-  llvm::SmallDenseSet<unsigned> dimsToCheck;
-  for (auto result : testMap.getResults()) {
-    auto expr = result.dyn_cast<AffineDimExpr>();
-    if (expr != nullptr)
-      dimsToCheck.insert(expr.getPosition());
-  }
-  for (const auto &it : llvm::enumerate(maps)) {
-    if (it.index() == testMapLocation)
-      continue;
-    auto map = it.value();
-    for (auto result : map.getResults()) {
-      auto expr = result.dyn_cast<AffineDimExpr>();
-      if (expr != nullptr) {
-        dimsToCheck.erase(expr.getPosition());
-      }
-      if (dimsToCheck.empty())
-        return false;
-    }
-  }
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -670,16 +644,12 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
-    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
+    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, resultTensorTypes, inputs, outputs,
-        builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getStrArrayAttr(iteratorTypes),
-        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
-        libraryCall.empty() ? StringAttr()
-                            : builder.getStringAttr(libraryCall));
+  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
+        iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
   if (!bodyBuild)
     return;
@@ -700,6 +670,20 @@ void GenericOp::build(
   bodyBuild(builder, result.location, bodyBlock->getArguments());
 }
 
+void GenericOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultTensorTypes, inputs, outputs,
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
+        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
+        bodyBuild, attributes);
+}
+
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
@@ -844,93 +828,165 @@ void GenericOp::getEffects(
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
-// Deduplicate redundant args of a linalg generic op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+
+/// Removes dead and redundant operands and results of `linalg.generic` ops:
+/// - inputs that are not used by the payload,
+/// - inputs that duplicate another input (same value and indexing map),
+/// - with tensor semantics, results that have no uses and whose `outs`
+///   operand is not used by the payload.
+/// Dead operands are only dropped if the remaining operands can still compute
+/// the loop bounds of the op (see `canOpOperandsBeDropped`).
+struct DeduplicateAndRemoveDeadOperandsAndResults
+    : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;

   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    // Associate each input to an equivalent "canonical" input that has the same
-    // Value and indexing map.
-    //
-    // In the non-duplicate case, input `i` will have canonical input `i`. But
-    // in the case of duplicated inputs, the canonical input could be some other
-    // input `< i`. That is, a later input will have some earlier input as its
-    // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
-    // For later remapping tasks like deduplicating payload block arguments,
-    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
-    // convenient.
-    SmallVector<unsigned> canonicalInputIndices;
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-      // STL-like maps have a convenient behavior for our use case here. In the
-      // case of duplicate keys, the insertion is rejected, and the returned
-      // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert(
-          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
-      canonicalInputIndices.push_back(pair.first->second);
+    // Map from block argument position in the original op to block argument
+    // position in the new op. Dead arguments get no entry; a duplicated input
+    // maps to the position of the copy that is kept.
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    unsigned numNewArgs = 0;
+    SmallVector<OpOperand *> droppedOpOperands;
+    llvm::SmallDenseSet<unsigned> droppedOutputs;
+
+    // Information needed to build the new op.
+    SmallVector<Value> newInputOperands, newOutputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<Type> newResultTypes;
+
+    // Input argument can be dropped if
+    // - it has no uses, or,
+    // - there is a duplicate operand which is accessed using the same
+    //   indexing map.
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+    for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
+      BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
+      unsigned argNum = arg.getArgNumber();
+
+      // Check if operand is dead and if dropping the indexing map makes the
+      // loops to shape computation invalid.
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+        // Add the current operands to the list of potentially droppable
+        // operands. If it cannot be dropped, this needs to be popped back.
+        droppedOpOperands.push_back(inputOpOperand);
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
+          continue;
+        droppedOpOperands.pop_back();
+      }
+
+      // Check if this operand is a duplicate.
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(inputOpOperand);
+      auto it = dedupedInputs.find(
+          std::make_pair(inputOpOperand->get(), indexingMap));
+      if (it != dedupedInputs.end()) {
+        origToNewPos[argNum] = it->second;
+        droppedOpOperands.push_back(inputOpOperand);
+        continue;
+      }
+
+      // This is a preserved argument.
+      origToNewPos[argNum] = numNewArgs;
+      dedupedInputs[{inputOpOperand->get(), indexingMap}] = numNewArgs;
+      newInputOperands.push_back(inputOpOperand->get());
+      newIndexingMaps.push_back(indexingMap);
+      numNewArgs++;
     }

-    // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == genericOp.getNumInputs())
-      return failure();
+    // If the op doesn't have tensor semantics, keep all the outputs as
+    // preserved.
+    if (!genericOp.hasTensorSemantics()) {
+      for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
+        BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
+        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        newOutputOperands.push_back(outputOpOperand->get());
+        newIndexingMaps.push_back(
+            genericOp.getTiedIndexingMap(outputOpOperand));
+      }
+    } else {
+      // Output argument can be dropped if the result has
+      // - no users, and
+      // - it is not used in the payload, and
+      // - the corresponding indexing maps are not needed for loop bound
+      //   computation.
+      for (auto outputOpOperand :
+           llvm::enumerate(genericOp.getOutputOperands())) {
+        Value result = genericOp.getResult(outputOpOperand.index());
+        BlockArgument arg =
+            genericOp.getTiedBlockArgument(outputOpOperand.value());
+        if (result.use_empty() &&
+            !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+          // Check if the operand can be dropped without affecting loop bound
+          // computation. Add the operand to the list of dropped operands for
+          // checking. If it cannot be dropped, need to pop the value back.
+          droppedOpOperands.push_back(outputOpOperand.value());
+          if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+            droppedOutputs.insert(outputOpOperand.index());
+            continue;
+          }
+          droppedOpOperands.pop_back();
+        }

-    // The operands for the newly canonicalized op.
-    SmallVector<Value> newInputOperands;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newInputOperands.push_back(opOperand->get());
+        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        newOutputOperands.push_back(outputOpOperand.value()->get());
+        newIndexingMaps.push_back(
+            genericOp.getTiedIndexingMap(outputOpOperand.value()));
+        newResultTypes.push_back(result.getType());
+      }
+    }

-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap> newIndexingMaps;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-
-    // Clone the old op with new operands.
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    // Check if there is any change to operands.
+    if (newInputOperands.size() + newOutputOperands.size() ==
+        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+      return failure();
+
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
     auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        loc, newResultTypes, newInputOperands, newOutputOperands,
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
         genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-
+        genericOp.library_callAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
     // Copy over unknown attributes. They might be load bearing for some flow.
     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs()) {
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
         newOp->setAttr(kv.getName(), kv.getValue());
-      }
-    }

-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-
-    // Repair the payload entry block by RAUW'ing redundant arguments and
-    // erasing them.
-    Block &payload = newOp.region().front();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      unsigned operandNumber = opOperand->getOperandNumber();
-      if (canonicalInputIndices[operandNumber] == operandNumber)
-        continue;
-      payload.getArgument(operandNumber)
-          .replaceAllUsesWith(
-              payload.getArgument(canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(operandNumber);
+    // Merge the body of the original op into the new op. Arguments without an
+    // entry in `origToNewPos` are unused in the payload; map them to null.
+    Block *newOpBlock = &newOp.region().front();
+    Block *origOpBlock = &genericOp.region().front();
+    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+    for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
+      auto it = origToNewPos.find(argNum);
+      if (it != origToNewPos.end())
+        replacements[argNum] = newOpBlock->getArgument(it->second);
+    }
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+
+    // Drop the yielded values that correspond to dropped results.
+    if (!droppedOutputs.empty()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      SmallVector<Value> newYieldVals;
+      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
+      rewriter.setInsertionPoint(origYieldOp);
+      for (auto yieldOpOperands : llvm::enumerate(origYieldOp.values()))
+        if (!droppedOutputs.count(yieldOpOperands.index()))
+          newYieldVals.push_back(yieldOpOperands.value());
+      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }

-    rewriter.replaceOp(genericOp, newOp->getResults());
+    // Replace all live uses of the op.
+    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+    unsigned newResultNum = 0;
+    for (auto result : llvm::enumerate(genericOp.getResults()))
+      if (!droppedOutputs.count(result.index()))
+        replacementsVals[result.index()] = newOp.getResult(newResultNum++);
+    rewriter.replaceOp(genericOp, replacementsVals);
     return success();
   }
 };
@@ -1007,72 +1063,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-
-/// Drop dead args of a linalg generic op.
-/// An arg is dead if it has zero uses in the op region.
-struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
-    // Maps must be projected permutations.
-    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
-          return !map.isProjectedPermutation();
-        }))
-      return failure();
-    Block &payload = genericOp.region().front();
-    SmallVector<Value> newInputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-    bool deadArgFound = false;
-    int inputSize = genericOp.getInputOperands().size();
-    for (int i = inputSize - 1; i >= 0; i--) {
-      OpOperand *opOperand = genericOp.getInputOperand(i);
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      if (payload.getArgument(i).use_empty() &&
-          !hasaUniqueDim(oldIndexingMaps, i)) {
-        payload.eraseArgument(i);
-        deadArgFound = true;
-        // remove this indexing map out of consideration for hasaUniqueDim check
-        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
-      } else {
-        newInputOperands.insert(newInputOperands.begin(), opOperand->get());
-        newIndexingMaps.insert(newIndexingMaps.begin(),
-                               genericOp.getTiedIndexingMap(opOperand));
-      }
-    }
-    // Bail out if there are no dead args.
-    if (!deadArgFound)
-      return failure();
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-
-    auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs()) {
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
-        newOp->setAttr(kv.getName(), kv.getValue());
-      }
-    }
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-    rewriter.replaceOp(genericOp, newOp->getResults());
-    return success();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              DeadArgsGenericOpInputs>(context);
+  results
+      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
+          context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
index 827168c5b413..f726922c66b6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -90,3 +90,175 @@ func.func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor
   } -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// Drop dead result.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %arg0, %arg0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %1 = arith.addf %b0, %b0: f32
+      linalg.yield %1, %1, %1, %1 : f32, f32, f32, f32
+    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>     
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//      CHECK: func @drop_dead_results(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>)
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       outs(%[[ARG0]], %[[ARG0]] :
+//      CHECK:   return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Current argmax lowering to `linalg.generic`. Cannot drop the
+// first result even though it isn't used, since its `outs` operand
+// is read in the payload.
+#map0 = affine_map<(d0) -> (d0)>
+#map1 = affine_map<(d0) -> ()>
+func.func @argmax_lowering(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %init0 = linalg.init_tensor [] : tensor<f32>
+  %init1 = linalg.init_tensor [] : tensor<i32>
+  %0:2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map1],
+    iterator_types = ["reduction"]}
+    ins(%arg0 : tensor<?xf32>)
+    outs(%init0, %init1 : tensor<f32>, tensor<i32>) {
+  ^bb0(%b0: f32, %b1: f32, %b2: i32):
+    %8 = linalg.index 0 : index
+    %9 = arith.index_cast %8 : index to i32
+    %10 = arith.cmpf oge, %b0, %b1 : f32
+    %11 = arith.select %10, %b0, %b1 : f32
+    %12 = arith.cmpf oeq, %b0, %b1 : f32
+    %13 = arith.minsi %9, %b2 : i32
+    %14 = arith.select %10, %9, %b2 : i32
+    %15 = arith.select %12, %13, %14 : i32
+    linalg.yield %11, %15 : f32, i32
+  } -> (tensor<f32>, tensor<i32>)
+  return %0#1 : tensor<i32>
+}
+//      CHECK: func @argmax_lowering(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [] : tensor<f32>
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [] : tensor<i32>
+//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+//      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+// Do not remove operand needed for loop dim.
+func.func @loop_dim_operand(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %cst = arith.constant 0 : i32
+  %init = linalg.init_tensor [] : tensor<i32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<i32>) -> tensor<i32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>) outs(%fill : tensor<i32>) {
+    ^bb0(%b0: f32, %b1: i32):
+      %1 = linalg.index 0 : index
+      %2 = arith.index_cast %1 : index to i32
+      %3 = arith.addi %b1, %2 : i32
+      linalg.yield %3 : i32
+    } -> tensor<i32>
+  return %0 : tensor<i32>
+}
+//      CHECK: func @loop_dim_operand(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+//      CHECK:   linalg.generic
+// CHECK-SAME:       ins(%[[ARG0]] :
+
+// -----
+
+// Do not remove outs operand needed for loop bound computation.
+func.func @loop_dim_outs_operand(%arg0 : index) -> tensor<i32> {
+  %cst = arith.constant 0 : i32
+  %init1 = linalg.init_tensor [%arg0] : tensor<?xi32>
+  %init = linalg.init_tensor [] : tensor<i32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<i32>) -> tensor<i32>
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["parallel"]}
+      outs(%init1, %fill : tensor<?xi32>, tensor<i32>) {
+    ^bb0(%b0: i32, %b1: i32):
+      %1 = linalg.index 0 : index
+      %2 = arith.index_cast %1 : index to i32
+      %3 = arith.addi %b1, %2 : i32
+      linalg.yield %2, %3 : i32, i32
+    } -> (tensor<?xi32>, tensor<i32>)
+  return %0#1 : tensor<i32>
+}
+//      CHECK: func @loop_dim_outs_operand(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[ARG0]]]
+//      CHECK:   linalg.generic
+// CHECK-SAME:       outs(%[[INIT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+#map2 = affine_map<(d0, d1) -> (d0)>
+#map3 = affine_map<(d0, d1) -> (d1)>
+func.func @multiple_redundant_args(%arg0 : tensor<?x?xi32>, %arg1 : tensor<?xi32>,
+    %arg2 : tensor<?xi32>, %arg3 : tensor<?x?xi32>, %arg4 : tensor<?xi32>) -> tensor<?xi32> {
+  %0 = linalg.generic {
+      indexing_maps = [#map3, #map0, #map0, #map2, #map1, #map1, #map2],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg4, %arg0, %arg0, %arg1, %arg3, %arg3
+          : tensor<?xi32>, tensor<?x?xi32>, tensor<?x?xi32>, tensor<?xi32>, tensor<?x?xi32>, tensor<?x?xi32>)
+      outs(%arg2 : tensor<?xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32, %b5 : i32, %b6 : i32):
+      %1 = arith.addi %b0, %b1 : i32
+      %2 = arith.addi %1, %b2 : i32
+      %3 = arith.addi %2, %b3 : i32
+      %4 = arith.addi %3, %b4 : i32
+      %5 = arith.addi %4, %b5 : i32
+      %6 = arith.addi %5, %b6 : i32
+      linalg.yield %6 : i32
+    } -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func @multiple_redundant_args(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xi32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xi32>
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?xi32>)
+//      CHECK:   %[[RETURN:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:       ins(%[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG3]] :
+// CHECK-SAME:       outs(%[[ARG2]] :
+//      CHECK:   ^{{.+}}(%[[B0:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME:       %[[B3:[a-zA-Z0-9]+]]: i32
+// CHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: i32)
+//      CHECK:     %[[T0:.+]] = arith.addi %[[B0]], %[[B1]]
+//      CHECK:     %[[T1:.+]] = arith.addi %[[T0]], %[[B1]]
+//      CHECK:     %[[T2:.+]] = arith.addi %[[T1]], %[[B2]]
+//      CHECK:     %[[T3:.+]] = arith.addi %[[T2]], %[[B3]]
+//      CHECK:     %[[T4:.+]] = arith.addi %[[T3]], %[[B3]]
+//      CHECK:     %[[T5:.+]] = arith.addi %[[T4]], %[[B4]]
+//      CHECK:     linalg.yield %[[T5]]
+//      CHECK:  return %[[RETURN]]


        


More information about the Mlir-commits mailing list