[Mlir-commits] [mlir] 8fb6c31 - [mlir][linalg] Cleanup LinalgOp usage in op declarations.

Tobias Gysi llvmlistbot at llvm.org
Thu Jun 3 07:06:08 PDT 2021


Author: Tobias Gysi
Date: 2021-06-03T14:04:44Z
New Revision: 8fb6c31cbba51b494f232273cdc54dc0788fcd59

URL: https://github.com/llvm/llvm-project/commit/8fb6c31cbba51b494f232273cdc54dc0788fcd59
DIFF: https://github.com/llvm/llvm-project/commit/8fb6c31cbba51b494f232273cdc54dc0788fcd59.diff

LOG: [mlir][linalg] Cleanup LinalgOp usage in op declarations.

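Replace uses of the deprecated Structured Op Interface methods (e.g. getInputShapedType, getOutputShapedType, getShapedOperands, getInputs/getOutputs, getInputIndexingMap) in LinalgOps.cpp with their OpOperand-based counterparts (getInputOperands, getOutputOperands, getInputAndOutputOperands, getTiedIndexingMap, getRank, getShape). This patch is based on https://reviews.llvm.org/D103394.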

Differential Revision: https://reviews.llvm.org/D103506

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 04c6e4b9d4029..89387b08c11c1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -375,11 +375,12 @@ ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
 void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
 
 static LogicalResult verify(CopyOp op) {
-  auto outputViewType = op.getOutputShapedType(0);
-  auto inputViewType = op.getInputShapedType(0);
-  if (inputViewType.getElementType() != outputViewType.getElementType())
+  OpOperand *output = op.getOutputOperand(0);
+  OpOperand *input = op.getInputOperand(0);
+  if (getElementTypeOrSelf(input->get().getType()) !=
+      getElementTypeOrSelf(output->get().getType()))
     return op.emitOpError("expects views of the same type");
-  if (inputViewType.getRank() != outputViewType.getRank())
+  if (op.getRank(input) != op.getRank(output))
     return op.emitOpError("expects views of the same rank");
   auto rank = op.getNumParallelLoops();
   auto inputPermutationMap = op.inputPermutation();
@@ -449,11 +450,11 @@ ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
 void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
 
 static LogicalResult verify(FillOp op) {
-  auto viewType = op.getOutputShapedType(0);
-  auto fillType = op.value().getType();
-  if (viewType.getElementType() != fillType)
+  OpOperand *output = op.getOutputOperand(0);
+  Type fillType = op.value().getType();
+  if (getElementTypeOrSelf(output->get().getType()) != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
-  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+  if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
     return op.emitOpError(
         "expected fill op with no result value to use memref type");
   }
@@ -739,11 +740,13 @@ struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
 
     // Create a generic replacement operation and clone the body.
     rewriter.setInsertionPointAfter(indexedOp);
+    SmallVector<Value> inputOperands = indexedOp.getInputOperands();
+    SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
     SmallVector<StringRef> iterators = llvm::to_vector<4>(
         indexedOp.iterator_types().getAsValueRange<StringAttr>());
     GenericOp genericOp = rewriter.create<GenericOp>(
-        indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(),
-        indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators);
+        indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
+        outputOperands, indexedOp.getIndexingMaps(), iterators);
     Region &genericRegion = genericOp.region();
     Region &indexedRegion = indexedOp.region();
     rewriter.cloneRegionBefore(indexedRegion, genericRegion,
@@ -2107,21 +2110,21 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
 
 // Check the operand number and types must match the element types of the
 // LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(linalg::YieldOp op,
-                                 LinalgOp linalgOpInterface) {
-  auto nOutputs = linalgOpInterface.getNumOutputs();
-  if (op.getNumOperands() != nOutputs)
+static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
+  if (op.getNumOperands() != linalgOp.getNumOutputs())
     return op.emitOpError("expected number of yield values (")
-           << nOutputs << ") to match the number of operands of the enclosing "
+           << linalgOp.getNumOutputs()
+           << ") to match the number of operands of the enclosing "
            << "LinalgOp (" << op.getNumOperands() << ")";
 
-  for (unsigned i = 0; i != nOutputs; ++i) {
-    auto elementType =
-        linalgOpInterface.getOutputShapedType(i).getElementType();
-    if (op.getOperand(i).getType() != elementType)
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    OpOperand *outputOperand =
+        linalgOp.getOutputOperand(opOperand.getOperandNumber());
+    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+    if (opOperand.get().getType() != elementType)
       return op.emitOpError("type of yield operand ")
-             << (i + 1) << " (" << op.getOperand(i).getType()
-             << ") doesn't match "
+             << (opOperand.getOperandNumber() + 1) << " ("
+             << opOperand.get().getType() << ") doesn't match "
              << "the element type of the enclosing linalg.generic op ("
              << elementType << ")";
   }
@@ -3096,14 +3099,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    for (Value v : op.getShapedOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       // Linalg "inputs" may be either tensor or memref type.
       // tensor<0xelt_type> is a convention that may not always mean
       // "0 iterations". Only erase in cases we see memref<...x0x...>.
-      auto mt = v.getType().dyn_cast<MemRefType>();
+      auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
       if (!mt)
         continue;
-      if (llvm::is_contained(mt.getShape(), 0)) {
+      if (llvm::is_contained(op.getShape(opOperand), 0)) {
         rewriter.eraseOp(op);
         return success();
       }
@@ -3119,10 +3122,10 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
                                 PatternRewriter &rewriter) const override {
     // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
-        llvm::any_of(op.getShapedOperands(), [&](Value v) {
-          if (v.isa<BlockArgument>())
+        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
+          if (opOperand->get().isa<BlockArgument>())
             return false;
-          auto castOp = v.getDefiningOp<tensor::CastOp>();
+          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
           return castOp && canFoldIntoConsumerOp(castOp);
         });
     if (!hasTensorCastOperand)
@@ -3133,16 +3136,18 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
     SmallVector<Value, 4> newOperands;
     newOperands.reserve(op->getNumOperands());
     // Inputs may fold.
-    for (Value v : op.getInputs()) {
-      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
-      newOperands.push_back(
-          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+    for (OpOperand *opOperand : op.getInputOperands()) {
+      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
+                                ? tensorCastOp.source()
+                                : opOperand->get());
     }
     // Init tensors may fold, in which case the resultType must also change.
-    for (Value v : op.getOutputs()) {
-      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
+    for (OpOperand *opOperand : op.getOutputOperands()) {
+      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+      newOperands.push_back(fold ? tensorCastOp.getOperand()
+                                 : opOperand->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
     auto extraOperands = op.getAssumedNonShapedOperands();
@@ -3189,18 +3194,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
     // in the case of duplicated inputs, the canonical input could be some other
     // input `< i`. That is, a later input will have some earlier input as its
     // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
     // For later remapping tasks like deduplicating payload block arguments,
     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
     // convenient.
-    SmallVector<int, 6> canonicalInputIndices;
-    for (int i = 0, e = op.getNumInputs(); i != e; i++) {
-      Value input = op.getInput(i);
-      AffineMap indexingMap = op.getInputIndexingMap(i);
+    SmallVector<unsigned> canonicalInputIndices;
+    for (OpOperand *opOperand : op.getInputOperands()) {
+      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
       // STL-like maps have a convenient behavior for our use case here. In the
       // case of duplicate keys, the insertion is rejected, and the returned
       // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert({{input, indexingMap}, i});
+      auto pair = canonicalInput.insert(
+          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
       canonicalInputIndices.push_back(pair.first->second);
     }
 
@@ -3209,26 +3214,29 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
       return failure();
 
     // The operands for the newly canonicalized op.
-    SmallVector<Value, 6> newOperands;
-    for (auto v : llvm::enumerate(op.getInputs()))
-      if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
-        newOperands.push_back(v.value());
-    llvm::append_range(newOperands, op.getOutputs());
+    SmallVector<Value> newOperands;
+    for (OpOperand *opOperand : op.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newOperands.push_back(opOperand->get());
+    SmallVector<Value> outputOperands = op.getOutputOperands();
+    llvm::append_range(newOperands, outputOperands);
     llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
+    // Repair the indexing maps by filtering out the ones that have been
+    // eliminated.
+    SmallVector<AffineMap> newIndexingMaps;
+    for (OpOperand *opOperand : op.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+    for (OpOperand *opOperand : op.getOutputOperands())
+      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+
     // Clone the old op with new operands.
     Operation *newOp =
         op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
     auto newLinalgOp = cast<LinalgOp>(newOp);
-
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap, 6> newIndexingMaps;
-    for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
-      if (canonicalInputIndices[i] == i)
-        newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
-    for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
-      newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
     newOp->setAttr("indexing_maps",
                    rewriter.getAffineMapArrayAttr(newIndexingMaps));
 
@@ -3243,18 +3251,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
     // Repair the payload entry block by RAUW'ing redundant arguments and
     // erasing them.
     Block &payload = newOp->getRegion(0).front();
-    for (int i = 0, e = op.getNumInputs(); i < e; i++) {
+    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
+    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
       // Iterate in reverse, so that we erase later args first, preventing the
       // argument list from shifting unexpectedly and invalidating all our
       // indices.
-      int reversed = e - i - 1;
-      int canonicalIndex = canonicalInputIndices[reversed];
-      if (canonicalInputIndices[reversed] == reversed)
+      unsigned operandNumber = opOperand->getOperandNumber();
+      if (canonicalInputIndices[operandNumber] == operandNumber)
         continue;
-      payload.getArgument(bbArgBaseOffset + reversed)
-          .replaceAllUsesWith(
-              payload.getArgument(bbArgBaseOffset + canonicalIndex));
-      payload.eraseArgument(bbArgBaseOffset + reversed);
+      payload.getArgument(bbArgBaseOffset + operandNumber)
+          .replaceAllUsesWith(payload.getArgument(
+              bbArgBaseOffset + canonicalInputIndices[operandNumber]));
+      payload.eraseArgument(bbArgBaseOffset + operandNumber);
     }
 
     rewriter.replaceOp(op, newOp->getResults());


        


More information about the Mlir-commits mailing list