[Mlir-commits] [mlir] [MLIR][LLVM] Change addressof builders to use opaque pointers (PR #69215)

Christian Ulmann llvmlistbot at llvm.org
Mon Oct 16 22:35:14 PDT 2023


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/69215

>From f1d4ee14d99fcaba27fc9e42be711487a2116bbf Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 16 Oct 2023 14:29:36 +0000
Subject: [PATCH 1/3] [MLIR][LLVM] Change addressof builders to use opaque
 pointers

This commit changes the builders of the `llvm.mlir.addressof` operation
so that they no longer produce typed pointers.

As a consequence, a GPU to NVVM pattern that still relied on typed
pointers had to be updated.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  4 +--
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 19 ++++++-------
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     | 28 +++++++++----------
 3 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8745d14c8d48318..2a572ab4de706a3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1071,7 +1071,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
+            LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
             global.getSymName());
       $_state.addAttributes(attrs);
     }]>,
@@ -1079,7 +1079,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
+            LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
       $_state.addAttributes(attrs);
     }]>
   ];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 96d8fceba706617..59823c6605fe25d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -441,7 +441,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   Location loc = gpuPrintfOp->getLoc();
 
   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
-  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
 
   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
   // This ensures that global constants and declarations are placed within
@@ -449,7 +449,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
 
   auto vprintfType =
-      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
   LLVM::LLVMFuncOp vprintfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
@@ -473,7 +473,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
   Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      loc, ptrType, ptrType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
   SmallVector<Type> types;
   SmallVector<Value> args;
   // Promote and pack the arguments into a stack allocation.
@@ -490,18 +490,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   }
   Type structType =
       LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
-  Type structPtrType = LLVM::LLVMPointerType::get(structType);
   Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                 rewriter.getIndexAttr(1));
-  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
-                                                    /*alignment=*/0);
+  Value tempAlloc =
+      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
+                                      /*alignment=*/0);
   for (auto [index, arg] : llvm::enumerate(args)) {
-    Value ptr = rewriter.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
-        ArrayRef<LLVM::GEPArg>{0, index});
+    Value ptr =
+        rewriter.create<LLVM::GEPOp>(loc, ptrType, arg.getType(), tempAlloc,
+                                     ArrayRef<LLVM::GEPArg>{0, index});
     rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
   }
-  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
   std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
 
   rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 391ccd74841dca4..a33a0797aa56507 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -542,16 +542,15 @@ gpu.module @test_module_28 {
 gpu.module @test_module_29 {
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
-  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
 
   // CHECK-LABEL: func @test_const_printf
   gpu.func @test_const_printf() {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
-    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello, world\n"
     gpu.return
   }
@@ -559,17 +558,16 @@ gpu.module @test_module_29 {
   // CHECK-LABEL: func @test_printf
   // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
   gpu.func @test_printf(%arg0: i32, %arg1: f32) {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
     // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
-    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
-    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
-    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
-    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
-    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
+    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr
+    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
+    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr
+    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
     gpu.return
   }

>From 2c1bd7107c25a22b1aa0226298c0e34a6d126ea2 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 16 Oct 2023 19:15:03 +0000
Subject: [PATCH 2/3] fix MLIR examples

---
 mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 11 ++++++-----
 mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 10 +++++-----
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..41a5be2177bbf36 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -161,10 +161,11 @@ class PrintOpLowering : public ConversionPattern {
     Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
+    // LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
     return builder.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
-        globalPtr, ArrayRef<Value>({cst0, cst0}));
+        loc, LLVM::LLVMPointerType::get(builder.getContext()),
+        IntegerType::get(builder.getContext(), 8), globalPtr,
+        ArrayRef<Value>({cst0, cst0}));
   }
 };
 } // namespace
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 684ce37b2398ce2..e8c5414f8f387f7 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -162,9 +162,9 @@ class PrintOpLowering : public ConversionPattern {
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
     return builder.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
-        globalPtr, ArrayRef<Value>({cst0, cst0}));
+        loc, LLVM::LLVMPointerType::get(builder.getContext()),
+        IntegerType::get(builder.getContext(), 8), globalPtr,
+        ArrayRef<Value>({cst0, cst0}));
   }
 };
 } // namespace

>From bb241cc1b4c350fe2cb2dfef418a902bf4eb7274 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 17 Oct 2023 05:35:01 +0000
Subject: [PATCH 3/3] remove leftover comment

---
 mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 41a5be2177bbf36..e8c5414f8f387f7 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -161,7 +161,6 @@ class PrintOpLowering : public ConversionPattern {
     Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
-    // LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
     return builder.create<LLVM::GEPOp>(
         loc, LLVM::LLVMPointerType::get(builder.getContext()),
         IntegerType::get(builder.getContext(), 8), globalPtr,



More information about the Mlir-commits mailing list