[Mlir-commits] [mlir] [mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (PR #92624)

Kunwar Grover llvmlistbot at llvm.org
Mon May 20 06:49:11 PDT 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/92624

>From 8e890c8908b3267c41bb616de5749dca3fc43f17 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 20 May 2024 13:17:01 +0000
Subject: [PATCH] save

---
 .../Linalg/TransformOps/LinalgTransformOps.td |   4 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +-
 .../SCF/Transforms/TileUsingInterface.h       |   4 +-
 .../mlir/Interfaces/TilingInterface.td        |   5 +-
 .../TransformOps/LinalgTransformOps.cpp       |   6 +-
 .../Transforms/MeshShardingInterfaceImpl.cpp  |  13 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |  25 +--
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 170 +++++++++++-------
 .../SCF/Transforms/TileUsingInterface.cpp     |  51 +++---
 .../Linalg/transform-tile-reduction.mlir      |  50 +++++-
 10 files changed, 211 insertions(+), 121 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$forall_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..308ce92e35520 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// The op initializing the tensor used for partial reductions.
-  Operation *initialOp;
+  /// Initial values used for partial reductions.
+  SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.
   scf::ForallOp loops;
 };
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..6d567171e185a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// Initial op
-  Operation *initialOp;
+  /// Initial values used for reduction.
+  SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
 };
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..6fff7e0da538a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           operation reduction. The tensor shape is equal to operation result
           shape with new dimension for each non zero tile size.
         }],
-        /*retType=*/"FailureOr<Operation*>",
+        /*retType=*/"FailureOr<Value>",
         /*methodName=*/"generateInitialTensorForPartialReduction",
         /*args=*/(ins
             "OpBuilder &":$b,
-            "Location ":$loc,
+            "Location":$loc,
+            "int64_t":$resultNumber,
             "ArrayRef<OpFoldResult>":$sizes,
             "ArrayRef<int>":$reductionDim),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9b3121774ab3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..e0394f852fcc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
     PartialReductionOpInterface partialReductionIface =
         llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    FailureOr<Operation *> reductionNeutralTensorOp =
+    assert(op->getNumResults() == 1 && "Multiple results not supported.");
+    FailureOr<Value> reductionNeutralTensor =
         partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensorOp));
-    builder.create<scf::YieldOp>(
-        reductionNeutralTensorOp.value()->getResult(0));
+            builder, builder.getLoc(), 0, shape, {});
+    assert(succeeded(reductionNeutralTensor));
+    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
   }
   return ifOp.getResult(0);
 }
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
     ImplicitLocOpBuilder &builder) {
   // TODO: add support for multiple destination passing style initial value
   // operands.
-  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
-  // needs to also support multiple DPS initial operands.
+  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
   Value spmdizedInitOperand =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..d51c4d0af0366 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -692,12 +692,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
+  SmallVector<Value> initTensors;
+  initTensors.reserve(op->getNumResults());
+  for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+    FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+        b, loc, idx, numThreads, reductionDim);
+    if (failed(initValue))
+      return b.notifyMatchFailure(
+          op, "cannot create a tensor of identity value for result " +
+                  std::to_string(idx));
+    initTensors.push_back(initValue.value());
+  }
 
   // Gather destination tensors.
   SmallVector<Value> dest;
@@ -715,8 +720,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 2. Create the ForallOp with an empty region.
   scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
-      (*identityTensor)->getResults(), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+      mapping);
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `forallOp`.
@@ -726,7 +731,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                                tiledSizes);
 
-  // 4. Clone the tileable op and update its destination operands to use the
+  // 4b. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +843,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 8. Return.
   ForallReductionTilingResult results;
-  results.initialOp = *identityTensor;
+  results.initialValues = initTensors;
   results.loops = forallOp;
   results.parallelTiledOp = tiledOp;
   results.mergeOp = mergeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..528f120287181 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
-      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+  FailureOr<Value> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+      ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
 
@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
     // loops. This could be controlled by user for more flexibility.
 
     SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+                        combinerOps) ||
         combinerOps.size() != 1)
       return op->emitOpError("Failed to anaysis the reduction operation.");
 
