[Mlir-commits] [mlir] b994d38 - [mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 1 00:18:57 PDT 2022


Author: Nicolas Vasilache
Date: 2022-07-01T00:16:02-07:00
New Revision: b994d388aeb26aa54611e53884447b21dd3b440b

URL: https://github.com/llvm/llvm-project/commit/b994d388aeb26aa54611e53884447b21dd3b440b
DIFF: https://github.com/llvm/llvm-project/commit/b994d388aeb26aa54611e53884447b21dd3b440b.diff

LOG: [mlir][SCF] Add a ParallelCombiningOpInterface to decouple scf::PerformConcurrently from its contained operations

This allows purging references of scf.ForeachThreadOp and scf.PerformConcurrentlyOp from
ParallelInsertSliceOp.
This will allow moving the op closer to tensor::InsertSliceOp with which it should share much more
code.

In the future, the decoupling will also allow extending the type of ops that can be used in the
parallel combinator as well as semantics related to multiple concurrent inserts to the same
result.

Differential Revision: https://reviews.llvm.org/D128857

Added: 
    mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
    mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
    mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCF.h
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Dialect/SCF/IR/CMakeLists.txt
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 2c0dad6382009..12675c8b3cbc4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 

diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 5de21cc350ac1..f45f76cc6f817 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
+include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
@@ -468,6 +469,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
 def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
        NoSideEffect,
        Terminator,
+       DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
        HasParent<"ForeachThreadOp">,
       ] # GraphRegionNoTerminator.traits> {
   let summary = "terminates a `foreach_thread` block";
@@ -495,8 +497,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
   // appear inside perform_concurrently.
   let extraClassDeclaration = [{
-    SmallVector<Type> yieldedTypes();
-    ::llvm::iterator_range<Block::iterator> yieldingOps();
+    ::llvm::SmallVector<::mlir::Type> getYieldedTypes();
+    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
+    ::mlir::OpResult getParentResult(int64_t idx);
   }];
 }
 
@@ -508,7 +511,9 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
 def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
-       HasParent<"PerformConcurrentlyOp">]> {
+       // TODO: Cannot use an interface here atm, verify this manually for now.
+       // HasParent<"ParallelCombiningOpInterface">
+  ]> {
   let summary = [{
     Specify the tensor slice update of a single thread within the terminator of
     an `scf.foreach_thread`.
@@ -568,6 +573,11 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
       return getSource().getType().cast<RankedTensorType>();
     }
 
+    ParallelCombiningOpInterface getParallelCombiningParent() {
+      return dyn_cast<ParallelCombiningOpInterface>(
+        getOperation()->getParentOp());
+    }
+
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
@@ -599,6 +609,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
   
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 918f3ea398e47..230df17b55a10 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
 add_mlir_interface(VectorInterfaces)

diff  --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
new file mode 100644
index 0000000000000..72db06163df37
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -0,0 +1,29 @@
+//===- ParallelCombiningOpInterface.h - Parallel combining op interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the operation interface for ops that perform parallel
+// combining operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
+#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+// TODO: Single region single block interface on interfaces ?
+LogicalResult verifyParallelCombiningOpInterface(Operation *op);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
new file mode 100644
index 0000000000000..45497fbe038db
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -0,0 +1,75 @@
+//===- ParallelCombiningOpInterface.td - Parallel iface ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for ops that perform parallel combining operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
+#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
+  let description = [{
+    A parallel combining op is an op with a region, that is not isolated from
+    above and yields values to its parent op without itself returning an SSA
+    value. The yielded values are determined by subvalues produced by the ops 
+    contained in the region (the `yieldingOps`) and combined in any unspecified
+    order to produce the values yielded to the parent op.
+
+    This is useful as a terminator to parallel operations that iterate over 
+    some set and return tensors while avoiding tight coupling between the 
+    iterating op, the combining op and the individual subtensor producing ops.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the `idx`^th result of the parent operation.
+      }],
+      /*retTy=*/"::mlir::OpResult",
+      /*methodName=*/"getParentResult",
+      /*args=*/(ins "int64_t":$idx),
+      /*methodBody=*/[{
+        return $_op.getParentResult(idx);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the contained ops that yield subvalues that this op combines to
+        yield to its parent.
+      }],
+      /*retTy=*/"::llvm::iterator_range<Block::iterator>",
+      /*methodName=*/"getYieldingOps",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return $_op.getYieldingOps();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the types of the subvalues yielded by the contained ops, which
+        this op combines to yield to its parent.
+      }],
+      /*retTy=*/"::llvm::SmallVector<::mlir::Type>",
+      /*methodName=*/"getYieldedTypes",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return $_op.getYieldedTypes();
+      }]
+    >,
+  ];
+  // TODO: Single region single block interface on interfaces ?
+  let verify = [{
+    return verifyParallelCombiningOpInterface($_op);
+  }];
+}
+
+#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE

