[Mlir-commits] [mlir] [mlir][Linalg] NFC: Expose a method to deduplicate operands/remove dead results of `linalg.generic` op. (PR #125141)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 31 17:10:56 PST 2025


================
@@ -52,255 +52,266 @@ static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
   return true;
 }
 
-namespace {
-
-struct DeduplicateAndRemoveDeadOperandsAndResults
-    : public OpRewritePattern<GenericOp> {
-  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
-                                             bool removeOutputs)
-      : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Create a map from argument position in the original op to the argument
-    // position in the new op. If the argument is dropped it wont have an entry.
-    SmallVector<OpOperand *> droppedOpOperands;
-
-    // Information needed to build the new op.
-    SmallVector<Value> newInputOperands, newOutputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-
-    // Gather information about duplicate input operands.
-    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
-        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
-                                 newIndexingMaps);
-
-    // Gather information about the dropped outputs.
-    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
-        deduplicateOutputOperands(genericOp, droppedOpOperands,
-                                  newOutputOperands, newIndexingMaps);
-
-    // Check if there is any change to operands.
-    if (newInputOperands.size() + newOutputOperands.size() ==
-        genericOp->getNumOperands())
-      return failure();
-
-    // Create the new op with the body being empty.
-    Location loc = genericOp.getLoc();
-    SmallVector<Type> newResultTypes;
-    for (Value v : newOutputOperands)
-      if (isa<TensorType>(v.getType()))
-        newResultTypes.push_back(v.getType());
-    auto newOp = rewriter.create<GenericOp>(
-        loc, newResultTypes, newInputOperands, newOutputOperands,
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
-        genericOp.getLibraryCallAttr(),
-        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
-          return;
-        });
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs())
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
-        newOp->setAttr(kv.getName(), kv.getValue());
-
-    // Fix up the payload of the canonicalized operation.
-    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
-                      origOutsToNewOutsPos, rewriter);
-
-    // Replace all live uses of the op.
-    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
-    for (const auto &result : llvm::enumerate(genericOp.getResults())) {
-      auto it = origOutsToNewOutsPos.find(result.index());
-      if (it == origOutsToNewOutsPos.end())
+//===---------------------------------------------------------------------===//
+// Helper methods for operand deduplication and dead results elimination
+//===---------------------------------------------------------------------===//
+
+// Deduplicate input operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved input operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateInputOperands(
+    GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+    SmallVector<Value> &newInputOperands,
+    SmallVector<AffineMap> &newIndexingMaps) {
+  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+  llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+  for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    OpOperand *inputOpOperand = en.value();
+    // Check if operand is dead and if dropping the indexing map makes the
+    // loops to shape computation invalid.
+    if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+      // Add the current operands to the list of potentially droppable
+      // operands. If it cannot be dropped, this needs to be popped back.
+      droppedOpOperands.push_back(inputOpOperand);
+      if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
         continue;
-      replacementsVals[result.index()] = newOp.getResult(it->second);
+      droppedOpOperands.pop_back();
     }
-    rewriter.replaceOp(genericOp, replacementsVals);
-    return success();
-  }
 
-private:
-  /// If unset, outputs are not modified by this pattern.
-  bool removeOutputs;
-
-  // Deduplicate input operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  // the canonicalized op.
-  // - The preserved input operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateInputOperands(GenericOp genericOp,
-                           SmallVector<OpOperand *> &droppedOpOperands,
-                           SmallVector<Value> &newInputOperands,
-                           SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
-      OpOperand *inputOpOperand = en.value();
-      // Check if operand is dead and if dropping the indexing map makes the
-      // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
-        // Add the current operands to the list of potentially droppable
-        // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand);
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
-          continue;
-        droppedOpOperands.pop_back();
-      }
+    // Check if this operand is a duplicate.
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
+    auto it =
+        dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));
+    if (it != dedupedInputs.end()) {
+      origToNewPos[en.index()] = it->second;
+      droppedOpOperands.push_back(inputOpOperand);
+      continue;
+    }
 
-      // Check if this operand is a duplicate.
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
-      auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand->get(), indexingMap));
-      if (it != dedupedInputs.end()) {
-        origToNewPos[en.index()] = it->second;
-        droppedOpOperands.push_back(inputOpOperand);
-        continue;
-      }
+    // This is a preserved argument.
+    origToNewPos[en.index()] = newInputOperands.size();
+    dedupedInputs[{inputOpOperand->get(), indexingMap}] =
+        newInputOperands.size();
+    newInputOperands.push_back(inputOpOperand->get());
+    newIndexingMaps.push_back(indexingMap);
+  }
+  return origToNewPos;
+}
 
