[Mlir-commits] [mlir] [mlir][linalg] Fix tiling with constants in indexing maps (PR #173038)

Andrey Pavlenko llvmlistbot at llvm.org
Fri Dec 19 08:36:44 PST 2025


https://github.com/AndreyPavlenko created https://github.com/llvm/llvm-project/pull/173038

Fixes #173025

>From d6b594e1aef84acaa2dec179de56cd3925132121 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Fri, 19 Dec 2025 16:33:41 +0000
Subject: [PATCH] [mlir][linalg] Fix tiling with constants in indexing maps

Fixes #173025
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 63 ++++++++++++-------
 1 file changed, 42 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..78d124a3ccd4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -436,17 +436,25 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
-  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  Type idxType = IndexType::get(context);
+  Attribute zero = IntegerAttr::get(idxType, 0);
+  Attribute one = IntegerAttr::get(idxType, 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    if (reductionDims.contains(dim)) {
-      initOffsets.push_back(zero);
+  for (AffineExpr expr : partialReductionMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      unsigned dim = dimExpr.getPosition();
+      if (reductionDims.contains(dim)) {
+        initOffsets.push_back(zero);
+      } else {
+        initOffsets.push_back(offsets[dim]);
+      }
+      initSizes.push_back(sizes[dim]);
+    } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
+      initSizes.push_back(one);
     } else {
-      initOffsets.push_back(offsets[dim]);
+      llvm_unreachable("Unsupported affine expression type");
     }
-    initSizes.push_back(sizes[dim]);
   }
   SmallVector<int64_t> resultShape;
   std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
@@ -462,18 +470,27 @@ static InitSliceInfo getInitSliceInfoForOuterParallel(
     ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  Type idxType = IndexType::get(context);
+  Attribute one = IntegerAttr::get(idxType, 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
   SmallVector<OpFoldResult> resultShape;
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-    if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
-      initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+  for (AffineExpr expr : partialReductionMap.getResults()) {
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+      unsigned dim = dimExpr.getPosition();
+      if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
+        initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+        initSizes.push_back(one);
+      } else {
+        initOffsets.push_back(offsets[dim]);
+        initSizes.push_back(sizes[dim]);
+        resultShape.push_back(sizes[dim]);
+      }
+    } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
       initSizes.push_back(one);
+      resultShape.push_back(one);
     } else {
-      initOffsets.push_back(offsets[dim]);
-      initSizes.push_back(sizes[dim]);
-      resultShape.push_back(sizes[dim]);
+      llvm_unreachable("Unsupported affine expression type");
     }
   }
   SmallVector<int64_t> staticShapes;
@@ -538,8 +555,11 @@ struct LinalgOpPartialReductionInterface
       // Append the new partial result dimensions.
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
-        auto dim = cast<AffineDimExpr>(dimExpr);
-        partialResultShape.push_back(sizes[dim.getPosition()]);
+        if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+          partialResultShape.push_back(sizes[dim.getPosition()]);
+        } else {
+          partialResultShape.push_back(b.getIndexAttr(1));
+        }
       }
 
       Type elType = getElementTypeOrSelf(result.getType());
@@ -667,9 +687,10 @@ struct LinalgOpPartialReductionInterface
       SmallVector<int64_t> partialReductionDims;
       for (auto [resultNum, dimExpr] :
            llvm::enumerate(partialMap.getResults())) {
-        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-        if (llvm::is_contained(reductionDims, dim)) {
-          partialReductionDims.push_back(resultNum);
+        if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+          if (llvm::is_contained(reductionDims, dim.getPosition())) {
+            partialReductionDims.push_back(resultNum);
+          }
         }
       }
 



More information about the Mlir-commits mailing list