[Mlir-commits] [mlir] [mlir][Linalg] Drop unit extent dim in non-trivial expressions (PR #173873)
Lukas Sommer
llvmlistbot at llvm.org
Wed Jan 7 14:04:45 PST 2026
https://github.com/sommerlukas updated https://github.com/llvm/llvm-project/pull/173873
>From 416113fe8aee354383add4b6c7d68d8752b733a0 Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at amd.com>
Date: Mon, 29 Dec 2025 15:34:32 +0000
Subject: [PATCH] [mlir][Linalg] Drop unit extent dim in non-trivial expressions
Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>
---
.../Linalg/Transforms/DropUnitDims.cpp | 7 +++-
.../Dialect/Linalg/drop-unit-extent-dims.mlir | 39 +++++++++++++++++++
2 files changed, 45 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ef58197bbfd0d..c3dca148b7f94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -362,7 +362,12 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
// Handle the other case where the shape is 1, and is accessed using a
// constant 0.
if (operandShape[dim] == 1) {
- auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
+ // Use the new expression after replacing dimensions that will be dropped
+ // here to handle cases where an affine expression with multiple
+ // dimensions (e.g., `d0 + d2`) can be simplified to 0 after dropping all
+ // dimensions used in the expression (`d0` and `d2` in this example).
+ AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+ auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
return constAffineExpr && constAffineExpr.getValue() == 0;
}
return false;
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 55b47bc2e9714..841d0e5f56512 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1136,6 +1136,45 @@ module {
// -----
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>  // Input map: the last two results are the non-trivial expressions `d2 + d5` and `d3 + d6`.
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>  // Filter map: each result is a single loop dim.
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>  // Output map: the parallel loops d0-d3.
+module {  // Test: dims 2 and 3 of %arg0 have size 1 and are indexed by sums of loops that are all unit-extent (d2, d3 from %arg2; d5, d6 from %arg1), so each sum should simplify to 0 once those loops are dropped, allowing the dims to be collapsed.
+ func.func @drop_unit_dim_binary_expr(%arg0: tensor<1x61x1x1xf32>, %arg1: tensor<48x61x1x1xf32>, %arg2: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32> {
+ %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<1x61x1x1xf32>, tensor<48x61x1x1xf32>) outs(%arg2 : tensor<1x48x1x1xf32>) {  // From the operand shapes, loops d0, d2, d3, d5 and d6 have extent 1; only d1 (48, parallel) and d4 (61, reduction) are expected to remain.
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %3 = arith.mulf %in, %in_0 : f32  // Multiply-accumulate: out += in * in_0.
+ %4 = arith.addf %out, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<1x48x1x1xf32>
+ return %2 : tensor<1x48x1x1xf32>
+ }
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func.func @drop_unit_dim_binary_expr
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x61x1x1xf32>, %[[ARG1:.*]]: tensor<48x61x1x1xf32>
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1x48x1x1xf32>) -> tensor<1x48x1x1xf32>
+// CHECK: %[[COLLAPSE_SHAPE_0:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2, 3]] : tensor<1x61x1x1xf32> into tensor<61xf32>
+// CHECK: %[[COLLAPSE_SHAPE_1:.*]] = tensor.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] : tensor<48x61x1x1xf32> into tensor<48x61xf32>
+// CHECK: %[[COLLAPSE_SHAPE_2:.*]] = tensor.collapse_shape %[[ARG2]] {{\[\[}}0, 1, 2, 3]] : tensor<1x48x1x1xf32> into tensor<48xf32>
+// CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[COLLAPSE_SHAPE_0]], %[[COLLAPSE_SHAPE_1]] : tensor<61xf32>, tensor<48x61xf32>)
+// CHECK-SAME: outs(%[[COLLAPSE_SHAPE_2]] : tensor<48xf32>)
+// CHECK: ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32):
+// CHECK: %[[MULF_0:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : f32
+// CHECK: %[[ADDF_0:.*]] = arith.addf %[[VAL_2]], %[[MULF_0]] : f32
+// CHECK: linalg.yield %[[ADDF_0]] : f32
+// CHECK: } -> tensor<48xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = tensor.expand_shape %[[GENERIC_0]] {{\[\[}}0, 1, 2, 3]] output_shape [1, 48, 1, 1]
+// CHECK-SAME: : tensor<48xf32> into tensor<1x48x1x1xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : tensor<1x48x1x1xf32>
+
+// -----
+
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
%cst = arith.constant 1.000000e+00 : f32
%cst7 = arith.constant 7 : index
More information about the Mlir-commits
mailing list