[Mlir-commits] [mlir] [mlir][Linalg] Drop unit extent dim in non-trivial expressions (PR #173873)

Lukas Sommer llvmlistbot at llvm.org
Mon Dec 29 07:44:22 PST 2025


https://github.com/sommerlukas created https://github.com/llvm/llvm-project/pull/173873

The current implementation does not drop a unit-extent dimension on the first application of the transformation if that dimension is indexed by a non-trivial affine expression (i.e., anything other than a single dimension or the constant 0). However, such a dimension can be dropped if all dimensions involved in the affine expression are themselves going to be dropped. So far, this required repeated application of the transformation; with the changes in this PR, these dimensions are dropped in a single application of the transformation.

>From f5c6e6906bb7d21ea0aa0c7a0ce0f63d127d7d3e Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 29 Dec 2025 15:34:32 +0000
Subject: [PATCH] [mlir][Linalg] Drop unit extent dim in non-trivial expr

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 .../Linalg/Transforms/DropUnitDims.cpp        |  7 +++-
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 39 +++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ef58197bbfd0d..c3dca148b7f94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -362,7 +362,12 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
     // Handle the other case where the shape is 1, and is accessed using a
     // constant 0.
     if (operandShape[dim] == 1) {
-      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
+      // Check the expression obtained after replacing the dimensions that will
+      // be dropped. This handles affine expressions over multiple dimensions
+      // (e.g., `d0 + d2`) that simplify to 0 once all dimensions they use
+      // (`d0` and `d2` in this example) are dropped.
+      AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+      auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
       return constAffineExpr && constAffineExpr.getValue() == 0;
     }
     return false;
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 55b47bc2e9714..841d0e5f56512 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1136,6 +1136,45 @@ module {
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_binary_expr(%arg0: tensor<1x61x1x1xf32>, %arg1: tensor<48x61x1x1xf32>, %arg2: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32> {
+    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x61x1x1xf32>, tensor<48x61x1x1xf32>) outs(%arg2 : tensor<1x48x1x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %3 = arith.mulf %in, %in_0 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<1x48x1x1xf32>
+    return %2 : tensor<1x48x1x1xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL:   func.func @drop_unit_dim_binary_expr
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x61x1x1xf32>, %[[ARG1:.*]]: tensor<48x61x1x1xf32>
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3]] : tensor<1x61x1x1xf32> into tensor<61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] : tensor<48x61x1x1xf32> into tensor<48x61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[ARG2]] {{\[\[}}0, 1, 2, 3]] : tensor<1x48x1x1xf32> into tensor<48xf32>
+// CHECK:           %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME:        iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:        ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<61xf32>, tensor<48x61xf32>)
+// CHECK-SAME:        outs(%[[COLLAPSE_SHAPE_2]] : tensor<48xf32>)
+// CHECK:           ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
+// CHECK:             %[[MULF_0:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : f32
+// CHECK:             %[[ADDF_0:.*]] = arith.addf %[[VAL_2]], %[[MULF_0]] : f32
+// CHECK:             linalg.yield %[[ADDF_0]] : f32
+// CHECK:           } -> tensor<48xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1, 2, 3]] output_shape [1, 48, 1, 1]
+// CHECK-SAME:        : tensor<48xf32> into tensor<1x48x1x1xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : tensor<1x48x1x1xf32>
+
+// -----
+
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
   %cst = arith.constant 1.000000e+00 : f32
   %cst7 = arith.constant 7 : index



More information about the Mlir-commits mailing list