[Mlir-commits] [mlir] [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for (PR #120118)
Kunwar Grover
llvmlistbot at llvm.org
Fri Dec 27 05:11:59 PST 2024
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/120118
>From 564cc9bc1e6437539e0b8bb6c9fc38d7498b1736 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 16 Dec 2024 17:31:15 +0000
Subject: [PATCH] [mlir][Linalg] Allow PartialReductionOpInterface ops in
tile_reduction_using_for
---
.../Dialect/Linalg/TransformOps/LinalgTransformOps.td | 6 +++---
.../Linalg/TransformOps/LinalgTransformOps.cpp | 11 +++++++++--
2 files changed, 12 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2e713bca24efc5..081bf9b6d3b239 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
- TransformHandleTypeInterface:$split_linalg_op,
- TransformHandleTypeInterface:$combining_linalg_op,
+ TransformHandleTypeInterface:$split_op,
+ TransformHandleTypeInterface:$combining_op,
TransformHandleTypeInterface:$for_op);
let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
+ Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 221ca27b80fdd0..a1d619c8cd19dc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2626,12 +2626,19 @@ void transform::TileReductionUsingForOp::build(
}
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp target,
+ transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
+
+ auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+ if (!partialReductionOp) {
+ return emitSilenceableFailure(
+ target->getLoc(),
+ "Operation should implement PartialReductionOpInterface");
+ }
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
- rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+ rewriter, partialReductionOp,
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
if (failed(result))
More information about the Mlir-commits
mailing list