[Mlir-commits] [mlir] d571639 - [mlir][Linalg] SplitReduction implementation without tensor::ExpandShapeOp
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 22 12:08:40 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-22T12:06:58-07:00
New Revision: d5716395792696f2b56a0d4debadd040ee385143
URL: https://github.com/llvm/llvm-project/commit/d5716395792696f2b56a0d4debadd040ee385143
DIFF: https://github.com/llvm/llvm-project/commit/d5716395792696f2b56a0d4debadd040ee385143.diff
LOG: [mlir][Linalg] SplitReduction implementation without tensor::ExpandShapeOp
This revision proposes a different implementation of the SplitReduction transformation that does
not rely on tensor::ExpandShapeOp.
Previously, a dimension `[k]` would be split into `[k][kk]` via an ExpandShapeOp.
Instead, this revision proposes to rewrite `[k]` into `[factor * k + kk]`.
There are different tradeoffs involved, but the proposed implementation is more general because
the affine rewrite is well-defined. In particular, it works naturally with dynamic (`?`) parallel
dimensions and non-trivial indexing maps.
A further rewrite of `[factor * k + kk]` into `[k][kk]` via an ExpandShapeOp is possible as a follow-up.
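Concretely, using the matmul example from this patch (split factor 4, reduction dimension `k` of size 256), the LHS indexing map of the matmul is rewritten as follows:

```
// Before: d2 is the reduction dimension k.
affine_map<(d0, d1, d2) -> (d0, d2)>
// After: k becomes d2 * 4 + d3, where d2 (size 64) is now parallel and
// d3 = kk (size 4) carries the remaining reduction.
affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
```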
Differential Revision: https://reviews.llvm.org/D128266
Added:
mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/IR/AffineMap.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2d8a4986e09d6..8f0dc16d35ab7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -222,6 +222,89 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
}];
}
+def SplitReductionByScalingOp :
+ Op<Transform_Dialect, "structured.split_reduction_by_scaling",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Indicates that the given `target` op should be transformed with the
+ `splitReductionByScaling` transformation, using the split factor provided
+ as an attribute.
+
+ Instead of introducing an ExpandShapeOp, this scaling-based implementation
+ rewrites a reduction dimension `k` into `k * split_factor + kk`.
+ The dimension `kk` is added as an extra parallel dimension to the
+ intermediate output tensor at position `insert_split_dimension`.
+
+ Consider a minimal example where `k` is reduced:
+ O(i, j) += I(i, j, k)
+ Assume i=3, j=5, k=128, split_factor=16 and insert_split_dimension=0.
+ The compute is rewritten as:
+ a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
+ b. O(i, j) += O_i(kk, i, j)
+ The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+
+ Example:
+
+ ```
+ %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ ```
+
+ Is transformed to:
+
+ ```
+ #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
+ #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
+ #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+ #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+ %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
+ tensor<16x32x64xf32>
+ %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
+
+ %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+ outs(%1 : tensor<16x32x64xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
+ %5 = arith.mulf %arg3, %arg4 : f32
+ %6 = arith.addf %arg6, %5 : f32
+ linalg.yield %6 : f32
+ } -> tensor<16x32x64xf32>
+
+ %4 = linalg.generic {indexing_maps = [#map4, #map5],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%3 : tensor<16x32x64xf32>)
+ outs(%C : tensor<16x32xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %5 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %5 : f32
+ } -> tensor<16x32xf32>
+
+ return %4 : tensor<16x32xf32>
+ ```
+
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64Attr, "{}">:$split_factor,
+ DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+ let results = (outs PDL_Operation:$fill_op,
+ PDL_Operation:$split_linalg_op,
+ PDL_Operation:$combining_linalg_op);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+ ::mlir::linalg::LinalgOp target, TransformState &state);
+ }];
+}
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7e2d58939da1c..78f17c1620ba9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1532,6 +1532,56 @@ FailureOr<SplitReductionResult>
splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn);
+/// Scaling-based implementation of the split reduction transformation.
+/// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
+/// `k` into `k * scale + kk`.
+///
+/// Example:
+/// ```
+/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+/// ```
+///
+/// Is transformed to:
+///
+/// ```
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
+/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+/// %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
+/// %cst = arith.constant 0.000000e+00 : f32
+/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
+/// tensor<16x32x64xf32>
+/// %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
+///
+/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
+/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+/// outs(%1 : tensor<16x32x64xf32>) {
+/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
+/// %5 = arith.mulf %arg3, %arg4 : f32
+/// %6 = arith.addf %arg6, %5 : f32
+/// linalg.yield %6 : f32
+/// } -> tensor<16x32x64xf32>
+///
+/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
+/// iterator_types = ["parallel", "parallel", "reduction"]}
+/// ins(%3 : tensor<16x32x64xf32>)
+/// outs(%C : tensor<16x32xf32>) {
+/// ^bb0(%arg3: f32, %arg4: f32):
+/// %5 = arith.addf %arg3, %arg4 : f32
+/// linalg.yield %5 : f32
+/// } -> tensor<16x32xf32>
+///
+/// return %4 : tensor<16x32xf32>
+/// ```
+FailureOr<SplitReductionResult>
+splitReductionByScaling(PatternRewriter &b, LinalgOp op,
+ const ControlSplitReductionFn &controlSplitReductionFn);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 87ac693492113..de94f43708fad 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -240,6 +240,22 @@ class AffineMap {
getContext());
}
+ /// Returns a new AffineMap with the same number of dims and symbols and one
+ /// fewer result, with the result at `pos` dropped.
+ AffineMap dropResult(unsigned pos) {
+ auto exprs = llvm::to_vector<4>(getResults());
+ exprs.erase(exprs.begin() + pos);
+ return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+ }
+
+ /// Returns a new AffineMap with the same number of dims and symbols and an
+ /// extra result inserted at `pos`.
+ AffineMap insertResult(AffineExpr expr, unsigned pos) {
+ auto exprs = llvm::to_vector<4>(getResults());
+ exprs.insert(exprs.begin() + pos, expr);
+ return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+ }
+
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible.
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
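As a usage sketch for the two new AffineMap helpers (illustrative only; `ctx` is an assumed MLIRContext):

```
// Given map = (d0, d1, d2) -> (d0, d2):
AffineExpr d0 = getAffineDimExpr(0, &ctx);
AffineExpr d1 = getAffineDimExpr(1, &ctx);
AffineExpr d2 = getAffineDimExpr(2, &ctx);
AffineMap map = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d2}, &ctx);
// insertResult yields (d0, d1, d2) -> (d0, d1, d2).
AffineMap inserted = map.insertResult(d1, /*pos=*/1);
// dropResult yields (d0, d1, d2) -> (d2).
AffineMap dropped = map.dropResult(/*pos=*/0);
```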
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 60c61cdd56a76..4bdc10a25023f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -249,6 +249,16 @@ class RankedTensorType::Builder {
return *this;
}
+ /// Insert `val` into the shape at position `pos`.
+ Builder &insertDim(int64_t val, unsigned pos) {
+ assert(pos <= shape.size() && "overflow");
+ if (storage.empty())
+ storage.append(shape.begin(), shape.end());
+ storage.insert(storage.begin() + pos, val);
+ shape = {storage.data(), storage.size()};
+ return *this;
+ }
+
operator RankedTensorType() {
return RankedTensorType::get(shape, elementType, encoding);
}
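Similarly, a minimal sketch of the new `RankedTensorType::Builder::insertDim` (illustrative only; mirrors the O_i shape from the commit message):

```
// Inserting a dim of size 8 at position 0 of tensor<3x5xf32>
// yields tensor<8x3x5xf32>.
RankedTensorType t = RankedTensorType::get({3, 5}, FloatType::getF32(&ctx));
RankedTensorType newT =
    RankedTensorType::Builder(t).insertDim(/*val=*/8, /*pos=*/0);
```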
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f8ce4701ab74d..e495a3ddfd483 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -421,6 +421,28 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
splitResult->resultCombiningLinalgOp};
}
+//===----------------------------------------------------------------------===//
+// SplitReductionByScalingOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Operation *>>
+transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
+ TransformState &state) {
+ ControlSplitReductionFn splitFn = [&](LinalgOp) {
+ return std::pair<int64_t, unsigned>(getSplitFactor(),
+ getInsertSplitDimension());
+ };
+ SimpleRewriter rewriter(getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<SplitReductionResult> splitResult =
+ splitReductionByScaling(rewriter, target, splitFn);
+ if (failed(splitResult))
+ return getOperation()->emitError("failed to apply");
+ return SmallVector<Operation *>{splitResult->fillOp,
+ splitResult->splitLinalgOp,
+ splitResult->resultCombiningLinalgOp};
+}
+
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 226b35d4495ce..8834000edd69b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -19,13 +19,14 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::linalg;
/// Return the identity numeric value associated with the given op.
-static Optional<Attribute> getIdentity(Operation *op) {
+static Attribute getNeutralElement(Operation *op) {
// Builder only used as helper for attribute creation.
OpBuilder b(op->getContext());
Type resultType = op->getResult(0).getType();
@@ -41,7 +42,7 @@ static Optional<Attribute> getIdentity(Operation *op) {
if (isa<arith::MinFOp>(op))
return b.getFloatAttr(resultType,
llvm::APFloat::getLargest(semantic, true));
- return llvm::None;
+ return Attribute();
}
if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
return b.getIntegerAttr(resultType, 0);
@@ -53,7 +54,7 @@ static Optional<Attribute> getIdentity(Operation *op) {
return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
if (isa<arith::MulIOp>(op))
return b.getIntegerAttr(resultType, 1);
- return llvm::None;
+ return Attribute();
}
FailureOr<LinalgOp> mlir::linalg::splitReduction(
@@ -84,7 +85,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
int64_t ratio = control.first;
- unsigned insertDimIndex = control.second;
+ unsigned insertSplitDimension = control.second;
if (ratio <= 1)
return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
@@ -95,7 +96,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
int64_t reductionDimSize = loopRanges[reductionDim];
if (reductionDimSize == ShapedType::kDynamicSize ||
- reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
+ reductionDimSize % ratio != 0 ||
+ insertSplitDimension >= loopRanges.size())
return b.notifyMatchFailure(
op, "Reduction dimension not divisible by split ratio");
@@ -105,7 +107,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
Operation *reductionOp = combinerOps[0];
- Optional<Attribute> identity = getIdentity(reductionOp);
+ Attribute identity = getNeutralElement(reductionOp);
if (!identity)
return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
@@ -125,13 +127,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
newShape.push_back(ratio);
newShape.push_back(op.getShape(operand)[idx] / ratio);
reassociation.push_back({index++, index++});
- exprs.push_back(b.getAffineDimExpr(insertDimIndex));
+ exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
exprs.push_back(
- b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
continue;
}
newShape.push_back(op.getShape(operand)[idx]);
- exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+ exprs.push_back(
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
reassociation.push_back({index++});
}
newMaps.push_back(
@@ -157,20 +160,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<AffineExpr> outputExpr;
for (unsigned idx :
llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
- if (idx == insertDimIndex) {
+ if (idx == insertSplitDimension) {
newOutputShape.push_back(ratio);
- outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
+ outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
continue;
}
- unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
+ unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
newOutputShape.push_back(oldShape[oldDim]);
unsigned dim = oldOutputMap.getDimPosition(oldDim);
outputExpr.push_back(
- b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
Value initTensor = b.create<linalg::InitTensorOp>(
loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
- Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+ Value constantOp = b.create<arith::ConstantOp>(loc, identity);
Value identityTensor =
b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
.getResult(0);
@@ -179,7 +182,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
op.getContext()));
SmallVector<StringRef> newIteratorTypes;
for (auto &it : llvm::enumerate(op.iterator_types())) {
- if (insertDimIndex == it.index())
+ if (insertSplitDimension == it.index())
newIteratorTypes.push_back(getParallelIteratorTypeName());
newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
}
@@ -199,7 +202,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<StringRef> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
- if (insertDimIndex == i) {
+ if (insertSplitDimension == i) {
reductionIteratorTypes.push_back(getReductionIteratorTypeName());
} else {
exprs.push_back(b.getAffineDimExpr(i));
@@ -225,6 +228,206 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
reduction};
}
+/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
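+/// E.g., with ratio = 4 and reductionDimPos = 2, a matmul LHS map
+/// (d0, d1, d2) -> (d0, d2) becomes (d0, d1, d2, d3) -> (d0, d2 * 4 + d3).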
+/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
+/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
+/// done as a transform to enable better vectorization.
+static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
+ unsigned reductionDimPos,
+ int64_t reductionRatio) {
+ auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
+ auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
+ AffineMap map = op.getTiedIndexingMap(&opOperand);
+ AffineMap idMap =
+ AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
+ AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
+ AffineMap composeMap = shiftedIdMap.replace(
+ reductionDim, reductionDim * reductionRatio + reductionDimP1,
+ shiftedIdMap.getNumDims(), /*numSymbols=*/0);
+ return map.compose(composeMap);
+}
+
+static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
+ unsigned reductionDimPos, int64_t size) {
+ auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
+ AffineMap map = op.getTiedIndexingMap(&opOperand);
+ AffineMap idMap =
+ AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
+ AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
+ return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
+}
+
+/// Core rewrite implementation.
+FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
+ PatternRewriter &b, LinalgOp op,
+ const ControlSplitReductionFn &controlSplitReductionFn) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(op);
+
+ // Matcher part, enforce preconditions.
+ std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
+ int64_t splitFactor = control.first;
+ unsigned insertSplitDimension = control.second;
+ if (splitFactor <= 1)
+ return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
+
+ SmallVector<unsigned> dims;
+ op.getReductionDims(dims);
+ if (dims.empty())
+ return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
+
+ unsigned reductionDimPos = dims[0];
+ SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
+ int64_t reductionDimSize = loopRanges[reductionDimPos];
+ if (reductionDimSize == ShapedType::kDynamicSize ||
+ reductionDimSize % splitFactor != 0 ||
+ insertSplitDimension >= loopRanges.size())
+ return b.notifyMatchFailure(
+ op, "first reduction dimension not divisible by split factor");
+
+ SmallVector<Operation *> combinerOps;
+ if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
+ return b.notifyMatchFailure(op, "cannot match a reduction pattern");
+
+ SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
+ llvm::map_range(combinerOps, [&](Operation *reductionOp) {
+ return getNeutralElement(reductionOp);
+ }));
+ if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
+ return b.notifyMatchFailure(op, "unknown reduction neutral");
+
+ // TODO: relax this when multi-reduction support is available.
+ if (op.getNumOutputs() != neutralElements.size())
+ return b.notifyMatchFailure(op, "expect one reduction per output");
+
+ // Rewrite part.
+ // Step 1. Build the intermediate outputs filled with the proper
+ // neutralElements. Such outputs have the same shape as the originals, with an
+ // extra dimension inserted at `insertSplitDimension`.
+ //
+ // Consider a minimal example where `k` is reduced:
+ // O(i, j) += I(i, j, k)
+ // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
+ // The compute is rewritten as:
+ // a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
+ // b. O(i, j) += O_i(kk, i, j)
+ // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+ Location loc = op->getLoc();
+ MLIRContext *context = op.getContext();
+ // For now assume outputs are 1-1 with reduction neutralElements.
+ // TODO: generalize when multi-reduction support is available.
+ SmallVector<Value> newOutputs;
+ newOutputs.reserve(op.getNumOutputs());
+ SmallVector<linalg::FillOp> fillOps;
+ fillOps.reserve(op.getNumOutputs());
+ for (auto it : llvm::zip(op.outputs(), neutralElements)) {
+ Value rankedTensor = std::get<0>(it);
+ auto t = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
+ reductionDimSize / splitFactor, insertSplitDimension);
+ SmallVector<Value> dims =
+ tensor::createDynamicDimValues(b, loc, rankedTensor);
+ Value initTensor = b.create<linalg::InitTensorOp>(
+ loc, dims, newT.getShape(), t.getElementType());
+ Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
+ fillOps.push_back(
+ b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
+ newOutputs.push_back(fillOps.back().getResult(0));
+ }
+
+ // Step 2. Reindex / expand indexing maps.
+ // Reindex existing input indexings: k -> k * splitFactor + k'.
+ SmallVector<AffineMap> newMaps;
+ newMaps.reserve(op.getNumInputsAndOutputs() + 1);
+ for (OpOperand *o : op.getInputOperands())
+ newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
+ // Provision a new indexing for the shape-only tensor.
+ auto nDims = op.getNumLoops() + 1;
+ auto redDim = getAffineDimExpr(reductionDimPos, context);
+ auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
+ newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
+ // Expand existing output indexings.
+ // TODO: a subset of these may not reduce along reducePos and should be
+ // reindexed: k -> k * splitFactor + k', when multi-reduction support is
+ // available.
+ for (OpOperand *o : op.getOutputOperands())
+ newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
+ reductionDimSize / splitFactor));
+
+ // Step 3. Handle operands.
+ // Compute the new input tensors.
+ auto newInputs = llvm::to_vector<4>(op.inputs());
+ // Add a single shape-only tensor to carry the dimensions without resorting to
+ // more complex inversions.
+ newInputs.push_back(b.create<linalg::InitTensorOp>(
+ loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
+ b.getIntegerType(1)));
+ // Output tensors are already good to go.
+
+ // Step 4. Create the new op matching the original op with an extra parallel
+ // dimension.
+ SmallVector<StringRef> iteratorTypes =
+ llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
+ iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
+ getParallelIteratorTypeName());
+ GenericOp genericOp =
+ b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
+ newOutputs, newMaps, iteratorTypes);
+ b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
+ genericOp.region().begin());
+ genericOp.region().front().insertArgument(reductionDimPos,
+ b.getIntegerType(1), loc);
+
+ // Step 5. Create new reduction ops that only reduce the newly added
+ // dimensions from the previous op.
+ // For now assume outputs are 1-1 with reduction ops.
+ // TODO: a subset of these may not reduce in the first place and do not
+ // require a new op, when multi-reduction support is available.
+ // TODO: all results can be handled in a single GenericOp, when
+ // multi-reduction support is available.
+ SmallVector<LinalgOp> results;
+ for (auto it :
+ llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
+ Value reindexedOutput = std::get<0>(it);
+ Value originalOutput = std::get<1>(it);
+ auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+ Operation *combinerOp = std::get<2>(it);
+
+ AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
+ SmallVector<AffineMap> indexingMaps = {
+ map, map.dropResult(insertSplitDimension)};
+ SmallVector<StringRef> reductionIteratorTypes(
+ originalOutputType.getRank() + 1, getParallelIteratorTypeName());
+ reductionIteratorTypes[insertSplitDimension] =
+ getReductionIteratorTypeName();
+
+ // clang-format off
+ auto reductionOp = b.create<GenericOp>(
+ loc,
+ originalOutputType,
+ reindexedOutput,
+ originalOutput,
+ indexingMaps,
+ reductionIteratorTypes,
+ [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
+ Operation *clonedReductionOp = b.clone(*combinerOp);
+ clonedReductionOp->setOperand(0, bbArgs[0]);
+ clonedReductionOp->setOperand(1, bbArgs[1]);
+ b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+ });
+ // clang-format on
+
+ results.push_back(reductionOp);
+ }
+
+ // TODO: extend when multi-reduction support is available.
+ assert(fillOps.size() == results.size() && results.size() == 1);
+ b.replaceOp(op, results.front()->getResults());
+ return SplitReductionResult{fillOps.front(),
+ cast<LinalgOp>(genericOp.getOperation()),
+ results.front()};
+}
+
namespace {
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
new file mode 100644
index 0000000000000..85ab597d61c18
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_split
+func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
+
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<?x32x64xf32>) {
+
+ // CHECK: linalg.generic
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<?x32x64xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<?x32xf32>) {
+ %0 = linalg.matmul ins(%A, %B: tensor<?x256xf32>, tensor<256x32xf32>)
+ outs(%C: tensor<?x32xf32>) -> tensor<?x32xf32>
+ return %0: tensor<?x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
+ }
+}
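For completeness, a minimal sketch of driving the new entry point directly from C++, mirroring `SplitReductionByScalingOp::applyToOne` above (`rewriter` and `matmulOp` are assumed to be provided by the surrounding pattern or pass):

```
// Sketch only: the control function returns
// {split_factor, insert_split_dimension}, matching the test values above.
linalg::ControlSplitReductionFn control = [](linalg::LinalgOp) {
  return std::pair<int64_t, unsigned>(4, 2);
};
FailureOr<linalg::SplitReductionResult> result =
    linalg::splitReductionByScaling(rewriter, matmulOp, control);
if (failed(result))
  return failure();
// On success, result->fillOp, result->splitLinalgOp and
// result->resultCombiningLinalgOp are the three ops created by the rewrite.
```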