[Mlir-commits] [mlir] 3310fe5 - [mlir][linalg] Add reduction tiling transformation

Thomas Raoux llvmlistbot at llvm.org
Thu Nov 3 16:07:34 PDT 2022

Author: Thomas Raoux
Date: 2022-11-03T23:07:12Z
New Revision: 3310fe55d9480ef3c27037043a5c3db8c7003914

URL: https://github.com/llvm/llvm-project/commit/3310fe55d9480ef3c27037043a5c3db8c7003914
DIFF: https://github.com/llvm/llvm-project/commit/3310fe55d9480ef3c27037043a5c3db8c7003914.diff

LOG: [mlir][linalg] Add reduction tiling transformation

Add a transformation to tile reduction ops into a parallel operation
followed by a merge operation. This is equivalent to the existing
reduction splitting transformation but using loops instead of using
higher-dimensional linalg ops.

Differential Revision: https://reviews.llvm.org/D136586




diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5c304f5efb6ea..6cb14acb1b089 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -608,6 +608,94 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
+def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Indicates that the given `target` op should be transformed with the 
+    `tileReduction` transformation with the tile size provided as attribute.
+    This transformation tiles the `target` along the reduction dimensions. It
+    creates a tensor initialized with the identity value. Then it creates nested
+    loops with a parallel version of `target` op inside. The parallel op
+    dimensions are less than or equal to the tile size passed by the user.
+    After the loop a merge operation is created to do a final reduction with the
+    partial reductions.
+    The initial tensor always uses the tile size dimension. This may overallocate
+    if the tile size is greater than the reduction dimension.
+    #### Return modes
+    The 3 returned handles point to:
+      - the fill op used to initialize the neutral element, 
+      - the parallel tiled op and 
+      - the result-combining op.
+    #### Example:
+    ```
+      %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                              affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%out : tensor<?xf32>) {
+        ^bb0(%arg7: f32, %arg9: f32):
+        %1 = arith.addf %arg7, %arg9 : f32
+        linalg.yield %1 : f32
+      } -> tensor<?xf32>
+      return %red : tensor<?xf32>
+    ```
+    is transformed into:
+    ```
+      %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+      %2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor<?x5xf32>) {
+        %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor<?x5xf32> to tensor<?x5xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor<?x?xf32> to tensor<?x5xf32>
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, 
+                                              affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%extracted_slice_2 : tensor<?x5xf32>)
+        outs(%extracted_slice : tensor<?x5xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %5 = arith.addf %in, %out : f32
+          linalg.yield %5 : f32
+        } -> tensor<?x5xf32>
+        %dim_3 = tensor.dim %1, %c0 : tensor<?x5xf32>
+        %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor<?x5xf32> into tensor<?x5xf32>
+        scf.yield %inserted_slice : tensor<?x5xf32>
+      }
+      %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                            affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%2 : tensor<?x5xf32>)
+      outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %4 = arith.addf %in, %out : f32
+        linalg.yield %4 : f32
+      } -> tensor<?xf32>
+    ```
+  }];
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+  let results = (outs PDL_Operation:$fill_op,
+                      PDL_Operation:$split_linalg_op,
+                      PDL_Operation:$combining_linalg_op);
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target, 
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+        ::mlir::transform::TransformState &state);
+  }];
 def TileOp : Op<Transform_Dialect, "structured.tile",
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 6a10d4332e7eb..5fc7938e0dd2f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -137,6 +137,10 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+/// Return the identity numeric value associated to the given op. Return
+/// llvm::None if there is no known neutral element.
+Optional<Attribute> getNeutralElement(Operation *op);
 // Fusion / Tiling utilities

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6cdef2512f607..9fa4114c77b11 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,6 +136,46 @@ tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
+/// Transformation information returned after reduction tiling.
+struct SCFReductionTilingResult {
+  /// The partial reduction tiled op generated.
+  Operation *parallelTiledOp;
+  /// The final reduction operation merging all the partial reductions.
+  Operation *mergeOp;
+  /// The operation creating the initial tensor filled with the identity value.
+  Operation *initialOp;
+  /// The `scf.for` operations that iterate over the tiles.
+  SmallVector<scf::ForOp> loops;
+/// Method to tile a reduction and generate a parallel op within a serial loop.
+/// Each of the partial reductions is calculated in parallel. Then, after the
+/// loop, all the partial reductions are merged into a final reduction.
+/// For example, the following sequence
+/// ```mlir
+/// %0 = linalg.generic %in ["parallel", "reduction"]
+///   : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+/// is transformed into:
+/// ```mlir
+/// %0 = linalg.fill ... : tensor<7x4xf32>
+/// %1 = scf.for ... iter_args(%arg0 = %0)
+///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
+///   %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
+///   %4 = linalg.generic %2, %3 ["parallel", "parallel"]
+///     : tensor<7x?xf32> -> tensor<7x?xf32>
+///   %5 = tensor.insert_slice %4, %0[0, 0] : tensor<7x4xf32>
+/// }
+/// %6 = linalg.generic %1 ["parallel", "reduction"]
+///   : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op,
+                      ArrayRef<OpFoldResult> tileSize);
 } // namespace scf
 } // namespace mlir

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 0cdf7a8eb649a..dc6ffcbb7accc 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -155,4 +155,72 @@ def TilingInterface : OpInterface<"TilingInterface"> {
+def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+  let description = [{
+    Interface for allowing operations to expose information needed to
+    tile reductions using partial reduction followed by merge. This is
+    complementary to TilingInterface to tile reductions.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate a tensor initialized with the identity value of the
+          operation reduction. The tensor shape is equal to operation result
+          shape with new dimension for each non zero tile size.
+        }],
+        /*retType=*/"FailureOr<Operation*>",
+        /*methodName=*/"generateInitialTensorForPartialReduction",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "Location ":$loc,
+            "ArrayRef<OpFoldResult>":$sizes,
+            "ArrayRef<int>":$reductionDim),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate a tiled version of the operation where the tiled
+          reduction dimension are converted to parallel dimensions with a size
+          less or equal to the tile size. This is meant to be used with
+          `mergeReductions` method which will combine the partial reductions.
+        }],
+        /*retType=*/"Operation*",
+        /*methodName=*/"tileToPartialReduction",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "Location ":$loc,
+            "ValueRange":$init,
+            "ArrayRef<OpFoldResult>":$offsets,
+            "ArrayRef<OpFoldResult>":$sizes,
+            "ArrayRef<int>":$reductionDims),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return nullptr;
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to merge partial reductions for an operation that has been
+          tiled along the reduction dimensions. This will only apply the
+          reduction of the operation.
+        }],
+        /*retType=*/"Operation*",
+        /*methodName=*/"mergeReductions",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "Location ":$loc,
+            "ValueRange":$partialReduce,
+            "ArrayRef<int>":$reductionDim),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return nullptr;
+        }]
+      >
+  ];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 513882ec91260..c8a3cb6946e3d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1094,6 +1094,33 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
   return DiagnosedSilenceableFailure(success());
