[Mlir-commits] [mlir] [mlir][linalg] Fix crash in tile_reduction when output map has constant exprs (PR #189166)
Mehdi Amini
llvmlistbot at llvm.org
Thu Apr 2 01:53:49 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189166
>From 2d1ea407534e0bd9736556abd4a6f4672d527820 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 03:46:27 -0700
Subject: [PATCH] [mlir][linalg] Fix crash in tile_reduction when output map
has constant exprs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
`generateInitialTensorForPartialReduction`, `mergeReductions`, and the
`getInitSliceInfo*` helpers unconditionally cast every result expression
of the partial result AffineMap to `AffineDimExpr`. When the original output indexing
map contains a constant (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`),
the constant expression propagates into the partial map and the cast
triggers an assertion.
Fix all four call sites:
- `generateInitialTensorForPartialReduction`: use the actual output
operand dimension size for constant result expressions (not a
hardcoded 1), so the partial init tensor shape matches the original
output shape at constant-indexed positions.
- `getInitSliceInfoForOuterReduction`: pass and use the actual output
operand shape for constant result expressions instead of size 1.
- `getInitSliceInfoForOuterParallel`: same, and include the actual dim
size in `resultShape` (it is a non-reduction parallel dim retained in
the tiled op's output).
- `mergeReductions`: skip constant result expressions (they can never be
reduction dims).
A constant index in the output map only means the op always accesses
that fixed index in the dimension — it does not imply the dimension has
size 1. Using the actual output operand size ensures that `linalg.reduce`
in `mergeReductions` receives matching shapes for its inputs and outputs.
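Add regression tests in `transform-tile-reduction.mlir` that cover
`tile_reduction_using_forall` and `tile_reduction_using_for`. They include
cases where the constant-indexed output dimension has size greater than 1.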
Fixes #173025
Assisted-by: Claude Code
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 101 ++++++++++--
.../Linalg/transform-tile-reduction.mlir | 153 ++++++++++++++++++
2 files changed, 237 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..9b455a0d4e460 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -434,13 +434,22 @@ struct InitSliceInfo {
static InitSliceInfo getInitSliceInfoForOuterReduction(
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
- ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
+ ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
+ ArrayRef<OpFoldResult> initOperandShape) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+ for (auto [resultIdx, dimExpr] :
+ llvm::enumerate(partialReductionMap.getResults())) {
+ if (isa<AffineConstantExpr>(dimExpr)) {
+ // A constant index in the output map accesses a fixed position; keep
+ // the full output dimension to match the original output operand shape.
+ initOffsets.push_back(zero);
+ initSizes.push_back(initOperandShape[resultIdx]);
+ continue;
+ }
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (reductionDims.contains(dim)) {
initOffsets.push_back(zero);
@@ -460,13 +469,24 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
static InitSliceInfo getInitSliceInfoForOuterParallel(
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
- ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
+ ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
+ ArrayRef<OpFoldResult> initOperandShape) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
+ Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
SmallVector<OpFoldResult> resultShape;
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+ for (auto [resultIdx, dimExpr] :
+ llvm::enumerate(partialReductionMap.getResults())) {
+ if (isa<AffineConstantExpr>(dimExpr)) {
+ // A constant index in the output map accesses a fixed position; keep
+ // the full output dimension to match the original output operand shape.
+ initOffsets.push_back(zero);
+ initSizes.push_back(initOperandShape[resultIdx]);
+ resultShape.push_back(initOperandShape[resultIdx]);
+ continue;
+ }
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
initOffsets.push_back(splitReductionIvs[dimPos.value()]);
@@ -490,17 +510,18 @@ static InitSliceInfo getInitSliceInfo(MLIRContext *context,
ArrayRef<OpFoldResult> sizes,
const SetVector<unsigned> &reductionDims,
ArrayRef<OpFoldResult> splitReductionIvs,
- AffineMap partialReductionMap) {
+ AffineMap partialReductionMap,
+ ArrayRef<OpFoldResult> initOperandShape) {
if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
- return getInitSliceInfoForOuterReduction(context, offsets, sizes,
- reductionDims, splitReductionIvs,
- partialReductionMap);
+ return getInitSliceInfoForOuterReduction(
+ context, offsets, sizes, reductionDims, splitReductionIvs,
+ partialReductionMap, initOperandShape);
}
assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
"unexpected ReductionTilingStrategy");
- return getInitSliceInfoForOuterParallel(context, offsets, sizes,
- reductionDims, splitReductionIvs,
- partialReductionMap);
+ return getInitSliceInfoForOuterParallel(
+ context, offsets, sizes, reductionDims, splitReductionIvs,
+ partialReductionMap, initOperandShape);
}
/// External model implementation of PartialReductionInterface for
@@ -538,9 +559,26 @@ struct LinalgOpPartialReductionInterface
// Append the new partial result dimensions.
SmallVector<OpFoldResult> partialResultShape;
+ Value initValue = linalgOp.getDpsInits()[initIdx];
+ auto initType = cast<RankedTensorType>(initValue.getType());
+ int64_t resultIdx = 0;
for (AffineExpr dimExpr : partialMap.getResults()) {
+ if (isa<AffineConstantExpr>(dimExpr)) {
+ // A constant index in the output map accesses a fixed position; use
+ // the actual output dimension size (not a hardcoded 1).
+ int64_t dimSize = initType.getDimSize(resultIdx);
+ if (ShapedType::isDynamic(dimSize))
+ partialResultShape.push_back(
+ tensor::DimOp::create(b, loc, initValue, resultIdx)
+ .getResult());
+ else
+ partialResultShape.push_back(b.getIndexAttr(dimSize));
+ ++resultIdx;
+ continue;
+ }
auto dim = cast<AffineDimExpr>(dimExpr);
partialResultShape.push_back(sizes[dim.getPosition()]);
+ ++resultIdx;
}
Type elType = getElementTypeOrSelf(result.getType());
@@ -591,11 +629,23 @@ struct LinalgOpPartialReductionInterface
// Step 2b: Extract a slice of the init operands.
SmallVector<Value, 1> tiledInits;
- for (auto [partialReductionMap, valueToTile] :
- llvm::zip_equal(partialReductionMaps, init)) {
+ for (auto [partialReductionMap, valueToTile, initOperandValue] :
+ llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) {
+ // Compute the actual shape of the original init operand for handling
+ // constant expressions in the partial reduction map.
+ auto initOperandType = cast<RankedTensorType>(initOperandValue.getType());
+ SmallVector<OpFoldResult> initOperandShape;
+ for (int64_t dim = 0; dim < initOperandType.getRank(); ++dim) {
+ int64_t dimSize = initOperandType.getDimSize(dim);
+ if (ShapedType::isDynamic(dimSize))
+ initOperandShape.push_back(
+ tensor::DimOp::create(b, loc, initOperandValue, dim).getResult());
+ else
+ initOperandShape.push_back(b.getIndexAttr(dimSize));
+ }
InitSliceInfo sliceInfo = getInitSliceInfo(
b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
- splitReductionIvs, partialReductionMap);
+ splitReductionIvs, partialReductionMap, initOperandShape);
auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
RankedTensorType sliceResultType = RankedTensorType::get(
sliceInfo.resultShape, valueToTileType.getElementType(),
@@ -670,6 +720,8 @@ struct LinalgOpPartialReductionInterface
SmallVector<int64_t> partialReductionDims;
for (auto [resultNum, dimExpr] :
llvm::enumerate(partialMap.getResults())) {
+ if (isa<AffineConstantExpr>(dimExpr))
+ continue; // Constant dims are never reduction dims.
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (llvm::is_contained(reductionDims, dim)) {
partialReductionDims.push_back(resultNum);
@@ -707,9 +759,24 @@ struct LinalgOpPartialReductionInterface
auto linalgOp = cast<LinalgOp>(op);
SmallVector<AffineMap> partialReductionMaps =
getPartialResultAffineMaps(linalgOp, reductionDims);
- InitSliceInfo sliceInfo = getInitSliceInfo(
- b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
- splitReductionIvs, partialReductionMaps[resultNumber]);
+ // Compute the actual shape of the init operand for handling constant
+ // expressions in the partial reduction map.
+ Value initOperandValue = linalgOp.getDpsInits()[resultNumber];
+ auto initOperandType = cast<RankedTensorType>(initOperandValue.getType());
+ Location loc = op->getLoc();
+ SmallVector<OpFoldResult> initOperandShape;
+ for (int64_t dim = 0; dim < initOperandType.getRank(); ++dim) {
+ int64_t dimSize = initOperandType.getDimSize(dim);
+ if (ShapedType::isDynamic(dimSize))
+ initOperandShape.push_back(
+ tensor::DimOp::create(b, loc, initOperandValue, dim).getResult());
+ else
+ initOperandShape.push_back(b.getIndexAttr(dimSize));
+ }
+ InitSliceInfo sliceInfo =
+ getInitSliceInfo(b.getContext(), tilingStrategy, offsets, sizes,
+ reductionDims, splitReductionIvs,
+ partialReductionMaps[resultNumber], initOperandShape);
std::swap(resultOffsets, sliceInfo.offsets);
std::swap(resultSizes, sliceInfo.sizes);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index e31d4f333557c..43941128eee4f 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -728,3 +728,156 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.generic
// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
// CHECK: affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]
+
+// -----
+
+// Verify that tile_reduction_using_forall handles output indexing maps that
+// contain constant expressions (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`)
+// without crashing. Previously, generateInitialTensorForPartialReduction
+// unconditionally cast every map result to AffineDimExpr, triggering an
+// assertion when a constant expression was present (issue #173025).
+
+func.func @reduction_tile_with_constant_in_output_map(
+ %arg0: tensor<1x4096x64xf32>,
+ %arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+ ],
+ iterator_types = ["parallel", "reduction", "parallel"]
+ } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = arith.addf %in, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<1x1x64xf32>
+ return %0 : tensor<1x1x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+ : (!transform.any_op) -> !transform.any_op
+ %1:4 = transform.structured.tile_reduction_using_forall %0
+ by tile_sizes = [0, 4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+ !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @reduction_tile_with_constant_in_output_map
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x1x64x1024xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32>
+// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x1x64x1024xf32>) {
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
+// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
+// CHECK: %[[PARTIAL:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[IN_SLICE]] :
+// CHECK-SAME: outs(%[[INIT_SLICE]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: linalg.reduce ins(%[[L]] : tensor<1x1x64x1024xf32>) outs(%arg1 : tensor<1x1x64xf32>) dimensions = [3]
+// CHECK: return %{{.*}} : tensor<1x1x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_forall with a constant in the output map when
+// the constant-indexed dimension size is greater than 1 (K=3). The partial
+// init tensor must use the actual output dim size (3), not a hardcoded 1.
+
+func.func @reduction_tile_forall_constant_dim_k_gt_1(
+ %arg0: tensor<1x4096x64xf32>,
+ %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+ ],
+ iterator_types = ["parallel", "reduction", "parallel"]
+ } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = arith.addf %in, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<1x3x64xf32>
+ return %0 : tensor<1x3x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+ : (!transform.any_op) -> !transform.any_op
+ %1:4 = transform.structured.tile_reduction_using_forall %0
+ by tile_sizes = [0, 4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+ !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @reduction_tile_forall_constant_dim_k_gt_1
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x1024xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x1024xf32>) -> tensor<1x3x64x1024xf32>
+// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x1024xf32>) {
+// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
+// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
+// CHECK: %[[PARTIAL:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[IN_SLICE]] :
+// CHECK-SAME: outs(%[[INIT_SLICE]] :
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x1024xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
+// CHECK: return %{{.*}} : tensor<1x3x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_for with a constant in the output map when
+// the constant-indexed dimension size is greater than 1 (K=3). The partial
+// init tensor must use the actual output dim size (3), not a hardcoded 1.
+
+func.func @reduction_tile_for_constant_dim_k_gt_1(
+ %arg0: tensor<1x4096x64xf32>,
+ %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+ ],
+ iterator_types = ["parallel", "reduction", "parallel"]
+ } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = arith.addf %in, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<1x3x64xf32>
+ return %0 : tensor<1x3x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+ : (!transform.any_op) -> !transform.any_op
+ %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+ by tile_sizes = [0, 4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+ !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @reduction_tile_for_constant_dim_k_gt_1
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x4xf32>) -> tensor<1x3x64x4xf32>
+// CHECK: %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x4xf32>) {
+// CHECK: %[[PARTIAL:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[ARG3]] :
+// CHECK: scf.yield %[[PARTIAL]]
+// CHECK: }
+// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x4xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
+// CHECK: return %{{.*}} : tensor<1x3x64xf32>
More information about the Mlir-commits
mailing list