-      // This is a preserved argument.
-      origToNewPos[en.index()] = newInputOperands.size();
-      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
-          newInputOperands.size();
-      newInputOperands.push_back(inputOpOperand->get());
-      newIndexingMaps.push_back(indexingMap);
+// Deduplicate output operands, and return the
+// - Mapping from operand position in the original op, to operand position in
+// the canonicalized op.
+// - The preserved output operands list (by reference).
+llvm::SmallDenseMap<unsigned, unsigned> static deduplicateOutputOperands(
+    GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
+    SmallVector<Value> &newOutputOperands,
+    SmallVector<AffineMap> &newIndexingMaps, bool removeOutputs) {
+  llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+  llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+      dedupedOutpts;
+  // If the op doesn't have tensor semantics or outputs should not be removed,
+  // keep all the outputs as preserved.
+  if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
+    for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
+      origToNewPos[en.index()] = newOutputOperands.size();
+      newOutputOperands.push_back(en.value().get());
+      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
     }
     return origToNewPos;
   }
-
-  // Deduplicate output operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  // the canonicalized op.
-  // - The preserved output operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateOutputOperands(GenericOp genericOp,
-                            SmallVector<OpOperand *> &droppedOpOperands,
-                            SmallVector<Value> &newOutputOperands,
-                            SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
-        dedupedOutpts;
-    // If the op doesn't have tensor semantics or outputs should not be removed,
-    // keep all the outputs as preserved.
-    if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
-      for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
-        origToNewPos[en.index()] = newOutputOperands.size();
-        newOutputOperands.push_back(en.value().get());
-        newIndexingMaps.push_back(
-            genericOp.getMatchingIndexingMap(&en.value()));
+  // Output argument can be dropped if the result has
+  // - no users, and
+  // - it is not used in the payload, and
+  // - the corresponding indexing maps are not needed for loop bound
+  //   computation.
+  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+  for (const auto &outputOpOperand :
+       llvm::enumerate(genericOp.getDpsInitsMutable())) {
+    OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
+    AffineMap indexingMap =
+        genericOp.getMatchingIndexingMap(&outputOpOperand.value());
+    auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
+                               yieldOp->getOperand(outputOpOperand.index()));
+    if (isResultValueDead(genericOp, result)) {
+      // Check if the opoperand can be dropped without affecting loop
+      // bound computation. Add the operand to the list of dropped op
+      // operand for checking. If it cannot be dropped, need to pop the
+      // value back.
+      droppedOpOperands.push_back(&outputOpOperand.value());
+      if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+        continue;
       }
-      return origToNewPos;
+      droppedOpOperands.pop_back();
     }
-    // Output argument can be dropped if the result has
-    // - no users, and
-    // - it is not used in the payload, and
-    // - the corresponding indexing maps are not needed for loop bound
-    //   computation.
-    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getDpsInitsMutable())) {
-      OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
-      AffineMap indexingMap =
-          genericOp.getMatchingIndexingMap(&outputOpOperand.value());
-      auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
-                                 yieldOp->getOperand(outputOpOperand.index()));
-      if (isResultValueDead(genericOp, result)) {
-        // Check if the opoperand can be dropped without affecting loop
-        // bound computation. Add the operand to the list of dropped op
-        // operand for checking. If it cannot be dropped, need to pop the
-        // value back.
+
+    if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
+      // The out operand can also be dropped if it is computed redundantly
+      // by another result, the conditions for that are
+      // - The same operand is used as the out operand
+      // - The same indexing map is used
+      // - The same yield value is used.
+      auto it = dedupedOutpts.find(key);
+      if (it != dedupedOutpts.end()) {
+        origToNewPos[outputOpOperand.index()] = it->second;
         droppedOpOperands.push_back(&outputOpOperand.value());
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-          continue;
-        }
-        droppedOpOperands.pop_back();
+        continue;
       }
+    }
 
-      if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
-        // The out operand can also be dropped if it is computed redundantly
-        // by another result, the conditions for that are
-        // - The same operand is used as the out operand
-        // - The same indexing map is used
-        // - The same yield value is used.
-        auto it = dedupedOutpts.find(key);
-        if (it != dedupedOutpts.end()) {
-          origToNewPos[outputOpOperand.index()] = it->second;
-          droppedOpOperands.push_back(&outputOpOperand.value());
-          continue;
-        }
-      }
+    origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+    dedupedOutpts[key] = newOutputOperands.size();
+    newOutputOperands.push_back(outputOpOperand.value().get());
+    newIndexingMaps.push_back(
+        genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+  }
+  return origToNewPos;
+}
 
