[Mlir-commits] [mlir] b6c58ec - [mlir] add producer fusion to structured transform ops

Alex Zinenko llvmlistbot at llvm.org
Thu Jun 9 05:30:52 PDT 2022


Author: Alex Zinenko
Date: 2022-06-09T14:30:45+02:00
New Revision: b6c58ec486891312b3373a6f580318235421d918

URL: https://github.com/llvm/llvm-project/commit/b6c58ec486891312b3373a6f580318235421d918
DIFF: https://github.com/llvm/llvm-project/commit/b6c58ec486891312b3373a6f580318235421d918.diff

LOG: [mlir] add producer fusion to structured transform ops

This relies on the existing TileAndFuse pattern for tensor-based structured
ops. It complements pure tiling, from which some utilities are generalized.

Depends On D127300

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D127319

Added: 
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 205b0987ff98..fca389b438a3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -37,6 +37,25 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
   }];
 }
 
+def FuseOp : Op<Transform_Dialect, "structured.fuse",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Tiles the operations pointed to by the target handle and fuses their
+    producers greedily using the options provided as attributes.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+  let results = (outs PDL_Operation:$transformed,
+                      Variadic<PDL_Operation>:$loops);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -136,7 +155,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
 
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+        FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
   let description = [{
     Indicates that the given `target` op should be tiled with the options
     provided as attributes. This transform generates a loop nest with a smaller

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 61a01cb65808..1e7cfb53ca8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -451,16 +451,15 @@ struct PayloadIRResource
   StringRef getName() override { return "transform.payload_ir"; }
 };
 
-/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or
-/// single-result operations that "consume" their operand and produce a new
-/// result.
+/// Trait implementing the MemoryEffectOpInterface for single-operand operations
+/// that "consume" their operand and produce a new result.
 template <typename OpTy>
 class FunctionalStyleTransformOpTrait
     : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
 public:
   /// This op "consumes" the operand by reading and freeing it, "produces" the
-  /// result by allocating and writing it and reads/writes the payload IR in the
-  /// process.
+  /// results by allocating and writing it and reads/writes the payload IR in
+  /// the process.
   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
     effects.emplace_back(MemoryEffects::Read::get(),
                          this->getOperation()->getOperand(0),
@@ -468,12 +467,10 @@ class FunctionalStyleTransformOpTrait
     effects.emplace_back(MemoryEffects::Free::get(),
                          this->getOperation()->getOperand(0),
                          TransformMappingResource::get());
-    if (this->getOperation()->getNumResults() == 1) {
-      effects.emplace_back(MemoryEffects::Allocate::get(),
-                           this->getOperation()->getResult(0),
+    for (Value result : this->getOperation()->getResults()) {
+      effects.emplace_back(MemoryEffects::Allocate::get(), result,
                            TransformMappingResource::get());
-      effects.emplace_back(MemoryEffects::Write::get(),
-                           this->getOperation()->getResult(0),
+      effects.emplace_back(MemoryEffects::Write::get(), result,
                            TransformMappingResource::get());
     }
     effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
@@ -484,9 +481,6 @@ class FunctionalStyleTransformOpTrait
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
                   "expected single-operand op");
-    static_assert(OpTy::template hasTrait<OpTrait::ZeroResults>() ||
-                      OpTy::template hasTrait<OpTrait::OneResult>(),
-                  "expected zero- or single-result op");
     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
       op->emitError()
           << "FunctionalStyleTransformOpTrait should only be attached to ops "

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b02f9337995a..b5cfb2ab58dd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -92,6 +92,130 @@ FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
   return reportUnknownTransformError(target);
 }
 
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+static LogicalResult
+applyTilingToAll(Operation *transformOp, Value target,
+                 ArrayRef<int64_t> tileSizes,
+                 transform::TransformResults &transformResults,
+                 transform::TransformState &state,
+                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+  // Number of loops: Number of tiles sizes that are not zero.
+  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
+  // All payload ops. These should all be LinalgOps for now.
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
+
+  SmallVector<Operation *> tiledLinalgOps;
+  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+  for (unsigned int i = 0; i < numLoops; ++i)
+    loopOps[i].reserve(payloadOps.size());
+
+  for (Operation *target : payloadOps) {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
+    if (!linalgOp)
+      return transformOp->emitError("only LinalgOps are supported");
+
+    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
+    if (failed(tiled))
+      return failure();
+
+    tiledLinalgOps.push_back(tiled->op);
+    if (tiled->loops.size() != numLoops)
+      // Not enough loops were generated. This usually means that the input size
+      // was smaller than the tiling size.
+      // TODO: LinalgTilingPattern should return failure().
+      return failure();
+    for (unsigned int i = 0; i < numLoops; ++i)
+      loopOps[i].push_back(tiled->loops[i]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  for (unsigned int i = 0; i < numLoops; ++i)
+    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+  return success();
+}
+
+/// Parse a tiling-like operation that returns the tiled op as well as the
+/// created tile loops. The function counts the non-zero tile sizes to compute
+/// the number of results.
+static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
+                                   StringRef sizesAttrName) {
+  OpAsmParser::UnresolvedOperand targetOperand;
+  SMLoc opLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(targetOperand) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  Attribute sizesAttr = result.attributes.get(sizesAttrName);
+  if (!sizesAttr)
+    return parser.emitError(opLoc)
+           << "expected '" << sizesAttrName << "' attribute";
+  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+  if (!sizesArrayAttr)
+    return parser.emitError(opLoc)
+           << "'" << sizesAttrName << "' attribute must be an array";
+  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
+  size_t numExpectedLoops =
+      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
+  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
+    return failure();
+  return success();
+}
+
+LogicalResult
+transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
+                         mlir::transform::TransformState &state) {
+  LinalgTilingAndFusionOptions fusionOptions;
+  fusionOptions.tileSizes = extractI64Array(getTileSizes());
+  fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
+
+  return applyTilingToAll(
+      getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
+      state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
+        LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+        SimpleRewriter rewriter(getContext());
+        rewriter.setInsertionPoint(linalgOp);
+        FailureOr<TileLoopNest> tileLoopNest =
+            pattern.returningMatchAndRewrite(linalgOp, rewriter);
+        if (failed(tileLoopNest))
+          return failure();
+
+        TiledLinalgOp tiledLinalgOp;
+        tiledLinalgOp.op = tileLoopNest->getRootOp();
+        tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
+                               tileLoopNest->getLoopOps().end()};
+        return tiledLinalgOp;
+      });
+}
+
+ParseResult transform::FuseOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  return parseTileLikeOp(
+      parser, result,
+      transform::FuseOp::getTileSizesAttrName(result.name).getValue());
+}
+
+void transform::FuseOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p << getTarget();
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult transform::FuseOp::verify() {
+  SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
+  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(),
+                           permutation.begin(), permutation.end())) {
+    return emitOpError() << "expects interchange to be a permutation, found "
+                         << getTileInterchange();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GeneralizeOp
 //===----------------------------------------------------------------------===//
@@ -274,49 +398,6 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
 // TileOp
 //===----------------------------------------------------------------------===//
 
-/// Apply a tiling transformation to all payload ops and store both the
-/// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTilingToAll(Operation *transformOp, Value target,
-                 ArrayRef<int64_t> tileSizes,
-                 transform::TransformResults &transformResults,
-                 transform::TransformState &state,
-                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
-  // Number of loops: Number of tiles sizes that are not zero.
-  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
-  // All payload ops. These should all be LinalgOps for now.
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
-
-  SmallVector<Operation *> tiledLinalgOps;
-  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
-  for (unsigned int i = 0; i < numLoops; ++i)
-    loopOps[i].reserve(payloadOps.size());
-
-  for (Operation *target : payloadOps) {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
-    if (!linalgOp)
-      return transformOp->emitError("only LinalgOps are supported");
-
-    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
-    if (failed(tiled))
-      return failure();
-
-    tiledLinalgOps.push_back(tiled->op);
-    if (tiled->loops.size() != numLoops)
-      // Not enough loops were generated. This usually means that the input size
-      // was smaller than the tiling size.
-      // TODO: LinalgTilingPattern should return failure().
-      return failure();
-    for (unsigned int i = 0; i < numLoops; ++i)
-      loopOps[i].push_back(tiled->loops[i]);
-  }
-
-  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
-  for (unsigned int i = 0; i < numLoops; ++i)
-    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
-  return success();
-}
-
 LogicalResult transform::TileOp::apply(TransformResults &transformResults,
                                        TransformState &state) {
   LinalgTilingOptions tilingOptions;
@@ -337,27 +418,8 @@ LogicalResult transform::TileOp::apply(TransformResults &transformResults,
 
 ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
-  OpAsmParser::UnresolvedOperand targetOperand;
-  SMLoc opLoc = parser.getCurrentLocation();
-  if (parser.parseOperand(targetOperand) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  Attribute sizesAttr = result.attributes.get(sizesAttrName);
-  if (!sizesAttr)
-    return parser.emitError(opLoc)
-           << "expected '" << sizesAttrName << "' attribute";
-  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
-  if (!sizesArrayAttr)
-    return parser.emitError(opLoc)
-           << "'" << sizesAttrName << "' attribute must be an array";
-  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
-  size_t numExpectedLoops =
-      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
-  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
-  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
-    return failure();
-  return success();
+  return parseTileLikeOp(parser, result,
+                         TileOp::getSizesAttrName(result.name).getValue());
 }
 
 void TileOp::print(OpAsmPrinter &p) {
@@ -366,26 +428,6 @@ void TileOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
-void TileOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  // `target` arg is consumed and can no longer be used.
-  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
-                       TransformMappingResource::get());
-  effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
-                       TransformMappingResource::get());
-
-  for (Value r : getResults()) {
-    effects.emplace_back(MemoryEffects::Write::get(), r,
-                         TransformMappingResource::get());
-    effects.emplace_back(MemoryEffects::Allocate::get(), r,
-                         TransformMappingResource::get());
-  }
-
-  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
-}
-
 //===----------------------------------------------------------------------===//
 // VectorizeOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
new file mode 100644
index 000000000000..af6da5d7eaeb
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @fuse_unary
+func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK:   scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_unary
+func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK:   scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  //     CHECK:   scf.for
+  //     CHECK:     scf.for
+  //     CHECK:       linalg.elemwise_unary
+  //     CHECK:       linalg.elemwise_binary
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+    transform.loop.peel %loops#0
+  }
+}


        


More information about the Mlir-commits mailing list