[Mlir-commits] [mlir] [mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (PR #127943)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 12 22:11:29 PDT 2025


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/127943

>From 1b675e9199d107ec091ec725fd0ef820d40807b1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Mon, 17 Feb 2025 21:03:56 -0600
Subject: [PATCH 1/2] [mlir][Linalg] Allow expand shape propagation across
 linalg ops with dynamic shapes.

With `tensor.expand_shape` now allowing a dynamic dimension to be expanded
into multiple dynamic dimensions, adapt the reshape propagation through
expansion to handle cases where one dynamic dimension is expanded into
multiple dynamic dimensions.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 186 +++++------
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 296 ++++++------------
 2 files changed, 177 insertions(+), 305 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33667e7ab0c5c..cfc5b25fa87a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <utility>
 
@@ -590,18 +591,17 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
 
 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -609,9 +609,9 @@ class ExpansionInfo {
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -619,15 +619,17 @@ class ExpansionInfo {
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });
 
   reassociation.clear();
   expandedShapeMap.clear();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -708,16 +683,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 
 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }
 
 /// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +752,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }
 
 // Create an expanded transpose op.
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
 
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the `linalgOp`
+    // to maintain SSA validity
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+                                   reassociationIndices, expandedShape,
+                                   rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +915,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +938,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -983,8 +950,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +966,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 3244418d445b7..67b4f2b32bad5 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[C3:.+]] = arith.constant 3 : index
 //      CHECK:   %[[C1:.+]] = arith.constant 1 : index
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-//      CHECK:   %[[C12:.+]] = arith.constant 12 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-//      CHECK:   %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ0]], 2, %[[SZ1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ1]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ0]], 2, %[[SZ1]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
@@ -258,7 +223,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 }
 
 // Only check the body in the indexed version of the test.
-//       CHECK: #[[MAP:.+]] =  affine_map<(d0, d1) -> (d0 + d1 * 4)>
+//       CHECK: #[[MAP:.+]] =  affine_map<()[s0, s1] -> (s0 + s1 * 4)>
 //       CHECK: func @indexed_consumer_reshape_producer_fusion
 //       CHECK:   linalg.generic
 //       CHECK:   ^{{.*}}(
@@ -268,7 +233,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 //   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
 //   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
 //   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
-//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]]()[%[[IDX1]], %[[IDX0]]]
 //       CHECK:     %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
 //       CHECK:     %[[T5:.+]] = arith.index_cast %[[T3]]
 //       CHECK:     %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
@@ -307,8 +272,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 }
 
 // Only check the body in the indexed version of the test.
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 5 + s1 * 20 + s2)>
 //       CHECK: func @indexed_producer_reshape_consumer_fusion
 //       CHECK:   linalg.generic
 //       CHECK:   ^{{.*}}(
@@ -318,12 +282,11 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 //   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
 //   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
 //   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
-//       CHECK:     %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
-//       CHECK:     %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
+//       CHECK:     %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[IDX2]], %[[IDX1]], %[[IDX3]]]
 //       CHECK:     %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
 //       CHECK:     %[[T5:.+]] = arith.index_cast %[[IDX0]]
 //       CHECK:     %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-//       CHECK:     %[[T7:.+]] = arith.index_cast %[[T2]]
+//       CHECK:     %[[T7:.+]] = arith.index_cast %[[T1]]
 //       CHECK:     %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
 //       CHECK:     linalg.yield %[[T8]]
 
@@ -362,16 +325,15 @@ func.func @reshape_as_consumer_permutation
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
 //   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
-//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 7 + s1 * 42 + s2)>
 //       CHECK: func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
 //   CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
 //       CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
 //       CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
-//       CHECK:   %[[T3:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+//       CHECK:   %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
 //       CHECK:   %[[T4:.+]] = linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
@@ -385,13 +347,12 @@ func.func @reshape_as_consumer_permutation
 //   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
 //   CHECK-DAG:       %[[IDX4:.+]] = linalg.index 4 : index
 //   CHECK-DAG:       %[[IDX5:.+]] = linalg.index 5 : index