+// TileReductionUsingScfOp
+DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
+    linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<OpFoldResult> sizes;
+  for (int64_t size : tileSizes) {
+    sizes.push_back(rewriter.getIndexAttr(size));
+  }
+  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+      sizes);
+  if (failed(result))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  results.push_back(result->initialOp);
+  results.push_back(result->parallelTiledOp);
+  results.push_back(result->mergeOp);
+  return DiagnosedSilenceableFailure(success());
 // TileOp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 32d05c5acbe6c..0608c361e774b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -26,38 +26,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
-/// Return the identity numeric value associated to the give op.
-static Attribute getNeutralElement(Operation *op) {
-  // Builder only used as helper for attribute creation.
-  OpBuilder b(op->getContext());
-  Type resultType = op->getResult(0).getType();
-  if (auto floatType = resultType.dyn_cast<FloatType>()) {
-    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
-    if (isa<arith::AddFOp>(op))
-      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
-    if (isa<arith::MulFOp>(op))
-      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
-    if (isa<arith::MaxFOp>(op))
-      return b.getFloatAttr(resultType,
-                            llvm::APFloat::getLargest(semantic, true));
-    if (isa<arith::MinFOp>(op))
-      return b.getFloatAttr(resultType,
-                            llvm::APFloat::getLargest(semantic, true));
-    return Attribute();
-  }
-  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
-    return b.getIntegerAttr(resultType, 0);
-  if (isa<arith::AndIOp>(op))
-    return b.getIntegerAttr(resultType, -1);
-  if (isa<arith::MaxSIOp>(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
-  if (isa<arith::MinSIOp>(op))
-    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
-  if (isa<arith::MulIOp>(op))
-    return b.getIntegerAttr(resultType, 1);
-  return Attribute();
 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
@@ -88,8 +56,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
   Operation *reductionOp = combinerOps[0];
-  Attribute identity = getNeutralElement(reductionOp);
-  if (!identity)
+  Optional<Attribute> identity = getNeutralElement(reductionOp);
+  if (!identity.has_value())
     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
   Location loc = op->getLoc();
@@ -187,7 +155,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     emptyOrAllocTensor = b.create<tensor::EmptyOp>(
         loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
-  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
+  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
   Value identityTensor =
       b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
@@ -309,10 +277,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
     return b.notifyMatchFailure(op, "cannot match a reduction pattern");
-  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
-      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
-        return getNeutralElement(reductionOp);
-      }));
+  SmallVector<Attribute> neutralElements;
+  for (Operation *reductionOp : combinerOps) {
+    Optional<Attribute> neutralElement = getNeutralElement(reductionOp);
+    if (!neutralElement.has_value())
+      return b.notifyMatchFailure(op, "cannot find neutral element.");
+    neutralElements.push_back(*neutralElement);
+  }
   if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
     return b.notifyMatchFailure(op, "unknown reduction neutral");

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c843f0f400793..d1fcc01ca853d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -240,11 +241,170 @@ struct LinalgOpTilingInterface
+// External Model for implementing `PartialReductionOpInterface` for `LinalgOp`s.
+/// External model implementation of PartialReductionOpInterface for LinalgOps.
+template <typename LinalgOpTy>
+struct LinalgOpPartialReductionInterface
+    : public PartialReductionOpInterface::ExternalModel<
+          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
+  FailureOr<Operation *> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    OpBuilder::InsertionGuard guard(b);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    if (linalgOp.hasBufferSemantics())
+      return op->emitOpError("expected operation to have tensor semantics");
+    // Insert the new parallel dimension based on the index of the reduction
+    // loop. This could be controlled by user for more flexibility.
+    int64_t insertSplitDimension = reductionDims[0];
+    SmallVector<Operation *, 4> combinerOps;
+    if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+        combinerOps.size() != 1)
+      return op->emitOpError("Failed to anaysis the reduction operation.");
+    Operation *reductionOp = combinerOps[0];
+    Optional<Attribute> identity = getNeutralElement(reductionOp);
+    if (!identity.has_value())
+      return op->emitOpError(
+          "Failed to get an identity value for the reduction operation.");
+    // Calculate the new shape, we insert the new dimension based on the index
+    // of the reduction dimension.
+    SmallVector<int64_t> newOutputShape;
+    ArrayRef<int64_t> oldShape =
+        linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+    SmallVector<Value> dynamicDims;
+    for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
+      if (idx == insertSplitDimension) {
+        dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape,
+                                   ShapedType::kDynamicStrideOrOffset);
+        continue;
+      }
+      int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+      int64_t dim = oldShape[oldIdx];
+      newOutputShape.push_back(dim);
+      if (ShapedType::isDynamic(dim))
+        dynamicDims.push_back(b.createOrFold<tensor::DimOp>(
+            loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+    }
+    Value emptyTensor = b.create<tensor::EmptyOp>(
+        loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
+        dynamicDims);
+    Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+    auto identityTensor =
+        b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
+    return identityTensor.getOperation();
+  }
+  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                                    ValueRange init,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    ArrayRef<int> reductionDims) const {
+    OpBuilder::InsertionGuard guard(b);
+    auto linalgOp = cast<LinalgOp>(op);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    int64_t insertSplitDimension = reductionDims[0];
+    AffineMap oldOutputMap =
+        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
+    SmallVector<AffineExpr> outputExpr;
+    for (auto &[idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
+      if (static_cast<int64_t>(idx) == insertSplitDimension) {
+        outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+      }
+      outputExpr.push_back(expr);
+    }
+    if (insertSplitDimension == oldOutputMap.getNumResults())
+      outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+    // Step 1: Extract a slice of the input operands.
+    SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
+    SmallVector<Value, 4> tiledOperands =
+        makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+    // Step 2: Extract the accumulator operands
+    SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
+    SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+    // TODO: use SubsetExtractOpInterface once it is available.
+    Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
+                                                 sizes, strides);
+    // Step3. create a generic op where the reduction dimension is replaced by a
+    // parallel dimension of the size of reduction.
+    SmallVector<StringRef> newIteratorTypes = linalgOp.getIteratorTypesArray();
+    newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
+                                    linalgOp.getContext());
+    auto genericOp =
+        b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
+                            ValueRange({out}), newMaps, newIteratorTypes);
+    BlockAndValueMapping mapping;
+    op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                               genericOp.getRegion().begin(), mapping);
+    return genericOp.getOperation();
+  }
+  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                             ValueRange partialReduce,
+                             ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    assert(reductionDims.size() == 1 &&
+           "only support single reduction right now.");
+    int64_t dimToMerge = reductionDims[0];
+    // Then create a new reduction that only reduce the newly added dimension
+    // from the previous op.
+    int64_t intermRank =
+        partialReduce[0].getType().cast<ShapedType>().getRank();
+    AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+    SmallVector<StringRef> reductionIteratorTypes;
+    SmallVector<AffineExpr> exprs;
+    for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
+      if (dimToMerge == i) {
+        reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+      } else {
+        exprs.push_back(b.getAffineDimExpr(i));
+        reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+      }
+    }
+    AffineMap outputMap =
+        AffineMap::get(intermRank, 0, exprs, op->getContext());
+    SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
+    SmallVector<Operation *, 4> combinerOps;
+    matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
+    Operation *reductionOp = combinerOps[0];
+    auto reduction = b.create<GenericOp>(
+        loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
+        SmallVector<Value>{linalgOp.getDpsInitOperands()}, reductionMaps,
+        reductionIteratorTypes,
+        [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          Operation *clonedReductionOp = b.clone(*reductionOp);
+          clonedReductionOp->setOperand(0, inputs[0]);
+          clonedReductionOp->setOperand(1, inputs[1]);
+          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+        });
+    return reduction.getOperation();
+  }
 } // namespace
 template <typename OpType>
 static void registerOne(MLIRContext *ctx) {
   OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
+      *ctx);
 /// Variadic helper function.

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index ce15c6767b24b..04cbed0c4e135 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -948,13 +948,14 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
   SmallVector<OpFoldResult> subShapeSizes =
       computeTileSizes(builder, loc, tileSizes, sizeBounds);
