[Mlir-commits] [mlir] c223521 - [mlir][TilingInterface] Allow tile and fuse to work with `ReductionTilingStrategy::PartialReductionOuterParallelStrategy`. (#147593)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 9 08:50:05 PDT 2025
Author: MaheshRavishankar
Date: 2025-07-09T08:50:01-07:00
New Revision: c22352175ef29c141de27485286275434c58e88a
URL: https://github.com/llvm/llvm-project/commit/c22352175ef29c141de27485286275434c58e88a
DIFF: https://github.com/llvm/llvm-project/commit/c22352175ef29c141de27485286275434c58e88a.diff
LOG: [mlir][TilingInterface] Allow tile and fuse to work with `ReductionTilingStrategy::PartialReductionOuterParallelStrategy`. (#147593)
Since `scf::tileUsingSCF` is the core method used for tiling the root
operation within `scf::tileConsumersAndFuseProducersUsingSCF`, the
latter can fuse into any tiled loop generated using `scf::tileUsingSCF`.
This patch adds a test for tiling a root operation using
`ReductionTilingStrategy::PartialReductionOuterParallel` and
fusing producers into the generated loop.
Since this strategy generates a rank-reducing extract slice,
`tensor::replaceExtractSliceWithTiledProducer`, which is the core method
used for the fusion, was extended to handle rank-reducing slices.
Also fix a small bug in the computation of the reduction induction
variable, which needs to use `floorDiv` instead of `ceilDiv`.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Added:
mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
Modified:
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 995120ad8680e..c7d634283fd4e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -681,7 +681,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc,
splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
- AffineExpr divExpr = s0.ceilDiv(s1);
+ AffineExpr divExpr = s0.floorDiv(s1);
int ivIndex = 0;
if (reductionStrategy ==
ReductionTilingStrategy::PartialReductionOuterParallel) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 4392a2c0eb839..6df401d4c6962 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -39,6 +39,23 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
if (failed(tiledResult))
return failure();
+ // For cases where the slice was rank-reducing, create a rank-reducing slice
+ // to get the same type back.
+ llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+ if (droppedDims.any()) {
+ assert(tiledResult->tiledValues.size() == 1 &&
+ "expected only a single tiled result value to replace the extract "
+ "slice");
+ SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
+ builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
+ builder.getIndexAttr(1));
+ auto newSliceOp = builder.create<tensor::ExtractSliceOp>(
+ sliceOp.getLoc(), sliceOp.getType(), tiledResult->tiledValues[0],
+ offsets, sliceOp.getMixedSizes(), strides);
+ tiledResult->tiledValues[0] = newSliceOp;
+ }
+
return *tiledResult;
}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 075d02ab75ad1..4cc58668944fe 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -555,6 +555,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 floordiv 5)>
// CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -566,7 +567,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
// CHECK-SAME: outs(%[[E]] :
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]])
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]]
-// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
+// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[IV]]]
// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1]
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1]
// CHECK: %[[PARTIAL:.+]] = linalg.generic
@@ -619,7 +620,7 @@ module {
}
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
// CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
// CHECK: %[[F:.*]] = linalg.fill
@@ -671,7 +672,7 @@ module {
}
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
// CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
// CHECK: %[[F:.*]] = linalg.fill
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
new file mode 100644
index 0000000000000..8cace28d441c6
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -transform-interpreter -cse -mlir-print-local-scope -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Check that tile + fuse works with the partial reduction outer parallel strategy.
+
+module {
+  // Payload: a row-wise sum of a dynamically shaped 2-D tensor. Loop d0 is
+  // parallel and loop d1 is the reduction loop that the transform below tiles;
+  // the `linalg.fill` initializing the accumulator is the producer to fuse.
+  func.func @tile_and_fuse_with_partial_reduction_outer_parallel(
+      %arg0 : tensor<?x?xf32>) -> tensor<?xf32> {
+    %dim_index = arith.constant 0 : index
+    %zero = arith.constant 0.0 : f32
+    %rows = tensor.dim %arg0, %dim_index : tensor<?x?xf32>
+    %init = tensor.empty(%rows) : tensor<?xf32>
+    %acc = linalg.fill ins(%zero : f32)
+        outs(%init : tensor<?xf32>) -> tensor<?xf32>
+    %sum = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                         affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0 : tensor<?x?xf32>) outs(%acc : tensor<?xf32>) {
+      ^bb0(%in : f32, %out : f32):
+        %added = arith.addf %in, %out : f32
+        linalg.yield %added : f32
+    } -> tensor<?xf32>
+    return %sum : tensor<?xf32>
+  }
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %root : !transform.any_op {transform.readonly}) {
+    // Select the reduction op, then tile its reduction loop by 128 using the
+    // partial-reduction outer-parallel strategy and fuse its producers.
+    %reduction_op = transform.structured.match ops{["linalg.generic"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    %tiled_ops, %forall_loop =
+        transform.test.tile_and_fuse_outer_parallel_partial_reduction
+        %reduction_op tile_sizes = [128]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @tile_and_fuse_with_partial_reduction_outer_parallel(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[REDUCTION_NUM:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%[[D1]]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[REDUCTION_NUM]])
+// CHECK: %[[FORALL:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+// CHECK-SAME: shared_outs(%[[ITER_ARG:.+]] = %[[EMPTY]])
+// CHECK-DAG: %[[TILESIZE:.+]] = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%[[IV0]])[%[[D1]]]
+// CHECK-DAG: %[[REDUCTION_IV:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%[[IV0]]]
+// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV0]]] [%[[D0]], %[[TILESIZE]]] [1, 1]
+// CHECK: %[[ITER_ARG_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[ITER_ARG_SLICE]] : tensor<?x1xf32>)
+// CHECK: %[[REDUCING_SLICE:.+]] = tensor.extract_slice %[[FILL]][0, 0] [%[[D0]], 1] [1, 1] : tensor<?x1xf32> to tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0_SLICE]] :
+// CHECK-SAME: outs(%[[REDUCING_SLICE]] :
+// CHECK: tensor.parallel_insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+// CHECK-SAME: [0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
+// CHECK: %[[REDUCE:.+]] = linalg.reduce
+// CHECK-SAME: ins(%[[FORALL]] :
+// CHECK: return %[[REDUCE]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index ee3eb9522db7e..3d24d4ecc4d0d 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h"
@@ -60,8 +61,7 @@ template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, unsigned numLoops,
- ArrayRef<OpFoldResult> tileSizes,
- ArrayRef<int64_t> interchange, bool useForall,
+ scf::SCFTilingOptions tilingOptions,
TransformResults &transformResults) {
SmallVector<Operation *> tiledOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
@@ -83,12 +83,6 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
}
}
- scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
- if (useForall) {
- tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
- }
-
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.setTilingOptions(tilingOptions);
@@ -157,10 +151,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange);
+ if (getUseForall()) {
+ tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+ }
+
LogicalResult result = applyTileAndFuseToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
- tileInterchange, getUseForall(), transformResults);
+ tileSizes.size() - llvm::count(tileSizes, 0), tilingOptions,
+ transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
@@ -399,6 +399,75 @@ void transform::TestFuseUsingForallOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// TestTileAndFuseOuterParallelPartialReduction
+//===----------------------------------------------------------------------===//
+
+/// Implements `transform.test.tile_and_fuse_outer_parallel_partial_reduction`.
+///
+/// Tiles the payload op(s) of `root_op` using
+/// `ReductionTilingStrategy::PartialReductionOuterParallel` with an
+/// `scf.forall` loop (`LoopType::ForallOp`), then fuses producers via
+/// `applyTileAndFuseToAll`.
+///
+/// Tile sizes are derived from the first payload op. Each loop listed in
+/// `reduction_dims` receives the matching entry of `tile_sizes`. If
+/// `reduction_dims` is empty, every reduction loop of that op is used. All
+/// other loops get a tile size of 0, meaning they are not tiled. The resulting
+/// options are applied to every payload op in the handle.
+///
+/// Emits an error and returns a definite failure when:
+///   - the handle is empty;
+///   - the first payload op does not implement `TilingInterface`;
+///   - no reduction dimension is given or found;
+///   - the number of tile sizes differs from the number of reduction dims;
+///   - a reduction dim is out of range or does not name a reduction loop;
+///   - tile-and-fuse itself fails.
+DiagnosedSilenceableFailure
+transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
+  // Dereferencing `begin()` of an empty payload range is undefined behavior,
+  // so reject an empty handle explicitly.
+  auto payloadOps = state.getPayloadOps(getRootOp());
+  if (llvm::empty(payloadOps)) {
+    emitOpError("expected the root handle to map to at least one payload op");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  auto target = dyn_cast<TilingInterface>(*payloadOps.begin());
+  if (!target) {
+    emitOpError("expected root operation to implement `TilingInterface`");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes =
+      target.getLoopIteratorTypes();
+
+  // When no reduction dims are given, default to all reduction loops.
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [index, iterator] : llvm::enumerate(iteratorTypes))
+      if (iterator == utils::IteratorType::reduction)
+        reductionDims.push_back(index);
+  }
+
+  if (reductionDims.empty()) {
+    emitOpError(
+        "no reduction dimension specified or found in the target operation");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  SmallVector<int64_t> reductionTileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  if (reductionTileSizes.size() != reductionDims.size()) {
+    emitOpError(
+        "missing tile sizes for reduction dimensions that are to be tiled");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  // Build the full per-loop tile-size vector. Each tile size is placed at the
+  // loop named by the corresponding entry of `reductionDims`, not at the
+  // N-th reduction loop. This way an explicit subset or reordering of the
+  // reduction dims is honored, and `reductionTileSizes` is never indexed out
+  // of bounds.
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  SmallVector<OpFoldResult> tileSizes(iteratorTypes.size(), zero);
+  for (auto [dim, tileSize] :
+       llvm::zip_equal(reductionDims, reductionTileSizes)) {
+    if (dim >= iteratorTypes.size() ||
+        iteratorTypes[dim] != utils::IteratorType::reduction) {
+      emitOpError("expected `reduction_dims` to name reduction loops of the "
+                  "target operation");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    tileSizes[dim] = rewriter.getIndexAttr(tileSize);
+  }
+
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizes)
+      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
+      .setReductionTilingStrategy(
+          ReductionTilingStrategy::PartialReductionOuterParallel)
+      .setReductionDims(reductionDims);
+  // Forward the optional device mapping, reusing the bound value instead of
+  // querying the attribute a second time.
+  if (auto mapping = getMapping())
+    tilingOptions.setMapping(mapping->getValue());
+
+  // A single loop handle is requested. The forall loop type presumably yields
+  // one `scf.forall` (as the accompanying test checks); confirm if extended.
+  LogicalResult result = applyTileAndFuseToAll(
+      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
+      /*numLoops=*/1, tilingOptions, transformResults);
+
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 3c09082e192ea..58ccd30bb99a2 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -126,4 +126,28 @@ def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
}];
}
+// Test-only transform op. It tiles the payload of `root_op` using
+// `ReductionTilingStrategy::PartialReductionOuterParallel` with an
+// `scf.forall` loop, then fuses producers into the generated loop.
+// The C++ `apply` lives in TestTilingInterfaceTransformOps.cpp.
+def TestTileAndFuseOuterParallelPartialReductionOp : Op<
+ Transform_Dialect, "test.tile_and_fuse_outer_parallel_partial_reduction",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Test operation to tile an operation using partial reduction with
+ outer parallel strategy, and to fuse its producers.
+ }];
+
+ // - `reduction_dims`: loops to tile. When empty, the C++ `apply` falls
+ //   back to every reduction loop of the target op.
+ // - `tile_sizes`: one tile size per (effective) reduction dimension. A
+ //   length mismatch is reported as an error by `apply`.
+ // - `mapping`: optional device mapping forwarded to `SCFTilingOptions`.
+ let arguments = (ins TransformHandleTypeInterface:$root_op,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+
+ // Result handles are filled in by `applyTileAndFuseToAll`. The C++
+ // `apply` currently requests a single loop handle (numLoops = 1).
+ let results = (outs TransformHandleTypeInterface:$tiled_ops,
+ Variadic<TransformHandleTypeInterface>:$loops);
+ let assemblyFormat = [{
+ $root_op (`reduction_dims` `=` $reduction_dims^)?
+ (`tile_sizes` `=` $tile_sizes^)? (`mapping` `=` $mapping^)?
+ attr-dict `:` functional-type(operands, results)
+ }];
+}
+
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
More information about the Mlir-commits
mailing list