@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
           "Failed to get an identity value for the reduction operation.");
 
     ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+        linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));
 
     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
         dynamicDims.push_back(b.create<tensor::DimOp>(
-            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+            loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
     }
     Value emptyTensor = b.create<tensor::EmptyOp>(
-        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
-        dynamicDims);
+        loc, newOutputShape,
+        linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
     Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
     auto identityTensor =
         b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
-    return identityTensor.getOperation();
+    return identityTensor.getResult(0);
   }
 
   Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,64 @@ struct LinalgOpPartialReductionInterface
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
-    AffineMap oldOutputMap =
-        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
-                                       reductionDims.size());
-
-    for (int idx : reductionDims)
-      outputExpr[idx] = b.getAffineDimExpr(idx);
-    int currExpr = 0;
-    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
-      if (outputExpr[idx])
-        continue;
-      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    // Step 1. Extend the init maps with the reduction dims, since we are
+    // converting those dims to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
     }
 
-    // Step 1: Extract a slice of the input operands.
-    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
+      int64_t initRank = valueMap.getNumResults();
+      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
+      SmallVector<OpFoldResult> initSizes;
+      for (AffineExpr dimExpr : valueMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        initSizes.push_back(sizes[dim.getPosition()]);
+      }
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, initSizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }
 
-    // Step 2: Extract the accumulator operands
-    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
-                                                 sizes, strides);
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }
 
-    // Step3. Create a generic op where the reduction dimensions are replaced
-    // by a parallel dimension of the size of reduction.
+    // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
     for (int dim : reductionDims)
       newIteratorTypes[dim] = utils::IteratorType::parallel;
-    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
-                                    linalgOp.getContext());
+
+    // Step 4. Create the new generic op.
     auto genericOp =
-        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
-                            ValueRange({out}), newMaps, newIteratorTypes);
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -361,40 +382,53 @@ struct LinalgOpPartialReductionInterface
                              ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
-    // Then create a new reduction that only reduce the newly added dimensions
-    // from the previous op.
-    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
-    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
-    SmallVector<AffineExpr> exprs;
-
-    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (reductionDimsSet.contains(i)) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
-      } else {
-        exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    // parallel                         --> parallel
+    // reduction + not in reductionDims --> parallel (already reduced)
+    // reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+                                               utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {
+        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
+                                         inputMap.getNumResults());
       }
     }
 
-    AffineMap outputMap =
-        AffineMap::get(intermRank, 0, exprs, op->getContext());
-    SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
-
-    SmallVector<Operation *, 4> combinerOps;
-    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
-    Operation *reductionOp = combinerOps[0];
-
     auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
-        linalgOp.getDpsInits(), reductionMaps, reductionIteratorTypes,
-        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          Operation *clonedReductionOp = b.clone(*reductionOp);
-          clonedReductionOp->setOperand(0, inputs[0]);
-          clonedReductionOp->setOperand(1, inputs[1]);
-          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
+        indexingMaps, iterators,
+        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          int64_t numInits = linalgOp.getNumDpsInits();
+          SmallVector<Value> yieldedValues;
+          for (int idx : llvm::seq<int>(0, numInits)) {
+            // Get the combiner op.
+            SmallVector<Operation *, 4> combinerOps;
+            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
+            // Combine the input at idx and output at numInits + idx.
+            clonedReductionOp->setOperand(0, inputs[idx]);
+            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
+            // Yield.
+            yieldedValues.push_back(clonedReductionOp->getResult(0));
+          }
+          b.create<linalg::YieldOp>(loc, yieldedValues);
         });
     return reduction.getOperation();
   }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..32a683eff8cdb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -182,6 +182,9 @@ static LogicalResult generateLoopNestUsingForOp(
   if (loops.empty())
     return success();
 
+  assert(tiledResults.size() == destinationTensors.size() &&
+         "Number of results of body should be equal to number of iter args");
+
   // 6. Yield all the results of the tiled operation.
   SmallVector<Value> yieldedValues;
   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
@@ -694,9 +697,6 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
                            zero);
   }
-  if (op->getNumResults() != 1)
-    return b.notifyMatchFailure(
-        op, "don't support ops with multiple results for now");
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
 
@@ -708,12 +708,18 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   }
 
   // 2. create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