-//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
-//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
-//   CHECK-DAG:       %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP3]]()[%[[IDX1]], %[[IDX0]]]
+//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP4]]()[%[[IDX3]], %[[IDX2]], %[[IDX4]]]
 //   CHECK-DAG:       %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
 //       CHECK:       %[[T9:.+]] = arith.index_cast %[[T5]]
 //       CHECK:       %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
-//       CHECK:       %[[T11:.+]] = arith.index_cast %[[T7]]
+//       CHECK:       %[[T11:.+]] = arith.index_cast %[[T6]]
 //       CHECK:       %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
 //       CHECK:       %[[T13:.+]] = arith.index_cast %[[IDX5]]
 //       CHECK:       %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
@@ -426,7 +387,7 @@ func.func @reshape_as_producer_projected_permutation(
 
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 8)>
 //       CHECK: @reshape_as_producer_projected_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
 //       CHECK:   %[[RES:.+]] = linalg.generic
@@ -439,7 +400,7 @@ func.func @reshape_as_producer_projected_permutation(
 //   CHECK-DAG:       %[[IDX1:.+]] = linalg.index 1 : index
 //   CHECK-DAG:       %[[IDX2:.+]] = linalg.index 2 : index
 //   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
-//   CHECK-DAG:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+//   CHECK-DAG:       %[[T0:.+]] = affine.apply #[[MAP2]]()[%[[IDX1]], %[[IDX0]]]
 //       CHECK:       %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
 //       CHECK:       %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
 //       CHECK:       %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
@@ -481,21 +442,9 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -528,9 +477,10 @@ func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf3
 // CHECK-SAME:       ins(%[[ARG0]] : tensor<10x10x20xf32>)
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
 //      CHECK:   return %[[COLLAPSE]]
+
 // -----
 
-func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
   %1 = tensor.dim %0, %c0 : tensor<?xf32>
@@ -546,39 +496,21 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   return %3 : tensor<?xf32>
 }
 
-//      CHECK: func @no_fuse_dynamic_dims
+//      CHECK: func @fuse_dynamic_dims
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[EMPTY:.+]] = tensor.empty
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//      CHECK:   %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[EMPTY]] {{\[}}[0, 1]{{\]}}
+// CHECK-SAME:       output_shape [%[[D0]], %[[D1]]]
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]] : tensor<?xf32>)
-//      CHECK:   return %[[GENERIC]]
-
-// -----
-
-func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
-  %1 = tensor.empty() : tensor<2xi64>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> (d0)>,
-                      affine_map<(d0) -> (d0)>,
-                      affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
-    outs(%1 : tensor<2xi64>) {
-  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
-    %3 = arith.addi %arg4, %arg5 : i64
-    linalg.yield %3 : i64
-  } -> tensor<2xi64>
-  return %2 : tensor<2xi64>
-}
-
-//      CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x1xi64>
-// CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
-//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-//      CHECK:   %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
-//      CHECK:   return %[[GENERIC]]
+// CHECK-SAME:       ins(%[[ARG0]] :
+// CHECK-SAME:       outs(%[[EXPAND_SHAPE]] :
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
+//      CHECK:   return %[[COLLAPSE]]
 
 // -----
 
@@ -610,32 +542,10 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
-//       CHECK:   %[[C12:.+]] = arith.constant 12 : index
-//       CHECK:   %[[C2:.+]] = arith.constant 2 : index
-//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-//       CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-//       CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-//       CHECK:   %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-//       CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//       CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//       CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-//       CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-//       CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-//       CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-//       CHECK:   %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-//       CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
-//       CHECK:   %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-//       CHECK:   %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-//       CHECK:   %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
-//       CHECK:   %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
-//       CHECK:   %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+//       CHECK:   %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ4]], 2, %[[SZ3]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+//       CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+//       CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ4]], 2, %[[SZ3]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+//       CHECK:   %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[SZ3]], %[[SZ4]], 2, 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
 //       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
 //  CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
 //  CHECK-SAME:      ins(%[[RESHAPE0]], %[[RESHAPE1]] :
