[Mlir-commits] [mlir] f84b908 - [mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 2 05:21:31 PDT 2021


Author: Tobias Gysi
Date: 2021-06-02T12:20:45Z
New Revision: f84b908f89af76002112acbf915ab0677b99c01c

URL: https://github.com/llvm/llvm-project/commit/f84b908f89af76002112acbf915ab0677b99c01c
DIFF: https://github.com/llvm/llvm-project/commit/f84b908f89af76002112acbf915ab0677b99c01c.diff

LOG: [mlir][linalg] Cleanup LinalgOp usage in fusion on tensors (NFC).

Replace the uses of deprecated Structured Op Interface methods in FusionOnTensors.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103471
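
As a quick illustration of the mechanical change (not part of the commit itself): the patch replaces positional accessors such as getIndexingMap(i) / getOutputIndexingMap(0) with OpOperand*-based accessors such as getTiedIndexingMap(opOperand). A minimal sketch of the two styles, using interface methods that appear in the diff below; the header paths and the helper name are assumptions for illustration only:

    // Illustrative sketch only; header paths and the availability of the
    // deprecated methods depend on the LLVM revision, and
    // collectInputIndexingMaps is a hypothetical helper, not part of the tree.
    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    #include "mlir/IR/AffineMap.h"

    using namespace mlir;
    using namespace mlir::linalg;

    // Collect the indexing map tied to every input operand of `genericOp`.
    static SmallVector<AffineMap> collectInputIndexingMaps(GenericOp genericOp) {
      SmallVector<AffineMap> maps;
      // Deprecated, index-based style (what this patch removes):
      //   for (unsigned i = 0, e = genericOp.getNumInputs(); i != e; ++i)
      //     maps.push_back(genericOp.getIndexingMap(i));
      // OpOperand-based style (what this patch introduces):
      for (OpOperand *opOperand : genericOp.getInputOperands())
        maps.push_back(genericOp.getTiedIndexingMap(opOperand));
      return maps;
    }

The same substitution drives the rest of the diff: getOutputIndexingMap(0) becomes getTiedIndexingMap(getOutputOperand(0)), and raw consumerIdx integers are threaded through as OpOperand* values instead.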

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6ee4d765d5f8d..9b2292f46c3a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,7 +28,7 @@ using namespace mlir::linalg;
 
 /// Conditions for elementwise fusion of generic operations.
 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
-                                     unsigned consumerIdx) {
+                                     OpOperand *consumerOpOperand) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -40,12 +40,12 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
-  if (consumerIdx >= consumer.getNumInputs())
+  if (!consumer.isInputTensor(consumerOpOperand))
     return false;
 
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
@@ -55,7 +55,8 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
-  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  AffineMap producerResultIndexMap =
+      producer.getTiedIndexingMap(producer.getOutputOperand(0));
   return producerResultIndexMap.isPermutation();
 }
 
@@ -63,7 +64,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 /// the `producer` to use in the fused operation given the indexing map of the
 /// result of the producer in the consumer.
 static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-    OpOperand &producerOpOperand, AffineMap producerResultIndexMap,
+    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
     AffineMap fusedConsumerArgIndexMap) {
   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
   // from consumer loop -> consumer arg tensor index/producer result tensor
@@ -78,10 +79,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   assert(invProducerResultIndexMap &&
          "expected producer result indexig map to be invertible");
 
-  LinalgOp producer = cast<LinalgOp>(producerOpOperand.getOwner());
+  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
   // argMap is a map from producer loop -> producer arg tensor index.
-  AffineMap argMap =
-      producer.getIndexingMap(producerOpOperand.getOperandNumber());
+  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
 
   // Compose argMap with invProducerResultIndexMap to get a map from
   // producer result tensor index -> producer arg tensor index.
@@ -96,9 +96,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 /// op must be empty.
 static void
 generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
-                                 GenericOp producer, GenericOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
-                                 unsigned consumerIdx, unsigned nloops) {
+                                 OpOperand *consumerOpOperand,
+                                 unsigned nloops) {
+  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
+  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -129,11 +131,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
     }
   }
   // TODO: allow fusing the producer of an output operand.
-  assert(consumerIdx < consumer.getNumInputs() &&
+  assert(consumer.isInputTensor(consumerOpOperand) &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
-           consumerIdx)) // input assumption.
+           consumerOpOperand->getOperandNumber())) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -147,7 +149,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
   // 4.b. Producer output operand/map that is fused needs to be mapped to the
   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
   assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
+  if (producer.isInitTensor(producer.getOutputOperand(0))) {
     BlockArgument bbArg = producerBlock.getArguments()
                               .drop_front(producer.getNumInputs())
                               // TODO: bbArg index of
@@ -155,9 +157,10 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
   }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
-  for (BlockArgument bbArg : consumerBlock.getArguments()
-                                 .take_front(consumer.getNumInputs())
-                                 .drop_front(consumerIdx + 1))
+  for (BlockArgument bbArg :
+       consumerBlock.getArguments()
+           .take_front(consumer.getNumInputs())
+           .drop_front(consumerOpOperand->getOperandNumber() + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
   // 6. All of consumer's output operands.
   for (BlockArgument bbArg :
@@ -191,7 +194,8 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
              "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
+  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+             replacement);
   // 10. Clone operations from the consumer to the fused op.
   for (auto &op : consumerBlock.getOperations())
     rewriter.clone(op, mapper);
@@ -202,17 +206,16 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
 }
 
 static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
+fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                        const ControlElementwiseOpsFusionFn &controlFn,
                        PatternRewriter &rewriter) {
-  auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
-  unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
-      !controlFn(producer->getResult(0), consumerOpOperand))
+  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
+      !controlFn(producer->getResult(0), *consumerOpOperand))
     return llvm::None;
 
   // TODO: allow fusing the producer of an output operand.
-  assert(consumerIdx < consumer.getNumInputs() &&
+  assert(consumer.isInputTensor(consumerOpOperand) &&
          "expected producer of input operand");
 
   // Compute the fused operands list and indexing maps.
@@ -224,62 +227,66 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
                          consumer->getNumOperands());
   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
-  llvm::append_range(fusedOperands,
-                     consumer.getInputs().take_front(consumerIdx));
-  llvm::append_range(
-      fusedIndexMaps,
-      ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.take_front(
-          consumerIdx));
+  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
+  SmallVector<OpOperand *>::iterator it =
+      llvm::find(consumerInputs, consumerOpOperand);
+  assert(it != consumerInputs.end() && "expected to find the consumer operand");
+  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
+    fusedOperands.push_back(opOperand->get());
+    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+  }
   // 4. Splice in producer's input operands/maps.
-  llvm::append_range(fusedOperands, producer.getInputs());
   assert(producer->getNumResults() == 1 && "expected single result producer");
-  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-  for (auto &inputOpOperand : producer.getInputOpOperands()) {
+  AffineMap producerResultIndexMap =
+      producer.getTiedIndexingMap(producer.getOutputOperand(0));
+  for (OpOperand *opOperand : producer.getInputOperands()) {
+    fusedOperands.push_back(opOperand->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-        inputOpOperand, producerResultIndexMap,
-        consumer.getInputIndexingMap(consumerIdx));
+        opOperand, producerResultIndexMap,
+        consumer.getTiedIndexingMap(consumerOpOperand));
     fusedIndexMaps.push_back(map);
   }
   // 4.b. Producer output operand/map that is fused needs to be passed if it is
   // an "initTensor" (i.e. its value is actually read).
   assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
-    llvm::append_range(fusedOperands, producer.getOutputs().take_front());
+  if (producer.isInitTensor(producer.getOutputOperand(0))) {
+    fusedOperands.push_back(producer.getOutputOperand(0)->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-        producer.getOutputOpOperands().front(), producerResultIndexMap,
-        consumer.getOutputIndexingMap(0));
+        producer.getOutputOperand(0), producerResultIndexMap,
+        consumer.getTiedIndexingMap(consumerOpOperand));
     fusedIndexMaps.push_back(map);
   }
   // 5. Remaining consumer's input operands/maps (drop past index
   // `consumerIdx`).
-  llvm::append_range(fusedOperands,
-                     consumer.getInputs().drop_front(consumerIdx + 1));
-  llvm::append_range(
-      fusedIndexMaps,
-      ArrayRef<AffineMap>{consumer.getInputIndexingMaps()}.drop_front(
-          consumerIdx + 1));
+  for (OpOperand *opOperand :
+       llvm::make_range(std::next(it), consumerInputs.end())) {
+    fusedOperands.push_back(opOperand->get());
+    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+  }
   // 6. All of consumer's output operands (skip operands: added by the builder).
-  // llvm::append_range(fusedOperands, consumer.getOutputs());
-  llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps());
+  for (OpOperand *opOperand : consumer.getOutputOperands())
+    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   // 7. All of producer's output operands/maps except the one fused.
   // TODO: allow fusion of multi-result producers.
   assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // Generate the fused op.
+  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), consumer->getResultTypes(),
       /*inputs=*/fusedOperands,
       // TODO: handle outputs.
-      consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
       consumer.iterator_types(),
       /*doc=*/nullptr,
       /*library_call=*/nullptr);
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
-  AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
+  AffineMap consumerResultIndexMap =
+      consumer.getTiedIndexingMap(consumerOpOperand);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -289,9 +296,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
-  generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
-                                   consumerToProducerLoopsMap, consumerIdx,
-                                   consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp,
+                                   consumerToProducerLoopsMap,
+                                   consumerOpOperand, consumer.getNumLoops());
   return SmallVector<Value>(fusedOp->getResults());
 }
 
@@ -449,7 +456,7 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
 ///  The added reshapes are again expanding patterns, so they will get fused
 ///  with its producers if possible.
 static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
-                                               unsigned fusedTensorIndex) {
+                                               OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
   //   permutations.
@@ -462,7 +469,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
+         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
          llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
            return attr.cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName();
@@ -478,7 +485,7 @@ class ExpansionInfo {
   // of the expanded op given the `indexingMap` of the fused operand/result of
   // the generic op, the `reassocationMaps` of the reshape op and the shape of
   // the expanded op.
-  LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
+  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
                         ArrayRef<int64_t> expandedShape,
                         PatternRewriter &rewriter);
@@ -503,13 +510,13 @@ class ExpansionInfo {
 } // namespace
 
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
-                                     unsigned fusedTensorIndex,
+                                     OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
                                      ArrayRef<int64_t> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
-  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
 
   Optional<SmallVector<int64_t, 4>> originalLoopRange =
       linalgOp.getStaticLoopRanges();
@@ -676,9 +683,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
 /// been satisfied.
 static Optional<SmallVector<Value>>
 fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
-                           unsigned fusedTensorIndex,
+                           OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
+  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   bool isExpanding =
@@ -687,7 +694,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
+  if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
                                    reshapeOp.getReassociationMaps(),
                                    expandedType.getShape(), rewriter)))
     return llvm::None;
@@ -701,39 +708,39 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
       }));
 
   SmallVector<Value> expandedOpOperands;
-  for (auto operand : llvm::enumerate(genericOp.getInputs())) {
-    if (operand.index() == fusedTensorIndex) {
+  for (OpOperand *opOperand : genericOp.getInputOperands()) {
+    if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(reshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
+    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
     RankedTensorType expandedOperandType =
-        getExpandedType(operand.value().getType().cast<RankedTensorType>(),
+        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
-    if (expandedOperandType != operand.value().getType()) {
+    if (expandedOperandType != opOperand->get().getType()) {
       // Reshape the operand to get the right type.
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
-          genericOp.getLoc(), expandedOperandType, operand.value(),
+          genericOp.getLoc(), expandedOperandType, opOperand->get(),
           reassociation));
       continue;
     }
-    expandedOpOperands.push_back(operand.value());
+    expandedOpOperands.push_back(opOperand->get());
   }
 
   Location loc = genericOp.getLoc();
   SmallVector<Value> outputs;
-  for (auto result : llvm::enumerate(genericOp.getOutputs())) {
-    AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
+  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
+    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
     RankedTensorType expandedOutputType =
-        getExpandedType(result.value().getType().cast<RankedTensorType>(),
+        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
-    if (expandedOutputType != result.value().getType()) {
+    if (expandedOutputType != opOperand->get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
       outputs.push_back(rewriter.create<TensorReshapeOp>(
-          genericOp.getLoc(), expandedOutputType, result.value(),
+          genericOp.getLoc(), expandedOutputType, opOperand->get(),
           reassociation));
     }
   }
@@ -757,17 +764,19 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
-  for (auto result : llvm::enumerate(genericOp->getResults())) {
-    if (!isExpanding &&
-        resultTypes[result.index()] != result.value().getType()) {
+  for (OpResult opResult : genericOp->getOpResults()) {
+    int64_t resultNumber = opResult.getResultNumber();
+    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getOutputIndexingMap(result.index()), expansionInfo);
+              genericOp.getTiedIndexingMap(
+                  genericOp.getOutputOperand(resultNumber)),
+              expansionInfo);
       resultVals.push_back(rewriter.create<TensorReshapeOp>(
-          genericOp.getLoc(), result.value().getType(),
-          fusedOp->getResult(result.index()), reassociation));
+          genericOp.getLoc(), opResult.getType(),
+          fusedOp->getResult(resultNumber), reassociation));
     } else {
-      resultVals.push_back(fusedOp->getResult(result.index()));
+      resultVals.push_back(fusedOp->getResult(resultNumber));
     }
   }
   // Assuming a single result.
@@ -809,12 +818,13 @@ struct FoldProducerReshapeOpByLinearization
                                 PatternRewriter &rewriter) const override {
     if (!genericOp.hasTensorSemantics())
       return failure();
-    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (auto en : llvm::enumerate(inputOperands)) {
       TensorReshapeOp reshapeOp =
-          operand.value().getDefiningOp<TensorReshapeOp>();
+          en.value()->get().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp ||
           !isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, genericOp.getInputIndexingMap(operand.index()),
+              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
               /*asProducer =*/true) ||
           (foldUnitDimReshapesOnly &&
            !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -822,18 +832,17 @@ struct FoldProducerReshapeOpByLinearization
         continue;
 
       // Compute the fused operands list,
-      SmallVector<Value> fusedOperands(genericOp.getInputs());
-      fusedOperands[operand.index()] = reshapeOp.src();
-      fusedOperands.append(genericOp.getOutputs().begin(),
-                           genericOp.getOutputs().end());
+      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
+      fusedOperands[en.index()] = reshapeOp.src();
+      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+      llvm::append_range(fusedOperands, outputOperands);
 
       // Compute indexing_maps for the fused operation. The indexing_maps for
       // the operands of the consumers that arent fused are the same.
-      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
+      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
 
       // Accepted consumer maps are either identity or permutation.
-      auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
+      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
 
       // Compute the indexing map to use for the result of the producer.
       AffineMap modifiedMap =
@@ -843,7 +852,7 @@ struct FoldProducerReshapeOpByLinearization
         if (!expr.isPureAffine())
           return failure();
       }
-      fusedIndexMaps[operand.index()] = modifiedMap;
+      fusedIndexMaps[en.index()] = modifiedMap;
 
       // Further check that the resulting index maps can be fused and
       // inverted. Without this the resultant op is not legal.
@@ -917,35 +926,36 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
       return failure();
     // Only support identity output maps. It could be extended to permuations if
     // needed.
-    if (llvm::any_of(genericOp.getOutputIndexingMaps(),
-                     [](AffineMap map) { return !map.isIdentity(); }))
+    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
+          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
+        }))
       return failure();
     int64_t destRank = genericOp.getNumParallelLoops();
-    SmallVector<Value, 4> newOperands =
-        llvm::to_vector<4>(genericOp.getInputs());
+    SmallVector<Value> newOperands = genericOp.getInputOperands();
     TensorReshapeOp reshapeFound;
     // 1. Look for tensor_reshape operands and figure out save the dimensions
     // merged.
-    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (auto en : llvm::enumerate(inputOperands)) {
       TensorReshapeOp reshapeOp =
-          operand.value().template getDefiningOp<TensorReshapeOp>();
+          en.value()->get().template getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp || reshapeOp.getSrcType().getRank() >
                             reshapeOp.getResultType().getRank()) {
         continue;
       }
       // TODO: We could support non-identity map as long as the merged
       // dimensions are still contiguous.
-      if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
+      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
         continue;
       if (reshapeFound) {
         // Only support a second reshape op if it has the same reassociate maps.
         if (reshapeFound.getReassociationMaps() ==
             reshapeOp.getReassociationMaps())
-          newOperands[operand.index()] = reshapeOp.src();
+          newOperands[en.index()] = reshapeOp.src();
         continue;
       }
       reshapeFound = reshapeOp;
-      newOperands[operand.index()] = reshapeOp.src();
+      newOperands[en.index()] = reshapeOp.src();
     }
     if (!reshapeFound)
       return failure();
@@ -962,9 +972,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
     // 2. Verify that we can merge the dimensions in the linalg and that we
     // don't need to create new reshapes operands. Inserting new reshape
     // operands would defeat the purpose of the transformation.
-    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
-      if (operand.value() == newOperands[operand.index()]) {
-        AffineMap map = genericOp.getIndexingMaps()[operand.index()];
+    for (auto en : llvm::enumerate(inputOperands)) {
+      if (en.value()->get() == newOperands[en.index()]) {
+        AffineMap map = genericOp.getTiedIndexingMap(en.value());
         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
             return failure();
@@ -1036,9 +1046,9 @@ class FoldWithProducerReshapeOpByExpansion
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
       TensorReshapeOp reshapeOp =
-          operand.value().getDefiningOp<TensorReshapeOp>();
+          opOperand->get().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp)
         continue;
       // Fold only if
@@ -1046,15 +1056,12 @@ class FoldWithProducerReshapeOpByExpansion
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
-          (!controlFoldingReshapes(
-              reshapeOp->getResult(0),
-              genericOp.getInputOpOperands()[operand.index()])))
+          !isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
         continue;
 
       Optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
-                                     rewriter);
+          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
       if (!replacementValues)
         return failure();
       rewriter.replaceOp(genericOp, replacementValues.getValue());
@@ -1080,7 +1087,8 @@ struct FoldConsumerReshapeOpByLinearization
     if (!producer || !producer.hasTensorSemantics() ||
         producer.getNumOutputs() != 1 ||
         !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp, producer.getOutputIndexingMap(0),
+            reshapeOp,
+            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
             /*asProducer =*/false) ||
         (foldUnitDimReshapesOnly &&
          !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
@@ -1088,10 +1096,10 @@ struct FoldConsumerReshapeOpByLinearization
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
-    SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-        producer.indexing_maps().getAsValueRange<AffineMapAttr>());
+    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
 
-    auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
+    auto invMap = inversePermutation(
+        producer.getTiedIndexingMap(producer.getOutputOperand(0)));
 
     // Compute the indexing map to use for the operand of the producer.
     AffineMap modifiedMap =
@@ -1113,11 +1121,13 @@ struct FoldConsumerReshapeOpByLinearization
     }
 
     Location loc = producer.getLoc();
+    SmallVector<Value> inputOperands = producer.getInputOperands();
     Value output = rewriter.create<TensorReshapeOp>(
-        loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
+        loc, producer.getOutputOperand(0)->get(),
+        reshapeOp.getReassociationExprs());
     auto fusedOp = rewriter.create<GenericOp>(
         loc, reshapeOp.getResultType(),
-        /*inputs=*/producer.getInputs(),
+        /*inputs=*/inputOperands,
         // TODO: handle outputs.
         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
         producer.iterator_types(),
@@ -1147,12 +1157,12 @@ struct FoldReshapeWithGenericOpByExpansion
     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getNumInputs()) ||
+                                            producer.getOutputOperand(0)) ||
         isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
                                reshapeOp.getReassociationMaps()))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getNumInputs(), rewriter);
+        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
     if (!replacementValues)
       return failure();
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1171,21 +1181,29 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
                                 PatternRewriter &rewriter) const override {
     if (!genericOp.hasTensorSemantics())
       return failure();
-    for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
-      Operation *def = operand.value().get().getDefiningOp();
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      Operation *def = opOperand->get().getDefiningOp();
       DenseElementsAttr constantAttr;
       if (!def ||
           !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
-          !constantAttr.isSplat() ||
-          !controlFn(def->getResult(0), operand.value()))
+          !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
         continue;
 
-      // The indexing_maps for the operands of the fused operation are same as
-      // those for the operands of the genericOp without the indexing map at
-      // operand.index()
-      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
-      fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
+      // The operands and the indexing_maps of the fused operation the same as
+      // the operands and indexing_maps of the generic operations with the
+      // values at the constant index dropped.
+      SmallVector<AffineMap> fusedIndexMaps;
+      SmallVector<Value> fusedOperands;
+      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
+      fusedOperands.reserve(genericOp.getNumInputs());
+      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
+        if (inputOperand == opOperand)
+          continue;
+        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
+        fusedOperands.push_back(inputOperand->get());
+      }
+      for (OpOperand *outputOperand : genericOp.getOutputOperands())
+        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
 
       // Check if the operation shapes to loops map is computable.
       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
@@ -1193,20 +1211,16 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
             genericOp, "fused op loop bound computation failed");
       }
 
-      // The operands list is same as the genericOp with the argument for
-      // constant index dropped.
-      SmallVector<Value> fusedOperands(genericOp.getInputs());
-      fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
-
       // Create a constant scalar value from the splat constant.
       Value scalarConstant = rewriter.create<ConstantOp>(
           def->getLoc(), constantAttr.getSplatValue(),
           constantAttr.getType().getElementType());
 
+      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
       auto fusedOp = rewriter.create<GenericOp>(
           rewriter.getUnknownLoc(), genericOp->getResultTypes(),
           /*inputs=*/fusedOperands,
-          /*outputs=*/genericOp.getOutputs(),
+          /*outputs=*/outputOperands,
           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
           genericOp.iterator_types(),
           /*doc=*/nullptr,
@@ -1217,7 +1231,8 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
       Region &region = genericOp->getRegion(0);
       Block &entryBlock = *region.begin();
       BlockAndValueMapping mapping;
-      mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
+      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
+                  scalarConstant);
       Region &fusedRegion = fusedOp->getRegion(0);
       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                  mapping);
@@ -1233,7 +1248,7 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
 } // namespace
 
 static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                    GenericOp producer,
                    const ControlElementwiseOpsFusionFn &controlFn) {
   if (producer->getNumResults() != 1)
@@ -1261,9 +1276,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
       auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
+          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
       if (!producer || !producer.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value>> fusedOpResults =
@@ -1322,9 +1337,9 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
     rewriter.startRootUpdate(op);
     bool modifiedOutput = false;
     Location loc = op.getLoc();
-    for (OpOperand &opOperand : op.getOutputOpOperands()) {
-      if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
-        Value operandVal = opOperand.get();
+    for (OpOperand *opOperand : op.getOutputOperands()) {
+      if (!op.payloadUsesValueFromOperand(opOperand)) {
+        Value operandVal = opOperand->get();
         auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
         if (!operandType)
           continue;
@@ -1344,7 +1359,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
         Value initTensor = rewriter.create<InitTensorOp>(
             loc, dynamicDims, operandType.getShape(),
             operandType.getElementType());
-        op->setOperand(opOperand.getOperandNumber(), initTensor);
+        op->setOperand(opOperand->getOperandNumber(), initTensor);
       }
     }
     if (!modifiedOutput) {
