[Mlir-commits] [mlir] [mlir][Linalg] Drop unit extent dim in non-trivial expressions (PR #173873)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 29 07:44:52 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Lukas Sommer (sommerlukas)

<details>
<summary>Changes</summary>

The current implementation does not drop a unit-extent dimension on the first application of the transformation if that dimension is indexed by a non-trivial affine expression (i.e., anything other than a single dimension or the constant 0). However, such a dimension can be dropped if all dimensions involved in the affine expression are themselves going to be dropped. So far, this required repeated application of the transformation; with the changes in this PR, these dimensions are dropped in a single application of the transformation.

---
Full diff: https://github.com/llvm/llvm-project/pull/173873.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+6-1) 
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+39) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ef58197bbfd0d..c3dca148b7f94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -362,7 +362,12 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
     // Handle the other case where the shape is 1, and is accessed using a
     // constant 0.
     if (operandShape[dim] == 1) {
-      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
+      // Use the new expression after replacing dimensions that will be dropped
+      // here to handle cases where an affine expression with multiple
+      // dimensions (e.g., `d0 + d2`) can be simplified to 0 after dropping all
+      // dimensions used in the expression (`d0` and `d2` in this example).
+      AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+      auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
       return constAffineExpr && constAffineExpr.getValue() == 0;
     }
     return false;
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 55b47bc2e9714..841d0e5f56512 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1136,6 +1136,45 @@ module {
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_binary_expr(%arg0: tensor<1x61x1x1xf32>, %arg1: tensor<48x61x1x1xf32>, %arg2: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32> {
+    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x61x1x1xf32>, tensor<48x61x1x1xf32>) outs(%arg2 : tensor<1x48x1x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %3 = arith.mulf %in, %in_0 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<1x48x1x1xf32>
+    return %2 : tensor<1x48x1x1xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL:   func.func @drop_unit_dim_binary_expr
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x61x1x1xf32>, %[[ARG1:.*]]: tensor<48x61x1x1xf32>
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3]] : tensor<1x61x1x1xf32> into tensor<61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] : tensor<48x61x1x1xf32> into tensor<48x61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[ARG2]] {{\[\[}}0, 1, 2, 3]] : tensor<1x48x1x1xf32> into tensor<48xf32>
+// CHECK:           %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME:        iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:        ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<61xf32>, tensor<48x61xf32>)
+// CHECK-SAME:        outs(%[[COLLAPSE_SHAPE_2]] : tensor<48xf32>)
+// CHECK:           ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
+// CHECK:             %[[MULF_0:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : f32
+// CHECK:             %[[ADDF_0:.*]] = arith.addf %[[VAL_2]], %[[MULF_0]] : f32
+// CHECK:             linalg.yield %[[ADDF_0]] : f32
+// CHECK:           } -> tensor<48xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1, 2, 3]] output_shape [1, 48, 1, 1]
+// CHECK-SAME:        : tensor<48xf32> into tensor<1x48x1x1xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : tensor<1x48x1x1xf32>
+
+// -----
+
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
   %cst = arith.constant 1.000000e+00 : f32
   %cst7 = arith.constant 7 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/173873


More information about the Mlir-commits mailing list