[Mlir-commits] [mlir] [MLIR][AsyncToLLVM] Remove typed pointer support (PR #70731)

Christian Ulmann llvmlistbot at llvm.org
Mon Oct 30 15:12:53 PDT 2023


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/70731

>From 7672173e465d7596cdf12ab89825433db25f62fa Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 30 Oct 2023 21:46:25 +0000
Subject: [PATCH 1/2] [MLIR][AsyncToLLVM] Remove typed pointer support

This commit removes support for lowering the Async dialect to the LLVM
dialect with typed pointers. Typed pointers have been deprecated for a
while now, and they are planned to be removed from the LLVM dialect soon.

Related PSA: https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502
---
 mlir/include/mlir/Conversion/Passes.td        |   5 -
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |   5 +-
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    | 239 ++++++------------
 .../AsyncToLLVM/convert-coro-to-llvm.mlir     |   2 +-
 .../AsyncToLLVM/convert-runtime-to-llvm.mlir  |   2 +-
 .../AsyncToLLVM/convert-to-llvm.mlir          |   2 +-
 .../AsyncToLLVM/typed-pointers.mlir           | 138 ----------
 7 files changed, 81 insertions(+), 312 deletions(-)
 delete mode 100644 mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf6e545749ffc64..74ac7135083f853 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -191,11 +191,6 @@ def ConvertAsyncToLLVMPass : Pass<"convert-async-to-llvm", "ModuleOp"> {
     "LLVM::LLVMDialect",
     "func::FuncDialect",
   ];
-  let options = [
-    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
-           /*default=*/"true", "Generate LLVM IR using opaque pointers "
-           "instead of typed pointers">,
-  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 9e69717f471bce2..05320c0c7186907 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -52,8 +52,9 @@ LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType,
                                         bool opaquePointers);
 LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, Type indexType,
-                                              bool opaquePointers);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp, bool opaquePointers);
+                                              bool opaquePointers = true);
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp,
+                                      bool opaquePointers = true);
 LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp, Type indexType,
                                               bool opaquePointers);
 LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index d9ea60a6749d926..3e61c9c7de50e2f 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -76,20 +76,16 @@ namespace {
 /// lowering all async data types become opaque pointers at runtime.
 struct AsyncAPI {
   // All async types are lowered to opaque LLVM pointers at runtime.
-  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx,
-                                                 bool useLLVMOpaquePointers) {
-    if (useLLVMOpaquePointers)
-      return LLVM::LLVMPointerType::get(ctx);
-    return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+    return LLVM::LLVMPointerType::get(ctx);
   }
 
   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
     return LLVM::LLVMTokenType::get(ctx);
   }
 
-  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx,
-                                               bool useLLVMOpaquePointers) {
-    auto ref = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
+    auto ref = opaquePointerType(ctx);
     auto count = IntegerType::get(ctx, 64);
     return FunctionType::get(ctx, {ref, count}, {});
   }
@@ -98,10 +94,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
   }
 
-  static FunctionType createValueFunctionType(MLIRContext *ctx,
-                                              bool useLLVMOpaquePointers) {
+  static FunctionType createValueFunctionType(MLIRContext *ctx) {
     auto i64 = IntegerType::get(ctx, 64);
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto value = opaquePointerType(ctx);
     return FunctionType::get(ctx, {i64}, {value});
   }
 
@@ -110,10 +105,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
   }
 
-  static FunctionType getValueStorageFunctionType(MLIRContext *ctx,
-                                                  bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
-    auto storage = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    auto storage = opaquePointerType(ctx);
     return FunctionType::get(ctx, {value}, {storage});
   }
 