-                                                  reductionDims);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
+  SmallVector<Value> initTensors;
+  initTensors.reserve(op->getNumResults());
+  for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+    FailureOr<Value> initTensor = op.generateInitialTensorForPartialReduction(
+        b, loc, idx, tileSizesVector, reductionDims);
+    if (failed(initTensor)) {
+      return b.notifyMatchFailure(
+          op, "cannot create a tensor of identity value for result: " +
+                  std::to_string(idx));
+    }
+    initTensors.push_back(initTensor.value());
+  }
 
   // 3. Define the callback to use for generating the inner most tile loop body.
   Operation *parallelOp = nullptr;
@@ -753,29 +759,26 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
     // 4d. Compute the offsets and sizes needed to insert the result of the
     // tiled value back into destination before yielding the destination.
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    resultOffsets.emplace_back(std::move(outOffsets));
-
-    SmallVector<OpFoldResult> outSizes;
-    for (size_t i = 0; i < offsets.size(); i++) {
-      outSizes.push_back(
-          tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+    for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+      SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+      resultOffsets.emplace_back(std::move(outOffsets));
+
+      SmallVector<OpFoldResult> outSizes;
+      for (size_t i = 0; i < offsets.size(); i++) {
+        outSizes.push_back(
+            tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+      }
+      resultSizes.emplace_back(std::move(outSizes));
     }
-    resultSizes.emplace_back(std::move(outSizes));
     return success();
   };
 
   // 5. Generate the tiled implementation using the destination tensors.
-  SmallVector<Value> destinationTensors =
-      llvm::map_to_vector(identityTensor.value()->getResults(),
-                          [](OpResult res) -> Value { return res; });
-
   SmallVector<LoopLikeOpInterface> loops;
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
-                              destinationTensors, innerYieldTiledValuesFn,
-                              loops)))
+                              initTensors, innerYieldTiledValuesFn, loops)))
     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
   SmallVector<Value> replacements = llvm::map_to_vector(
@@ -787,7 +790,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   b.replaceOp(op, mergeOp->getResults());
 
   SCFReductionTilingResult results;
-  results.initialOp = *identityTensor;
+  results.initialValues = initTensors;
   results.loops = loops;
   results.parallelTiledOp = parallelOp;
   results.mergeOp = mergeOp;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 0e1512717a22d..f3cf7c4dffa05 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -80,13 +80,14 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
 //     CHECK: func @reduction_tile_transpose
 //     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
 //     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
 //     CHECK:   scf.for
 //     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
-//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
+//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
 //     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 //     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
 //     CHECK:   }
@@ -403,3 +404,48 @@ module {
 // CHECK:     scf.yield %[[L1]] : tensor<4096x2x64xf32>
 // CHECK:   %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
 // CHECK:  return %[[OUT2]] : tensor<4096xf32>
+
+// -----
+
+func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>, %out2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %red:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                            affine_map<(d0, d1) -> (d0)>,
+                                            affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out, %out2 : tensor<?xf32>, tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32, %arg9_1: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      %3 = arith.maximumf %1, %arg9_1 : f32
+      linalg.yield %2, %3 : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+  return %red#0, %red#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+  }
+}
+
+// CHECK: func @reduction_tile_multiple_results
+// CHECK-DAG:   %[[SUM_ID:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[MAX_ID:.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG:   %[[SUM_INIT:.+]] = linalg.fill ins(%[[SUM_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK-DAG:   %[[MAX_INIT:.+]] = linalg.fill ins(%[[MAX_ID]] : f32) outs(%{{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK:       %[[OUT:.+]]:2 = scf.for
+// CHECK-SAME:            iter_args(%[[SUM:.+]] = %[[SUM_INIT]], %[[MAX:.+]] = %[[MAX_INIT]])
+// CHECK:         %[[UPDATED:.*]]:2 = linalg.generic
+// CHECK:         arith.mulf
+// CHECK:         arith.addf
+// CHECK:         arith.maximumf
+// CHECK:       %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
+// CHECK:       %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
+// CHECK:       scf.yield %[[INSERT1]], %[[INSERT2]]
+// CHECK:       linalg.generic
+// CHECK:         arith.addf
+// CHECK:         arith.maximumf



More information about the Mlir-commits mailing list