[Mlir-commits] [mlir] c223521 - [mlir][TilingInterface] Allow tile and fuse to work with `ReductionTilingStrategy::PartialReductionOuterParallelStrategy`. (#147593)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 9 08:50:05 PDT 2025


Author: MaheshRavishankar
Date: 2025-07-09T08:50:01-07:00
New Revision: c22352175ef29c141de27485286275434c58e88a

URL: https://github.com/llvm/llvm-project/commit/c22352175ef29c141de27485286275434c58e88a
DIFF: https://github.com/llvm/llvm-project/commit/c22352175ef29c141de27485286275434c58e88a.diff

LOG: [mlir][TilingInterface] Allow tile and fuse to work with `ReductionTilingStrategy::PartialReductionOuterParallelStrategy`. (#147593)

Since `scf::tileUsingSCF` is the core method used to tile the root
operation within `scf::tileConsumersAndFuseProducersUsingSCF`, the
latter can fuse producers into any tiled loop generated by
`scf::tileUsingSCF`. This patch adds a test for tiling a root operation
with `ReductionTilingStrategy::PartialReductionOuterParallel` and
fusing producers into the resulting loop.

Since this strategy generates rank-reducing extract slices,
`tensor::replaceExtractSliceWithTiledProducer`, the core method used
for the fusion, was extended to handle rank-reducing slices.
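
For example, in the test added below the fused `linalg.fill` is tiled
to a `tensor<?x1xf32>` slice of the loop's iteration argument, and the
fusion helper now inserts an extra rank-reducing `tensor.extract_slice`
(zero offsets, unit strides, `tensor<?x1xf32>` to `tensor<?xf32>`) so
that the tiled consumer still sees the original rank-reduced type.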

Also fix a small bug in the computation of the reduction induction
variable, which needs to use `floorDiv` instead of `ceilDiv`.
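
With this strategy the induction variable is an offset into the
reduction dimension, and the partial-result column holding that offset
is `iv floorDiv tileSize`; `ceilDiv` only agrees when the offset is an
exact multiple of the tile size (e.g. with tile size 64 an offset of 70
belongs to column 1, but `ceilDiv` would give 2).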

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>

Added: 
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir

Modified: 
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 995120ad8680e..c7d634283fd4e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -681,7 +681,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc,
   splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
   AffineExpr s0, s1;
   bindSymbols(rewriter.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
+  AffineExpr divExpr = s0.floorDiv(s1);
   int ivIndex = 0;
   if (reductionStrategy ==
       ReductionTilingStrategy::PartialReductionOuterParallel) {

diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 4392a2c0eb839..6df401d4c6962 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -39,6 +39,23 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
   if (failed(tiledResult))
     return failure();
 
+  // For cases where the slice was rank-reducing, create a rank-reducing slice
+  // to get the same type back.
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+  if (droppedDims.any()) {
+    assert(tiledResult->tiledValues.size() == 1 &&
+           "expected only a single tiled result value to replace the extract "
+           "slice");
+    SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
+                                      builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
+                                      builder.getIndexAttr(1));
+    auto newSliceOp = builder.create<tensor::ExtractSliceOp>(
+        sliceOp.getLoc(), sliceOp.getType(), tiledResult->tiledValues[0],
+        offsets, sliceOp.getMixedSizes(), strides);
+    tiledResult->tiledValues[0] = newSliceOp;
+  }
+
   return *tiledResult;
 }
 

diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 075d02ab75ad1..4cc58668944fe 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -555,6 +555,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
 }
 //  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)>
 //  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 floordiv 5)>
 //      CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
 //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
@@ -566,7 +567,7 @@ func.func @reduction_tile_parallel_using_tile_sizes(
 // CHECK-SAME:      outs(%[[E]] :
 //      CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]])
 //  CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]]
-//  CHECK-DAG:     %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
+//  CHECK-DAG:     %[[INIT_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[IV]]]
 //  CHECK-DAG:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1]
 //  CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1]
 //      CHECK:     %[[PARTIAL:.+]] = linalg.generic
@@ -619,7 +620,7 @@ module {
     }
   }
 }
-//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
 //      CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
 //      CHECK:   %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
 //      CHECK:   %[[F:.*]] = linalg.fill
@@ -671,7 +672,7 @@ module {
     }
   }
 }
-//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 64)>
 //      CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
 //      CHECK:   %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
 //      CHECK:   %[[F:.*]] = linalg.fill

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
new file mode 100644
index 0000000000000..8cace28d441c6
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-with-reduction-tiling.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -transform-interpreter -cse -mlir-print-local-scope -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Check that tile-and-fuse works with the partial reduction outer parallel strategy.
+
+module{
+  func.func @tile_and_fuse_with_partial_reduction_outer_parallel(
+      %arg0 : tensor<?x?xf32>) -> tensor<?xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    %empty = tensor.empty(%d0) : tensor<?xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+        iterator_types = ["parallel", "reduction"]}
+        ins(%arg0 : tensor<?x?xf32>) outs(%fill : tensor<?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %0 = arith.addf %b0, %b1 : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    return %generic : tensor<?xf32>
+  }
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %loop = transform.test.tile_and_fuse_outer_parallel_partial_reduction
+      %generic tile_sizes = [128] 
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @tile_and_fuse_with_partial_reduction_outer_parallel(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//       CHECK:   %[[REDUCTION_NUM:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%[[D1]]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[REDUCTION_NUM]])
+//       CHECK:   %[[FORALL:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]]) =
+//  CHECK-SAME:       shared_outs(%[[ITER_ARG:.+]] = %[[EMPTY]])
+//   CHECK-DAG:     %[[TILESIZE:.+]] = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%[[IV0]])[%[[D1]]]
+//   CHECK-DAG:     %[[REDUCTION_IV:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%[[IV0]]]
+//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV0]]] [%[[D0]], %[[TILESIZE]]] [1, 1]
+//       CHECK:     %[[ITER_ARG_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
+//       CHECK:     %[[FILL:.+]] = linalg.fill
+//  CHECK-SAME:         outs(%[[ITER_ARG_SLICE]] : tensor<?x1xf32>)
+//       CHECK:     %[[REDUCING_SLICE:.+]] = tensor.extract_slice %[[FILL]][0, 0] [%[[D0]], 1] [1, 1] : tensor<?x1xf32> to tensor<?xf32>
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
+//  CHECK-SAME:         outs(%[[REDUCING_SLICE]] :
+//       CHECK:     tensor.parallel_insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+//  CHECK-SAME:         [0, %[[REDUCTION_IV]]] [%[[D0]], 1] [1, 1]
+//       CHECK:   %[[REDUCE:.+]] = linalg.reduce
+//  CHECK-SAME:       ins(%[[FORALL]] :
+//       CHECK:   return %[[REDUCE]]

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index ee3eb9522db7e..3d24d4ecc4d0d 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -60,8 +61,7 @@ template <typename Range>
 static LogicalResult
 applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
                       Range &&payloadOps, unsigned numLoops,
-                      ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<int64_t> interchange, bool useForall,
+                      scf::SCFTilingOptions tilingOptions,
                       TransformResults &transformResults) {
   SmallVector<Operation *> tiledOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
@@ -83,12 +83,6 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
       }
     }
 
-    scf::SCFTilingOptions tilingOptions;
-    tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
-    if (useForall) {
-      tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
-    }
-
     scf::SCFTileAndFuseOptions tileAndFuseOptions;
     tileAndFuseOptions.setTilingOptions(tilingOptions);
 
@@ -157,10 +151,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
   SmallVector<OpFoldResult> tileSizesOfr =
       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
 
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange);
+  if (getUseForall()) {
+    tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  }
+
   LogicalResult result = applyTileAndFuseToAll(
       rewriter, getOperation(), state.getPayloadOps(getTarget()),
-      tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
-      tileInterchange, getUseForall(), transformResults);
+      tileSizes.size() - llvm::count(tileSizes, 0), tilingOptions,
+      transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
@@ -399,6 +399,75 @@ void transform::TestFuseUsingForallOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// TestTileAndFuseOuterParallelPartialReduction
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
+    TransformRewriter &rewriter, TransformResults &transformResults,
+    TransformState &state) {
+  auto target =
+      dyn_cast<TilingInterface>(*state.getPayloadOps(getRootOp()).begin());
+  if (!target) {
+    emitOpError("expected root operation to implement `TilingInterface`");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [index, iterator] :
+         llvm::enumerate(target.getLoopIteratorTypes()))
+      if (iterator == utils::IteratorType::reduction)
+        reductionDims.push_back(index);
+  }
+
+  if (reductionDims.empty()) {
+    emitOpError(
+        "no reduction dimension specified or found in the target operation");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  SmallVector<int64_t> reductionTileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  if (reductionTileSizes.size() != reductionDims.size()) {
+    emitOpError(
+        "missing tile sizes for reduction dimensions that are to be tiled");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  // Adjust tile sizes so that it corresponds to the reduction iterator types.
+  SmallVector<OpFoldResult> tileSizes;
+  int reductionTileSizeNum = 0;
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  for (auto iterator : target.getLoopIteratorTypes()) {
+    if (iterator == utils::IteratorType::parallel) {
+      tileSizes.push_back(zero);
+      continue;
+    }
+    tileSizes.push_back(
+        rewriter.getIndexAttr(reductionTileSizes[reductionTileSizeNum++]));
+  }
+
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizes)
+      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
+      .setReductionTilingStrategy(
+          ReductionTilingStrategy::PartialReductionOuterParallel)
+      .setReductionDims(reductionDims);
+  if (auto mapping = getMapping()) {
+    tilingOptions.setMapping(getMapping().value());
+  }
+
+  LogicalResult result = applyTileAndFuseToAll(
+      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
+      /*numLoops =*/1, tilingOptions, transformResults);
+
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 3c09082e192ea..58ccd30bb99a2 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -126,4 +126,28 @@ def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
   }];
 }
 
+def TestTileAndFuseOuterParallelPartialReductionOp : Op<
+    Transform_Dialect, "test.tile_and_fuse_outer_parallel_partial_reduction",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Test operation to tile an operation using partial reduction with
+    outer parallel strategy, and to fuse its producers.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$root_op,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+
+  let results = (outs TransformHandleTypeInterface:$tiled_ops,
+                 Variadic<TransformHandleTypeInterface>:$loops);
+  let assemblyFormat = [{
+    $root_op (`reduction_dims` `=` $reduction_dims^)?
+    (`tile_sizes` `=` $tile_sizes^)? (`mapping` `=` $mapping^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS


        