@@ -121,9 +115,8 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType emplaceValueFunctionType(MLIRContext *ctx,
-                                               bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -131,9 +124,8 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType setValueErrorFunctionType(MLIRContext *ctx,
-                                                bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -142,9 +134,8 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
   }
 
-  static FunctionType isValueErrorFunctionType(MLIRContext *ctx,
-                                               bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
     auto i1 = IntegerType::get(ctx, 1);
     return FunctionType::get(ctx, {value}, {i1});
   }
@@ -158,9 +149,8 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType awaitValueFunctionType(MLIRContext *ctx,
-                                             bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -168,15 +158,9 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
   }
 
-  static FunctionType executeFunctionType(MLIRContext *ctx,
-                                          bool useLLVMOpaquePointers) {
-    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
-    Type resume;
-    if (useLLVMOpaquePointers)
-      resume = LLVM::LLVMPointerType::get(ctx);
-    else
-      resume = LLVM::LLVMPointerType::get(
-          resumeFunctionType(ctx, useLLVMOpaquePointers));
+  static FunctionType executeFunctionType(MLIRContext *ctx) {
+    auto hdl = opaquePointerType(ctx);
+    Type resume = AsyncAPI::opaquePointerType(ctx);
     return FunctionType::get(ctx, {hdl, resume}, {});
   }
 
@@ -186,42 +170,22 @@ struct AsyncAPI {
                              {i64});
   }
 
-  static FunctionType
-  awaitTokenAndExecuteFunctionType(MLIRContext *ctx,
-                                   bool useLLVMOpaquePointers) {
-    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
-    Type resume;
-    if (useLLVMOpaquePointers)
-      resume = LLVM::LLVMPointerType::get(ctx);
-    else
-      resume = LLVM::LLVMPointerType::get(
-          resumeFunctionType(ctx, useLLVMOpaquePointers));
+  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
+    auto hdl = opaquePointerType(ctx);
+    Type resume = AsyncAPI::opaquePointerType(ctx);
     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
   }
 
-  static FunctionType
-  awaitValueAndExecuteFunctionType(MLIRContext *ctx,
-                                   bool useLLVMOpaquePointers) {
-    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
-    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
-    Type resume;
-    if (useLLVMOpaquePointers)
-      resume = LLVM::LLVMPointerType::get(ctx);
-    else
-      resume = LLVM::LLVMPointerType::get(
-          resumeFunctionType(ctx, useLLVMOpaquePointers));
+  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    auto hdl = opaquePointerType(ctx);
+    Type resume = AsyncAPI::opaquePointerType(ctx);
     return FunctionType::get(ctx, {value, hdl, resume}, {});
   }
 
-  static FunctionType
-  awaitAllAndExecuteFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
-    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
-    Type resume;
-    if (useLLVMOpaquePointers)
-      resume = LLVM::LLVMPointerType::get(ctx);
-    else
-      resume = LLVM::LLVMPointerType::get(
-          resumeFunctionType(ctx, useLLVMOpaquePointers));
+  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
+    auto hdl = opaquePointerType(ctx);
+    Type resume = AsyncAPI::opaquePointerType(ctx);
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
@@ -230,17 +194,16 @@ struct AsyncAPI {
   }
 
   // Auxiliary coroutine resume intrinsic wrapper.
-  static Type resumeFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
+  static Type resumeFunctionType(MLIRContext *ctx) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    auto ptrType = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto ptrType = opaquePointerType(ctx);
     return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
   }
 };
 } // namespace
 
 /// Adds Async Runtime C API declarations to the module.
-static void addAsyncRuntimeApiDeclarations(ModuleOp module,
-                                           bool useLLVMOpaquePointers) {
+static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   auto builder =
       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
 
@@ -251,39 +214,30 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module,
   };
 
   MLIRContext *ctx = module.getContext();
