[Mlir-commits] [mlir] b62060a - [mlir][Linalg] NFC: Refactor canonicalization for deduping generic op operands.

Mahesh Ravishankar llvmlistbot at llvm.org
Fri Jul 15 12:47:59 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-07-15T19:47:45Z
New Revision: b62060a8e330e9cc7b537b50d797344666ae51fc

URL: https://github.com/llvm/llvm-project/commit/b62060a8e330e9cc7b537b50d797344666ae51fc
DIFF: https://github.com/llvm/llvm-project/commit/b62060a8e330e9cc7b537b50d797344666ae51fc.diff

LOG: [mlir][Linalg] NFC: Refactor canonicalization for deduping generic op operands.

This is an NFC change to make it easier to update this canonicalization
for more use cases. The refactoring makes things easier to
understand/adapt.

Differential Revision: https://reviews.llvm.org/D129829

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1ce7d4dab1f13..928a8039a43e8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -860,67 +860,128 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
                                 PatternRewriter &rewriter) const override {
     // Create a map from argument position in the original op to the argument
     // position in the new op. If the argument is dropped it wont have an entry.
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    unsigned numNewArgs = 0;
     SmallVector<OpOperand *> droppedOpOperands;
-    llvm::SmallDenseSet<unsigned> droppedOutputs;
 
     // Information needed to build the new op.
     SmallVector<Value> newInputOperands, newOutputOperands;
     SmallVector<AffineMap> newIndexingMaps;
+
+    // Gather information about duplicate input operands.
+    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
+        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
+                                 newIndexingMaps);
+
+    // Gather information about the dropped outputs.
+    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
+        deduplicateOutputOperands(genericOp, droppedOpOperands,
+                                  newOutputOperands, newIndexingMaps);
+
+    // Check if there is any change to operands.
+    if (newInputOperands.size() + newOutputOperands.size() ==
+        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+      return failure();
+
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
     SmallVector<Type> newResultTypes;
+    if (genericOp.hasTensorSemantics()) {
+      newResultTypes = llvm::to_vector(llvm::map_range(
+          newOutputOperands, [](Value v) { return v.getType(); }));
+    }
+    auto newOp = rewriter.create<GenericOp>(
+        loc, newResultTypes, newInputOperands, newOutputOperands,
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
+    // Copy over unknown attributes. They might be load bearing for some flow.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+        newOp->setAttr(kv.getName(), kv.getValue());
 
-    // Input argument can be dropped if
-    // - it has no uses, or,
-    // - there is a duplicate operand which is accessed using the same
-    //   indexing map.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    auto indexingMaps = genericOp.getIndexingMaps();
-    ArrayRef<AffineMap> unprocessedIndexingMaps(indexingMaps);
-    for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
-      BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
-      unsigned argNum = arg.getArgNumber();
-      unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
+    // Fix up the payload of the canonicalized operation.
+    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
+                      origOutsToNewOutsPos, rewriter);
 
+    // Replace all live uses of the op.
+    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+    for (auto result : llvm::enumerate(genericOp.getResults())) {
+      auto it = origOutsToNewOutsPos.find(result.index());
+      if (it == origOutsToNewOutsPos.end())
+        continue;
+      replacementsVals[result.index()] = newOp.getResult(it->second);
+    }
+    rewriter.replaceOp(genericOp, replacementsVals);
+    return success();
+  }
+
+private:
+  // Deduplicate input operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  // the canonicalized op.
+  // - The preserved input operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateInputOperands(GenericOp genericOp,
+                           SmallVector<OpOperand *> &droppedOpOperands,
+                           SmallVector<Value> &newInputOperands,
+                           SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+    for (auto inputOpOperand : llvm::enumerate(genericOp.getInputOperands())) {
       // Check if operand is dead and if dropping the indexing map makes the
       // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
         // Add the current operands to the list of potentially droppable
         // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand);
+        droppedOpOperands.push_back(inputOpOperand.value());
         if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
           continue;
         droppedOpOperands.pop_back();
       }
 
       // Check if this operand is a duplicate.
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(inputOpOperand);
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMap(inputOpOperand.value());
       auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand->get(), indexingMap));
+          std::make_pair(inputOpOperand.value()->get(), indexingMap));
       if (it != dedupedInputs.end()) {
-        origToNewPos[argNum] = it->second;
-        droppedOpOperands.push_back(inputOpOperand);
+        origToNewPos[inputOpOperand.index()] = it->second;
+        droppedOpOperands.push_back(inputOpOperand.value());
         continue;
       }
 
       // This is a preserved argument.
-      origToNewPos[argNum] = numNewArgs;
-      dedupedInputs[{inputOpOperand->get(), indexingMap}] = numNewArgs;
-      newInputOperands.push_back(inputOpOperand->get());
+      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
+      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
+          newInputOperands.size();
+      newInputOperands.push_back(inputOpOperand.value()->get());
       newIndexingMaps.push_back(indexingMap);
-      numNewArgs++;
     }
+    return origToNewPos;
+  }
 
