[Mlir-commits] [mlir] [mlir][linalg] Fix tiling with constants in indexing maps (PR #173038)

Andrey Pavlenko llvmlistbot at llvm.org
Mon Jan 12 05:35:13 PST 2026


https://github.com/AndreyPavlenko updated https://github.com/llvm/llvm-project/pull/173038

>From 9bf61ac75956621c43e047d4f1209a1a20fb17c7 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Fri, 19 Dec 2025 16:33:41 +0000
Subject: [PATCH 1/3] [mlir][linalg] Fix tiling with constants in indexing maps

Fixes #173025
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 63 ++++++++++++-------
 1 file changed, 42 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..78d124a3ccd4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -436,17 +436,25 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
-  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  Type idxType = IndexType::get(context);
+  Attribute zero = IntegerAttr::get(idxType, 0);
+  Attribute one = IntegerAttr::get(idxType, 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    if (reductionDims.contains(dim)) {
-      initOffsets.push_back(zero);
+  for (AffineExpr expr : partialReductionMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      unsigned dim = dimExpr.getPosition();
+      if (reductionDims.contains(dim)) {
+        initOffsets.push_back(zero);
+      } else {
+        initOffsets.push_back(offsets[dim]);
+      }
+      initSizes.push_back(sizes[dim]);
+    } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
+      initSizes.push_back(one);
     } else {
-      initOffsets.push_back(offsets[dim]);
+      llvm_unreachable("Unsupported affine expression type");
     }
-    initSizes.push_back(sizes[dim]);
   }
   SmallVector<int64_t> resultShape;
   std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
@@ -462,18 +470,27 @@ static InitSliceInfo getInitSliceInfoForOuterParallel(
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  Type idxType = IndexType::get(context);
+  Attribute one = IntegerAttr::get(idxType, 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
   SmallVector<OpFoldResult> resultShape;
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
-      initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+  for (AffineExpr expr : partialReductionMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      unsigned dim = dimExpr.getPosition();
+      if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
+        initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+        initSizes.push_back(one);
+      } else {
+        initOffsets.push_back(offsets[dim]);
+        initSizes.push_back(sizes[dim]);
+        resultShape.push_back(sizes[dim]);
+      }
+    } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
       initSizes.push_back(one);
+      resultShape.push_back(one);
     } else {
-      initOffsets.push_back(offsets[dim]);
-      initSizes.push_back(sizes[dim]);
-      resultShape.push_back(sizes[dim]);
+      llvm_unreachable("Unsupported affine expression type");
     }
   }
   SmallVector<int64_t> staticShapes;
@@ -538,8 +555,11 @@ struct LinalgOpPartialReductionInterface
       // Append the new partial result dimensions.
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
-        auto dim = cast<AffineDimExpr>(dimExpr);
-        partialResultShape.push_back(sizes[dim.getPosition()]);
+        if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+          partialResultShape.push_back(sizes[dim.getPosition()]);
+        } else {
+          partialResultShape.push_back(b.getIndexAttr(1));
+        }
       }
 
       Type elType = getElementTypeOrSelf(result.getType());
@@ -667,9 +687,10 @@ struct LinalgOpPartialReductionInterface
       SmallVector<int64_t> partialReductionDims;
       for (auto [resultNum, dimExpr] :
            llvm::enumerate(partialMap.getResults())) {
-        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-        if (llvm::is_contained(reductionDims, dim)) {
-          partialReductionDims.push_back(resultNum);
+        if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+          if (llvm::is_contained(reductionDims, dim.getPosition())) {
+            partialReductionDims.push_back(resultNum);
+          }
         }
       }
 

>From b113b84b1eb5ac1a0acfc1b09961a4800447691c Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Wed, 7 Jan 2026 21:07:25 +0000
Subject: [PATCH 2/3] Added test

