[Mlir-commits] [mlir] 2cc5f5d - [mlir][Linalg] Implement tileReductionUsingScf for multiple reductions

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 16 13:47:15 PDT 2023

Author: Groverkss
Date: 2023-08-17T02:17:03+05:30
New Revision: 2cc5f5d43c20428c9b9799176ae2d1da0bf9330d

URL: https://github.com/llvm/llvm-project/commit/2cc5f5d43c20428c9b9799176ae2d1da0bf9330d
DIFF: https://github.com/llvm/llvm-project/commit/2cc5f5d43c20428c9b9799176ae2d1da0bf9330d.diff

LOG: [mlir][Linalg] Implement tileReductionUsingScf for multiple reductions

This patch improves the reduction tiling for linalg to support multiple
reduction dimensions.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D158005




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index a1ab4d9fd03801..cedaa4344a2958 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -255,15 +255,11 @@ struct LinalgOpPartialReductionInterface
       ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
     if (linalgOp.hasBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
     // Insert the new parallel dimension based on the index of the reduction
-    // loop. This could be controlled by user for more flexibility.
-    int64_t insertSplitDimension = reductionDims[0];
-    assert(sizes.size() >= static_cast<size_t>(insertSplitDimension) &&
-           "reduction dimension must be tiled");
+    // loops. This could be controlled by user for more flexibility.
     SmallVector<Operation *, 4> combinerOps;
     if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
@@ -276,18 +272,31 @@ struct LinalgOpPartialReductionInterface
       return op->emitOpError(
           "Failed to get an identity value for the reduction operation.");
-    // Calculate the new shape, we insert the new dimension based on the index
-    // of the reduction dimension.
-    SmallVector<int64_t> newOutputShape;
     ArrayRef<int64_t> oldShape =
+    // Extend tile size vector to the rank of the output tensor.
+    SmallVector<Value> tileSizeVector =
+        getValueOrCreateConstantIndexOp(b, loc, sizes);
+    if (tileSizeVector.size() < oldShape.size()) {
+      auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+      tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
+    }
+    // Calculate the new shape, we insert the new dimensions based on the index
+    // of the reduction dimensions.
+    SmallVector<int64_t> newOutputShape;
     SmallVector<Value> dynamicDims;
-    for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
-      if (idx == insertSplitDimension) {
+    int64_t currReductionDims = 0;
+    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+    for (int64_t idx :
+         llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
+      if (reductionDimsSet.contains(idx)) {
         dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
+        currReductionDims++;
-      int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+      int64_t oldIdx = idx - currReductionDims;
       int64_t dim = oldShape[oldIdx];
       if (ShapedType::isDynamic(dim))
@@ -310,21 +319,20 @@ struct LinalgOpPartialReductionInterface
                                     ArrayRef<int> reductionDims) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
-    int64_t insertSplitDimension = reductionDims[0];
     AffineMap oldOutputMap =
-    SmallVector<AffineExpr> outputExpr;
-    for (auto [idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
-      if (static_cast<int64_t>(idx) == insertSplitDimension) {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
-      }
-      outputExpr.push_back(expr);
+    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
+                                       reductionDims.size());
+    for (int idx : reductionDims)
+      outputExpr[idx] = b.getAffineDimExpr(idx);
+    int currExpr = 0;
+    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
+      if (outputExpr[idx])
+        continue;
+      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
-    if (insertSplitDimension == oldOutputMap.getNumResults())
-      outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
     // Step 1: Extract a slice of the input operands.
     SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
@@ -338,11 +346,12 @@ struct LinalgOpPartialReductionInterface
     Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
                                                  sizes, strides);
-    // Step3. create a generic op where the reduction dimension is replaced by a
-    // parallel dimension of the size of reduction.
+    // Step3. Create a generic op where the reduction dimensions are replaced
+    // by parallel dimensions of the size of the reduction.
     SmallVector<utils::IteratorType> newIteratorTypes =
-    newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
+    for (int dim : reductionDims)
+      newIteratorTypes[dim] = utils::IteratorType::parallel;
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
     newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
@@ -359,24 +368,25 @@ struct LinalgOpPartialReductionInterface
                              ValueRange partialReduce,
                              ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
-    int64_t dimToMerge = reductionDims[0];
-    // Then create a new reduction that only reduce the newly added dimension
+    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+    // Then create a new reduction that only reduces the newly added dimensions
     // from the previous op.
     int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
     AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
     SmallVector<utils::IteratorType> reductionIteratorTypes;
     SmallVector<AffineExpr> exprs;
     for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (dimToMerge == i) {
+      if (reductionDimsSet.contains(i)) {
       } else {
     AffineMap outputMap =
         AffineMap::get(intermRank, 0, exprs, op->getContext());
     SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 3f8fe9a23f55e5..597676a017bf48 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -419,26 +419,18 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
         op, "don't support ops with multiple results for now");
   SmallVector<utils::IteratorType> iterators =
-  int64_t numReductionDims = llvm::count(
-      tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
-  if (numReductionDims != 1)
-    return b.notifyMatchFailure(
-        op, "only support ops with one reduction dimension.");
-  int reductionDim;
+  SmallVector<int> reductionDims;
   for (auto [idx, iteratorType] :
        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
-    if (iteratorType == utils::IteratorType::reduction) {
-      reductionDim = idx;
-      break;
-    }
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(idx);
-  if (static_cast<size_t>(reductionDim) >= tileSize.size())
-    return b.notifyMatchFailure(op, "reduction dimension must be tiled");
   // 1. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, tileSize,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
@@ -450,7 +442,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   // 3. Generate the tiled implementation within the inner most loop.
   Operation *parallelOp = op.tileToPartialReduction(
-      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
+      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
   SmallVector<OpFoldResult> resultSizesList;
   for (size_t i = 0; i < offsets.size(); i++)
@@ -472,7 +464,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   // 4. Apply the merge reduction to combine all the partial values.
-  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
+  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
   b.replaceOp(op, mergeOp->getResults());
   SCFReductionTilingResult results;

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 8c67b80c63f463..70e535b74f055b 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -353,3 +353,35 @@ module {
     %for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+// -----
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_multiple_reduction(%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  transform.sequence  failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 2, 64] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  }
+// CHECK: func @reduction_tile_multiple_reduction(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>
+// CHECK:   %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4096x2x64xf32>) -> tensor<4096x2x64xf32>
+// CHECK:   %[[L0:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<4096x2x64xf32>)
+// CHECK:     %[[L1:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]) -> (tensor<4096x2x64xf32>)
+// CHECK:       %[[OUT:.*]] = linalg.generic  {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
+// CHECK:       scf.yield %[[OUT]] : tensor<4096x2x64xf32>
+// CHECK:     scf.yield %[[L1]] : tensor<4096x2x64xf32>
+// CHECK:   %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK:  return %[[OUT2]] : tensor<4096xf32>


More information about the Mlir-commits mailing list