+  // Deduplicate output operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  // the canonicalized op.
+  // - The preserved output operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateOutputOperands(GenericOp genericOp,
+                            SmallVector<OpOperand *> &droppedOpOperands,
+                            SmallVector<Value> &newOutputOperands,
+                            SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
     // If the op doesnt have tensor semantics, keep all the outputs as
     // preserved.
     if (!genericOp.hasTensorSemantics()) {
-      for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
-        unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
-        BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
-        origToNewPos[arg.getArgNumber()] = numNewArgs++;
-        newOutputOperands.push_back(outputOpOperand->get());
+      for (auto outputOpOperand :
+           llvm::enumerate(genericOp.getOutputOperands())) {
+        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+        newOutputOperands.push_back(outputOpOperand.value()->get());
         newIndexingMaps.push_back(
-            genericOp.getTiedIndexingMap(outputOpOperand));
+            genericOp.getTiedIndexingMap(outputOpOperand.value()));
       }
     } else {
       // Output argument can be dropped if the result has
@@ -928,12 +989,9 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
       // - it is not used in the payload, and
       // - the corresponding indexing maps are not needed for loop bound
       //   computation.
-      for (const auto &outputOpOperand :
+      for (auto outputOpOperand :
            llvm::enumerate(genericOp.getOutputOperands())) {
-        unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
         Value result = genericOp.getResult(outputOpOperand.index());
-        BlockArgument arg =
-            genericOp.getTiedBlockArgument(outputOpOperand.value());
         if (result.use_empty() &&
             !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
           // Check if the opoperand can be dropped without affecting loop bound
@@ -941,77 +999,75 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
           // checking. If it cannot be dropped, need to pop the value back.
           droppedOpOperands.push_back(outputOpOperand.value());
           if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-            droppedOutputs.insert(outputOpOperand.index());
             continue;
           }
           droppedOpOperands.pop_back();
         }
 
-        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
         newOutputOperands.push_back(outputOpOperand.value()->get());
         newIndexingMaps.push_back(
             genericOp.getTiedIndexingMap(outputOpOperand.value()));
-        newResultTypes.push_back(result.getType());
       }
     }
 
-    // Check if there is any change to operands.
-    if (newInputOperands.size() + newOutputOperands.size() ==
-        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
-      return failure();
-
-    // Create the new op with the body being empty.
-    Location loc = genericOp.getLoc();
-    auto newOp = rewriter.create<GenericOp>(
-        loc, newResultTypes, newInputOperands, newOutputOperands,
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr(),
-        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
-          return;
-        });
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs())
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
-        newOp->setAttr(kv.getName(), kv.getValue());
+    return origToNewPos;
+  }
 
+  // Populate the body of the canonicalized operation.
+  void populateOpPayload(
+      GenericOp genericOp, GenericOp newOp,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+      PatternRewriter &rewriter) const {
     // Merge the body of the original op with the new op.
     Block *newOpBlock = &newOp.region().front();
+    assert(newOpBlock->empty() && "expected new op to have an empty payload");
     Block *origOpBlock = &genericOp.region().front();
     SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-    for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
-      auto it = origToNewPos.find(argNum);
-      if (it != origToNewPos.end())
-        replacements[argNum] = newOpBlock->getArgument(it->second);
-    }
+
+    // Replace all arguments in the original op, with arguments from the
+    // canonicalized op.
+    auto updateReplacements =
+        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
+            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+          for (auto origOperand : llvm::enumerate(origOperands)) {
+            auto it = map.find(origOperand.index());
+            if (it == map.end())
+              continue;
+            OpOperand *newOperand = newOperands[it->second];
+            replacements[origOperand.value()->getOperandNumber()] =
+                newOpBlock->getArgument(newOperand->getOperandNumber());
+          }
+        };
+
+    OpOperandVector origInputOperands = genericOp.getInputOperands();
+    OpOperandVector newInputOperands = newOp.getInputOperands();
+    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+    OpOperandVector origOutputOperands = genericOp.getOutputOperands();
+    OpOperandVector newOutputOperands = newOp.getOutputOperands();
+    updateReplacements(origOutputOperands, newOutputOperands,
+                       origOutsToNewOutsPos);
+
     rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
 
     // Drop the unused yield args.
-    Block *block = &newOp.region().front();
-    if (!droppedOutputs.empty()) {
+    if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
       OpBuilder::InsertionGuard g(rewriter);
-      SmallVector<Value> newYieldVals;
-      YieldOp origYieldOp = cast<YieldOp>(block->getTerminator());
+      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
       rewriter.setInsertionPoint(origYieldOp);
+
+      SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
       for (const auto &yieldOpOperands :
            llvm::enumerate(origYieldOp.values())) {
-        if (!droppedOutputs.count(yieldOpOperands.index())) {
-          newYieldVals.push_back(yieldOpOperands.value());
+        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+        if (it == origOutsToNewOutsPos.end())
           continue;
-        }
+        newYieldVals[it->second] = yieldOpOperands.value();
       }
       rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
-
-    // Replace all live uses of the op.
-    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
-    unsigned newResultNum = 0;
-    for (const auto &result : llvm::enumerate(genericOp.getResults()))
-      if (!droppedOutputs.count(result.index()))
-        replacementsVals[result.index()] = newOp.getResult(newResultNum++);
-    rewriter.replaceOp(genericOp, replacementsVals);
-    return success();
   }
 };
 


        


More information about the Mlir-commits mailing list