[Mlir-commits] [mlir] [MLIR][Tensor] Remove FoldDimOf[Expand|Collapse]Shape Pattern (PR #134219)

Vivek Khandelwal llvmlistbot at llvm.org
Thu Apr 3 07:17:49 PDT 2025


https://github.com/vivekkhandelwal1 updated https://github.com/llvm/llvm-project/pull/134219

>From dad4284571c7bc233f4deeae5825d03ddb9a8d75 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 3 Apr 2025 13:59:47 +0530
Subject: [PATCH 1/3] [MLIR][Tensor] Remove FoldDimOf[Expand|Collapse]Shape
 Pattern

This commit removes the FoldDimOfExpandShape and FoldDimOfCollapseShape
patterns from the list of ExpandShapeOp's canonicalization patterns.

These patterns were causing a crash while folding the dims of an
expanded tensor. The issue can be reproduced by reverting the changes
in this commit and running the command:

```
mlir-opt --linalg-fuse-elementwise-ops repro.mlir
```

over the IR: https://gist.github.com/vivekkhandelwal1/56a1a398c21d739df77a67ce372b9366.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d589f627d896e..35265ac49fcbf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2158,8 +2158,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
       ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
-      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-      FoldDimOfCollapseShape>(context);
+      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,

>From 532a05323ebac79d10f552e4792cdd390713f1d9 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 3 Apr 2025 15:11:50 +0530
Subject: [PATCH 2/3] Fix regression tests

---
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 29 ++++---------------
 .../Linalg/rank-reduce-contraction-ops.mlir   | 10 +++----
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 28 +++++++-----------
 3 files changed, 21 insertions(+), 46 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 3256daa8e0b59..a00c798197e5a 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -25,10 +25,8 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//   CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL: func @drop_one_trip_loops
 //       CHECK: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: %[[C0:.*]] = arith.constant 0 : index
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
@@ -36,11 +34,9 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
 //       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
-//       CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
 //       CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
-//       CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
 //       CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
-//       CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
+//       CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[DIM]], 1, %[[DIM_1]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
 
 //   CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //   CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -79,18 +75,15 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
 }
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
-//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
 // CHECK-LABEL: func @drop_one_trip_loops_all_ones
 //       CHECK: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: tensor.collapse_shape %{{.*}} []
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME:   iterator_types = ["parallel"]
 //       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
-//       CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
-//       CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
+//       CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[DIM]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
 
 // -----
 
@@ -406,7 +399,6 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 }
 //  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
 //      CHECK: func @unit_dim_for_reduction
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
 //      CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -422,8 +414,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
 //      CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
-//      CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
-//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[DIM_0]]] : tensor<?xf32> into tensor<1x?xf32>
 //      CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
 
 // -----
@@ -482,10 +473,8 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
 }
 //  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 //      CHECK: func @unit_dim_for_reduction_inner
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
-//      CHECK: %[[C1:.*]] = arith.constant 1 : index
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
 //      CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
 //      CHECK: %[[C2:.*]] = arith.constant 2 : index
@@ -499,8 +488,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
 //      CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
-//      CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]]
-//      CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
+//      CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[DIM_0]], 1] : tensor<?xf32> into tensor<?x1xf32>
 //      CHECK: return %[[RESULT_RESHAPE]]
 
 // -----
@@ -1017,7 +1005,6 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
   return %0 : tensor<1x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
 // CHECK-LABEL: func @drop_unit_pad_dynamic_dims
 //       CHECK:   %[[C1:.*]] = arith.constant 1 : index
@@ -1027,8 +1014,7 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
 //       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
 //       CHECK:   } : tensor<?xf32> to tensor<?xf32>
 //       CHECK:   %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
-//       CHECK:   %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
-//       CHECK:   %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
+//       CHECK:   %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[DIM]]]
 //       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>
 
 // CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
@@ -1090,7 +1076,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
 
@@ -1098,12 +1083,10 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 // CHECK-SAME:                    %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
 // CHECK-SAME:                    %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
 // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
 // CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
 // CHECK:           %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
-// CHECK:           %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
-// CHECK:           %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
 // CHECK:           %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
 // CHECK:           ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
 // CHECK:             %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
diff --git a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
index c68a6362f52c5..43bddb075e649 100644
--- a/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
+++ b/mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -76,13 +76,13 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
   //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
   //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
   //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-DAG:    %[[C0:.*]] = arith.constant 0
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.vecmat 
   //  CHECK-SAME:   ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   //  CHECK-NEXT:   return %[[RES]]
   %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
@@ -134,7 +134,7 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[MATMUL:.+]] = linalg.matvec 
   //  CHECK-SAME:   ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+  //  CHECK-NEXT:   %[[DIM0:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
   //  CHECK-NEXT:   return %[[RES]]
     %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
@@ -171,12 +171,12 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
   //  CHECK-SAME:     %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
   //  CHECK-SAME:     %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
   //  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-  //  CHECK-DAG:    %[[C1:.*]] = arith.constant 1
+  //  CHECK-DAG:    %[[C0:.*]] = arith.constant 0
   //  CHECK-NEXT:   %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
   //  CHECK-NEXT:   %[[RESULT:.*]] = linalg.vecmat 
   //  CHECK-SAME:   ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+  //  CHECK-NEXT:   %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
   //  CHECK-NEXT:   %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
   //  CHECK-NEXT:   return %[[RES]]
     %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fd96328c6033d..85bf6fba52aa4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1105,15 +1105,13 @@ func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -
   %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
   return %expanded : tensor<?x384xf32>
 }
