[Mlir-commits] [mlir] [mlir] [bufferization] Default implementation of BufferizableOpInterface::isParallelRegion() based on new trait OpTrait::HasParallelRegion (PR #91184)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 6 03:52:32 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Rafael Ubal (rafaelubalmw)

<details>
<summary>Changes</summary>

@matthias-springer: This is a follow-up to your comment https://github.com/llvm/llvm-project/pull/90735#issuecomment-2092922097.

I wasn't able to fully remove `isParallelRegion()`. The reason is that `scf.forall` has a specialized implementation that marks its region as non-parallel when it can be statically determined that the loop runs only one iteration (https://github.com/llvm/llvm-project/blob/8e7618aa21652132f930b6576b92291c5f1d46b6/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp#L1330).

The most I could do was to modify the default implementation of `BufferizableOpInterface::isParallelRegion()` to query the `HasParallelRegion` trait instead of merely returning `false`. Since at the moment only `scf.forall` and `scf.parallel` have the `HasParallelRegion` trait, and `scf.forall` overrides the method, the new default implementation of `isParallelRegion()` only affects `scf.parallel`. If I'm not mistaken, this exposes a bug: `scf.parallel` was previously not identified as containing a parallel region for bufferization purposes!
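To make the effect of the new default concrete, here is a minimal, self-contained C++ model of how the interface method would resolve for each op before and after this change. It is not MLIR code: `MockOp`, `oldDefaultIsParallelRegion`, `newDefaultIsParallelRegion`, `forallIsParallelRegion` and the single trip-count field are hypothetical stand-ins, and the real `scf.forall` override inspects its bounds per dimension rather than a single precomputed trip count.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR op (not the real mlir::Operation API).
struct MockOp {
  std::string name;                        // e.g. "scf.parallel"
  bool hasParallelRegionTrait;             // models OpTrait::HasParallelRegion
  unsigned numRegions;                     // number of regions of the op
  std::optional<int64_t> staticTripCount;  // statically known trip count, if any
};

// Behavior before this PR: every region is assumed to be sequential.
bool oldDefaultIsParallelRegion(const MockOp &, unsigned) { return false; }

// Behavior after this PR: mirrors detail::defaultIsParallelRegion, which
// reports a region as parallel iff the op carries the trait.
bool newDefaultIsParallelRegion(const MockOp &op, unsigned index) {
  assert(index < op.numRegions && "invalid region index");
  return op.hasParallelRegionTrait;
}

// Simplified model of the scf.forall override: a loop statically known to
// run exactly one iteration is treated as non-parallel.
bool forallIsParallelRegion(const MockOp &op, unsigned index) {
  assert(index < op.numRegions && "invalid region index");
  return !(op.staticTripCount && *op.staticTripCount == 1);
}

// Models interface dispatch: only scf.forall overrides the method; every
// other op falls back to whichever default implementation is selected.
bool isParallelRegion(const MockOp &op, unsigned index, bool useNewDefault) {
  if (op.name == "scf.forall")
    return forallIsParallelRegion(op, index);
  return useNewDefault ? newDefaultIsParallelRegion(op, index)
                       : oldDefaultIsParallelRegion(op, index);
}
```

With these stand-ins, `scf.parallel` flips from sequential to parallel once the trait-based default is selected, `scf.forall` keeps its own single-iteration special case, and ops without the trait (such as `scf.for`) are unaffected.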

I'd like to come up with a unit test that exposes an invalid bufferization of `scf.parallel` with the previous implementation, and that shows that this change hopefully addresses the issue. I'm not too familiar with bufferization at the moment, so if @matthias-springer or anyone else could advise on what such an MLIR code sample would look like, that'd be great. Otherwise, I can investigate further.

FWIW, such MLIR code would have to make the call chain `hasReadAfterWriteInterference()` -> `hasParallelRegion()` -> `isParallelRegion()` return an incorrect result for `scf.parallel` during bufferization analysis (https://github.com/llvm/llvm-project/blob/47214903b1c6d0590780c7e69a2e3e612f43e4a2/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp#L632), and that result would have to manifest as a tangible invalid bufferization decision.
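As a rough illustration of why the answer for `scf.parallel` matters in that call chain, the sketch below models a lookup that walks from an op up through its enclosing ops and stops at the first region reported as parallel. This is a hypothetical, heavily simplified model (`Node`, `IsParallelFn` and `findEnclosingParallelOp` are invented for illustration), not the actual One-Shot Analysis code; it only shows that when `isParallelRegion()` always returns `false`, a tensor op nested in `scf.parallel` is never seen as being inside a parallel region, so the stricter RaW handling for parallel regions could not kick in.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical flattened op tree: each node records its parent op and the
// index of the parent's region that contains it.
struct Node {
  std::string name;      // e.g. "scf.parallel"
  int parent;            // index into the node table, -1 for the top level
  unsigned regionIndex;  // region of `parent` that contains this node
};

// Stand-in for BufferizableOpInterface::isParallelRegion(index).
using IsParallelFn = std::function<bool(const Node &op, unsigned regionIndex)>;

// Walks up from `opIdx` and returns the index of the closest ancestor op
// whose containing region is reported parallel, or -1 if there is none.
int findEnclosingParallelOp(const std::vector<Node> &nodes, int opIdx,
                            const IsParallelFn &isParallelRegion) {
  for (int cur = opIdx; nodes[cur].parent >= 0; cur = nodes[cur].parent) {
    const Node &parentOp = nodes[nodes[cur].parent];
    if (isParallelRegion(parentOp, nodes[cur].regionIndex))
      return nodes[cur].parent;
  }
  return -1;
}
```

For a `tensor.insert` nested directly in an `scf.parallel` body, an always-`false` callback yields -1, while a trait-based callback returns the `scf.parallel` node.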

@sabauma

---
Full diff: https://github.com/llvm/llvm-project/pull/91184.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+6) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+7-2) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2d8add82383bef..8dbbe1141aea10 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -701,6 +701,12 @@ bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
 bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
                                unsigned index);
 
+/// This is the default implementation of
+/// BufferizableOpInterface::isParallelRegion. Should not be called from other
+/// places.
+bool defaultIsParallelRegion(BufferizableOpInterface bufferizableOp,
+                             unsigned index);
+
 /// This is the default implementation of getAliasingOpOperands in case the
 /// defining op does not implement the BufferizableOpInterface.
 AliasingOpOperandList unknownGetAliasingOpOperands(Value value);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 007c05adc30b5f..3ab16c70b14b66 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -565,14 +565,19 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           The RaW conflict detection of One-Shot Analysis is more strict inside
           parallel regions: Buffer may have to be privatized.
 
-          By default, regions are assumed to be sequential.
+          By default, an op region is considered parallel if the containing op
+          has trait `HasParallelRegion`. While this default implementation is
+          generally sufficient, a specific op may relax this condition by
+          marking a region as non-parallel when it is detected to execute
+          exactly once, and in spite of its parallel semantics.
         }],
         /*retType=*/"bool",
         /*methodName=*/"isParallelRegion",
         /*args=*/(ins "unsigned":$index),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return false;
+          return ::mlir::bufferization::detail::defaultIsParallelRegion(
+              ::llvm::cast<BufferizableOpInterface>($_op.getOperation()), index);
         }]
       >,
       InterfaceMethod<
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0c..cb36ad8f0f964c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
@@ -956,6 +957,12 @@ bool bufferization::detail::defaultIsRepetitiveRegion(
   return regionInterface.isRepetitiveRegion(index);
 }
 
+bool bufferization::detail::defaultIsParallelRegion(
+    BufferizableOpInterface bufferizableOp, unsigned index) {
+  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
+  return bufferizableOp->hasTrait<OpTrait::HasParallelRegion>();
+}
+
 AliasingOpOperandList
 bufferization::detail::unknownGetAliasingOpOperands(Value value) {
   // TODO: Take into account successor blocks.

``````````

</details>


https://github.com/llvm/llvm-project/pull/91184