-  addFuncDecl(kAddRef,
-              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
-  addFuncDecl(kDropRef,
-              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
-  addFuncDecl(kCreateValue,
-              AsyncAPI::createValueFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
-  addFuncDecl(kEmplaceValue,
-              AsyncAPI::emplaceValueFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
-  addFuncDecl(kSetValueError,
-              AsyncAPI::setValueErrorFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
-  addFuncDecl(kIsValueError,
-              AsyncAPI::isValueErrorFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
-  addFuncDecl(kAwaitValue,
-              AsyncAPI::awaitValueFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
-  addFuncDecl(kExecute,
-              AsyncAPI::executeFunctionType(ctx, useLLVMOpaquePointers));
-  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(
-                                    ctx, useLLVMOpaquePointers));
+  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
+  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
-  addFuncDecl(kAwaitTokenAndExecute, AsyncAPI::awaitTokenAndExecuteFunctionType(
-                                         ctx, useLLVMOpaquePointers));
-  addFuncDecl(kAwaitValueAndExecute, AsyncAPI::awaitValueAndExecuteFunctionType(
-                                         ctx, useLLVMOpaquePointers));
-  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(
-                                       ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitTokenAndExecute,
+              AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitValueAndExecute,
+              AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitAllAndExecute,
+              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
 }
 
@@ -296,7 +250,7 @@ static constexpr const char *kResume = "__resume";
 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
 /// intrinsics. We need this function to be able to pass it to the async
 /// runtime execute API.
-static void addResumeFunction(ModuleOp module, bool useOpaquePointers) {
+static void addResumeFunction(ModuleOp module) {
   if (module.lookupSymbol(kResume))
     return;
 
@@ -305,11 +259,7 @@ static void addResumeFunction(ModuleOp module, bool useOpaquePointers) {
   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
 
   auto voidTy = LLVM::LLVMVoidType::get(ctx);
-  Type ptrType;
-  if (useOpaquePointers)
-    ptrType = LLVM::LLVMPointerType::get(ctx);
-  else
-    ptrType = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Type ptrType = AsyncAPI::opaquePointerType(ctx);
 
   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
       kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
@@ -330,15 +280,10 @@ namespace {
 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
 /// their runtime type (opaque pointers) and does not convert any other types.
 class AsyncRuntimeTypeConverter : public TypeConverter {
-  bool llvmOpaquePointers = false;
-
 public:
-  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options)
-      : llvmOpaquePointers(options.useOpaquePointers) {
+  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
     addConversion([](Type type) { return type; });
-    addConversion([this](Type type) {
-      return convertAsyncTypes(type, llvmOpaquePointers);
-    });
+    addConversion([](Type type) { return convertAsyncTypes(type); });
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
@@ -352,28 +297,14 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
     addTargetMaterialization(addUnrealizedCast);
   }
 
-  /// Returns whether LLVM opaque pointers should be used instead of typed
-  /// pointers.
-  bool useOpaquePointers() const { return llvmOpaquePointers; }
-
-  /// Creates an LLVM pointer type which may either be a typed pointer or an
-  /// opaque pointer, depending on what options the converter was constructed
-  /// with.
-  LLVM::LLVMPointerType getPointerType(Type elementType) const {
-    if (llvmOpaquePointers)
-      return LLVM::LLVMPointerType::get(elementType.getContext());
-    return LLVM::LLVMPointerType::get(elementType);
-  }
-
-  static std::optional<Type> convertAsyncTypes(Type type,
-                                               bool useOpaquePointers) {
+  static std::optional<Type> convertAsyncTypes(Type type) {
     if (isa<TokenType, GroupType, ValueType>(type))
-      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
+      return AsyncAPI::opaquePointerType(type.getContext());
 
     if (isa<CoroIdType, CoroStateType>(type))
       return AsyncAPI::tokenType(type.getContext());
     if (isa<CoroHandleType>(type))
-      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
+      return AsyncAPI::opaquePointerType(type.getContext());
 
     return std::nullopt;
   }
@@ -414,8 +345,7 @@ class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto token = AsyncAPI::tokenType(op->getContext());
-    auto ptrType = AsyncAPI::opaquePointerType(
-        op->getContext(), getTypeConverter()->useOpaquePointers());
+    auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
     auto loc = op->getLoc();
 
     // Constants for initializing coroutine frame.
@@ -444,8 +374,7 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
   LogicalResult
   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto ptrType = AsyncAPI::opaquePointerType(
-        op->getContext(), getTypeConverter()->useOpaquePointers());
+    auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
     auto loc = op->getLoc();
 
     // Get coroutine frame size: @llvm.coro.size.i64.
@@ -472,8 +401,7 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
 
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
-        op->getParentOfType<ModuleOp>(), rewriter.getI64Type(),
-        getTypeConverter()->useOpaquePointers());
+        op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
         loc, allocFuncOp, ValueRange{coroAlign, coroSize});
 
@@ -499,8 +427,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
   LogicalResult
   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto ptrType = AsyncAPI::opaquePointerType(
-        op->getContext(), getTypeConverter()->useOpaquePointers());
+    auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
     auto loc = op->getLoc();
 
     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
@@ -509,8 +436,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
 
     // Free the memory.
     auto freeFuncOp =
-        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>(),
-                                   getTypeConverter()->useOpaquePointers());
+        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
                                               ValueRange(coroMem.getResult()));
 
@@ -538,8 +464,9 @@ class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
 
     // Mark the end of a coroutine: @llvm.coro.end.
     auto coroHdl = adaptor.getHandle();
-    rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
-                                     ValueRange({coroHdl, constFalse, noneToken}));
+    rewriter.create<LLVM::CoroEndOp>(
+        op->getLoc(), rewriter.getI1Type(),
+        ValueRange({coroHdl, constFalse, noneToken}));
     rewriter.eraseOp(op);
 
     return success();
@@ -673,7 +600,8 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
         auto i64 = rewriter.getI64Type();
 
         auto storedType = converter->convertType(valueType.getValueType());
-        auto storagePtrType = getTypeConverter()->getPointerType(storedType);
+        auto storagePtrType =
+            AsyncAPI::opaquePointerType(rewriter.getContext());
 
         // %Size = getelementptr %T* null, int 1
         // %SizeI = ptrtoint %T* %Size to i64
@@ -846,12 +774,10 @@ class RuntimeAwaitAndResumeOpLowering
     Value handle = adaptor.getHandle();
 
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>(),
-                      getTypeConverter()->useOpaquePointers());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(
-        op->getContext(), getTypeConverter()->useOpaquePointers());
+    addResumeFunction(op->getParentOfType<ModuleOp>());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
+        op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
+        kResume);
 
     rewriter.create<func::CallOp>(
         op->getLoc(), apiFuncName, TypeRange(),
@@ -877,12 +803,10 @@ class RuntimeResumeOpLowering
   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>(),
-                      getTypeConverter()->useOpaquePointers());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(
-        op->getContext(), getTypeConverter()->useOpaquePointers());
+    addResumeFunction(op->getParentOfType<ModuleOp>());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
+        op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
+        kResume);
 
     // Call async runtime API to execute a coroutine in the managed thread.
     auto coroHdl = adaptor.getHandle();
@@ -909,8 +833,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto ptrType = AsyncAPI::opaquePointerType(
-        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
+    auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
     auto storage = adaptor.getStorage();
     auto storagePtr = rewriter.create<func::CallOp>(
         loc, kGetValueStorage, TypeRange(ptrType), storage);
@@ -923,11 +846,6 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
           op, "failed to convert stored value type to LLVM type");
 
     Value castedStoragePtr = storagePtr.getResult(0);
-    if (!getTypeConverter()->useOpaquePointers())
-      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-          loc, getTypeConverter()->getPointerType(llvmValueType),
-          castedStoragePtr);
-
     // Store the yielded value into the async value storage.
     auto value = adaptor.getValue();
     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
@@ -955,8 +873,7 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto ptrType = AsyncAPI::opaquePointerType(
-        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
+    auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
     auto storage = adaptor.getStorage();
     auto storagePtr = rewriter.create<func::CallOp>(
         loc, kGetValueStorage, TypeRange(ptrType), storage);
@@ -969,10 +886,6 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
           op, "failed to convert loaded value type to LLVM type");
 
     Value castedStoragePtr = storagePtr.getResult(0);
-    if (!getTypeConverter()->useOpaquePointers())
-      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-          loc, getTypeConverter()->getPointerType(llvmValueType),
-          castedStoragePtr);
 
     // Load from the casted pointer.
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
@@ -1115,12 +1028,11 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   MLIRContext *ctx = module->getContext();
 
   LowerToLLVMOptions options(ctx);
-  options.useOpaquePointers = useOpaquePointers;
 
   // Add declarations for most functions required by the coroutines lowering.
   // We delay adding the resume function until it's needed because it currently
   // fails to compile unless '-O0' is specified.
-  addAsyncRuntimeApiDeclarations(module, useOpaquePointers);
+  addAsyncRuntimeApiDeclarations(module);
 
   // Lower async.runtime and async.coro operations to Async Runtime API and
   // LLVM coroutine intrinsics.
@@ -1133,8 +1045,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   // operations.
   LLVMTypeConverter llvmConverter(ctx, options);
   llvmConverter.addConversion([&](Type type) {
-    return AsyncRuntimeTypeConverter::convertAsyncTypes(
-        type, llvmConverter.useOpaquePointers());
+    return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
   });
 
   // Convert async types in function signatures and function calls.
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
index 8a611cf96f5b5f8..a398bc5710a865c 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
 
 // CHECK-LABEL: @coro_id
 func.func @coro_id() {
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
index 3672be91bbc07ad..4077edc7420dca1 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @create_token
 func.func @create_token() {
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index fd419dc95e7a1aa..dd54bdb79872441 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm | FileCheck %s
 
 // CHECK-LABEL: reference_counting
 func.func @reference_counting(%arg0: !async.token) {
diff --git a/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir b/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir
deleted file mode 100644
index 07cd2add3b15122..000000000000000
--- a/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir
+++ /dev/null
@@ -1,138 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=0' | FileCheck %s
-
-
-
-// CHECK-LABEL: @store
-func.func @store() {
-  // CHECK: %[[CST:.*]] = arith.constant 1.0
-  %0 = arith.constant 1.0 : f32
-  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
-  %1 = async.runtime.create : !async.value<f32>
-  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: llvm.store %[[CST]], %[[P1]]
-  async.runtime.store %0, %1 : !async.value<f32>
-  return
-}
-
-// CHECK-LABEL: @load
-func.func @load() -> f32 {
-  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
-  %0 = async.runtime.create : !async.value<f32>
-  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]]
-  %1 = async.runtime.load %0 : !async.value<f32>
-  // CHECK: return %[[VALUE]] : f32
-  return %1 : f32
-}
-
-// -----
-
-// CHECK-LABEL: execute_no_async_args
-func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
-  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
-  %token = async.execute {
-    %c0 = arith.constant 0 : index
-    memref.store %arg0, %arg1[%c0] : memref<1xf32>
-    async.yield
-  }
-  // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
-  // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
-  // CHECK: %[[TRUE:.*]] = arith.constant true
-  // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
-  // CHECK: cf.assert %[[NOT_ERROR]]
-  // CHECK-NEXT: return
-  async.await %token : !async.token
-  return
-}
-
-// Function outlined from the async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
-
-// Create token for return op, and mark a function as a coroutine.
-// CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
-// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
-
-// Pass a suspended coroutine to the async runtime.
-// CHECK: %[[STATE:.*]] = llvm.intr.coro.save
-// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
-// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]])
-// CHECK: %[[SUSPENDED:.*]] = llvm.intr.coro.suspend %[[STATE]]
-
-// Decide the next block based on the code returned from suspend.
-// CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
-// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
-// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
-// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
-
-// Resume coroutine after suspension.
-// CHECK: ^[[RESUME]]:
-// CHECK: memref.store %arg0, %arg1[%c0] : memref<1xf32>
-// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET]])
-
-// Delete coroutine.
-// CHECK: ^[[CLEANUP]]:
-// CHECK: %[[MEM:.*]] = llvm.intr.coro.free
-// CHECK: llvm.call @free(%[[MEM]])
-
-// Suspend coroutine, and also a return statement for ramp function.
-// CHECK: ^[[SUSPEND]]:
-// CHECK: llvm.intr.coro.end
-// CHECK: return %[[RET]]
-
-// -----
-
-// CHECK-LABEL: execute_and_return_f32
-func.func @execute_and_return_f32() -> f32 {
- // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
-  %token, %result = async.execute -> !async.value<f32> {
-    %c0 = arith.constant 123.0 : f32
-    async.yield %c0 : f32
-  }
-
-  // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
-  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] :  !llvm.ptr<f32>
-  %0 = async.await %result : !async.value<f32>
-
-  return %0 : f32
-}
-
-// Function outlined from the async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn()
-// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
-// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
-// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
-
-// Suspend coroutine in the beginning.
-// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
-// CHECK: llvm.intr.coro.suspend
-
-// Emplace result value.
-// CHECK: %[[CST:.*]] = arith.constant 1.230000e+02 : f32
-// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-// CHECK: llvm.store %[[CST]], %[[ST_F32]] : !llvm.ptr<f32>
-// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
-
-// Emplace result token.
-// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
-
-// -----
-
-// CHECK-LABEL: @await_and_resume_group
-func.func @await_and_resume_group() {
-  %c = arith.constant 1 : index
-  %0 = async.coro.id
-  // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
-  %1 = async.coro.begin %0
-  // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
-  %2 = async.runtime.create_group %c : !async.group
-  // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
-  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
-  // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
-  async.runtime.await_and_resume %2, %1 : !async.group
-  return
-}

