[Mlir-commits] [mlir] b6c58ec - [mlir] add producer fusion to structured transform ops
Alex Zinenko
llvmlistbot at llvm.org
Thu Jun 9 05:30:52 PDT 2022
Author: Alex Zinenko
Date: 2022-06-09T14:30:45+02:00
New Revision: b6c58ec486891312b3373a6f580318235421d918
URL: https://github.com/llvm/llvm-project/commit/b6c58ec486891312b3373a6f580318235421d918
DIFF: https://github.com/llvm/llvm-project/commit/b6c58ec486891312b3373a6f580318235421d918.diff
LOG: [mlir] add producer fusion to structured transform ops
This relies on the existing TileAndFuse pattern for tensor-based structured
ops. It complements pure tiling, from which some utilities are generalized.
Depends On D127300
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D127319
Added:
mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 205b0987ff98..fca389b438a3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -37,6 +37,25 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
}];
}
+def FuseOp : Op<Transform_Dialect, "structured.fuse",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Tiles the operations pointed to by the target handle and fuses their
+ producers greedily using the options provided as attributes.
+ }];
+
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+ let results = (outs PDL_Operation:$transformed,
+ Variadic<PDL_Operation>:$loops);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -136,7 +155,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
let description = [{
Indicates that the given `target` op should be tiled with the options
provided as attributes. This transform generates a loop nest with a smaller
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 61a01cb65808..1e7cfb53ca8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -451,16 +451,15 @@ struct PayloadIRResource
StringRef getName() override { return "transform.payload_ir"; }
};
-/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or
-/// single-result operations that "consume" their operand and produce a new
-/// result.
+/// Trait implementing the MemoryEffectOpInterface for single-operand operations
+/// that "consume" their operand and produce zero or more new results.
template <typename OpTy>
class FunctionalStyleTransformOpTrait
: public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
public:
/// This op "consumes" the operand by reading and freeing it, "produces" the
- /// result by allocating and writing it and reads/writes the payload IR in the
- /// process.
+ /// results by allocating and writing them and reads/writes the payload IR in
+ /// the process.
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(MemoryEffects::Read::get(),
this->getOperation()->getOperand(0),
@@ -468,12 +467,10 @@ class FunctionalStyleTransformOpTrait
effects.emplace_back(MemoryEffects::Free::get(),
this->getOperation()->getOperand(0),
TransformMappingResource::get());
- if (this->getOperation()->getNumResults() == 1) {
- effects.emplace_back(MemoryEffects::Allocate::get(),
- this->getOperation()->getResult(0),
+ for (Value result : this->getOperation()->getResults()) {
+ effects.emplace_back(MemoryEffects::Allocate::get(), result,
TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Write::get(),
- this->getOperation()->getResult(0),
+ effects.emplace_back(MemoryEffects::Write::get(), result,
TransformMappingResource::get());
}
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
@@ -484,9 +481,6 @@ class FunctionalStyleTransformOpTrait
static LogicalResult verifyTrait(Operation *op) {
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
"expected single-operand op");
- static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
- OpTy::template hasTrait<OpTrait::OneResult>(),
- "expected zero- or single-result op");
if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
op->emitError()
<< "FunctionalStyleTransformOpTrait should only be attached to ops "
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b02f9337995a..b5cfb2ab58dd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -92,6 +92,130 @@ FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
return reportUnknownTransformError(target);
}
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+static LogicalResult
+applyTilingToAll(Operation *transformOp, Value target,
+ ArrayRef<int64_t> tileSizes,
+ transform::TransformResults &transformResults,
+ transform::TransformState &state,
+ function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+ // Number of loops: Number of tile sizes that are not zero.
+ size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
+ // All payload ops. These should all be LinalgOps for now.
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
+
+ SmallVector<Operation *> tiledLinalgOps;
+ SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+ for (unsigned int i = 0; i < numLoops; ++i)
+ loopOps[i].reserve(payloadOps.size());
+
+ for (Operation *target : payloadOps) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
+ if (!linalgOp)
+ return transformOp->emitError("only LinalgOps are supported");
+
+ FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
+ if (failed(tiled))
+ return failure();
+
+ tiledLinalgOps.push_back(tiled->op);
+ if (tiled->loops.size() != numLoops)
+ // Not enough loops were generated. This usually means that the input size
+ // was smaller than the tiling size.
+ // TODO: LinalgTilingPattern should return failure().
+ return failure();
+ for (unsigned int i = 0; i < numLoops; ++i)
+ loopOps[i].push_back(tiled->loops[i]);
+ }
+
+ transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+ for (unsigned int i = 0; i < numLoops; ++i)
+ transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+ return success();
+}
+
+/// Parse a tiling-like operation that returns the tiled op as well as the
+/// created tile loops. The function counts the non-zero tile sizes to compute
+/// the number of results.
+static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
+ StringRef sizesAttrName) {
+ OpAsmParser::UnresolvedOperand targetOperand;
+ SMLoc opLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(targetOperand) ||
+ parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ Attribute sizesAttr = result.attributes.get(sizesAttrName);
+ if (!sizesAttr)
+ return parser.emitError(opLoc)
+ << "expected '" << sizesAttrName << "' attribute";
+ auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+ if (!sizesArrayAttr)
+ return parser.emitError(opLoc)
+ << "'" << sizesAttrName << "' attribute must be an array";
+ Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
+ size_t numExpectedLoops =
+ sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+ result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
+ if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
+ return failure();
+ return success();
+}
+
+LogicalResult
+transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
+ mlir::transform::TransformState &state) {
+ LinalgTilingAndFusionOptions fusionOptions;
+ fusionOptions.tileSizes = extractI64Array(getTileSizes());
+ fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
+
+ return applyTilingToAll(
+ getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
+ state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
+ LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+ SimpleRewriter rewriter(getContext());
+ rewriter.setInsertionPoint(linalgOp);
+ FailureOr<TileLoopNest> tileLoopNest =
+ pattern.returningMatchAndRewrite(linalgOp, rewriter);
+ if (failed(tileLoopNest))
+ return failure();
+
+ TiledLinalgOp tiledLinalgOp;
+ tiledLinalgOp.op = tileLoopNest->getRootOp();
+ tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
+ tileLoopNest->getLoopOps().end()};
+ return tiledLinalgOp;
+ });
+}
+
+ParseResult transform::FuseOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseTileLikeOp(
+ parser, result,
+ transform::FuseOp::getTileSizesAttrName(result.name).getValue());
+}
+
+void transform::FuseOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p << getTarget();
+ p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult transform::FuseOp::verify() {
+ SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ return emitOpError() << "expects interchange to be a permutation, found "
+ << getTileInterchange();
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//
@@ -274,49 +398,6 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
// TileOp
//===----------------------------------------------------------------------===//
-/// Apply a tiling transformation to all payload ops and store both the
-/// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTilingToAll(Operation *transformOp, Value target,
- ArrayRef<int64_t> tileSizes,
- transform::TransformResults &transformResults,
- transform::TransformState &state,
- function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
- // Number of loops: Number of tiles sizes that are not zero.
- size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
- // All payload ops. These should all be LinalgOps for now.
- ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
-
- SmallVector<Operation *> tiledLinalgOps;
- SmallVector<SmallVector<Operation *>> loopOps(numLoops);
- for (unsigned int i = 0; i < numLoops; ++i)
- loopOps[i].reserve(payloadOps.size());
-
- for (Operation *target : payloadOps) {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
- if (!linalgOp)
- return transformOp->emitError("only LinalgOps are supported");
-
- FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
- if (failed(tiled))
- return failure();
-
- tiledLinalgOps.push_back(tiled->op);
- if (tiled->loops.size() != numLoops)
- // Not enough loops were generated. This usually means that the input size
- // was smaller than the tiling size.
- // TODO: LinalgTilingPattern should return failure().
- return failure();
- for (unsigned int i = 0; i < numLoops; ++i)
- loopOps[i].push_back(tiled->loops[i]);
- }
-
- transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
- for (unsigned int i = 0; i < numLoops; ++i)
- transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
- return success();
-}
-
LogicalResult transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
@@ -337,27 +418,8 @@ LogicalResult transform::TileOp::apply(TransformResults &transformResults,
ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
- StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
- OpAsmParser::UnresolvedOperand targetOperand;
- SMLoc opLoc = parser.getCurrentLocation();
- if (parser.parseOperand(targetOperand) ||
- parser.parseOptionalAttrDict(result.attributes))
- return failure();
- Attribute sizesAttr = result.attributes.get(sizesAttrName);
- if (!sizesAttr)
- return parser.emitError(opLoc)
- << "expected '" << sizesAttrName << "' attribute";
- auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
- if (!sizesArrayAttr)
- return parser.emitError(opLoc)
- << "'" << sizesAttrName << "' attribute must be an array";
- Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
- size_t numExpectedLoops =
- sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
- result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
- if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
- return failure();
- return success();
+ return parseTileLikeOp(parser, result,
+ TileOp::getSizesAttrName(result.name).getValue());
}
void TileOp::print(OpAsmPrinter &p) {
@@ -366,26 +428,6 @@ void TileOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
}
-void TileOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- // `target` arg is consumed and can no longer be used.
- effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
- TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
- TransformMappingResource::get());
-
- for (Value r : getResults()) {
- effects.emplace_back(MemoryEffects::Write::get(), r,
- TransformMappingResource::get());
- effects.emplace_back(MemoryEffects::Allocate::get(), r,
- TransformMappingResource::get());
- }
-
- effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
-}
-
//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
new file mode 100644
index 000000000000..af6da5d7eaeb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @fuse_unary
+func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.elemwise_unary
+ // CHECK: linalg.elemwise_binary
+ %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_unary
+func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.elemwise_unary
+ // CHECK: linalg.elemwise_binary
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.elemwise_unary
+ // CHECK: linalg.elemwise_binary
+ %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ transform.loop.peel %loops#0
+ }
+}
More information about the Mlir-commits
mailing list