[Mlir-commits] [mlir] [mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (PR #92624)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 17 17:58:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

<details>
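<summary>Changes</summary>

Summary of the visible (truncated) patch:

- `PartialReductionOpInterface::generateInitialTensorForPartialReduction` now returns `FailureOr<Value>` instead of `FailureOr<Operation*>` and takes an additional `int64_t resultNumber` argument selecting which result the initial tensor is generated for.
- `SCFReductionTilingResult` and `linalg::ForallReductionTilingResult` replace `Operation *initialOp` with `SmallVector<Value> initialValues`.
- The `fill_op` result of `TileReductionUsingForOp` and `TileReductionUsingForallOp` becomes variadic; the transform implementations push the defining op of every initial value.
- `linalg::tileReductionUsingForall` no longer rejects ops with multiple results and creates one initial tensor per result.
- The Linalg implementation of the interface uses `resultNumber` instead of a hard-coded `0`, and `tileToPartialReduction` extends the indexing map of every DPS init rather than only the last map.
- `MeshShardingInterfaceImpl.cpp` asserts that the op has a single result / single DPS init, since multiple initial values are not supported there yet.
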
---

Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92624.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-2) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2-2) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+3-2) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+4-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+6-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+17-15) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+96-68) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+27-24) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+40-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..93e2c2db729da 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
-  let results = (outs TransformHandleTypeInterface:$fill_op,
+  let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
                       TransformHandleTypeInterface:$split_linalg_op,
                       TransformHandleTypeInterface:$combining_linalg_op,
                       TransformHandleTypeInterface:$forall_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..308ce92e35520 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// The op initializing the tensor used for partial reductions.
-  Operation *initialOp;
+  /// Initial values used for partial reductions.
+  SmallVector<Value> initialValues;
   /// The `scf.forall` operation that iterate over the tiles.
   scf::ForallOp loops;
 };
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..6d567171e185a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
   Operation *parallelTiledOp;
   /// The final reduction operation merging all the partial reductions.
   Operation *mergeOp;
-  /// Initial op
-  Operation *initialOp;
+  /// Initial values used for reduction.
+  SmallVector<Value> initialValues;
   /// The loop operations that iterate over the tiles.
   SmallVector<LoopLikeOpInterface> loops;
 };
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c2424..6fff7e0da538a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
           operation reduction. The tensor shape is equal to operation result
           shape with new dimension for each non zero tile size.
         }],
-        /*retType=*/"FailureOr<Operation*>",
+        /*retType=*/"FailureOr<Value>",
         /*methodName=*/"generateInitialTensorForPartialReduction",
         /*args=*/(ins
             "OpBuilder &":$b,
-            "Location ":$loc,
+            "Location":$loc,
+            "int64_t":$resultNumber,
             "ArrayRef<OpFoldResult>":$sizes,
             "ArrayRef<int>":$reductionDim),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9b3121774ab3a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
-  results.push_back(result->initialOp);
+  for (Value initValue : result->initialValues)
+    results.push_back(initValue.getDefiningOp());
   results.push_back(result->parallelTiledOp);
   results.push_back(result->mergeOp);
   results.push_back(result->loops);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..e0394f852fcc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
     PartialReductionOpInterface partialReductionIface =
         llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    FailureOr<Operation *> reductionNeutralTensorOp =
+    assert(op->getNumResults() == 1 && "Multiple results not supported.");
+    FailureOr<Value> reductionNeutralTensor =
         partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensorOp));
-    builder.create<scf::YieldOp>(
-        reductionNeutralTensorOp.value()->getResult(0));
+            builder, builder.getLoc(), 0, shape, {});
+    assert(succeeded(reductionNeutralTensor));
+    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
   }
   return ifOp.getResult(0);
 }
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
     ImplicitLocOpBuilder &builder) {
   // TODO: add support for multiple destination passing style initial value
   // operands.
-  // PartialReductionOpInterface::generateInitialTensorForPartialReduction
-  // needs to also support multiple DPS initial operands.
+  assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
   SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
   Value spmdizedInitOperand =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..d4805611a68b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -657,9 +657,9 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
 
   // Ops implementing PartialReductionOpInterface are not necessarily expected
-  // to implement TilingInterface.. This cast is unsafe atm.
+  // to implement DestinationStyleOpInterface. This cast is unsafe atm.
   // TODO: proper core mechanism to tie interfaces together.
-  // TODO: this function requires a pair of interfaces ..
+  // TODO: this function requires a pair of interfaces.
   auto destinationStyleOp =
       dyn_cast<DestinationStyleOpInterface>(op.getOperation());
   if (!destinationStyleOp)
@@ -671,9 +671,6 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
     return b.notifyMatchFailure(op, "not a linalg op");
 
   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  if (op->getNumResults() != 1)
-    return b.notifyMatchFailure(
-        op, "don't support ops with multiple results for now");
 
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
@@ -692,12 +689,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
         op, "reduction dimension must be mapped to threads");
 
   // 1. Create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
+  SmallVector<Value> initTensors;
+  initTensors.reserve(op->getNumResults());
+  for (int idx : llvm::seq<int>(0, op->getNumResults())) {
+    FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
+        b, loc, idx, numThreads, reductionDim);
+    if (failed(initValue))
+      return b.notifyMatchFailure(
+          op, "cannot create a tensor of identity value for result " +
+                  std::to_string(idx));
+    initTensors.push_back(initValue.value());
+  }
 
   // Gather destination tensors.
   SmallVector<Value> dest;
@@ -715,8 +717,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 2. Create the ForallOp with an empty region.
   scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult(materializedNonZeroNumThreads),
-      (*identityTensor)->getResults(), mapping);
+      loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
+      mapping);
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `forallOp`.
@@ -726,7 +728,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
                                tiledSizes);
 
-  // 4. Clone the tileable op and update its destination operands to use the
+  // 4b. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +840,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 
   // 8. Return.
   ForallReductionTilingResult results;
-  results.initialOp = *identityTensor;
+  results.initialValues = initTensors;
   results.loops = forallOp;
   results.parallelTiledOp = tiledOp;
   results.mergeOp = mergeOp;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5..edae03fec72a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -250,9 +250,9 @@ template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
           LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
-  FailureOr<Operation *> generateInitialTensorForPartialReduction(
-      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<int> reductionDims) const {
+  FailureOr<Value> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, int64_t resultNumber,
+      ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
 
@@ -262,7 +262,8 @@ struct LinalgOpPartialReductionInterface
     // loops. This could be controlled by user for more flexibility.
 
     SmallVector<Operation *, 4> combinerOps;
-    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), resultNumber,
+                        combinerOps) ||
         combinerOps.size() != 1)
       return op->emitOpError("Failed to anaysis the reduction operation.");
 
@@ -273,7 +274,7 @@ struct LinalgOpPartialReductionInterface
           "Failed to get an identity value for the reduction operation.");
 
     ArrayRef<int64_t> oldShape =
-        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+        linalgOp.getShape(linalgOp.getDpsInitOperand(resultNumber));
 
     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
@@ -293,15 +294,15 @@ struct LinalgOpPartialReductionInterface
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
         dynamicDims.push_back(b.create<tensor::DimOp>(
-            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+            loc, linalgOp.getDpsInitOperand(resultNumber)->get(), oldIdx));
     }
     Value emptyTensor = b.create<tensor::EmptyOp>(
-        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
-        dynamicDims);
+        loc, newOutputShape,
+        linalgOp.getRegionOutputArgs()[resultNumber].getType(), dynamicDims);
     Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
     auto identityTensor =
         b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
-    return identityTensor.getOperation();
+    return identityTensor.getResult(0);
   }
 
   Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
@@ -312,44 +313,58 @@ struct LinalgOpPartialReductionInterface
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
 
-    AffineMap oldOutputMap =
-        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
-                                       reductionDims.size());
-
-    for (int idx : reductionDims)
-      outputExpr[idx] = b.getAffineDimExpr(idx);
-    int currExpr = 0;
-    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
-      if (outputExpr[idx])
-        continue;
-      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
+    // Step 1. Extend init maps to have reduction dimension dims, since we
+    // are converting them to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
     }
 
-    // Step 1: Extract a slice of the input operands.
-    SmallVector<Value> valuesToTile = linalgOp.getDpsInputs();
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value, 4> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (Value valueToTile : init) {
+      SmallVector<OpFoldResult> initOffset(offsets.size(), b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(offsets.size(), b.getIndexAttr(1));
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, sizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }
 
-    // Step 2: Extract the accumulator operands
-    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
-                                                 sizes, strides);
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }
 
-    // Step3. Create a generic op where the reduction dimensions are replaced
-    // by a parallel dimension of the size of reduction.
+    // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
     for (int dim : reductionDims)
       newIteratorTypes[dim] = utils::IteratorType::parallel;
-    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
-    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
-                                    linalgOp.getContext());
+
+    // Step 4. Create the new generic op.
     auto genericOp =
-        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
-                            ValueRange({out}), newMaps, newIteratorTypes);
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
     IRMapping mapping;
     op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                genericOp.getRegion().begin(), mapping);
@@ -361,40 +376,53 @@ struct LinalgOpPartialReductionInterface
                              ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
-
-    // Then create a new reduction that only reduce the newly added dimensions
-    // from the previous op.
-    int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
-    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-    SmallVector<utils::IteratorType> reductionIteratorTypes;
-    SmallVector<AffineExpr> exprs;
-
-    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (reductionDimsSet.contains(i)) {
-        reductionIteratorTypes.push_back(utils::IteratorType::reduction);
-      } else {
-        exprs.push_back(b.getAffineDimExpr(i));
-        reductionIteratorTypes.push_back(utils::IteratorType::parallel);
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    // parallel                         --> parallel
+    // reduction + not in reductionDims --> parallel (already reduced)
+    // reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+                                               utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/92624


More information about the Mlir-commits mailing list