[Mlir-commits] [mlir] [mlir][gpu] Add field to mark asynchronous side effects (PR #72013)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 11 01:02:01 PST 2023


https://github.com/spaceotter updated https://github.com/llvm/llvm-project/pull/72013

>From fa023139829225ddd48a3d96539f16484d9ffc2e Mon Sep 17 00:00:00 2001
From: Eric Eaton <eric at nod-labs.com>
Date: Thu, 9 Nov 2023 15:47:58 -0800
Subject: [PATCH] [mlir][gpu] Add field to mark asynchronous side effects

This adds an extra field to EffectInstance to indicate that a side
effect may occur after the associated op exits. This is the case for
ops that launch DMA transfers or other work performed in parallel by an
accelerator attached to the processor. Importantly, in such cases
gpu.barrier cannot synchronize the effects, because it only synchronizes
threads and not the DMA engine or accelerator. Codegen using such
asynchronous ops must separately ensure they are waited on correctly.
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  6 ++---
 .../Interfaces/SideEffectInterfaceBase.td     |  6 ++++-
 .../mlir/Interfaces/SideEffectInterfaces.h    | 18 +++++++++-----
 .../mlir/Interfaces/SideEffectInterfaces.td   | 24 +++++++++----------
 mlir/include/mlir/TableGen/SideEffects.h      |  3 +++
 .../GPU/Transforms/EliminateBarriers.cpp      |  5 ++++
 mlir/lib/TableGen/SideEffects.cpp             |  4 ++++
 .../test/Dialect/GPU/barrier-elimination.mlir | 20 ++++++++++++++++
 mlir/test/lib/Dialect/Test/TestInterfaces.td  |  2 +-
 mlir/test/mlir-tblgen/op-side-effects.td      | 14 +++++------
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 14 ++++++-----
 11 files changed, 80 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 440f7d0380eb17e..7c1e94d28fbce9a 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -423,9 +423,9 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
     ```
   }];
   let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
-  let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
+  let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect, 1>]>:$dst,
                        Variadic<Index>:$dstIndices,
-                       Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
+                       Arg<AnyMemRef, "", [MemReadAt<0, FullEffect, 1>]>:$src,
                        Variadic<Index>:$srcIndices,
                        IndexAttr:$dstElements,
                        Optional<Index>:$srcElements,
@@ -642,7 +642,7 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
 
     The Op uses `$barrier` mbarrier based completion mechanism. 
   }];  
-  let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
+  let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect, 1>]>:$dst,
                        NVGPU_MBarrierGroup:$barriers,
                        NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
                        Variadic<Index>:$coordinates, 
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
index 45a9ffa94363ef3..bd9a28a225c152a 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
@@ -160,7 +160,8 @@ def PartialEffect : EffectRange<0>;
 // This class is the general base side effect class. This is used by derived
 // effect interfaces to define their effects.
 class SideEffect<EffectOpInterfaceBase interface, string effectName,
-                 Resource resourceReference, int effectStage, EffectRange range>
+                 Resource resourceReference, int effectStage, EffectRange range,
+                 bits<1> isAsync>
     : OpVariableDecorator {
   /// The name of the base effects class.
   string baseEffectName = interface.baseEffectName;
@@ -183,6 +184,9 @@ class SideEffect<EffectOpInterfaceBase interface, string effectName,
 
   // Does this side effect act on every single value of resource.
   bit effectOnFullRegion = range.Value;
+
+  // Does this side effect potentially occur after the op exits.
+  bit asynchronous = isAsync;
 }
 
 // This class is the base used for specifying effects applied to an operation.
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index ec4e36263bbe6d1..9bd23fb4f18a977 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -141,28 +141,28 @@ class EffectInstance {
   EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), stage(0),
         effectOnFullRegion(false) {}
-  EffectInstance(EffectT *effect, int stage, bool effectOnFullRegion,
+  EffectInstance(EffectT *effect, int stage, bool effectOnFullRegion, bool async,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), stage(stage),
-        effectOnFullRegion(effectOnFullRegion) {}
+        effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
   EffectInstance(EffectT *effect, Value value,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value), stage(0),
         effectOnFullRegion(false) {}
   EffectInstance(EffectT *effect, Value value, int stage,
-                 bool effectOnFullRegion,
+                 bool effectOnFullRegion, bool async,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(value), stage(stage),
-        effectOnFullRegion(effectOnFullRegion) {}
+        effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
   EffectInstance(EffectT *effect, SymbolRefAttr symbol,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(symbol), stage(0),
         effectOnFullRegion(false) {}
   EffectInstance(EffectT *effect, SymbolRefAttr symbol, int stage,
-                 bool effectOnFullRegion,
+                 bool effectOnFullRegion, bool async,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), value(symbol), stage(stage),
-        effectOnFullRegion(effectOnFullRegion) {}
+        effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
   EffectInstance(EffectT *effect, Attribute parameters,
                  Resource *resource = DefaultResource::get())
       : effect(effect), resource(resource), parameters(parameters), stage(0),
@@ -221,6 +221,9 @@ class EffectInstance {
   /// Return if this side effect act on every single value of resource.
   bool getEffectOnFullRegion() const { return effectOnFullRegion; }
 
+  /// Return if the side effect may occur after the op exits.
+  bool getAsynchronous() const { return asynchronous; }
+
 private:
   /// The specific effect being applied.
   EffectT *effect;
@@ -242,6 +245,9 @@ class EffectInstance {
 
   // Does this side effect act on every single value of resource.
   bool effectOnFullRegion;
+
+  /// Does this side effect potentially occur after op exit (default: no).
+  bool asynchronous = false;
 };
 } // namespace SideEffects
 
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
index b2ab4fee9d29c03..ca9bf4f8fcd39ee 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -35,8 +35,8 @@ def MemoryEffectsOpInterface
 
 // The base class for defining specific memory effects.
 class MemoryEffect<string effectName, Resource resource, int stage,
-                   EffectRange range>
-  : SideEffect<MemoryEffectsOpInterface, effectName, resource, stage, range>;
+                   EffectRange range, bits<1> async>
+  : SideEffect<MemoryEffectsOpInterface, effectName, resource, stage, range, async>;
 
 // This class represents the trait for memory effects that may be placed on
 // operations.
@@ -51,7 +51,7 @@ class MemoryEffects<list<MemoryEffect> effects = []>
 // not any visible mutation or dereference.
 class MemAlloc<Resource resource, int stage = 0,
                EffectRange range = PartialEffect>
-  : MemoryEffect<"::mlir::MemoryEffects::Allocate", resource, stage, range>;
+  : MemoryEffect<"::mlir::MemoryEffects::Allocate", resource, stage, range, 0>;
 def MemAlloc : MemAlloc<DefaultResource, 0, PartialEffect>;
 class MemAllocAt<int stage, EffectRange range = PartialEffect>
   : MemAlloc<DefaultResource, stage, range>;
@@ -61,7 +61,7 @@ class MemAllocAt<int stage, EffectRange range = PartialEffect>
 // resource, and not any visible allocation, mutation or dereference.
 class MemFree<Resource resource, int stage = 0,
               EffectRange range = PartialEffect>
-  : MemoryEffect<"::mlir::MemoryEffects::Free", resource, stage, range>;
+  : MemoryEffect<"::mlir::MemoryEffects::Free", resource, stage, range, 0>;
 def MemFree : MemFree<DefaultResource, 0, PartialEffect>;
 class MemFreeAt<int stage, EffectRange range = PartialEffect>
   : MemFree<DefaultResource, stage, range>;
@@ -70,21 +70,21 @@ class MemFreeAt<int stage, EffectRange range = PartialEffect>
 // resource. A 'read' effect implies only dereferencing of the resource, and
 // not any visible mutation.
 class MemRead<Resource resource, int stage = 0,
-              EffectRange range = PartialEffect>
-  : MemoryEffect<"::mlir::MemoryEffects::Read", resource, stage, range>;
+              EffectRange range = PartialEffect, bits<1> async = 0>
+  : MemoryEffect<"::mlir::MemoryEffects::Read", resource, stage, range, async>;
 def MemRead : MemRead<DefaultResource, 0, PartialEffect>;
-class MemReadAt<int stage, EffectRange range = PartialEffect>
-  : MemRead<DefaultResource, stage, range>;
+class MemReadAt<int stage, EffectRange range = PartialEffect, bits<1> async = 0>
+  : MemRead<DefaultResource, stage, range, async>;
 
 // The following effect indicates that the operation writes to some
 // resource. A 'write' effect implies only mutating a resource, and not any
 // visible dereference or read.
 class MemWrite<Resource resource, int stage = 0,
-               EffectRange range = PartialEffect>
-  : MemoryEffect<"::mlir::MemoryEffects::Write", resource, stage, range>;
+               EffectRange range = PartialEffect, bits<1> async = 0>
+  : MemoryEffect<"::mlir::MemoryEffects::Write", resource, stage, range, async>;
 def MemWrite : MemWrite<DefaultResource, 0, PartialEffect>;
-class MemWriteAt<int stage, EffectRange range = PartialEffect>
-  : MemWrite<DefaultResource, stage, range>;
+class MemWriteAt<int stage, EffectRange range = PartialEffect, bits<1> async = 0>
+  : MemWrite<DefaultResource, stage, range, async>;
 
 //===----------------------------------------------------------------------===//
 // Effect Traits
diff --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
index 5a9a34d4e427ccf..cceb03eca61fe5d 100644
--- a/mlir/include/mlir/TableGen/SideEffects.h
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -41,6 +41,9 @@ class SideEffect : public Operator::VariableDecorator {
   // Return if this side effect act on every single value of resource.
   bool getEffectOnfullRegion() const;
 
+  // Return if the side effect may occur after the op exits.
+  bool getAsynchronous() const;
+
   static bool classof(const Operator::VariableDecorator *var);
 };
 
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index 1adc381092bf3ae..b1d2f0d6768743a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -519,6 +519,11 @@ static bool
 haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
                        ArrayRef<MemoryEffects::EffectInstance> afterEffects) {
   for (const MemoryEffects::EffectInstance &before : beforeEffects) {
+    // Before may conflict with after, but since it is async, a BarrierOp cannot
+    // synchronize the effects. If the async field is set, it is presumed that
+    // some architecture-specific mechanism is needed to synchronize the effect.
+    if (before.getAsynchronous()) continue;
+
     for (const MemoryEffects::EffectInstance &after : afterEffects) {
       // If cannot alias, definitely no conflict.
       if (!mayAlias(before, after))
diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
index 55ad59d3d0d01a8..145f3adbdddd92d 100644
--- a/mlir/lib/TableGen/SideEffects.cpp
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -42,6 +42,10 @@ bool SideEffect::getEffectOnfullRegion() const {
   return def->getValueAsBit("effectOnFullRegion");
 }
 
+bool SideEffect::getAsynchronous() const {
+  return def->getValueAsBit("asynchronous");
+}
+
 bool SideEffect::classof(const Operator::VariableDecorator *var) {
   return var->getDef().isSubClassOf("SideEffect");
 }
diff --git a/mlir/test/Dialect/GPU/barrier-elimination.mlir b/mlir/test/Dialect/GPU/barrier-elimination.mlir
index 844dc7dd6ac00da..8330a7853b118d2 100644
--- a/mlir/test/Dialect/GPU/barrier-elimination.mlir
+++ b/mlir/test/Dialect/GPU/barrier-elimination.mlir
@@ -182,3 +182,23 @@ attributes {__parallel_region_boundary_for_test} {
   %4 = memref.load %C[] : memref<f32>
   return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
 }
+
+// CHECK-LABEL: @async_copy
+func.func @async_copy() -> ()
+attributes {__parallel_region_boundary_for_test} {
+  // CHECK: %[[A:.+]] = memref.alloc
+  // CHECK: %[[B:.+]] = memref.alloc
+  %A = memref.alloc() : memref<f32>
+  %B = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
+  gpu.barrier
+  // CHECK: %[[T:.+]] = nvgpu.device_async_copy %[[A]][], %[[B]][], 1
+  %token = nvgpu.device_async_copy %A[], %B[], 1 : memref<f32> to memref<f32, #gpu.address_space<workgroup>>
+  // This barrier is expected to be erased: it can't synchronize the effects on %B.
+  gpu.barrier
+  // This does synchronize the effects on %B.
+  // CHECK-NEXT: nvgpu.device_async_wait %[[T]]
+  nvgpu.device_async_wait %token
+  // CHECK-NEXT: linalg.abs ins(%[[B]] : memref<f32, #gpu.address_space<workgroup>>) outs(%[[A]] : memref<f32>)
+  linalg.abs ins(%B: memref<f32, #gpu.address_space<workgroup>>) outs(%A: memref<f32>)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index dea26b8dda62a0b..2fbefb59e85f9e8 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -128,7 +128,7 @@ def TestEffectOpInterface
 
 class TestEffect<string effectName>
     : SideEffect<TestEffectOpInterface, effectName, DefaultResource, 0,
-                 PartialEffect>;
+                 PartialEffect, 0>;
 
 class TestEffects<list<TestEffect> effects = []>
    : SideEffectsTraitBase<TestEffectOpInterface, effects>;
diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td
index 09612db905899fd..cca5b1487cc9c9f 100644
--- a/mlir/test/mlir-tblgen/op-side-effects.td
+++ b/mlir/test/mlir-tblgen/op-side-effects.td
@@ -26,15 +26,15 @@ def SideEffectOpB : TEST_Op<"side_effect_op_b",
 
 // CHECK: void SideEffectOpA::getEffects
 // CHECK:   for (::mlir::Value value : getODSOperands(0))
-// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Read::get(), value, 0, false, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Read::get(), value, 0, false, false, ::mlir::SideEffects::DefaultResource::get());
 // CHECK:   for (::mlir::Value value : getODSOperands(1))
-// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Write::get(), value, 1, true, ::mlir::SideEffects::DefaultResource::get());
-// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Read::get(), getSymbolAttr(), 0, false, ::mlir::SideEffects::DefaultResource::get());
-// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Write::get(), getFlatSymbolAttr(), 0, false, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Write::get(), value, 1, true, false, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Read::get(), getSymbolAttr(), 0, false, false, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Write::get(), getFlatSymbolAttr(), 0, false, false, ::mlir::SideEffects::DefaultResource::get());
 // CHECK:   if (auto symbolRef = getOptionalSymbolAttr())
-// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Read::get(), symbolRef, 0, false, ::mlir::SideEffects::DefaultResource::get());
+// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Read::get(), symbolRef, 0, false, false, ::mlir::SideEffects::DefaultResource::get());
 // CHECK:   for (::mlir::Value value : getODSResults(0))
-// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Allocate::get(), value, 0, false, CustomResource::get());
+// CHECK:     effects.emplace_back(::mlir::MemoryEffects::Allocate::get(), value, 0, false, false, CustomResource::get());
 
 // CHECK: void SideEffectOpB::getEffects
-// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Write::get(), 0, false, CustomResource::get());
+// CHECK:   effects.emplace_back(::mlir::MemoryEffects::Write::get(), 0, false, false, CustomResource::get());
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 842964b853d084d..71f06c292fe5cf7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3304,9 +3304,10 @@ void OpEmitter::genSideEffectInterfaceMethods() {
   // {1}: Optional value or symbol reference.
   // {2}: The side effect stage.
   // {3}: Does this side effect act on every single value of resource.
-  // {4}: The resource class.
+  // {4}: Does this side effect potentially occur after op exit.
+  // {5}: The resource class.
   const char *addEffectCode =
-      "  effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
+      "  effects.emplace_back({0}::get(), {1}{2}, {3}, {4}, {5}::get());\n";
 
   for (auto &it : interfaceEffects) {
     // Generate the 'getEffects' method.
@@ -3325,10 +3326,11 @@ void OpEmitter::genSideEffectInterfaceMethods() {
       StringRef resource = location.effect.getResource();
       int stage = (int)location.effect.getStage();
       bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
+      bool async = (int)location.effect.getAsynchronous();
       if (location.kind == EffectKind::Static) {
         // A static instance has no attached value.
         body << llvm::formatv(addEffectCode, effect, "", stage,
-                              effectOnFullRegion, resource)
+                              effectOnFullRegion, async, resource)
                     .str();
       } else if (location.kind == EffectKind::Symbol) {
         // A symbol reference requires adding the proper attribute.
@@ -3337,11 +3339,11 @@ void OpEmitter::genSideEffectInterfaceMethods() {
         if (attr->attr.isOptional()) {
           body << "  if (auto symbolRef = " << argName << "Attr())\n  "
                << llvm::formatv(addEffectCode, effect, "symbolRef, ", stage,
-                                effectOnFullRegion, resource)
+                                effectOnFullRegion, async, resource)
                       .str();
         } else {
           body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
-                                stage, effectOnFullRegion, resource)
+                                stage, effectOnFullRegion, async, resource)
                       .str();
         }
       } else {
@@ -3350,7 +3352,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
              << "(" << location.index << "))\n  "
              << llvm::formatv(addEffectCode, effect, "value, ", stage,
-                              effectOnFullRegion, resource)
+                              effectOnFullRegion, async, resource)
                     .str();
       }
     }



More information about the Mlir-commits mailing list