[Mlir-commits] [mlir] [mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (PR #118208)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 1 02:58:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This PR fixes how the iteration domain of linalg.generic is collapsed when fusing with tensor.expand_shape. Previously, the output_shape of the generated tensor.expand_shape was inferred, which only works in some special cases.

This patch explicitly sets the bounds of the new collapsed iteration domain, since they are already known from the loop ranges of the original operation.

---

Patch is 24.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118208.diff


8 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+14-9) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+23-96) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+6-26) 
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+45-18) 
- (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+3-4) 
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+2-6) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c44194a1231588..fa730241203039 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1549,7 +1549,7 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 /// value in the collapsed operation.
 void generateCollapsedIndexingRegion(Location loc, Block *block,
                                      const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
+                                     ArrayRef<OpFoldResult> loopRange,
                                      RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
@@ -1571,10 +1571,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemUIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivUIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1721,14 +1723,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1746,15 +1747,19 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         MemRefType expandShapeResultType = MemRefType::get(
             originalResultType.getShape(), originalResultType.getElementType());
         result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       }
       results.push_back(result);
     } else {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebb88bf695d4c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyCollapseShapeOp, CollapseShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+    auto loc = op->getLoc();
+    auto collapseShape = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+        b, loc, collapseShape.getSrc(),
+        collapseShape.getResultType().getShape(),
+        collapseShape.getReassociationMaps()));
+    return success();
   }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
+};
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
-    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
-        b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps()));
+    auto expandShape = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> outputShape = getMixedValues(
+        expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+    reifiedReturnShapes.push_back(outputShape);
     return success();
   }
 };
@@ -202,10 +131,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
     if (!dim.has_value())
       return failure();
 
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+    OpFoldResult outputDim = outputShape[dim.value()];
+    rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+                                  rewriter, dimOp.getLoc(), outputDim));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index f17881d59a266e..f29f231cdeca87 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>) 
+      outs(%init : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          %out = arith.negf %b0 : f32
+          linalg.yield %out : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK:     %[[OUT:.+]] = linalg.generic
+// CHECK-SAME:   ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME:   outs(%{{.*}} : tensor<?xf32>)
+// CHECK:     %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME:    output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:      return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//  CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
 //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+//  CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 //  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 //      CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+//      CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+//      CHECK-DAG:   %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 //  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
 //      CHECK:       linalg.yield %[[T7]]
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C4]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 751ece37bc094f..fd3c3217225086 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-//      CHECK: %[[C112:.*]] = arith.constant 112 : index
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 //      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/118208


More information about the Mlir-commits mailing list