@@ -710,17 +620,10 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[DIM]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
@@ -760,21 +663,12 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
 //      CHECK:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[DIM_0]], 8, 4, %[[DIM]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_0]], 8, %[[DIM]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
@@ -807,21 +701,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T4:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -848,20 +730,12 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -888,15 +762,11 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 //      CHECK: func @linalg_copy_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
 //      CHECK:   %[[T2:.+]] = linalg.copy
 // CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
 // CHECK-SAME:     outs(%[[T1]] : tensor<?x7x?x8xf32>)
@@ -907,7 +777,6 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
-
 func.func @reshape_as_producer_transpose
   (%a :  tensor<4x5x6x7x2x3xf32>)
     -> tensor<6x4x210xf32> {
@@ -991,3 +860,36 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
 //      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
+    %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x128xf16>)
+      outs(%empty : tensor<4x?x32x128xf16>) {
+    ^bb0(%b0: f16, %b1 : f16) :
+      %iv0 = linalg.index 0 : index
+      %iv1 = linalg.index 1 : index
+      %iv2 = linalg.index 2 : index
+      %iv3 = linalg.index 3 : index
+      %1 = tensor.extract %arg1[%iv0, %iv1, %iv2, %iv3] : tensor<4x?x32x128xf16>
+      %2 = arith.addf %1, %b0 : f16
+      linalg.yield %2 : f16
+  } -> tensor<4x?x32x128xf16>
+  %1 = tensor.dim %arg0, %c0 : tensor<?x128xf16>
+  %2 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [4, %1, 32, 8, 16]
+      : tensor<4x?x32x128xf16> into tensor<4x?x32x8x16xf16>
+  func.return %2 : tensor<4x?x32x8x16xf16>
+}
+//      CHECK: func @move_operand_deps(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x128xf16>
+//  CHECK-DAG:   %[[MOVED_OP:.+]] = tensor.dim %[[ARG0]]
+//  CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[EXPANDED]] :
+//      CHECK:   return %[[GENERIC]]

>From dbaa97a091a47a724812be410dc5798ee61c0fc2 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 12 Mar 2025 22:11:01 -0700
Subject: [PATCH 2/2] Address comments.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 25 ++++++++-----------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index cfc5b25fa87a1..afeb162a71e31 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -681,28 +681,21 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                         builder.getContext());
 }
 
-/// Return the type of the operand/result to use in the expanded op given the
-/// type in the original op.
+/// Return the shape and type of the operand/result to use in the expanded op
+/// given the type in the original op.
 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
 getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                         const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedStaticShape;
   SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
     ArrayRef<OpFoldResult> dimExpansion =
         expansionInfo.getExpandedShapeOfDim(dim);
-    llvm::append_range(expandedStaticShape,
-                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
-                         std::optional<int64_t> staticShape =
-                             getConstantIntValue(ofr);
-                         if (staticShape) {
-                           return staticShape.value();
-                         }
-                         return ShapedType::kDynamic;
-                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
+  SmallVector<int64_t> expandedStaticShape;
+  std::tie(expandedStaticShape, std::ignore) =
+      decomposeMixedValues(expandedShape);
   return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                                originalType.getElementType())};
 }
@@ -761,13 +754,14 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
     OpFoldResult newIndex =
         rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
-    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+    for (auto [expandedShape, expandedIndex] :
+         llvm::zip(expandedDimsShape, expandedIndices)) {
       AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
       bindSymbols(rewriter.getContext(), shape);
       newIndex = affine::makeComposedFoldedAffineApply(
           rewriter, indexOp.getLoc(), idx + acc * shape,
-          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
+          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
     }
     Value newIndexVal =
         getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
@@ -890,6 +884,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     src = expandingReshapeOp.getSrc();
   } else {
     auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    if (!collapsingReshapeOp)
+      return std::nullopt;
+
     expandedShape = tensor::getMixedSizes(
         rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
     reassociationIndices = collapsingReshapeOp.getReassociationMaps();



More information about the Mlir-commits mailing list