[Mlir-commits] [mlir] [mlir][Linalg] Drop unit extent dim in non-trivial expressions (PR #173873)

Lukas Sommer llvmlistbot at llvm.org
Wed Jan 7 14:04:45 PST 2026


https://github.com/sommerlukas updated https://github.com/llvm/llvm-project/pull/173873

>From 416113fe8aee354383add4b6c7d68d8752b733a0 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 29 Dec 2025 15:34:32 +0000
Subject: [PATCH] [mlir][Linalg] Drop unit extent dim in non-trivial expr

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
 .../Linalg/Transforms/DropUnitDims.cpp        |  7 +++-
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 39 +++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ef58197bbfd0d..c3dca148b7f94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -362,7 +362,12 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
     // Handle the other case where the shape is 1, and is accessed using a
     // constant 0.
     if (operandShape[dim] == 1) {
-      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
+      // Use the new expression after replacing dimensions that will be dropped
+      // here to handle cases where an affine expression with multiple
+      // dimensions (e.g., `d0 + d2`) can be simplified to 0 after dropping all
+      // dimensions used in the expression (`d0` and `d2` in this example).
+      AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+      auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
       return constAffineExpr && constAffineExpr.getValue() == 0;
     }
     return false;
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 55b47bc2e9714..841d0e5f56512 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1136,6 +1136,45 @@ module {
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+  func.func @drop_unit_dim_binary_expr(%arg0: tensor<1x61x1x1xf32>, %arg1: tensor<48x61x1x1xf32>, %arg2: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32> {
+    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x61x1x1xf32>, tensor<48x61x1x1xf32>) outs(%arg2 : tensor<1x48x1x1xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %3 = arith.mulf %in, %in_0 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<1x48x1x1xf32>
+    return %2 : tensor<1x48x1x1xf32>
+  }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL:   func.func @drop_unit_dim_binary_expr
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x61x1x1xf32>, %[[ARG1:.*]]: tensor<48x61x1x1xf32>
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3]] : tensor<1x61x1x1xf32> into tensor<61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] : tensor<48x61x1x1xf32> into tensor<48x61xf32>
+// CHECK:           %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[ARG2]] {{\[\[}}0, 1, 2, 3]] : tensor<1x48x1x1xf32> into tensor<48xf32>
+// CHECK:           %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME:        iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:        ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<61xf32>, tensor<48x61xf32>)
+// CHECK-SAME:        outs(%[[COLLAPSE_SHAPE_2]] : tensor<48xf32>)
+// CHECK:           ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
+// CHECK:             %[[MULF_0:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : f32
+// CHECK:             %[[ADDF_0:.*]] = arith.addf %[[VAL_2]], %[[MULF_0]] : f32
+// CHECK:             linalg.yield %[[ADDF_0]] : f32
+// CHECK:           } -> tensor<48xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1, 2, 3]] output_shape [1, 48, 1, 1]
+// CHECK-SAME:        : tensor<48xf32> into tensor<1x48x1x1xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : tensor<1x48x1x1xf32>
+
+// -----
+
 func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
   %cst = arith.constant 1.000000e+00 : f32
   %cst7 = arith.constant 7 : index



More information about the Mlir-commits mailing list