-      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-      dedupedOutpts[key] = newOutputOperands.size();
-      newOutputOperands.push_back(outputOpOperand.value().get());
-      newIndexingMaps.push_back(
-          genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
+// Populate the body of the canonicalized operation.
+static void populateOpPayload(
+    GenericOp genericOp, GenericOp newOp,
+    const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+    const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+    RewriterBase &rewriter) {
+  // Merge the body of the original op with the new op.
+  Block *newOpBlock = &newOp.getRegion().front();
+  assert(newOpBlock->empty() && "expected new op to have an empty payload");
+  Block *origOpBlock = &genericOp.getRegion().front();
+  SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+
+  // Replace all arguments in the original op, with arguments from the
+  // canonicalized op.
+  auto updateReplacements =
+      [&](SmallVector<OpOperand *> &origOperands,
+          SmallVector<OpOperand *> &newOperands,
+          const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+        for (const auto &origOperand : llvm::enumerate(origOperands)) {
+          auto it = map.find(origOperand.index());
+          if (it == map.end())
+            continue;
+          OpOperand *newOperand = newOperands[it->second];
+          replacements[origOperand.value()->getOperandNumber()] =
+              newOpBlock->getArgument(newOperand->getOperandNumber());
+        }
+      };
+
+  SmallVector<OpOperand *> origInputOperands = genericOp.getDpsInputOperands();
+  SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
+  updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+  SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
+      genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+  SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
+      newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
+  updateReplacements(origOutputOperands, newOutputOperands,
+                     origOutsToNewOutsPos);
+
+  // Drop the unused yield args.
+  if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
+    OpBuilder::InsertionGuard g(rewriter);
+    YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
+    rewriter.setInsertionPoint(origYieldOp);
+
+    SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
+    for (const auto &yieldOpOperands :
+         llvm::enumerate(origYieldOp.getValues())) {
+      auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+      if (it == origOutsToNewOutsPos.end())
+        continue;
+      newYieldVals[it->second] = yieldOpOperands.value();
     }
-    return origToNewPos;
+    rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
   }
 
-  // Populate the body of the canonicalized operation.
-  void populateOpPayload(
-      GenericOp genericOp, GenericOp newOp,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
-      PatternRewriter &rewriter) const {
-    // Merge the body of the original op with the new op.
-    Block *newOpBlock = &newOp.getRegion().front();
-    assert(newOpBlock->empty() && "expected new op to have an empty payload");
-    Block *origOpBlock = &genericOp.getRegion().front();
-    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
-    // Replace all arguments in the original op, with arguments from the
-    // canonicalized op.
-    auto updateReplacements =
-        [&](SmallVector<OpOperand *> &origOperands,
-            SmallVector<OpOperand *> &newOperands,
-            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
-          for (const auto &origOperand : llvm::enumerate(origOperands)) {
-            auto it = map.find(origOperand.index());
-            if (it == map.end())
-              continue;
-            OpOperand *newOperand = newOperands[it->second];
-            replacements[origOperand.value()->getOperandNumber()] =
-                newOpBlock->getArgument(newOperand->getOperandNumber());
-          }
-        };
-
-    SmallVector<OpOperand *> origInputOperands =
-        genericOp.getDpsInputOperands();
-    SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
-    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
-    SmallVector<OpOperand *> origOutputOperands =
-        llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
-                                        [](OpOperand &o) { return &o; }));
-    SmallVector<OpOperand *> newOutputOperands =
-        llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
-                                        [](OpOperand &o) { return &o; }));
-    updateReplacements(origOutputOperands, newOutputOperands,
-                       origOutsToNewOutsPos);
-
-    // Drop the unused yield args.
-    if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
-      rewriter.setInsertionPoint(origYieldOp);
-
-      SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
-      for (const auto &yieldOpOperands :
-           llvm::enumerate(origYieldOp.getValues())) {
-        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
-        if (it == origOutsToNewOutsPos.end())
-          continue;
-        newYieldVals[it->second] = yieldOpOperands.value();
-      }
-      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
-    }
+  rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+}
----------------
MaheshRavishankar wrote:

Is there a need for that? It's already static.

https://github.com/llvm/llvm-project/pull/125141


More information about the Mlir-commits mailing list