[Mlir-commits] [mlir] [mlir][linalg] Fix tiling with constants in indexing maps (PR #173038)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 19 08:37:35 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrey Pavlenko (AndreyPavlenko)
<details>
<summary>Changes</summary>
Fixes #173025
---
Full diff: https://github.com/llvm/llvm-project/pull/173038.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+42-21)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..78d124a3ccd4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -436,17 +436,25 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
- Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
- Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+ Type idxType = IndexType::get(context);
+ Attribute zero = IntegerAttr::get(idxType, 0);
+ Attribute one = IntegerAttr::get(idxType, 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (reductionDims.contains(dim)) {
- initOffsets.push_back(zero);
+ for (AffineExpr expr : partialReductionMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ unsigned dim = dimExpr.getPosition();
+ if (reductionDims.contains(dim)) {
+ initOffsets.push_back(zero);
+ } else {
+ initOffsets.push_back(offsets[dim]);
+ }
+ initSizes.push_back(sizes[dim]);
+ } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
+ initSizes.push_back(one);
} else {
- initOffsets.push_back(offsets[dim]);
+ llvm_unreachable("Unsupported affine expression type");
}
- initSizes.push_back(sizes[dim]);
}
SmallVector<int64_t> resultShape;
std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
@@ -462,18 +470,27 @@ static InitSliceInfo getInitSliceInfoForOuterParallel(
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
- Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+ Type idxType = IndexType::get(context);
+ Attribute one = IntegerAttr::get(idxType, 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
SmallVector<OpFoldResult> resultShape;
- for (AffineExpr dimExpr : partialReductionMap.getResults()) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
- initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+ for (AffineExpr expr : partialReductionMap.getResults()) {
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ unsigned dim = dimExpr.getPosition();
+ if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
+ initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+ initSizes.push_back(one);
+ } else {
+ initOffsets.push_back(offsets[dim]);
+ initSizes.push_back(sizes[dim]);
+ resultShape.push_back(sizes[dim]);
+ }
+ } else if (auto cstExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ initOffsets.push_back(IntegerAttr::get(idxType, cstExpr.getValue()));
initSizes.push_back(one);
+ resultShape.push_back(one);
} else {
- initOffsets.push_back(offsets[dim]);
- initSizes.push_back(sizes[dim]);
- resultShape.push_back(sizes[dim]);
+ llvm_unreachable("Unsupported affine expression type");
}
}
SmallVector<int64_t> staticShapes;
@@ -538,8 +555,11 @@ struct LinalgOpPartialReductionInterface
// Append the new partial result dimensions.
SmallVector<OpFoldResult> partialResultShape;
for (AffineExpr dimExpr : partialMap.getResults()) {
- auto dim = cast<AffineDimExpr>(dimExpr);
- partialResultShape.push_back(sizes[dim.getPosition()]);
+ if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+ partialResultShape.push_back(sizes[dim.getPosition()]);
+ } else {
+ partialResultShape.push_back(b.getIndexAttr(1));
+ }
}
Type elType = getElementTypeOrSelf(result.getType());
@@ -667,9 +687,10 @@ struct LinalgOpPartialReductionInterface
SmallVector<int64_t> partialReductionDims;
for (auto [resultNum, dimExpr] :
llvm::enumerate(partialMap.getResults())) {
- unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
- if (llvm::is_contained(reductionDims, dim)) {
- partialReductionDims.push_back(resultNum);
+ if (auto dim = dyn_cast<AffineDimExpr>(dimExpr)) {
+ if (llvm::is_contained(reductionDims, dim.getPosition())) {
+ partialReductionDims.push_back(resultNum);
+ }
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/173038
More information about the Mlir-commits mailing list