[Mlir-commits] [mlir] 178f9bd - [mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 30 03:32:28 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-30T03:32:23-07:00
New Revision: 178f9bd63c9e0a207acc3ac2461ba53b99576e69
URL: https://github.com/llvm/llvm-project/commit/178f9bd63c9e0a207acc3ac2461ba53b99576e69
DIFF: https://github.com/llvm/llvm-project/commit/178f9bd63c9e0a207acc3ac2461ba53b99576e69.diff
LOG: [mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor
This revision merges the 2 split_reduction transforms and adds extra control by using attributes.
SplitReduction is known to require a concrete additional buffer to store temporary information.
Add an option to introduce a `bufferization.alloc_tensor` instead of `linalg.init_tensor`.
This behaves better with subset-based tiling and bufferization.
Differential Revision: https://reviews.llvm.org/D128722
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f3e42cefb2d45..461388dd61af3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -164,8 +164,24 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
reduction into a parallel and reduction dimension.
A new `linalg.generic` op is created to perform the rest of the reduction.
- Example:
-
+ The transformation supports different configuration attributes:
+ - split_factor: the factor by which to split (i.e. the size of the
+ remaining reduction after splitting).
+ - insert_split_dimension: the dimension in the temporary tensor into
+ which the new parallel dimension is inserted.
+ - use_scaling_algorithm: whether to use a scaling based formulation that
+ does not create an ExpandShapeOp (default: do not use scaling)
+ - use_alloc: whether to use an alloc op to allocate the temporary
+ tensor (default: do not use alloc op)
+
+ This op returns 4 handles to:
+ - the init op (or tensor_alloc op if use_alloc = true),
+ - the fill op used to initialize the neutral element,
+ - the split op and
+ - the result-combining op.
+
+ Example (default: use_scaling_algorithm = false, use_alloc = false):
+ ====================================================================
```
%r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
@@ -178,7 +194,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
} -> tensor<f32>
```
- To:
+ is split into:
```
%cst = arith.constant 0.000000e+00 : f32
@@ -203,34 +219,8 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
} -> tensor<f32>
```
- This op returns handles to the fill op used to initialize the neutral
- element, the split op and the result-combining op.
- }];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64Attr, "{}">:$split_factor,
- DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
- let results = (outs PDL_Operation:$fill_op,
- PDL_Operation:$split_linalg_op,
- PDL_Operation:$combining_linalg_op);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
- ::mlir::linalg::LinalgOp target, TransformState &state);
- }];
-}
-
-def SplitReductionByScalingOp :
- Op<Transform_Dialect, "structured.split_reduction_by_scaling",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface]> {
- let description = [{
- Indicates that the given `target` op should be transformed with the
- `splitReductionByScaling` transformation and split factor provided as
- attribute.
-
+ Example (use_scaling_algorithm = true, use_alloc = true):
+ =========================================================
Instead of introducing an ExpandShapeOp, this scaling-based implementation
rewrites a reduction dimension `k` into `k * split_factor + kk`.
The dimension `kk` is added as an extra parallel dimension to the
@@ -287,12 +277,13 @@ def SplitReductionByScalingOp :
return %4 : tensor<16x32xf32>
```
-
}];
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64Attr, "{}">:$split_factor,
- DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+ DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
+ UnitAttr:$use_scaling_algorithm,
+ UnitAttr:$use_alloc);
let results = (outs PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 78f17c1620ba9..6b3230ade0033 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1474,7 +1474,8 @@ using ControlSplitReductionFn =
void populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
- const LinalgTransformationFilter &f = LinalgTransformationFilter());
+ const LinalgTransformationFilter &f = LinalgTransformationFilter(),
+ bool useAlloc = false);
/// Apply transformation to split the single linalg op reduction into a parallel
/// and reduction dimension. Then create a new linalg.generic op doing the rest
@@ -1518,19 +1519,21 @@ void populateSplitReductionPattern(
FailureOr<LinalgOp>
splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
- const LinalgTransformationFilter &f);
+ const LinalgTransformationFilter &f, bool useAlloc = false);
/// Filterless version of the above.
/// Returns both the new linalg ops as well as the fillOp needed to initialize
/// the temporary expanded tensor with the proper neutral element.
struct SplitReductionResult {
+ Operation *initOrAlloc;
FillOp fillOp;
LinalgOp splitLinalgOp;
LinalgOp resultCombiningLinalgOp;
};
FailureOr<SplitReductionResult>
splitReduction(PatternRewriter &b, LinalgOp op,
- const ControlSplitReductionFn &controlSplitReductionFn);
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ bool useAlloc = false);
/// Scaling-based implementation of the split reduction transformation.
/// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
@@ -1580,7 +1583,8 @@ splitReduction(PatternRewriter &b, LinalgOp op,
/// ```
FailureOr<SplitReductionResult>
splitReductionByScaling(PatternRewriter &b, LinalgOp op,
- const ControlSplitReductionFn &controlSplitReductionFn);
+ const ControlSplitReductionFn &controlSplitReductionFn,
+ bool useAlloc = false);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e495a3ddfd483..b644848c53172 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -413,29 +413,9 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
- splitReduction(rewriter, target, splitFn);
- if (failed(splitResult))
- return getOperation()->emitError("failed to apply");
- return SmallVector<Operation *>{splitResult->fillOp,
- splitResult->splitLinalgOp,
- splitResult->resultCombiningLinalgOp};
-}
-
-//===----------------------------------------------------------------------===//
-// SplitReductionByScalingOp
-//===----------------------------------------------------------------------===//
-
-FailureOr<SmallVector<Operation *>>
-transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
- TransformState &state) {
- ControlSplitReductionFn splitFn = [&](LinalgOp) {
- return std::pair<int64_t, unsigned>(getSplitFactor(),
- getInsertSplitDimension());
- };
- SimpleRewriter rewriter(getContext());
- rewriter.setInsertionPoint(target);
- FailureOr<SplitReductionResult> splitResult =
- splitReductionByScaling(rewriter, target, splitFn);
+ (getUseScalingAlgorithm())
+ ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
+ : splitReduction(rewriter, target, splitFn, getUseAlloc());
if (failed(splitResult))
return getOperation()->emitError("failed to apply");
return SmallVector<Operation *>{splitResult->fillOp,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 61989f1eb7b29..6eb263a0bf156 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -15,6 +15,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -60,14 +61,14 @@ static Attribute getNeutralElement(Operation *op) {
FailureOr<LinalgOp> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
- const LinalgTransformationFilter &filter) {
+ const LinalgTransformationFilter &filter, bool useAlloc) {
if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
!op.hasOnlyProjectedPermutations())
return b.notifyMatchFailure(op, "precondition not met");
FailureOr<SplitReductionResult> res =
- splitReduction(b, op, controlSplitReductionFn);
+ splitReduction(b, op, controlSplitReductionFn, useAlloc);
if (failed(res))
return failure();
@@ -79,7 +80,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
- const ControlSplitReductionFn &controlSplitReductionFn) {
+ const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);
@@ -171,11 +172,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
outputExpr.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
- Value initTensor = b.create<linalg::InitTensorOp>(
- loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+ Value initOrAllocTensor;
+ if (useAlloc) {
+ initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
+ loc,
+ RankedTensorType::get(newOutputShape,
+ op.getRegionOutputArgs()[0].getType()),
+ ValueRange{});
+ } else {
+ initOrAllocTensor = b.create<linalg::InitTensorOp>(
+ loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+ }
Value constantOp = b.create<arith::ConstantOp>(loc, identity);
Value identityTensor =
- b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
+ b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
.getResult(0);
newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
@@ -189,7 +199,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
// Create the new op matching the original op with an extra parallel
// dimension.
GenericOp genericOp = b.create<GenericOp>(
- loc, TypeRange({initTensor.getType()}), newInputs,
+ loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
ValueRange({identityTensor}), newMaps, newIteratorTypes);
b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
genericOp.region().begin());
@@ -223,9 +233,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
});
b.replaceOp(op, reduction.getResults());
- return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
- cast<LinalgOp>(genericOp.getOperation()),
- reduction};
+ return SplitReductionResult{
+ initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
+ cast<LinalgOp>(genericOp.getOperation()), reduction};
}
/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
@@ -260,7 +270,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
PatternRewriter &b, LinalgOp op,
- const ControlSplitReductionFn &controlSplitReductionFn) {
+ const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);
@@ -297,7 +307,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
return b.notifyMatchFailure(op, "unknown reduction neutral");
// TODO: relax this when multi-reduction support is available.
- if (op.getNumOutputs() != (int)neutralElements.size())
+ if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
return b.notifyMatchFailure(op, "expect one reduction per output");
// Rewrite part.
@@ -318,6 +328,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: generalize when multi-reduction support is available.
SmallVector<Value> newOutputs;
newOutputs.reserve(op.getNumOutputs());
+ SmallVector<Operation *> initOrAllocTensorOps;
SmallVector<linalg::FillOp> fillOps;
fillOps.reserve(op.getNumOutputs());
for (auto it : llvm::zip(op.outputs(), neutralElements)) {
@@ -327,12 +338,19 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
reductionDimSize / splitFactor, insertSplitDimension);
SmallVector<Value> dims =
tensor::createDynamicDimValues(b, loc, rankedTensor);
- Value initTensor = b.create<linalg::InitTensorOp>(
- loc, dims, newT.getShape(), t.getElementType());
+ Value initOrAllocTensor;
+ if (useAlloc) {
+ initOrAllocTensor =
+ b.create<bufferization::AllocTensorOp>(loc, newT, dims);
+ } else {
+ initOrAllocTensor = b.create<linalg::InitTensorOp>(
+ loc, dims, newT.getShape(), t.getElementType());
+ }
Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
fillOps.push_back(
- b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
+ b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
newOutputs.push_back(fillOps.back().getResult(0));
+ initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
}
// Step 2. Reindex / expand indexing maps.
@@ -423,7 +441,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: extend when multi-reduction support is available.
assert(fillOps.size() == results.size() && results.size() == 1);
b.replaceOp(op, results.front()->getResults());
- return SplitReductionResult{fillOps.front(),
+ return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
cast<LinalgOp>(genericOp.getOperation()),
results.front()};
}
@@ -434,18 +452,21 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgSplitReduction(MLIRContext *context,
ControlSplitReductionFn controlSplitReductionFn,
- LinalgTransformationFilter f, PatternBenefit benefit = 1)
+ LinalgTransformationFilter f, bool useAlloc = false,
+ PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlSplitReductionFn(std::move(controlSplitReductionFn)),
- filter(std::move(f)) {}
+ useAlloc(useAlloc), filter(std::move(f)) {}
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- return splitReduction(rewriter, op, controlSplitReductionFn, filter);
+ return splitReduction(rewriter, op, controlSplitReductionFn, filter,
+ useAlloc);
}
private:
ControlSplitReductionFn controlSplitReductionFn;
+ bool useAlloc;
LinalgTransformationFilter filter;
};
@@ -454,7 +475,7 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
void linalg::populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
- const LinalgTransformationFilter &f) {
+ const LinalgTransformationFilter &f, bool useAlloc) {
patterns.add<LinalgSplitReduction>(patterns.getContext(),
- controlSplitReductionFn, f);
+ controlSplitReductionFn, f, useAlloc);
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
index 85ab597d61c18..572c746d583e5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
@@ -3,6 +3,7 @@
// CHECK-LABEL: func.func @matmul_split
func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
+ // CHECK: bufferization.alloc_tensor({{.*}}) : tensor<?x32x64xf32>
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
@@ -30,6 +31,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
- %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
+ %1:3 = transform.structured.split_reduction %0
+ { split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
}
}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4c8a9484eb4ee..2c44d2ff83ee2 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -41,6 +42,7 @@ struct TestLinalgTransforms
void getDependentDialects(DialectRegistry &registry) const override {
// clang-format off
registry.insert<AffineDialect,
+ bufferization::BufferizationDialect,
memref::MemRefDialect,
scf::SCFDialect,
linalg::LinalgDialect,
More information about the Mlir-commits
mailing list