[Mlir-commits] [mlir] 3310fe5 - [mlir][linalg] Add reduction tiling transformation
Thomas Raoux
llvmlistbot at llvm.org
Thu Nov 3 16:07:34 PDT 2022
Author: Thomas Raoux
Date: 2022-11-03T23:07:12Z
New Revision: 3310fe55d9480ef3c27037043a5c3db8c7003914
URL: https://github.com/llvm/llvm-project/commit/3310fe55d9480ef3c27037043a5c3db8c7003914
DIFF: https://github.com/llvm/llvm-project/commit/3310fe55d9480ef3c27037043a5c3db8c7003914.diff
LOG: [mlir][linalg] Add reduction tiling transformation
Add a transformation to tile reduction ops into a parallel operation
followed by a merge operation. This is equivalent to the existing
reduction splitting transformation, but it uses loops instead of
higher-dimensional linalg ops.
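In pseudo-IR (a simplified sketch mirroring the example added to the
documentation in this patch, with shapes and attribute details elided),
the rewrite turns

```mlir
%0 = linalg.generic %in ["parallel", "reduction"]
  : tensor<7x9xf32> -> tensor<7xf32>
```

into

```mlir
%0 = linalg.fill ... : tensor<7x4xf32>
%1 = scf.for ... iter_args(%arg0 = %0)
  %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
  %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
  %4 = linalg.generic %2, %3 ["parallel", "parallel"]
    : tensor<7x?xf32> -> tensor<7x?xf32>
  %5 = tensor.insert_slice %4, %arg0[0, 0] : tensor<7x4xf32>
}
%6 = linalg.generic %1 ["parallel", "reduction"]
  : tensor<7x4xf32> -> tensor<7xf32>
```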
Differential Revision: https://reviews.llvm.org/D136586
Added:
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/include/mlir/Interfaces/TilingInterface.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5c304f5efb6ea..6cb14acb1b089 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -608,6 +608,94 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
}];
}
+
+def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+    Indicates that the given `target` op should be transformed with the
+    `tileReductionUsingScf` transformation with the tile sizes provided as an
+    attribute.
+
+    This transformation tiles the `target` op along its reduction dimensions.
+    It first creates a tensor initialized with the identity value of the
+    reduction, then creates nested loops containing a parallel version of the
+    `target` op. The dimensions of the parallel op are less than or equal to
+    the tile sizes passed by the user.
+    After the loop, a merge operation is created to perform a final reduction
+    of the partial reductions.
+    The initial tensor always uses the tile size for the tiled dimension; this
+    may over-allocate if the tile size is greater than the reduction dimension.
+
+ #### Return modes
+
+    The 3 returned handles point to:
+ - the fill op used to initialize the neutral element,
+ - the parallel tiled op and
+ - the result-combining op.
+
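+    A minimal transform-dialect invocation, mirroring the test added in this
+    patch (the handle names `%fill`, `%split` and `%combine` are illustrative),
+    looks like:
+
+    ```
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %fill, %split, %combine =
+      transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+    ```
+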
+ #### Example:
+
+ ```
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.addf %arg7, %arg9 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+ ```
+
+ is transformed into:
+
+ ```
+ %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+ %2 = scf.for %arg2 = %c0 to %dim_0 step %c5 iter_args(%arg3 = %1) -> (tensor<?x5xf32>) {
+ %extracted_slice = tensor.extract_slice %1[0, 0] [%dim, 5] [1, 1] : tensor<?x5xf32> to tensor<?x5xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[0, %arg2] [%dim, 5] [1, 1] : tensor<?x?xf32> to tensor<?x5xf32>
+ %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%extracted_slice_2 : tensor<?x5xf32>)
+ outs(%extracted_slice : tensor<?x5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %5 = arith.addf %in, %out : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x5xf32>
+ %dim_3 = tensor.dim %1, %c0 : tensor<?x5xf32>
+ %inserted_slice = tensor.insert_slice %4 into %arg3[0, 0] [%dim_3, 5] [1, 1] : tensor<?x5xf32> into tensor<?x5xf32>
+ scf.yield %inserted_slice : tensor<?x5xf32>
+ }
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%2 : tensor<?x5xf32>)
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %4 = arith.addf %in, %out : f32
+ linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+ let results = (outs PDL_Operation:$fill_op,
+ PDL_Operation:$split_linalg_op,
+ PDL_Operation:$combining_linalg_op);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::linalg::LinalgOp target,
+ ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 6a10d4332e7eb..5fc7938e0dd2f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -137,6 +137,10 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+/// Return the identity numeric value associated with the given op. Return
+/// llvm::None if there is no known neutral element.
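+/// For example, the neutral element of `arith.addf` is +0.0 and the one of
+/// `arith.mulf` is 1.0.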
+Optional<Attribute> getNeutralElement(Operation *op);
+
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6cdef2512f607..9fa4114c77b11 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,6 +136,46 @@ tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
+/// Transformation information returned after reduction tiling.
+struct SCFReductionTilingResult {
+ /// The partial reduction tiled op generated.
+ Operation *parallelTiledOp;
+ /// The final reduction operation merging all the partial reductions.
+ Operation *mergeOp;
+  /// The op initializing the tensor holding the partial reductions.
+ Operation *initialOp;
+ /// The `scf.for` operations that iterate over the tiles.
+ SmallVector<scf::ForOp> loops;
+};
+
+/// Method to tile a reduction and generate a parallel op within a serial loop.
+/// Each of the partial reductions is calculated in parallel. Then, after the
+/// loop, all the partial reductions are merged into a final reduction.
+/// For example, it transforms the following sequence
+///
+/// ```mlir
+/// %0 = linalg.generic %in ["parallel", "reduction"]
+/// : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+///
+/// into:
+///
+/// ```mlir
+/// %0 = linalg.fill ... : tensor<7x4xf32>
+/// %1 = scf.for ... iter_args(%arg0 = %0)
+/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
+/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
+/// %4 = linalg.generic %2, %3 ["parallel", "parallel"]
+/// : tensor<7x?xf32> -> tensor<7x?xf32>
+///   %5 = tensor.insert_slice %4, %arg0[0, 0] : tensor<7x4xf32>
+/// }
+/// %6 = linalg.generic %1 ["parallel", "reduction"]
+/// : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+FailureOr<scf::SCFReductionTilingResult>
+tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op,
+ ArrayRef<OpFoldResult> tileSize);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 0cdf7a8eb649a..dc6ffcbb7accc 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -155,4 +155,72 @@ def TilingInterface : OpInterface<"TilingInterface"> {
>
];
}
+
+def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
+ let description = [{
+    Interface for allowing operations to expose the information needed to
+    tile reductions using a partial reduction followed by a merge. It is
+    complementary to TilingInterface for tiling reductions.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+        Method to generate a tensor initialized with the identity value of the
+        operation's reduction. The tensor shape is equal to the operation's
+        result shape, with a new dimension for each non-zero tile size.
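+
+        For example, for a floating-point sum reduction with a result of type
+        `tensor<?xf32>` and a tile size of 5 (an illustrative sketch; `%dim`
+        and `%cst` stand for the dynamic size and the zero constant), this
+        generates:
+
+        ```
+        %0 = tensor.empty(%dim) : tensor<?x5xf32>
+        %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+        ```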
+ }],
+ /*retType=*/"FailureOr<Operation*>",
+ /*methodName=*/"generateInitialTensorForPartialReduction",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "Location ":$loc,
+ "ArrayRef<OpFoldResult>":$sizes,
+ "ArrayRef<int>":$reductionDim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate a tiled version of the operation where the tiled
+        reduction dimensions are converted to parallel dimensions with a size
+        less than or equal to the tile size. This is meant to be used with the
+        `mergeReductions` method, which will combine the partial reductions.
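+
+        For example, for a sum reduction tiled by 5 (an illustrative sketch;
+        `%input_slice` and `%acc_slice` stand for the slices extracted inside
+        the loop), this generates:
+
+        ```
+        %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                              affine_map<(d0, d1) -> (d0, d1)>],
+                             iterator_types = ["parallel", "parallel"]}
+          ins(%input_slice : tensor<?x5xf32>)
+          outs(%acc_slice : tensor<?x5xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %5 = arith.addf %in, %out : f32
+          linalg.yield %5 : f32
+        } -> tensor<?x5xf32>
+        ```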
+ }],
+ /*retType=*/"Operation*",
+ /*methodName=*/"tileToPartialReduction",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "Location ":$loc,
+ "ValueRange":$init,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes,
+ "ArrayRef<int>":$reductionDims),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return nullptr;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to merge partial reductions for an operation that has been
+        tiled along the reduction dimensions. This will only apply the
+        reduction part of the operation.
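+
+        For example, for a sum reduction whose partial results live in a
+        `tensor<?x5xf32>` (an illustrative sketch; `%partial` and `%init`
+        stand for the partial tensor and the original destination), this
+        generates:
+
+        ```
+        %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                              affine_map<(d0, d1) -> (d0)>],
+                             iterator_types = ["parallel", "reduction"]}
+          ins(%partial : tensor<?x5xf32>)
+          outs(%init : tensor<?xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %4 = arith.addf %in, %out : f32
+          linalg.yield %4 : f32
+        } -> tensor<?xf32>
+        ```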
+ }],
+ /*retType=*/"Operation*",
+ /*methodName=*/"mergeReductions",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "Location ":$loc,
+ "ValueRange":$partialReduce,
+ "ArrayRef<int>":$reductionDim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return nullptr;
+ }]
+ >
+ ];
+}
#endif // MLIR_TILINGINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 513882ec91260..c8a3cb6946e3d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1094,6 +1094,33 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
return DiagnosedSilenceableFailure(success());
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
+ linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+ transform::TransformState &state) {
+ SimpleRewriter rewriter(getContext());
+ rewriter.setInsertionPoint(target);
+ SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+ SmallVector<OpFoldResult> sizes;
+ for (int64_t size : tileSizes) {
+ sizes.push_back(rewriter.getIndexAttr(size));
+ }
+
+ FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
+ rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+ sizes);
+
+ if (failed(result))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+ results.push_back(result->initialOp);
+ results.push_back(result->parallelTiledOp);
+ results.push_back(result->mergeOp);
+ return DiagnosedSilenceableFailure(success());
+}
+
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 32d05c5acbe6c..0608c361e774b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -26,38 +26,6 @@
using namespace mlir;
using namespace mlir::linalg;
-/// Return the identity numeric value associated to the give op.
-static Attribute getNeutralElement(Operation *op) {
- // Builder only used as helper for attribute creation.
- OpBuilder b(op->getContext());
- Type resultType = op->getResult(0).getType();
- if (auto floatType = resultType.dyn_cast<FloatType>()) {
- const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
- if (isa<arith::AddFOp>(op))
- return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
- if (isa<arith::MulFOp>(op))
- return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
- if (isa<arith::MaxFOp>(op))
- return b.getFloatAttr(resultType,
- llvm::APFloat::getLargest(semantic, true));
- if (isa<arith::MinFOp>(op))
- return b.getFloatAttr(resultType,
- llvm::APFloat::getLargest(semantic, true));
- return Attribute();
- }
- if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
- return b.getIntegerAttr(resultType, 0);
- if (isa<arith::AndIOp>(op))
- return b.getIntegerAttr(resultType, -1);
- if (isa<arith::MaxSIOp>(op))
- return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
- if (isa<arith::MinSIOp>(op))
- return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
- if (isa<arith::MulIOp>(op))
- return b.getIntegerAttr(resultType, 1);
- return Attribute();
-}
-
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
@@ -88,8 +56,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
Operation *reductionOp = combinerOps[0];
- Attribute identity = getNeutralElement(reductionOp);
- if (!identity)
+ Optional<Attribute> identity = getNeutralElement(reductionOp);
+ if (!identity.has_value())
return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
Location loc = op->getLoc();
@@ -187,7 +155,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
emptyOrAllocTensor = b.create<tensor::EmptyOp>(
loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
}
- Value constantOp = b.create<arith::ConstantOp>(loc, identity);
+ Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
Value identityTensor =
b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
.getResult(0);
@@ -309,10 +277,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
return b.notifyMatchFailure(op, "cannot match a reduction pattern");
- SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
- llvm::map_range(combinerOps, [&](Operation *reductionOp) {
- return getNeutralElement(reductionOp);
- }));
+ SmallVector<Attribute> neutralElements;
+ for (Operation *reductionOp : combinerOps) {
+ Optional<Attribute> neutralElement = getNeutralElement(reductionOp);
+ if (!neutralElement.has_value())
+ return b.notifyMatchFailure(op, "cannot find neutral element.");
+ neutralElements.push_back(*neutralElement);
+ }
if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
return b.notifyMatchFailure(op, "unknown reduction neutral");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c843f0f400793..d1fcc01ca853d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -240,11 +241,170 @@ struct LinalgOpTilingInterface
}
};
+//===----------------------------------------------------------------------===//
+// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+/// External model implementation of PartialReductionInterface for LinalgOps.
+template <typename LinalgOpTy>
+struct LinalgOpPartialReductionInterface
+ : public PartialReductionOpInterface::ExternalModel<
+ LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
+ FailureOr<Operation *> generateInitialTensorForPartialReduction(
+ Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<int> reductionDims) const {
+ auto linalgOp = cast<LinalgOp>(op);
+ OpBuilder::InsertionGuard guard(b);
+ assert(reductionDims.size() == 1 &&
+ "only support single reduction right now.");
+ if (linalgOp.hasBufferSemantics())
+ return op->emitOpError("expected operation to have tensor semantics");
+ // Insert the new parallel dimension based on the index of the reduction
+ // loop. This could be controlled by user for more flexibility.
+ int64_t insertSplitDimension = reductionDims[0];
+
+ SmallVector<Operation *, 4> combinerOps;
+ if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
+ combinerOps.size() != 1)
+      return op->emitOpError("Failed to analyze the reduction operation.");
+
+ Operation *reductionOp = combinerOps[0];
+ Optional<Attribute> identity = getNeutralElement(reductionOp);
+ if (!identity.has_value())
+ return op->emitOpError(
+ "Failed to get an identity value for the reduction operation.");
+
+    // Calculate the new shape; the new dimension is inserted based on the
+    // index of the reduction dimension.
+ SmallVector<int64_t> newOutputShape;
+ ArrayRef<int64_t> oldShape =
+ linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+ SmallVector<Value> dynamicDims;
+ for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
+ if (idx == insertSplitDimension) {
+ dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape,
+ ShapedType::kDynamicStrideOrOffset);
+ continue;
+ }
+ int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+ int64_t dim = oldShape[oldIdx];
+ newOutputShape.push_back(dim);
+ if (ShapedType::isDynamic(dim))
+ dynamicDims.push_back(b.createOrFold<tensor::DimOp>(
+ loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
+ }
+ Value emptyTensor = b.create<tensor::EmptyOp>(
+ loc, newOutputShape, linalgOp.getRegionOutputArgs()[0].getType(),
+ dynamicDims);
+ Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+ auto identityTensor =
+ b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
+ return identityTensor.getOperation();
+ }
+
+ Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+ ValueRange init,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<int> reductionDims) const {
+ OpBuilder::InsertionGuard guard(b);
+ auto linalgOp = cast<LinalgOp>(op);
+ assert(reductionDims.size() == 1 &&
+ "only support single reduction right now.");
+ int64_t insertSplitDimension = reductionDims[0];
+
+ AffineMap oldOutputMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
+ SmallVector<AffineExpr> outputExpr;
+ for (auto &[idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
+ if (static_cast<int64_t>(idx) == insertSplitDimension) {
+ outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+ }
+ outputExpr.push_back(expr);
+ }
+ if (insertSplitDimension == oldOutputMap.getNumResults())
+ outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
+
+ // Step 1: Extract a slice of the input operands.
+ SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
+ SmallVector<Value, 4> tiledOperands =
+ makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
+
+ // Step 2: Extract the accumulator operands
+ SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
+ SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+ // TODO: use SubsetExtractOpInterface once it is available.
+ Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
+ sizes, strides);
+
+    // Step 3: Create a generic op where the reduction dimension is replaced
+    // by a parallel dimension of the size of the reduction tile.
+ SmallVector<StringRef> newIteratorTypes = linalgOp.getIteratorTypesArray();
+ newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName();
+ SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+ newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
+ linalgOp.getContext());
+ auto genericOp =
+ b.create<GenericOp>(loc, TypeRange({out.getType()}), tiledOperands,
+ ValueRange({out}), newMaps, newIteratorTypes);
+ BlockAndValueMapping mapping;
+ op->getRegion(0).cloneInto(&genericOp.getRegion(),
+ genericOp.getRegion().begin(), mapping);
+ return genericOp.getOperation();
+ }
+
+ Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
+ ValueRange partialReduce,
+ ArrayRef<int> reductionDims) const {
+ auto linalgOp = cast<LinalgOp>(op);
+ assert(reductionDims.size() == 1 &&
+ "only support single reduction right now.");
+ int64_t dimToMerge = reductionDims[0];
+
+    // Create a new reduction that only reduces the newly added dimension
+    // of the previous op.
+ int64_t intermRank =
+ partialReduce[0].getType().cast<ShapedType>().getRank();
+ AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+ SmallVector<StringRef> reductionIteratorTypes;
+ SmallVector<AffineExpr> exprs;
+ for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
+ if (dimToMerge == i) {
+ reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+ } else {
+ exprs.push_back(b.getAffineDimExpr(i));
+ reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+ }
+ }
+ AffineMap outputMap =
+ AffineMap::get(intermRank, 0, exprs, op->getContext());
+ SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
+
+ SmallVector<Operation *, 4> combinerOps;
+ matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps);
+ Operation *reductionOp = combinerOps[0];
+
+ auto reduction = b.create<GenericOp>(
+ loc, op->getResultTypes(), ValueRange({partialReduce[0]}),
+ SmallVector<Value>{linalgOp.getDpsInitOperands()}, reductionMaps,
+ reductionIteratorTypes,
+ [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+ Operation *clonedReductionOp = b.clone(*reductionOp);
+ clonedReductionOp->setOperand(0, inputs[0]);
+ clonedReductionOp->setOperand(1, inputs[1]);
+ b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ });
+ return reduction.getOperation();
+ }
+};
+
} // namespace
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+ OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
+ *ctx);
}
/// Variadic helper function.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index ce15c6767b24b..04cbed0c4e135 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -948,13 +948,14 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
SmallVector<OpFoldResult> subShapeSizes =
computeTileSizes(builder, loc, tileSizes, sizeBounds);
- assert(static_cast<int64_t>(valuesToTile.size()) ==
+ assert(static_cast<int64_t>(valuesToTile.size()) <=
linalgOp->getNumOperands() &&
- "expected one value to tile for every operand");
+         "more values to tile than operands.");
SmallVector<Optional<SliceParameters>> allSliceParams;
allSliceParams.reserve(valuesToTile.size());
- for (OpOperand &opOperand : linalgOp->getOpOperands()) {
- Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
+ for (auto [opOperand, val] :
+ llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
+ Value shapedOp = val;
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
// Use `opOperand` as is if it is not tiled and not an output tensor. Having
@@ -1059,5 +1060,37 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
return reassociation;
}
+/// Return the identity numeric value associated with the given op.
+Optional<Attribute> getNeutralElement(Operation *op) {
+ // Builder only used as helper for attribute creation.
+ OpBuilder b(op->getContext());
+ Type resultType = op->getResult(0).getType();
+ if (auto floatType = resultType.dyn_cast<FloatType>()) {
+ const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+ if (isa<arith::AddFOp>(op))
+ return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+ if (isa<arith::MulFOp>(op))
+ return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+ if (isa<arith::MaxFOp>(op))
+ return b.getFloatAttr(resultType,
+ llvm::APFloat::getLargest(semantic, true));
+ if (isa<arith::MinFOp>(op))
+ return b.getFloatAttr(resultType,
+ llvm::APFloat::getLargest(semantic, true));
+    return llvm::None;
+ }
+ if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
+ return b.getIntegerAttr(resultType, 0);
+ if (isa<arith::AndIOp>(op))
+ return b.getIntegerAttr(resultType, -1);
+ if (isa<arith::MaxSIOp>(op))
+ return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
+ if (isa<arith::MinSIOp>(op))
+ return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
+ if (isa<arith::MulIOp>(op))
+ return b.getIntegerAttr(resultType, 1);
+ return llvm::None;
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2d6edb7332ac8..0c86bd4d1262a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -424,6 +424,90 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
return tilingResult;
}
+FailureOr<scf::SCFReductionTilingResult>
+mlir::scf::tileReductionUsingScf(PatternRewriter &b,
+ PartialReductionOpInterface op,
+ ArrayRef<OpFoldResult> tileSize) {
+ Location loc = op.getLoc();
+ // Ops implementing PartialReductionOpInterface are expected to implement
+ // TilingInterface.
+ auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+ SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+ SmallVector<Value> tileSizeVector =
+ getValueOrCreateConstantIndexOp(b, loc, tileSize);
+ if (tileSizeVector.size() < iterationDomain.size()) {
+ auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+ }
+ if (op->getNumResults() != 1)
+ return b.notifyMatchFailure(
+ op, "don't support ops with multiple results for now");
+ SmallVector<utils::IteratorType> iterators =
+ tilingInterfaceOp.getLoopIteratorTypes();
+ int64_t numReductionDims = llvm::count(
+ tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
+ if (numReductionDims != 1)
+ return b.notifyMatchFailure(
+ op, "only support ops with one reduction dimension.");
+ int reductionDim;
+ for (auto &[idx, iteratorType] :
+ llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction) {
+ reductionDim = idx;
+ break;
+ }
+ }
+  // 1. Create the initial tensor value.
+ FailureOr<Operation *> identityTensor =
+ op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+ reductionDim);
+ if (failed(identityTensor))
+ return b.notifyMatchFailure(op,
+ "cannot create a tensor of identity value.");
+ // 2. Create the nested loops.
+ SmallVector<OpFoldResult> offsets, sizes;
+ SmallVector<scf::ForOp> loops = generateTileLoopNest(
+ b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+
+  // 3. Generate the tiled implementation within the innermost loop.
+ b.setInsertionPoint(loops.back().getBody()->getTerminator());
+ Operation *parallelOp =
+ op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(),
+ offsets, sizes, reductionDim);
+
+ SmallVector<OpFoldResult> resultSizesList;
+ for (size_t i = 0; i < offsets.size(); i++)
+ resultSizesList.push_back(
+ b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
+ SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
+ FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
+ b, identityTensor.value()->getResults(), parallelOp->getResults(),
+ outOffsets, resultSizesList, loops);
+ if (failed(replacementOr))
+ return b.notifyMatchFailure(op, "failed to yield replacement");
+
+ auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
+ auto innerMostLoop = loops.back();
+ SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
+ assert(destinationTensors.size() ==
+ innerMostLoop.getRegionIterArgs().size() &&
+ "unexpected number of outputs");
+ updateDestinationOperandsForTiledOp(b, destinationTensors,
+ innerMostLoop.getRegionIterArgs());
+
+ // 4. Apply the merge reduction to combine all the partial values.
+ b.setInsertionPointAfter(*loops.begin());
+ Operation *mergeOp =
+ op.mergeReductions(b, loc, replacementOr.value(), reductionDim);
+ b.replaceOp(op, mergeOp->getResults());
+
+ SCFReductionTilingResult results;
+ results.initialOp = identityTensor.value();
+ results.loops = std::move(loops);
+ results.parallelTiledOp = parallelOp;
+ results.mergeOp = mergeOp;
+ return results;
+}
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
new file mode 100644
index 0000000000000..dad2f8476d1ff
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s
+
+func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %1 = arith.mulf %arg7, %arg7 : f32
+ %2 = arith.addf %1, %arg9 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+// CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
+// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
+// CHECK: }
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>],
+ iterator_types = ["reduction", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out : tensor<?xf32>) {
+ ^bb0(%arg7: f32, %arg9: f32):
+ %42 = arith.addf %arg7, %arg9 : f32
+ linalg.yield %42 : f32
+ } -> tensor<?xf32>
+ return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+}
+
+// CHECK: func @reduction_tile_transpose
+// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
+// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+// CHECK: scf.for
+// CHECK: linalg.generic
+// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
+// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
+// CHECK: }
+// CHECK: linalg.generic
+// CHECK: return