diff  --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 9dad732618de3..1b450acdf017a 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRControlFlowDialect
   MLIRIR
   MLIRLoopLikeInterface
+  MLIRParallelCombiningOpInterface
   MLIRSideEffectInterfaces
   )
 

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 557a9edc2f18e..d7536f72e100c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1061,7 +1061,7 @@ LogicalResult ForeachThreadOp::verify() {
     return emitOpError("region expects ") << getRank() << " arguments";
 
   // Verify consistency between the result types and the terminator.
-  auto terminatorTypes = getTerminator().yieldedTypes();
+  auto terminatorTypes = getTerminator().getYieldedTypes();
   auto opResults = getResults();
   if (opResults.size() != terminatorTypes.size())
     return emitOpError("produces ")
@@ -1182,7 +1182,7 @@ void ForeachThreadOp::build(
       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
   assert(terminator &&
          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
-  result.addTypes(terminator.yieldedTypes());
+  result.addTypes(terminator.getYieldedTypes());
 }
 
 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
@@ -1216,15 +1216,15 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
 //===----------------------------------------------------------------------===//
 
 OpResult ParallelInsertSliceOp::getTiedOpResult() {
-  auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
-  assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
-  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
-  for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+  ParallelCombiningOpInterface parallelCombiningParent =
+      getParallelCombiningParent();
+  for (const auto &it :
+       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
     Operation &nextOp = it.value();
     if (&nextOp == getOperation())
-      return foreachThreadOp->getResult(it.index());
+      return parallelCombiningParent.getParentResult(it.index());
   }
-  llvm_unreachable("ParallelInsertSliceOp not found");
+  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
 }
 
 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
@@ -1262,6 +1262,13 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+LogicalResult ParallelInsertSliceOp::verify() {
+  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
+    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+           << *(getOperation()->getParentOp());
+  return success();
+}
+
 namespace {
 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
 class ParallelInsertSliceOpConstantArgumentFolder final
@@ -1382,15 +1389,19 @@ ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
   return success();
 }
 
-SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
+  return getOperation()->getParentOp()->getResult(idx);
+}
+
+SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
   return llvm::to_vector<4>(
-      llvm::map_range(this->yieldingOps(), [](Operation &op) {
+      llvm::map_range(getYieldingOps(), [](Operation &op) {
         auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
       }));
 }
 
-llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
+llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
   return getRegion().front().getOperations();
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 09bd37cd85d38..431b9ef47eda1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1043,8 +1043,7 @@ struct ParallelInsertSliceOpInterface
     if (&opOperand != &op->getOpOperand(1) /*dest*/)
       return {};
 
-    // ParallelInsertSliceOp itself has no results. Tensors are returned via
-    // the parent op.
+    // ParallelInsertSliceOp itself has no results, query its tied op results.
     auto insertOp = cast<ParallelInsertSliceOp>(op);
     return {insertOp.getTiedOpResult()};
   }
@@ -1090,8 +1089,10 @@ struct ParallelInsertSliceOpInterface
     // }
 
     OpBuilder::InsertionGuard g(rewriter);
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
 
     // Nothing to do if the destination tensor is inplace.
     assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
@@ -1100,20 +1101,21 @@ struct ParallelInsertSliceOpInterface
       return success();
 
     // Find corresponding OpResult.
-    OpResult opResult = insertOp.getTiedOpResult();
+    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
 
     // Insert tensor allocation right before the ForeachThreadOp.
-    rewriter.setInsertionPoint(foreachThreadOp);
+    rewriter.setInsertionPoint(parallelIteratingOp);
     bool isYielded = state.isTensorYielded(opResult);
-    FailureOr<Value> alloc =
-        allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
-                                     /*escape=*/isYielded, state.getOptions());
+    FailureOr<Value> alloc = allocateTensorForShapedValue(
+        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
+        /*escape=*/isYielded, state.getOptions());
     if (failed(alloc))
       return failure();
 
     // Update destination operand.
