[Mlir-commits] [mlir] [mlir][linalg] Fix tiling with constants in indexing maps (PR #173038)
Andrey Pavlenko
llvmlistbot at llvm.org
Wed Jan 7 13:10:09 PST 2026
https://github.com/AndreyPavlenko updated https://github.com/llvm/llvm-project/pull/173038
From 5d6d20a80aa5415a630f3561dba4f6a92465e660 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Fri, 19 Dec 2025 16:33:41 +0000
Subject: [PATCH 1/2] [mlir][linalg] Fix tiling with constants in indexing maps
Fixes #173025
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 63 ++++++++++++-------
1 file changed, 42 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..78d124a3ccd4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -436,17 +436,25 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
- Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
- Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+ Type idxType = IndexType::get(context);
+ Attribute zero = IntegerAttr::get(idxType, 0);
+ Attribute one = IntegerAttr::get(idxType, 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (reductionDims.contains(dim)) {
- initOffsets.push_back(zero);
+ for (AffineExpr expr : partialReductionMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ unsigned dim = dimExpr.getPosition();
+ if (reductionDims.contains(dim)) {
+ initOffsets.push_back(zero);
+ } else {
+ initOffsets.push_back(offsets[dim]);
+ }
+ initSizes.push_back(sizes[dim]);
+ } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
+ initSizes.push_back(one);
} else {
- initOffsets.push_back(offsets[dim]);
+ llvm_unreachable("Unsupported affine expression type");
}
- initSizes.push_back(sizes[dim]);
}
SmallVector<int64_t> resultShape;
std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
@@ -462,18 +470,27 @@ static InitSliceInfo getInitSliceInfoForOuterParallel(
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
- Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+ Type idxType = IndexType::get(context);
+ Attribute one = IntegerAttr::get(idxType, 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
SmallVector<OpFoldResult> resultShape;
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
- initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+ for (AffineExpr expr : partialReductionMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ unsigned dim = dimExpr.getPosition();
+ if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
+ initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+ initSizes.push_back(one);
+ } else {
+ initOffsets.push_back(offsets[dim]);
+ initSizes.push_back(sizes[dim]);
+ resultShape.push_back(sizes[dim]);
+ }
+ } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
initSizes.push_back(one);
+ resultShape.push_back(one);
} else {
- initOffsets.push_back(offsets[dim]);
- initSizes.push_back(sizes[dim]);
- resultShape.push_back(sizes[dim]);
+ llvm_unreachable("Unsupported affine expression type");
}
}
SmallVector<int64_t> staticShapes;
@@ -538,8 +555,11 @@ struct LinalgOpPartialReductionInterface
// Append the new partial result dimensions.
SmallVector<OpFoldResult> partialResultShape;
for (AffineExpr dimExpr : partialMap.getResults()) {
- auto dim = cast<AffineDimExpr>(dimExpr);
- partialResultShape.push_back(sizes[dim.getPosition()]);
+ if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+ partialResultShape.push_back(sizes[dim.getPosition()]);
+ } else {
+ partialResultShape.push_back(b.getIndexAttr(1));
+ }
}
Type elType = getElementTypeOrSelf(result.getType());
@@ -667,9 +687,10 @@ struct LinalgOpPartialReductionInterface
SmallVector<int64_t> partialReductionDims;
for (auto [resultNum, dimExpr] :
llvm::enumerate(partialMap.getResults())) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (llvm::is_contained(reductionDims, dim)) {
- partialReductionDims.push_back(resultNum);
+ if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+ if (llvm::is_contained(reductionDims, dim.getPosition())) {
+ partialReductionDims.push_back(resultNum);
+ }
}
}
From 8f9fef1668de22b5247e69acaa62a3c92adaa300 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Wed, 7 Jan 2026 21:07:25 +0000
Subject: [PATCH 2/2] [mlir][linalg] Add test for tiling reductions with constants in indexing maps
---
.../Linalg/transform-tile-reduction.mlir | 45 +++++++++++++++++++
1 file changed, 45 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..6b5161f4e9e5b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,48 @@ module {
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK: return %[[R]]
+
+// -----
+
+// Check reduction that has constants in indexing maps. Issue #173025.
+
+module {
+ func.func @test(%arg0: tensor<1x4096x64xf32>, %arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+ ],
+ iterator_types = ["parallel", "reduction", "parallel"]
+ } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %in, %out : f32
+ linalg.yield %5 : f32
+ } -> tensor<1x1x64xf32>
+ return %0 : tensor<1x1x64xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %1:4 = transform.structured.tile_reduction_using_forall %0 by tile_sizes = [0, 4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x64x1024xf32>
+// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32>
+// CHECK: %[[FORALL:.*]] = scf.forall (%[[IV:.*]]) = (0) to (4096) step (4) shared_outs(%[[ARG:.*]] = %[[FILL]]) -> (tensor<1x1x64x1024xf32>)
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[MAP]]()[%[[IV]]]
+// CHECK: %[[SLICE0:.*]] = tensor.extract_slice %{{.*}}[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1] : tensor<1x4096x64xf32> to tensor<1x4x64xf32>
+// CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG]][0, 0, 0, %[[OFFSET]]] [1, 1, 64, 1] [1, 1, 1, 1] : tensor<1x1x64x1024xf32> to tensor<1x1x64xf32>
+// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[SLICE0]] : tensor<1x4x64xf32>) outs(%[[SLICE1]] : tensor<1x1x64xf32>)
+// CHECK: tensor.parallel_insert_slice %[[GENERIC]] into %[[ARG]][0, 0, 0, %[[OFFSET]]] [1, 1, 64, 1] [1, 1, 1, 1] : tensor<1x1x64xf32> into tensor<1x1x64x1024xf32>
+// CHECK: %[[REDUCE:.*]] = linalg.reduce ins(%[[FORALL]] : tensor<1x1x64x1024xf32>) outs(%{{.*}} : tensor<1x1x64xf32>) dimensions = [3]
+// CHECK: return %[[REDUCE]] : tensor<1x1x64xf32>
More information about the Mlir-commits mailing list