-//       CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
 // CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
 //  CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
-//       CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 //       CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
+//       CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 //       CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
-//       CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
-//       CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-//       CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
+//       CHECK: %[[DIM:.+]] = tensor.dim %[[COLLAPSE]], %[[CONSTANT0]] : tensor<?xf32>
+//       CHECK: %[[DIVUI:.+]] = arith.divui %[[DIM]], %[[CONSTANT384]] : index
 //       CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
 //       CHECK: return %[[RESULT]]
 
@@ -2137,13 +2135,12 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
-//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
-//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-//       CHECK:   return %[[apply]]
+//       CHECK:   %[[c2:.*]] = arith.constant 2 : index
+//       CHECK:   %[[expanded:.*]] = tensor.expand_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4, 5]] output_shape [%arg1, 1, %arg2, 5, 1, 8] : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[expanded]], %[[c2]] : tensor<?x1x?x5x1x8xf32>
+//       CHECK:   return %[[dim]]
 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
@@ -2154,17 +2151,12 @@ func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) ->
 
 // -----
 
-//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
 // CHECK-LABEL: func @dim_of_collapse_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
 //   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
-//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
-//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
-//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
-//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
-//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
-//       CHECK:   return %[[apply]]
+//   CHECK-DAG:   %[[collapsed:.*]] = tensor.collapse_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4]] : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+//   CHECK-DAG:   %[[dim:.*]] = tensor.dim %[[collapsed]], %[[c1]]
+//       CHECK:   return %[[dim]]
 func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
   %c1 = arith.constant 1 : index
   %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]

>From 5c6b295e80af89e367637bc72c3ce879eb7d657a Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivekkhandelwal1424 at gmail.com>
Date: Thu, 3 Apr 2025 19:47:37 +0530
Subject: [PATCH 3/3] Add test for reproducing the crash

---
 .../Linalg/fusion-elementwise-ops.mlir        | 90 +++++++++++++++++++
 1 file changed, 90 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fa..0405913093411 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -977,3 +977,93 @@ module {
 //   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 //       CHECK:     linalg.yield %[[T3]] : f32
 //       CHECK:   return %[[GENERIC]]
+
+// -----
+
+#map = affine_map<()[s0, s1] -> (s0 * s1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @no_fusion(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
+    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
+    %0 = arith.index_cast %dim : index to i64
+    %1 = arith.index_cast %dim_0 : index to i64
+    %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
+    %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+    %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+    %2 = affine.apply #map()[%dim, %dim_0]
+    %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
+    %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
+    ^bb0(%in: i64, %out: f32):
+      %7 = arith.index_cast %in : i64 to index
+      %8 = linalg.index 1 : index
+      %9 = linalg.index 2 : index
+      %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
+      linalg.yield %extracted : f32
+    } -> tensor<?x?x?xf32>
+    %5 = arith.index_cast %dim_1 : index to i64
+    %6 = arith.index_cast %dim_2 : index to i64
+    %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
+    %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+    return %reshape : tensor<?x?x?x?xf32>
+  }
+}
+
+// -----
+
+#map = affine_map<()[s0, s1] -> (s0 * s1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL:   func.func @no_fuse_expand_collapsed_generic_input(
+// CHECK-SAME:                                              %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                              %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
+// CHECK-SAME:                                              %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
+// CHECK-SAME:                                              %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>)
+func.func @no_fuse_expand_collapsed_generic_input(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
+  // CHECK:           %[[EXPANDED:.*]] = tensor.expand_shape %{{.+}} {{\[\[}}0, 1], [2], [3]] output_shape {{\[}}%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  // CHECK:           %[[OUT:.*]] = tensor.empty(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) : tensor<?x?x?x?xf32>
+  // CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_3]] : tensor<?x?xi64>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) {
+  // CHECK:           ^bb0(%[[VAL_5:.*]]: i64, %[[VAL_6:.*]]: f32):
+  // CHECK:             %[[IDX0:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
+  // CHECK:             %[[IDX1:.*]] = linalg.index 2 : index
+  // CHECK:             %[[IDX2:.*]] = linalg.index 3 : index
+  // CHECK:             %[[EXTRACT:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]], %[[IDX2]]] : tensor<?x?x?xf32>
+  // CHECK:             linalg.yield %[[EXTRACT]] : f32
+  // CHECK:           } -> tensor<?x?x?x?xf32>
+  // CHECK:           %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_4]] {{\[\[}}0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  // CHECK:           %[[SHAPE:.*]] = tensor.from_elements
+  // CHECK:           %[[RESULT:.*]] = tensor.reshape %[[COLLAPSED]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+  // CHECK:           return %[[RESULT]] : tensor<?x?x?x?xf32>
+  // CHECK:         }
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
+  %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
+  %0 = arith.index_cast %dim : index to i64
+  %1 = arith.index_cast %dim_0 : index to i64
+  %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
+  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %2 = affine.apply #map()[%dim, %dim_0]
+  %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
+  %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%in: i64, %out: f32):
+    %7 = arith.index_cast %in : i64 to index
+    %8 = linalg.index 1 : index
+    %9 = linalg.index 2 : index
+    %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?x?xf32>
+  %5 = arith.index_cast %dim_1 : index to i64
+  %6 = arith.index_cast %dim_2 : index to i64
+  %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
+  %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+  return %reshape : tensor<?x?x?x?xf32>
+}



More information about the Mlir-commits mailing list