-    rewriter.updateRootInPlace(
-        insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
+      parallelInsertSliceOp.getDestMutable().assign(*alloc);
+    });
 
     return success();
   }
@@ -1121,39 +1123,41 @@ struct ParallelInsertSliceOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
-    auto foreachThreadOp =
-        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    ParallelCombiningOpInterface parallelCombiningParent =
+        parallelInsertSliceOp.getParallelCombiningParent();
+    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
 
     // Get destination buffer.
     FailureOr<Value> destBuffer =
-        getBuffer(rewriter, insertOp.getDest(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
     if (failed(destBuffer))
       return failure();
 
-    // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
-    rewriter.setInsertionPoint(performConcurrentlyOp);
+    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
+    rewriter.setInsertionPoint(parallelCombiningParent);
     FailureOr<Value> srcBuffer =
-        getBuffer(rewriter, insertOp.getSource(), options);
+        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
     if (failed(srcBuffer))
       return failure();
     Value subview = rewriter.create<memref::SubViewOp>(
-        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
-        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+        parallelInsertSliceOp.getLoc(), *destBuffer,
+        parallelInsertSliceOp.getMixedOffsets(),
+        parallelInsertSliceOp.getMixedSizes(),
+        parallelInsertSliceOp.getMixedStrides());
     // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
-                                    subview)))
+    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
+                                    *srcBuffer, subview)))
       return failure();
 
-    // Replace all uses of ForeachThreadOp (just the corresponding result).
-    rewriter.setInsertionPointAfter(foreachThreadOp);
+    // Replace all uses of parallelIteratingOp (just the corresponding result).
+    rewriter.setInsertionPointAfter(parallelIteratingOp);
     Value toTensorOp =
-        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
+        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
     // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
-    SmallVector<OpOperand *> resultUses =
-        llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
-                                        [](OpOperand &use) { return &use; }));
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
+                        [](OpOperand &use) { return &use; }));
     for (OpOperand *use : resultUses) {
       rewriter.updateRootInPlace(use->getOwner(),
                                  [&]() { use->set(toTensorOp); });

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 2082ad41a7f27..783f3860c81cd 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  ParallelCombiningOpInterface.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
   VectorInterfaces.cpp
@@ -38,6 +39,7 @@ add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
+add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)

diff  --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
new file mode 100644
index 0000000000000..2b6703543bbd3
--- /dev/null
+++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
@@ -0,0 +1,27 @@
+//===- ParallelCombiningOpInterface.cpp - Parallel combining op interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ParallelCombiningOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO: Single region single block interface on interfaces ?
+LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return op->emitError("expected single region op");
+  if (!op->getRegion(0).hasOneBlock())
+    return op->emitError("expected single block op region");
+  return success();
+}
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"


        


More information about the Mlir-commits mailing list