-  assert(static_cast<int64_t>(valuesToTile.size()) ==
+  assert(static_cast<int64_t>(valuesToTile.size()) <=
              linalgOp->getNumOperands() &&
-         "expected one value to tile for every operand");
+         "more value to tile than operands.");
   SmallVector<Optional<SliceParameters>> allSliceParams;
-  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-    Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
+  for (auto [opOperand, val] :
+       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
+    Value shapedOp = val;
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
@@ -1059,5 +1060,37 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
   return reassociation;
+/// Return the identity (neutral) numeric value associated to the given
+/// combiner op, typed after the op's first result. Return llvm::None if there
+/// is no known neutral element for `op`, so callers can rely on `has_value()`.
+Optional<Attribute> getNeutralElement(Operation *op) {
+  // Builder only used as helper for attribute creation.
+  OpBuilder b(op->getContext());
+  Type resultType = op->getResult(0).getType();
+  if (auto floatType = resultType.dyn_cast<FloatType>()) {
+    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+    if (isa<arith::AddFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+    if (isa<arith::MulFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+    // Max reduction: start from the most negative finite value.
+    if (isa<arith::MaxFOp>(op))
+      return b.getFloatAttr(resultType,
+                            llvm::APFloat::getLargest(semantic, true));
+    // Min reduction: start from the most positive finite value.
+    if (isa<arith::MinFOp>(op))
+      return b.getFloatAttr(resultType,
+                            llvm::APFloat::getLargest(semantic, false));
+    // Unknown float combiner: report "no neutral element" rather than an
+    // engaged Optional wrapping a null Attribute.
+    return llvm::None;
+  }
+  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
+    return b.getIntegerAttr(resultType, 0);
+  if (isa<arith::AndIOp>(op))
+    return b.getIntegerAttr(resultType, -1);
+  if (isa<arith::MaxSIOp>(op))
+    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
+  if (isa<arith::MinSIOp>(op))
+    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
+  if (isa<arith::MulIOp>(op))
+    return b.getIntegerAttr(resultType, 1);
+  return llvm::None;
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2d6edb7332ac8..0c86bd4d1262a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -424,6 +424,90 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   return tilingResult;
+mlir::scf::tileReductionUsingScf(PatternRewriter &b,
+                                 PartialReductionOpInterface op,
+                                 ArrayRef<OpFoldResult> tileSize) {
+  Location loc = op.getLoc();
+  // Ops implementing PartialReductionOpInterface are expected to implement
+  // TilingInterface.
+  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  SmallVector<Value> tileSizeVector =
+      getValueOrCreateConstantIndexOp(b, loc, tileSize);
+  if (tileSizeVector.size() < iterationDomain.size()) {
+    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+  }
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  SmallVector<utils::IteratorType> iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  int64_t numReductionDims = llvm::count(
+      tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
+  if (numReductionDims != 1)
+    return b.notifyMatchFailure(
+        op, "only support ops with one reduction dimension.");
+  int reductionDim;
+  for (auto &[idx, iteratorType] :
+       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
+    if (iteratorType == utils::IteratorType::reduction) {
+      reductionDim = idx;
+      break;
+    }
+  }
+  // 1. create the initial tensor value.
+  FailureOr<Operation *> identityTensor =
+      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+                                                  reductionDim);
+  if (failed(identityTensor))
+    return b.notifyMatchFailure(op,
+                                "cannot create a tensor of identity value.");
+  // 2. Create the nested loops.
+  SmallVector<OpFoldResult> offsets, sizes;
+  SmallVector<scf::ForOp> loops = generateTileLoopNest(
+      b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+  // 3. Generate the tiled implementation within the inner most loop.
+  b.setInsertionPoint(loops.back().getBody()->getTerminator());
+  Operation *parallelOp =
+      op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(),
+                                offsets, sizes, reductionDim);
+  SmallVector<OpFoldResult> resultSizesList;
+  for (size_t i = 0; i < offsets.size(); i++)
+    resultSizesList.push_back(
+        b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
+  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+  FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+      b, identityTensor.value()->getResults(), parallelOp->getResults(),
+      outOffsets, resultSizesList, loops);
+  if (failed(replacementOr))
+    return b.notifyMatchFailure(op, "failed to yield replacement");
+  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
+  auto innerMostLoop = loops.back();
+  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
+  assert(destinationTensors.size() ==
+             innerMostLoop.getRegionIterArgs().size() &&
+         "unexpected number of outputs");
+  updateDestinationOperandsForTiledOp(b, destinationTensors,
+                                      innerMostLoop.getRegionIterArgs());
+  // 4. Apply the merge reduction to combine all the partial values.
+  b.setInsertionPointAfter(*loops.begin());
+  Operation *mergeOp =
+      op.mergeReductions(b, loc, replacementOr.value(), reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+  SCFReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = std::move(loops);
+  results.parallelTiledOp = parallelOp;
+  results.mergeOp = mergeOp;
+  return results;
 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
new file mode 100644
index 0000000000000..dad2f8476d1ff
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s
+func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+//     CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+//     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
+//     CHECK:     %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+//     CHECK:     %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
+//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+//     CHECK:       arith.mulf
+//     CHECK:       arith.addf
+//     CHECK:       linalg.yield
+//     CHECK:     } -> tensor<?x?xf32>
+//     CHECK:     %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
+//     CHECK:     %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+//     CHECK:     scf.yield %[[INS]] : tensor<?x5xf32>
+//     CHECK:   }
+//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:     arith.addf
+//     CHECK:     linalg.yield
+//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   return %[[R]] : tensor<?xf32>
+// -----
+func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d1)>],
+   iterator_types = ["reduction", "parallel"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %42 = arith.addf %arg7, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+}
+//     CHECK: func @reduction_tile_transpose
+//     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
+//     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+//     CHECK:   scf.for
+//     CHECK:     linalg.generic
+//     CHECK:     %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
+//     CHECK:     %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
+//     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
+//     CHECK:   }
+//     CHECK:   linalg.generic 
+//     CHECK:   return


More information about the Mlir-commits mailing list