[Mlir-commits] [mlir] [MLIR] Modify lowering of gpu.alloc op to llvm (PR #69969)

Nishant Patel llvmlistbot at llvm.org
Tue Oct 24 11:13:01 PDT 2023


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/69969

From 6665a452e895bfe16fd717c9f7b73d3cc6b8348a Mon Sep 17 00:00:00 2001
From: Nishant Patel <nishant.b.patel at intel.com>
Date: Mon, 23 Oct 2023 20:43:43 +0000
Subject: [PATCH 1/2] [MLIR] Modify lowering of gpu.alloc op to llvm

---
 .../GPUCommon/GPUToLLVMConversion.cpp          | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 097caf23edfa5dd..da1c468ed1dfd71 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -836,7 +836,11 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   // Allocate the underlying buffer and store a pointer to it in the MemRef
   // descriptor.
   Type elementPtrType = this->getElementPtrType(memRefType);
-  auto stream = adaptor.getAsyncDependencies().front();
+
+  Value stream =
+      adaptor.getAsyncDependencies().empty()
+          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
+          : adaptor.getAsyncDependencies().front();
 
   auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
       loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
@@ -855,7 +859,17 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto memRefDescriptor = this->createMemRefDescriptor(
       loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
 
-  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+  if (allocOp.getAsyncToken()) {
+    // Async alloc: make dependent ops use the same stream.
+    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+  } else {
+    // Synchronize with host and destroy stream. This must be the stream created
+    // above (with no other uses) because we check that the synchronous version
+    // does not have any async dependencies.
+    streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
+    streamDestroyCallBuilder.create(loc, rewriter, {stream});
+    rewriter.replaceOp(allocOp, {memRefDescriptor});
+  }
 
   return success();
 }

From 18dce073686693224ba23b328820af6e91dcbde4 Mon Sep 17 00:00:00 2001
From: Nishant Patel <nishant.b.patel at intel.com>
Date: Tue, 24 Oct 2023 18:05:25 +0000
Subject: [PATCH 2/2] Pass nullptr for stream for sync execution

---
 .../Conversion/GPUCommon/GPUToLLVMConversion.cpp    | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index da1c468ed1dfd71..12bd02050be036c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -837,10 +837,10 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   // descriptor.
   Type elementPtrType = this->getElementPtrType(memRefType);
 
-  Value stream =
-      adaptor.getAsyncDependencies().empty()
-          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
-          : adaptor.getAsyncDependencies().front();
+  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
+  Value stream = adaptor.getAsyncDependencies().empty()
+                     ? nullPtr
+                     : adaptor.getAsyncDependencies().front();
 
   auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
       loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
@@ -863,11 +863,6 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
     // Async alloc: make dependent ops use the same stream.
     rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
   } else {
-    // Synchronize with host and destroy stream. This must be the stream created
-    // above (with no other uses) because we check that the synchronous version
-    // does not have any async dependencies.
-    streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
-    streamDestroyCallBuilder.create(loc, rewriter, {stream});
     rewriter.replaceOp(allocOp, {memRefDescriptor});
   }
 



More information about the Mlir-commits mailing list