---
 .../Linalg/transform-tile-reduction.mlir      | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4cc58668944fe..6b5161f4e9e5b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -692,3 +692,48 @@ module {
 //      CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]]
 // CHECK-SAME:       outs(%[[ARG2]] :
 //      CHECK:   return %[[R]]
+
+// -----
+
+// Check reduction that has constants in indexing maps. Issue #173025.
+
+module {
+  func.func @test(%arg0: tensor<1x4096x64xf32>, %arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
+    %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+        affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %5 = arith.addf %in, %out : f32
+      linalg.yield %5 : f32
+    } -> tensor<1x1x64xf32>
+    return %0 : tensor<1x1x64xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg : (!transform.any_op) -> !transform.any_op
+      %1:4 = transform.structured.tile_reduction_using_forall %0 by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//     CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x64x1024xf32>
+//     CHECK:   %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32>
+//     CHECK:   %[[FORALL:.*]] = scf.forall (%[[IV:.*]]) = (0) to (4096) step (4) shared_outs(%[[ARG:.*]] = %[[FILL]]) -> (tensor<1x1x64x1024xf32>)
+//     CHECK:     %[[OFFSET:.*]] = affine.apply #[[MAP]]()[%[[IV]]]
+//     CHECK:     %[[SLICE0:.*]] = tensor.extract_slice %{{.*}}[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1] : tensor<1x4096x64xf32> to tensor<1x4x64xf32>
+//     CHECK:     %[[SLICE1:.*]] = tensor.extract_slice %[[ARG]][0, 0, 0, %[[OFFSET]]] [1, 1, 64, 1] [1, 1, 1, 1] : tensor<1x1x64x1024xf32> to tensor<1x1x64xf32>
+//     CHECK:     %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[SLICE0]] : tensor<1x4x64xf32>) outs(%[[SLICE1]] : tensor<1x1x64xf32>)
+//     CHECK:       tensor.parallel_insert_slice %[[GENERIC]] into %[[ARG]][0, 0, 0, %[[OFFSET]]] [1, 1, 64, 1] [1, 1, 1, 1] : tensor<1x1x64xf32> into tensor<1x1x64x1024xf32>
+//     CHECK:   %[[REDUCE:.*]] = linalg.reduce ins(%[[FORALL]] : tensor<1x1x64x1024xf32>) outs(%{{.*}} : tensor<1x1x64xf32>) dimensions = [3]
+//     CHECK:   return %[[REDUCE]] : tensor<1x1x64xf32>

>From 4c1a1889bd25a9e4b5927d914a6e1da2a75b879e Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Mon, 12 Jan 2026 13:34:56 +0000
Subject: [PATCH 3/3] Added error propagation

---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 58 ++++++++++---------
 1 file changed, 32 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 78d124a3ccd4f..a6e024b98da57 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -430,8 +430,8 @@ struct InitSliceInfo {
 /// Return the result shape, offsets, sizes and strides of the slice of the
 /// `initValue` to use as the destination of the partial reduction op generated
 /// with outer reduction strategy.
-static InitSliceInfo getInitSliceInfoForOuterReduction(
-    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
+static FailureOr<InitSliceInfo> getInitSliceInfoForOuterReduction(
+    Operation *op, MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
@@ -453,19 +453,21 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
       initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
       initSizes.push_back(one);
     } else {
-      llvm_unreachable("Unsupported affine expression type");
+      return op->emitOpError(
+          "Unexpected affine expression type: only dimension and constant "
+          "expressions are supported");
     }
   }
   SmallVector<int64_t> resultShape;
   std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
-  return {resultShape, initOffsets, initSizes, initStrides};
+  return InitSliceInfo{resultShape, initOffsets, initSizes, initStrides};
 }
 
 /// Return the result shape, offsets, sizes and strides of the slice of the
 /// `initValue` to use as destination of the partial reduction op generated with
 /// outer parallel strategy.
-static InitSliceInfo getInitSliceInfoForOuterParallel(
-    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
+static FailureOr<InitSliceInfo> getInitSliceInfoForOuterParallel(
+    Operation *op, MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
@@ -490,31 +492,31 @@ static InitSliceInfo getInitSliceInfoForOuterParallel(
       initSizes.push_back(one);
       resultShape.push_back(one);
     } else {
-      llvm_unreachable("Unsupported affine expression type");
+      return op->emitOpError(
+          "Unexpected affine expression type: only dimension and constant "
+          "expressions are supported");
     }
   }
   SmallVector<int64_t> staticShapes;
   std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
-  return {staticShapes, initOffsets, initSizes, initStrides};
+  return InitSliceInfo{staticShapes, initOffsets, initSizes, initStrides};
 }
 
 /// Return the result shape, offsets, sizes and strides of the slice of the
 /// `initValue` to use as destination of the partial reduction op.
-static InitSliceInfo getInitSliceInfo(MLIRContext *context,
-                                      ReductionTilingStrategy strategy,
-                                      ArrayRef<OpFoldResult> offsets,
-                                      ArrayRef<OpFoldResult> sizes,
-                                      const SetVector<unsigned> &reductionDims,
-                                      ArrayRef<OpFoldResult> splitReductionIvs,
-                                      AffineMap partialReductionMap) {
+static FailureOr<InitSliceInfo> getInitSliceInfo(
+    Operation *op, MLIRContext *context, ReductionTilingStrategy strategy,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    const SetVector<unsigned> &reductionDims,
+    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
-    return getInitSliceInfoForOuterReduction(context, offsets, sizes,
+    return getInitSliceInfoForOuterReduction(op, context, offsets, sizes,
                                              reductionDims, splitReductionIvs,
                                              partialReductionMap);
   }
   assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
          "unexpected ReductionTilingStrategy");
-  return getInitSliceInfoForOuterParallel(context, offsets, sizes,
+  return getInitSliceInfoForOuterParallel(op, context, offsets, sizes,
                                           reductionDims, splitReductionIvs,
                                           partialReductionMap);
 }
@@ -612,16 +614,18 @@ struct LinalgOpPartialReductionInterface
     SmallVector<Value, 1> tiledInits;
     for (auto [partialReductionMap, valueToTile] :
          llvm::zip_equal(partialReductionMaps, init)) {
-      InitSliceInfo sliceInfo = getInitSliceInfo(
-          b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
+      FailureOr<InitSliceInfo> sliceInfo = getInitSliceInfo(
+          op, b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
           splitReductionIvs, partialReductionMap);
+      if (failed(sliceInfo))
+        return failure();
       auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
       RankedTensorType sliceResultType = RankedTensorType::get(
-          sliceInfo.resultShape, valueToTileType.getElementType(),
+          sliceInfo->resultShape, valueToTileType.getElementType(),
           valueToTileType.getEncoding());
       auto sliceOp = tensor::ExtractSliceOp::create(
-          b, loc, sliceResultType, valueToTile, sliceInfo.offsets,
-          sliceInfo.sizes, sliceInfo.strides);
+          b, loc, sliceResultType, valueToTile, sliceInfo->offsets,
+          sliceInfo->sizes, sliceInfo->strides);
       tiledInits.push_back(sliceOp.getResult());
       generatedSlices.push_back(sliceOp);
     }
@@ -725,11 +729,13 @@ struct LinalgOpPartialReductionInterface
     auto linalgOp = cast<LinalgOp>(op);
     SmallVector<AffineMap> partialReductionMaps =
         getPartialResultAffineMaps(linalgOp, reductionDims);
-    InitSliceInfo sliceInfo = getInitSliceInfo(
-        b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
+    FailureOr<InitSliceInfo> sliceInfo = getInitSliceInfo(
+        op, b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
         splitReductionIvs, partialReductionMaps[resultNumber]);
-    std::swap(resultOffsets, sliceInfo.offsets);
-    std::swap(resultSizes, sliceInfo.sizes);
+    if (failed(sliceInfo))
+      return failure();
+    std::swap(resultOffsets, sliceInfo->offsets);
+    std::swap(resultSizes, sliceInfo->sizes);
 
     return success();
   }



More information about the Mlir-commits mailing list