[Mlir-commits] [mlir] 17723e4 - [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (#157932)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 10 12:49:28 PDT 2025
Author: Bangtian Liu
Date: 2025-09-10T15:49:24-04:00
New Revision: 17723e472e228be5404ab4377498b52a0c5db03b
URL: https://github.com/llvm/llvm-project/commit/17723e472e228be5404ab4377498b52a0c5db03b
DIFF: https://github.com/llvm/llvm-project/commit/17723e472e228be5404ab4377498b52a0c5db03b.diff
LOG: [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (#157932)
Following [PR #120118](https://github.com/llvm/llvm-project/pull/120118),
this PR extends `transform.structured.tile_reduction_using_forall` so that
it can be applied to any operation implementing
`PartialReductionOpInterface`, rather than being restricted to `LinalgOp`.
The existing tests for Linalg ops remain valid:
https://github.com/llvm/llvm-project/blob/2a2296b1aab4614bf6c95c3003000832c9d43de5/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir#L114
Additional tests for non-Linalg operations (e.g., IREE custom ops that
implement `PartialReductionOpInterface`) will be added on the IREE side.
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..8f3232f01544f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
- TransformHandleTypeInterface:$split_linalg_op,
- TransformHandleTypeInterface:$combining_linalg_op,
+ TransformHandleTypeInterface:$split_op,
+ TransformHandleTypeInterface:$combining_op,
TransformHandleTypeInterface:$forall_op);
let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
+ Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..f3db8f7ccfaa1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
}
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp target,
+ transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
+
+ auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+ if (!partialReductionOp) {
+ return emitSilenceableFailure(
+ target->getLoc(),
+ "Operation should implement PartialReductionOpInterface");
+ }
SmallVector<OpFoldResult> numThreads =
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
if (reductionDims.empty()) {
for (auto [idx, iteratorType] :
- llvm::enumerate(target.getIteratorTypesArray())) {
+ llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
}
options.setReductionDims(reductionDims);
- FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
- rewriter, cast<TilingInterface>(target.getOperation()), options);
+ FailureOr<scf::SCFTilingResult> result =
+ scf::tileUsingSCF(rewriter, partialReductionOp, options);
if (failed(result)) {
auto diag = emitSilenceableError() << "could not tile reduction";
More information about the Mlir-commits
mailing list