[Mlir-commits] [mlir] [MLIR] Modify lowering of gpu.alloc op to llvm (PR #69969)
llvmlistbot at llvm.org
Mon Oct 23 13:52:22 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
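If gpu.alloc has no async dependency (which is the case when gpu.alloc performs a hostShared allocation), create a new stream for it. For a synchronous gpu.alloc, that stream is synchronized with the host and destroyed after the allocation. This PR is a follow-up to #66401.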
---
Full diff: https://github.com/llvm/llvm-project/pull/69969.diff
1 File Affected:
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+16-2)
``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 097caf23edfa5dd..da1c468ed1dfd71 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -836,7 +836,11 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   // Allocate the underlying buffer and store a pointer to it in the MemRef
   // descriptor.
   Type elementPtrType = this->getElementPtrType(memRefType);
-  auto stream = adaptor.getAsyncDependencies().front();
+
+  Value stream =
+      adaptor.getAsyncDependencies().empty()
+          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
+          : adaptor.getAsyncDependencies().front();

   auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
       loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
@@ -855,7 +859,17 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto memRefDescriptor = this->createMemRefDescriptor(
       loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

-  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+  if (allocOp.getAsyncToken()) {
+    // Async alloc: make dependent ops use the same stream.
+    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+  } else {
+    // Synchronize with host and destroy stream. This must be the stream created
+    // above (with no other uses) because we check that the synchronous version
+    // does not have any async dependencies.
+    streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
+    streamDestroyCallBuilder.create(loc, rewriter, {stream});
+    rewriter.replaceOp(allocOp, {memRefDescriptor});
+  }

   return success();
 }
``````````
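The stream handling in the diff can be modeled outside MLIR as a small, self-contained C++ sketch. It is not the pattern's implementation: `Stream`, `FakeRuntime`, `LoweringResult`, and `lowerAlloc` are hypothetical names that only record, in order, the runtime calls the lowered code would make. The sketch assumes, as the comment in the patch states, that a synchronous alloc carries no async dependencies, and it does not model the pattern's other match checks.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a GPU stream handle (not an MLIR or CUDA type).
using Stream = int;

// Hypothetical runtime that only records, in order, the calls the lowered
// code would make at run time.
struct FakeRuntime {
  std::vector<std::string> calls;
  std::vector<Stream> created;  // streams created by the lowering itself
  Stream nextStream = 100;      // arbitrary id for the first new stream

  Stream streamCreate() {
    Stream s = nextStream++;
    created.push_back(s);
    calls.push_back("streamCreate");
    return s;
  }
  void memAlloc(Stream s, bool isHostShared) {
    calls.push_back("memAlloc(stream=" + std::to_string(s) +
                    ", hostShared=" + (isHostShared ? "1" : "0") + ")");
  }
  void streamSynchronize(Stream s) {
    calls.push_back("streamSynchronize(" + std::to_string(s) + ")");
  }
  void streamDestroy(Stream s) {
    // Invariant from the patch: only the stream created for this alloc is
    // destroyed, never one borrowed from an async dependency.
    assert(std::find(created.begin(), created.end(), s) != created.end());
    calls.push_back("streamDestroy(" + std::to_string(s) + ")");
  }
};

struct LoweringResult {
  bool matched = false;              // false models a pattern match failure
  std::optional<Stream> asyncToken;  // set only when the alloc is async
};

// Sketch of the control flow added to ConvertAllocOpToGpuRuntimeCallPattern.
LoweringResult lowerAlloc(FakeRuntime &rt, const std::vector<Stream> &asyncDeps,
                          bool hasAsyncToken, bool isHostShared) {
  // Assumed precondition: a synchronous alloc has no async dependencies.
  if (!hasAsyncToken && !asyncDeps.empty())
    return {};
  // No async dependency (e.g. a host_shared alloc): create a fresh stream.
  Stream stream = asyncDeps.empty() ? rt.streamCreate() : asyncDeps.front();
  rt.memAlloc(stream, isHostShared);
  if (hasAsyncToken)
    return {true, stream};  // dependent ops keep using the same stream
  // Synchronous alloc: wait for the allocation, then release the stream.
  rt.streamSynchronize(stream);
  rt.streamDestroy(stream);
  return {true, std::nullopt};
}
```

For example, `lowerAlloc(rt, {}, false, true)` records `streamCreate`, the allocation, `streamSynchronize(100)`, and `streamDestroy(100)`, while an async alloc with one dependency reuses that dependency's stream and returns it as the token. Destroying the stream is only safe in the synchronous branch because, under the stated precondition, that stream was created by the lowering and has no other users.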
</details>
https://github.com/llvm/llvm-project/pull/69969