[Mlir-commits] [mlir] e377a5d - [MLIR][Tensor] Remove tensor.dim canonicalization patterns registered on tensor.expand_shape/tensor.collapse_shape (#134219)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 11 03:57:38 PDT 2025
Author: Vivek Khandelwal
Date: 2025-04-11T06:57:34-04:00
New Revision: e377a5d1682d0410b1fd38b011c29b7d81d6b53b
URL: https://github.com/llvm/llvm-project/commit/e377a5d1682d0410b1fd38b011c29b7d81d6b53b
DIFF: https://github.com/llvm/llvm-project/commit/e377a5d1682d0410b1fd38b011c29b7d81d6b53b.diff
LOG: [MLIR][Tensor] Remove tensor.dim canonicalization patterns registered on tensor.expand_shape/tensor.collapse_shape (#134219)
These patterns are problematic because, when applied iteratively, they
resolve each tensor.dim operation locally and introduce an intermediate
floor_div. The floor_div loses the information that the division carried
out in the original IR was exact, so the iterative algorithm cannot
converge towards the simplest form.
Information loss is not acceptable for canonicalization.
Resolving the tensor.dim op can instead be achieved with the
resolve-ranked-shaped-type-result-dims and resolve-shaped-type-result-dims
passes.
---------
Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d589f627d896e..b42e60d5cebd7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1986,90 +1986,6 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
}
};
-struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
- if (!expandShapeOp)
- return failure();
-
- // Only constant dimension values are supported.
- std::optional<int64_t> dim = dimOp.getConstantIndex();
- if (!dim.has_value())
- return failure();
-
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = expandShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Find reassociation group that contains this result dimension.
- int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
- // `dim` is the only dynamic dimension in `group`. (Otherwise, the
- // ExpandShapeOp would be ambiguous.)
- int64_t product = 1;
- ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
- for (int64_t d : grp) {
- if (d != dim) {
- assert(!resultType.isDynamicDim(d) && "expected static dim");
- product *= resultType.getDimSize(d);
- }
- }
-
- // result dim size = src dim size / (product(other dims in reassoc group))
- Value srcDimSz =
- rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
- AffineExpr expr;
- bindSymbols(dimOp.getContext(), expr);
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
- dimOp, expr.floorDiv(product), srcDimSz);
- return success();
- }
-};
-
-struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
- if (!collapseShapeOp)
- return failure();
-
- // Only constant dimension values are supported.
- std::optional<int64_t> dim = dimOp.getConstantIndex();
- if (!dim.has_value() ||
- dim.value() >= collapseShapeOp.getResultType().getRank())
- return failure();
-
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = collapseShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Get reassociation group of the result dimension.
- ReassociationIndices group =
- collapseShapeOp.getReassociationIndices()[*dim];
-
- // result dim size = product(dims in reassoc group)
- SmallVector<Value> srcDimSizes;
- SmallVector<AffineExpr> syms;
- AffineExpr product;
- for (const auto &it : llvm::enumerate(group)) {
- srcDimSizes.push_back(rewriter.create<DimOp>(
- dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
- syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
- product = product ? product * syms.back() : syms.back();
- }
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
- srcDimSizes);
- return success();
- }
-};
-
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
@@ -2158,8 +2074,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
- FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
- FoldDimOfCollapseShape>(context);
+ FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 3256daa8e0b59..a00c798197e5a 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -25,10 +25,8 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
@@ -36,11 +34,9 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
-// CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
// CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
// CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[DIM]], 1, %[[DIM_1]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
// CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -79,18 +75,15 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
}
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
// CHECK-LABEL: func @drop_one_trip_loops_all_ones
// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: tensor.collapse_shape %{{.*}} []
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
-// CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
-// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[DIM]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
// -----
@@ -406,7 +399,6 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
// CHECK: func @unit_dim_for_reduction
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -422,8 +414,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
-// CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[DIM_0]]] : tensor<?xf32> into tensor<1x?xf32>
// CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
// -----
@@ -482,10 +473,8 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: func @unit_dim_for_reduction_inner
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x1xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[C2:.*]] = arith.constant 2 : index
@@ -499,8 +488,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
-// CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
+// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[DIM_0]], 1] : tensor<?xf32> into tensor<?x1xf32>
// CHECK: return %[[RESULT_RESHAPE]]
// -----
@@ -1017,7 +1005,6 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
return %0 : tensor<1x?xf32>
}
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
// CHECK-LABEL: func @drop_unit_pad_dynamic_dims
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -1027,8 +1014,7 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
// CHECK: } : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
-// CHECK: %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
-// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
+// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[DIM]]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>
// CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
@@ -1090,7 +1076,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// -----
-// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
@@ -1098,12 +1083,10 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
-// CHECK: %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
-// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
// CHECK: %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index c68a6362f52c5..43bddb075e649 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -76,13 +76,13 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
%1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
@@ -134,7 +134,7 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
// CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+ // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
// CHECK-NEXT: return %[[RES]]
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
@@ -171,12 +171,12 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
- // CHECK-DAG: %[[C1:.*]] = arith.constant 1
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
- // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+ // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
%0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fd96328c6033d..85bf6fba52aa4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1105,15 +1105,13 @@ func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -
%expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
return %expanded : tensor<?x384xf32>
}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
-// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
+// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
-// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[COLLAPSE]], %[[CONSTANT0]] : tensor<?xf32>
+// CHECK: %[[DIVUI:.+]] = arith.divui %[[DIM]], %[[CONSTANT384]] : index
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
// CHECK: return %[[RESULT]]
@@ -2137,13 +2135,12 @@ func.func @empty_tensor_canonicalize(%i : index) {
// -----
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[expanded:.*]] = tensor.expand_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4, 5]] output_shape [%arg1, 1, %arg2, 5, 1, 8] : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+// CHECK: %[[dim:.*]] = tensor.dim %[[expanded]], %[[c2]] : tensor<?x1x?x5x1x8xf32>
+// CHECK: return %[[dim]]
func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
@@ -2154,17 +2151,12 @@ func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) ->
// -----
-// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
// CHECK-LABEL: func @dim_of_collapse_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
-// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
-// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
-// CHECK: return %[[apply]]
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4]] : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+// CHECK-DAG: %[[dim:.*]] = tensor.dim %[[collapsed]], %[[c1]]
+// CHECK: return %[[dim]]
func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
%c1 = arith.constant 1 : index
%0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
More information about the Mlir-commits
mailing list