[Mlir-commits] [mlir] a7bfdc2 - [mlir][Linalg] Handle multi-result operations in Elementwise op fusion.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Aug 23 22:58:00 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-08-24T05:57:30Z
New Revision: a7bfdc23ab3ade54da99f0f59dababe4d71ae75b

URL: https://github.com/llvm/llvm-project/commit/a7bfdc23ab3ade54da99f0f59dababe4d71ae75b
DIFF: https://github.com/llvm/llvm-project/commit/a7bfdc23ab3ade54da99f0f59dababe4d71ae75b.diff

LOG: [mlir][Linalg] Handle multi-result operations in Elementwise op fusion.

This drops the artificial requirement that producers have a single
result value in order to fuse with consumers.

The current default also only fuses a producer with a consumer when the
producer has a single use. This is a simplifying assumption: there are
legitimate use cases where a producer can be fused with a consumer and
the fused op used to replace the remaining uses of the producer as
well. This needs to be done with care to avoid use-def violations. To
allow downstream users to explore more fusion opportunities, the core
transformation method is exposed as a utility function.

This patch also modifies the control function to take just the fused
operand as its argument. This is enough information for callers to
recover the producer and consumer operations being considered for
fusion, as well as which producer result is used.
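
To illustrate, a control function against the new signature might look
as follows (a sketch only, not code from this patch; the result-number
check is a hypothetical policy):

  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    // fusedOperand->getOwner() is the consumer; the incoming value is
    // the producer's result.
    auto producerResult = fusedOperand->get().dyn_cast<OpResult>();
    if (!producerResult)
      return false;
    // getResultNumber() says which producer result is being fused. As
    // an example policy, fuse only through a producer's first result
    // and only when that result has no other uses.
    return producerResult.getResultNumber() == 0 &&
           producerResult.hasOneUse();
  };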

Differential Revision: https://reviews.llvm.org/D132301

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
    mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8f53d5a8fea8c..abd0243c7119a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -66,8 +66,7 @@ void populateSparseTensorRewriting(RewritePatternSet &patterns);
 /// Function type which is used to control when to stop fusion. It is expected
 /// that OpOperand is not modified in the callback. The OpOperand is not marked
 /// as const to allow callers to use non-const methods.
-using ControlFusionFn =
-    std::function<bool(const OpResult &producer, OpOperand &consumer)>;
+using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;
 
 /// Patterns for fusing linalg operation on tensors.
 
@@ -111,6 +110,17 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Return true if two `linalg.generic` operations with producer/consumer
+/// relationship through `fusedOperand` can be fused using elementwise op
+/// fusion.
+bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+
+/// Fuse two `linalg.generic` operations that have a producer-consumer
+/// relationship captured through `fusedOperand`. The method expects
+/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
+FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
+                                          OpOperand *fusedOperand);
+
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
 ///

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 05421bf0a0ac0..e46f7aefde1a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -122,10 +122,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (int i = 0; i < numInputs; ++i) {
-      OpOperand *consumer = genericOp.getInputOperand(i);
-      OpResult producer = consumer->get().cast<OpResult>();
-      if (!controlFn(producer, *consumer))
+    for (auto operand : genericOp.getInputOperands()) {
+      if (!controlFn(operand))
         return failure();
     }
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index d19f926da8660..e3d121c32c1c2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -65,8 +65,14 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 }
 
 /// Conditions for elementwise fusion of generic operations.
-static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
-                                     OpOperand *consumerOpOperand) {
+bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
+  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+
+  // Check producer and consumer are generic ops.
+  if (!producer || !consumer)
+    return false;
+
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -78,19 +84,15 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
-  if (!consumer.isInputTensor(consumerOpOperand))
+  if (!consumer.isInputTensor(fusedOperand))
     return false;
 
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
-  // Currently support only operations with single result.
-  if (producer.getNumOutputs() != 1)
-    return false;
-
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
@@ -114,7 +116,7 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
     for (auto pair :
          llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
       Value operand = std::get<0>(pair);
-      if (operand == consumerOpOperand->get())
+      if (operand == fusedOperand->get())
         continue;
       AffineMap operandMap = std::get<1>(pair);
       addToCoveredDims(operandMap);
@@ -136,12 +138,11 @@ static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
                                  AffineMap consumerToProducerLoopsMap,
-                                 OpOperand *consumerOpOperand,
-                                 unsigned nloops) {
-  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
+                                 OpOperand *fusedOperand, unsigned nloops) {
+  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -172,11 +173,11 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
     }
   }
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
-           consumerOpOperand->getOperandNumber())) // input assumption.
+           fusedOperand->getOperandNumber())) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
@@ -187,29 +188,22 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
        producerBlock.getArguments().take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
 
-  // 4.b. Producer output operand/map that is fused needs to be mapped to the
-  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    BlockArgument bbArg = producerBlock.getArguments()
-                              .drop_front(producer.getNumInputs())
-                              // TODO: bbArg index of
-                              .front();
-    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg :
        consumerBlock.getArguments()
            .take_front(consumer.getNumInputs())
-           .drop_front(consumerOpOperand->getOperandNumber() + 1))
+           .drop_front(fusedOperand->getOperandNumber() + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 6. All of consumer's output operands.
+
+  // 6. All of the producer's output operands
+  for (BlockArgument bbArg :
+       producerBlock.getArguments().take_back(producer.getNumOutputs()))
+    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
+
+  // 7. All of consumer's output operands.
   for (BlockArgument bbArg :
        consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
-  // 7. All of producer's output operands except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // 8. Clone all producer operations except for the yield and index operations
   // to the fused operation.
@@ -219,15 +213,15 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
   }
   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
   // forward the yield operand.
-  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  unsigned producerResultNumber = 0;
+  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
+  unsigned producerResultNumber =
+      fusedOperand->get().cast<OpResult>().getResultNumber();
   Value replacement =
-      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
+      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
+
   // Sanity checks, if replacement is not already in the mapper then it must be
   // produced outside.
-  if (replacement == yieldOp.getOperand(producerResultNumber)) {
+  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
     if (auto bb = replacement.dyn_cast<BlockArgument>())
       assert(bb.getOwner() != &producerBlock &&
              "yielded block argument must have been mapped");
@@ -235,91 +229,101 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
              "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
+  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
              replacement);
   // 10. Clone operations from the consumer to the fused op.
-  for (auto &op : consumerBlock.getOperations())
+  for (auto &op : consumerBlock.without_terminator())
     rewriter.clone(op, mapper);
 
+  // 11. Include the final yield (which is the remapped values for all the
+  // yield)
+  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
+  SmallVector<Value> fusedYieldValues;
+  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
+                           consumerYieldOp.getNumOperands());
+  for (auto producerYieldVal : producerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
+  for (auto consumerYieldVal : consumerYieldOp.getOperands())
+    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
+  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
+
   // Sanity checks.
   assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
          "Ill-formed GenericOp region");
 }
 
-static Optional<SmallVector<Value>>
-fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
-                       const ControlFusionFn &controlFn,
-                       PatternRewriter &rewriter) {
-  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
-  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
-      !controlFn(producer->getResult(0), *consumerOpOperand))
-    return llvm::None;
-
+FailureOr<Operation *>
+mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
+                                 OpOperand *fusedOperand) {
+  assert(areElementwiseOpsFusable(fusedOperand) &&
+         "expected elementwise operation pre-conditions to pass");
+  auto producerResult = fusedOperand->get().cast<OpResult>();
+  auto producer = cast<GenericOp>(producerResult.getOwner());
+  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(consumerOpOperand) &&
+  assert(consumer.isInputTensor(fusedOperand) &&
          "expected producer of input operand");
 
   // Compute the fused operands list and indexing maps.
-  SmallVector<Value> fusedOperands;
+  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
+  SmallVector<Type> fusedResultTypes;
   SmallVector<AffineMap> fusedIndexMaps;
-  fusedOperands.reserve(producer->getNumOperands() +
-                        consumer->getNumOperands());
-  fusedIndexMaps.reserve(producer->getNumOperands() +
-                         consumer->getNumOperands());
+  fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs());
+  fusedOutputOperands.reserve(producer.getNumOutputs() +
+                              consumer.getNumOutputs());
+  fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
+  fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
+                         consumer.getNumInputsAndOutputs());
   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
   SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
   SmallVector<OpOperand *>::iterator it =
-      llvm::find(consumerInputs, consumerOpOperand);
+      llvm::find(consumerInputs, fusedOperand);
   assert(it != consumerInputs.end() && "expected to find the consumer operand");
   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
   // 4. Splice in producer's input operands/maps.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
   AffineMap producerResultIndexMap =
-      producer.getTiedIndexingMap(producer.getOutputOperand(0));
+      producer.getTiedIndexingMapForResult(producerResult);
   for (OpOperand *opOperand : producer.getInputOperands()) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
         opOperand, producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
-    fusedIndexMaps.push_back(map);
-  }
-  // 4.b. Producer output operand/map that is fused needs to be passed if it is
-  // an "initTensor" (i.e. its value is actually read).
-  assert(producer->getNumResults() == 1 && "expected single result producer");
-  if (producer.isInitTensor(producer.getOutputOperand(0))) {
-    fusedOperands.push_back(producer.getOutputOperand(0)->get());
-    // Compute indexing maps for the producer args in the fused operation.
-    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-        producer.getOutputOperand(0), producerResultIndexMap,
-        consumer.getTiedIndexingMap(consumerOpOperand));
+        consumer.getTiedIndexingMap(fusedOperand));
     fusedIndexMaps.push_back(map);
   }
   // 5. Remaining consumer's input operands/maps (drop past index
   // `consumerIdx`).
   for (OpOperand *opOperand :
        llvm::make_range(std::next(it), consumerInputs.end())) {
-    fusedOperands.push_back(opOperand->get());
+    fusedInputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
   }
-  // 6. All of consumer's output operands (skip operands: added by the builder).
-  for (OpOperand *opOperand : consumer.getOutputOperands())
+
+  // 6. Collect all of the producer outputs.
+  for (OpOperand *opOperand : producer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
+    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+        opOperand, producerResultIndexMap,
+        consumer.getTiedIndexingMap(fusedOperand));
+    fusedIndexMaps.push_back(map);
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
+
+  // 7. All of consumer's output operands (skip operands: added by the builder).
+  for (OpOperand *opOperand : consumer.getOutputOperands()) {
+    fusedOutputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
-  // 7. All of producer's output operands/maps except the one fused.
-  // TODO: allow fusion of multi-result producers.
-  assert(producer->getNumResults() == 1 && "expected single result producer");
+    fusedResultTypes.push_back(opOperand->get().getType());
+  }
 
   // Generate the fused op.
-  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
   auto fusedOp = rewriter.create<GenericOp>(
-      consumer.getLoc(), consumer->getResultTypes(),
-      /*inputs=*/fusedOperands,
-      // TODO: handle outputs.
-      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
+      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
       consumer.getIteratorTypes(),
       /*doc=*/nullptr,
       /*library_call=*/nullptr);
@@ -328,13 +332,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
     // in the input, but going ahead here would result in verification errors.
     // So cleanup and abort.
     rewriter.eraseOp(fusedOp);
-    return llvm::None;
+    return rewriter.notifyMatchFailure(
+        fusedOp, "fused op failed loop bound computation check");
   }
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
-  AffineMap consumerResultIndexMap =
-      consumer.getTiedIndexingMap(consumerOpOperand);
+  AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -345,19 +349,9 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
   generateFusedElementwiseOpRegion(rewriter, fusedOp,
-                                   consumerToProducerLoopsMap,
-                                   consumerOpOperand, consumer.getNumLoops());
-  return SmallVector<Value>(fusedOp->getResults());
-}
-
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
-                   GenericOp producer, const ControlFusionFn &controlFn) {
-  if (producer->getNumResults() != 1)
-    return llvm::None;
-
-  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
-                                rewriter);
+                                   consumerToProducerLoopsMap, fusedOperand,
+                                   consumer.getNumLoops());
+  return fusedOp.getOperation();
 }
 
 namespace {
@@ -373,14 +367,16 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
-      if (!producer || !producer.hasTensorSemantics())
+      if (!areElementwiseOpsFusable(opOperand))
+        continue;
+      if (!controlFn(opOperand))
         continue;
-      Optional<SmallVector<Value>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
-      if (fusedOpResults) {
-        rewriter.replaceOp(genericOp, *fusedOpResults);
+
+      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
+      if (succeeded(fusedOp)) {
+        auto replacements = fusedOp.getValue()->getResults().take_back(
+            genericOp.getNumResults());
+        rewriter.replaceOp(genericOp, replacements);
         return success();
       }
     }
@@ -713,6 +709,10 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
+  // Set insertion point to the generic op.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp);
+
   SmallVector<Value> expandedOpOperands;
   expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
@@ -792,7 +792,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   SmallVector<Value> resultVals;
   for (OpResult opResult : genericOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
-    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
+    if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
               genericOp.getTiedIndexingMap(
@@ -834,7 +834,7 @@ class FoldWithProducerReshapeOpByExpansion
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
       if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
-          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
+          (!controlFoldingReshapes(opOperand)))
         continue;
 
       Optional<SmallVector<Value>> replacementValues =
@@ -865,18 +865,50 @@ struct FoldReshapeWithGenericOpByExpansion
   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
-    GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
-    if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getOutputOperand(0)) ||
-        !controlFoldingReshapes(producer->getResult(0),
-                                reshapeOp->getOpOperand(0)))
-      return failure();
+    auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "source not produced by an operation");
+    }
+
+    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    if (!producer) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "producer not a generic op");
+    }
+
+    if (!isFusableWithReshapeByDimExpansion(
+            producer,
+            producer.getOutputOperand(producerResult.getResultNumber()))) {
+      return rewriter.notifyMatchFailure(
+          reshapeOp, "failed preconditions of fusion with producer generic op");
+    }
+
+    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion blocked by control function");
+    }
+
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
-    if (!replacementValues)
-      return failure();
-    rewriter.replaceOp(reshapeOp, *replacementValues);
+        producer, reshapeOp,
+        producer.getOutputOperand(producerResult.getResultNumber()), rewriter);
+    if (!replacementValues) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion by expansion failed");
+    }
+
+    // Find the replacement for the reshape op. Since the replacements have the
+    // same type as the returns of the original generic op, the consumer reshape
+    // op can be replaced by the source of the collapse_shape op that defines
+    // the replacement.
+    Value reshapeReplacement = (*replacementValues)
+        [reshapeOp.getSrc().cast<OpResult>().getResultNumber()];
+    if (auto collapseOp =
+            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
+      reshapeReplacement = collapseOp.getSrc();
+    }
+    rewriter.replaceOp(reshapeOp, reshapeReplacement);
+    rewriter.replaceOp(producer, *replacementValues);
     return success();
   }
 
@@ -1469,7 +1501,7 @@ class FoldWithProducerReshapeOpByCollapsing
           getCollapsableIterationSpaceDims(genericOp, opOperand,
                                            reshapeOp.getReassociationIndices());
       if (collapsableIterationDims.empty() ||
-          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
+          !controlFoldingReshapes(opOperand)) {
         continue;
       }
 
@@ -1726,9 +1758,9 @@ struct LinalgElementwiseOpFusionPass
     RewritePatternSet patterns(context);
 
     // Add folding with reshape by expansion patterns.
-    ControlFusionFn defaultControlFn = [](const OpResult &producer,
-                                          const OpOperand &consumer) {
-      return producer.hasOneUse();
+    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
+      Operation *producer = fusedOperand->get().getDefiningOp();
+      return producer && producer->hasOneUse();
     };
 
     // Add elementwise op fusion patterns.

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index b7c91c0aebad0..3002d447fb68c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -662,12 +662,12 @@ func.func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) ->
     linalg.yield %r : f64
   } -> tensor<1x8xf64>
 
-  // CHECK-NEXT:   %[[R:.*]] = linalg.generic
+  // CHECK-NEXT:   %[[R:.*]]:2 = linalg.generic
   //      CHECK:     bb0(%[[BBA:[0-9a-z]*]]: f64, %[[BBB:[0-9a-z]*]]: i32):
   // CHECK-NEXT:       %[[A:.*]] = func.call @compute1(%[[BBA]]) : (f64) -> f64
   // CHECK-NEXT:       %[[B:.*]] = func.call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
-  // CHECK-NEXT:       linalg.yield %[[B]] : i32
-  // CHECK-NEXT:   } -> tensor<1x8xi32>
+  // CHECK-NEXT:       linalg.yield %[[A]], %[[B]] : f64, i32
+  // CHECK-NEXT:   } -> (tensor<1x8xf64>, tensor<1x8xi32>)
   %1 = linalg.generic {
     indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
     iterator_types = ["parallel", "parallel"]}
@@ -678,7 +678,7 @@ func.func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) ->
     linalg.yield %r : i32
   } -> tensor<1x8xi32>
 
-  // CHECK-NEXT:   return %[[R]] : tensor<1x8xi32>
+  // CHECK-NEXT:   return %[[R]]#1 : tensor<1x8xi32>
   return %1 : tensor<1x8xi32>
 }
 
@@ -948,7 +948,7 @@ func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -
 
 // -----
 
-func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
   %c1_i32 = arith.constant 1 : i32
   %0 = linalg.generic {
         indexing_maps = [affine_map<(d0) -> (d0)>],
@@ -971,10 +971,25 @@ func.func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) ->
         } -> tensor<5000xi32>
   return %2 : tensor<5000xi32>
 }
-// CHECK-LABEL: func @illegal_fusion(
-//       CHECK:   %[[PRODUCER:.+]] = linalg.generic
-//       CHECK:    linalg.generic
-//   CHECK-SAME:       ins(%[[PRODUCER]]
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+//      CHECK: func @fusion_different_axes(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<5000xi64>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<5000xi32>
+//  CHECK-DAG:   %[[INIT0:.+]] = linalg.init_tensor [5000] : tensor<5000xi64>
+//  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [5000] : tensor<5000xi32>
+//      CHECK:   %[[RESULT:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:.+]]: i64
+// CHECK-SAME:       %[[B1:.+]]: i32
+//  CHECK-DAG:     %[[T0:.+]] = linalg.index 0
+//  CHECK-DAG:     %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
+//  CHECK-DAG:     %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
+//      CHECK:     %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
+//      CHECK:     linalg.yield %[[CAST1]], %[[EXTRACT]]
+//      CHECK:   return %[[RESULT]]#1
 
 // -----
 
@@ -995,7 +1010,7 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
   ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
     %5 = arith.addf  %arg1, %arg2 : f32
-	linalg.yield %5 : f32
+        linalg.yield %5 : f32
   } -> tensor<?xf32>
   return %4 : tensor<?xf32>
 }
@@ -1024,7 +1039,50 @@ func.func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?x
   %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
   ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
     %8 = arith.divf  %arg1, %arg2 : f32
-	linalg.yield %8 : f32
+        linalg.yield %8 : f32
   } -> tensor<?x?xf32>
   return %7 : tensor<?x?xf32>
 }
+
+// -----
+
+#map = affine_map<() -> ()>
+module {
+  func.func @fuse_multi_result_producer(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<f32> {
+    %0 = linalg.init_tensor [] : tensor<f32>
+    %1 = linalg.init_tensor [] : tensor<f32>
+    %2:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []}
+      ins(%arg0, %arg1, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>) outs(%0, %1 : tensor<f32>, tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg7 : f32
+      linalg.yield %4, %5 : f32, f32
+    } -> (tensor<f32>, tensor<f32>)
+    %3 = linalg.generic {
+      indexing_maps = [#map, #map, #map], iterator_types = []}
+      ins(%2#1, %arg1 : tensor<f32>, tensor<f32>) outs(%arg4 : tensor<f32>) {
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+      %4 = arith.addf %arg5, %arg6 : f32
+      %5 = arith.addf %4, %arg6 : f32
+      linalg.yield %5 : f32
+    } -> tensor<f32>
+    return %3 : tensor<f32>
+  }
+}
+// CHECK-LABEL: func.func @fuse_multi_result_producer
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:       outs(%[[INIT]] :
+//  CHECK-NEXT:     ^bb0
+//  CHECK-SAME:         %[[B0:[a-zA-Z0-9]+]]: f32
+//  CHECK-SAME:         %[[B1:[a-zA-Z0-9]+]]: f32
+//   CHECK-DAG:     %[[T0:.+]] = arith.addf %[[B0]], %[[B1]]
+//   CHECK-DAG:     %[[T1:.+]] = arith.addf %[[T0]], %[[B1]]
+//   CHECK-DAG:     %[[T2:.+]] = arith.addf %[[T1]], %[[B1]]
+//   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
+//       CHECK:     linalg.yield %[[T3]] : f32
+//       CHECK:   return %[[GENERIC]]

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
index 1efaaa3b75a84..b0517288bcffc 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
@@ -58,5 +59,17 @@ func.func @test_fusion_limit(
 //  CHECK-SAME:   %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 //       CHECK:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
 //       CHECK:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-//       CHECK:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-//       CHECK:   return %[[OP3]]
+//       CHECK:   %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CHECK:   return %[[OP3]]#1
+
+// CANONICALIZE-LABEL: func @test_fusion_limit
+//  CANONICALIZE-SAME:   %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME:   %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME:   %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME:   %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME:   %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//  CANONICALIZE-SAME:   %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
+//       CANONICALIZE:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
+//       CANONICALIZE:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
+//       CANONICALIZE:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CANONICALIZE:   return %[[OP3]]

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 45e8721278e1c..e151b99b66ac2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -499,3 +499,43 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @reshape_as_consumer_permutation_with_multiple_results
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
+  %c:2 = linalg.generic {
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
+                          affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+          ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+         outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+       ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
+         %1 = arith.addf %arg0, %arg1 : f32
+         linalg.yield %1, %1 : f32, f32
+       } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]]
+       : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]]
+       : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+  return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
+//      CHECK: func @reshape_as_consumer_permutation_with_multiple_results
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:  %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}}
+//  CHECK-DAG:  %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}}
+//  CHECK-DAG:  %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}}
+//  CHECK-DAG:  %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}}
+//      CHECK:  %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:      ins(%[[RESHAPE0]], %[[RESHAPE1]] :
+// CHECK-SAME:      outs(%[[RESHAPE2]], %[[RESHAPE3]] :
+//      CHECK:  return %[[GENERIC]]#0, %[[GENERIC]]#1

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 1b046b9b2b981..41e46d0d69206 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -35,14 +35,18 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
 }
 
 template <int limit = 3>
-static bool setFusedOpOperandLimit(const OpResult &producer,
-                                   const OpOperand &consumer) {
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer)
+    return false;
+
+  Operation *consumer = fusedOperand->getOwner();
   SetVector<Value> fusedOpOperands;
-  if (producer.getOwner()->getNumResults() != 1)
+  if (producer->getNumResults() != 1)
     return false;
-  addOperands(consumer.getOwner(), fusedOpOperands);
-  fusedOpOperands.remove(producer);
-  addOperands(producer.getOwner(), fusedOpOperands);
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
   return fusedOpOperands.size() <= limit;
 }
 
@@ -113,8 +117,7 @@ struct TestLinalgElementwiseFusion
     if (fuseWithReshapeByExpansion) {
       RewritePatternSet fusionPatterns(context);
       linalg::populateFoldReshapeOpsByExpansionPatterns(
-          fusionPatterns, [](const OpResult & /*producer*/,
-                             OpOperand & /*consumer*/) { return true; });
+          fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
       if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                               std::move(fusionPatterns))))
         return signalPassFailure();
@@ -125,15 +128,19 @@ struct TestLinalgElementwiseFusion
       RewritePatternSet fusionPatterns(context);
 
       linalg::ControlFusionFn controlReshapeFusionFn =
-          [](const OpResult &producer, OpOperand &consumer) {
-            if (auto collapseOp =
-                    producer.getDefiningOp<tensor::CollapseShapeOp>()) {
+          [](OpOperand *fusedOperand) {
+            Operation *producer = fusedOperand->get().getDefiningOp();
+            if (!producer)
+              return false;
+
+            if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer)) {
               if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
                 return false;
               }
             }
-            if (auto expandOp =
-                    dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+
+            Operation *consumer = fusedOperand->getOwner();
+            if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer)) {
               if (expandOp->hasOneUse()) {
                 OpOperand &use = *expandOp->getUses().begin();
                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
@@ -155,18 +162,17 @@ struct TestLinalgElementwiseFusion
     if (fuseWithReshapeByCollapsing) {
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(
-          patterns, [](const OpResult & /*producer*/,
-                       OpOperand & /*consumer*/) { return true; });
+          patterns, [](OpOperand * /*fusedOperand */) { return true; });
       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
     }
 
     if (fuseWithReshapeByCollapsingWithControlFn) {
       RewritePatternSet patterns(context);
-      linalg::ControlFusionFn controlFn = [](const OpResult &producer,
-                                             OpOperand &consumer) -> bool {
-        if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
+      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) -> bool {
+        Operation *producer = fusedOperand->get().getDefiningOp();
+        if (isa<tensor::ExpandShapeOp>(producer)) {
           // Skip fusing the first operand.
-          return consumer.getOperandNumber();
+          return fusedOperand->getOperandNumber();
         }
         return true;
       };