>From 72b032925975c32a5872d71c6af0de55f64a82e4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 30 Oct 2023 22:12:39 +0000
Subject: [PATCH 2/2] Address review comments: reuse a single opaque pointer type in AsyncAPI function types

---
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    | 26 +++++++------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 3e61c9c7de50e2f..0ab53ce7e3327e4 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -106,9 +106,8 @@ struct AsyncAPI {
   }
 
   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto storage = opaquePointerType(ctx);
-    return FunctionType::get(ctx, {value}, {storage});
+    auto ptrType = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {ptrType}, {ptrType});
   }
 
   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
@@ -159,9 +158,8 @@ struct AsyncAPI {
   }
 
   static FunctionType executeFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    Type resume = AsyncAPI::opaquePointerType(ctx);
-    return FunctionType::get(ctx, {hdl, resume}, {});
+    auto ptrType = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {ptrType, ptrType}, {});
   }
 
   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
@@ -171,22 +169,18 @@ struct AsyncAPI {
   }
 
   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    Type resume = AsyncAPI::opaquePointerType(ctx);
-    return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
+    auto ptrType = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
   }
 
   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto hdl = opaquePointerType(ctx);
-    Type resume = AsyncAPI::opaquePointerType(ctx);
-    return FunctionType::get(ctx, {value, hdl, resume}, {});
+    auto ptrType = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
   }
 
   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    Type resume = AsyncAPI::opaquePointerType(ctx);
-    return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
+    auto ptrType = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
   }
 
   static FunctionType getNumWorkerThreads(MLIRContext *ctx) {



More information about the Mlir-commits mailing list