[Mlir-commits] [mlir] [mlir][gpu] Fix gpu.host_register lowering and runtime support (PR #170085)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 1 02:32:51 PST 2025


https://github.com/Men-cotton updated https://github.com/llvm/llvm-project/pull/170085

>From 83403cadd5b271dc6bdefd4384cabfdabadd3c5e Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 1 Dec 2025 01:51:50 +0900
Subject: [PATCH 1/3] [mlir][GPU] Fix gpu.host_register lowering with bare
 pointer calling convention

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  4 +-
 .../GPUCommon/GPUToLLVMConversion.cpp         | 74 +++++++++++++++----
 .../Transforms/SparseGPUCodegen.cpp           | 16 ++--
 3 files changed, 68 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a6c6038e1e224..81e8801aae1f0 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1566,7 +1566,7 @@ def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
 }
 
 def GPU_HostRegisterOp : GPU_Op<"host_register">,
-    Arguments<(ins AnyUnrankedMemRef:$value)> {
+    Arguments<(ins AnyMemRef:$value)> {
   let summary = "Registers a memref for access from device.";
   let description = [{
     This op maps the provided host buffer into the device address space.
@@ -1583,7 +1583,7 @@ def GPU_HostRegisterOp : GPU_Op<"host_register">,
 }
 
 def GPU_HostUnregisterOp : GPU_Op<"host_unregister">,
-    Arguments<(ins AnyUnrankedMemRef:$value)> {
+    Arguments<(ins AnyMemRef:$value)> {
   let summary = "Unregisters a memref for access from device.";
   let description = [{
       This op unmaps the provided host buffer from the device address space.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 5994b64f3d9a5..c34005893ed8c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -709,6 +709,50 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+static LogicalResult prepareHostRegisterUnregisterArguments(
+    Operation *op, Value value, Value adaptorValue,
+    const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter,
+    SmallVectorImpl<Value> &arguments, SmallVectorImpl<Type> &elementTypes) {
+  Location loc = op->getLoc();
+  auto valueType = value.getType();
+
+  if (auto memRefType = dyn_cast<MemRefType>(valueType)) {
+    Type elementType = memRefType.getElementType();
+    elementTypes.push_back(elementType);
+    Type llvmIntPtrType = IntegerType::get(
+        rewriter.getContext(), typeConverter->getPointerBitwidth(0));
+    Value rank = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmIntPtrType,
+        rewriter.getIntegerAttr(llvmIntPtrType, memRefType.getRank()));
+    Value descriptor = adaptorValue;
+    Value descriptorPtr;
+    // With the bare-pointer calling convention, reject memrefs that cannot be
+    // represented as a bare pointer and rebuild a descriptor if needed.
+    if (typeConverter->getOptions().useBarePtrCallConv) {
+      if (!LLVMTypeConverter::canConvertToBarePtr(memRefType))
+        return op->emitError(
+            "cannot lower memref with bare pointer calling convention");
+
+      if (isa<LLVM::LLVMPointerType>(descriptor.getType()))
+        descriptor = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *typeConverter, memRefType, descriptor);
+    }
+
+    // The runtime wrappers take (rank, pointer to the ranked descriptor); the
+    // descriptor is stored to a stack slot to obtain that pointer.
+    descriptorPtr =
+        typeConverter->promoteOneMemRefDescriptor(loc, descriptor, rewriter);
+    arguments.push_back(rank);
+    arguments.push_back(descriptorPtr);
+    return success();
+  }
+
+  // NOTE(review): the operand is now AnyMemRef, so unranked memrefs (the only
+  // form accepted before this change) are rejected; supporting both would need
+  // AnyRankedOrUnrankedMemRef plus an UnrankedMemRefDescriptor-based path.
+  return rewriter.notifyMatchFailure(op, "expected memref operand");
+}
+
 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -716,14 +760,15 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
     return failure();
 
-  Location loc = op->getLoc();
-
-  auto memRefType = hostRegisterOp.getValue().getType();
-  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
-  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+  SmallVector<Value> arguments;
+  SmallVector<Type> elementTypes;
+  if (failed(prepareHostRegisterUnregisterArguments(
+          op, hostRegisterOp.getValue(), adaptor.getValue(), getTypeConverter(),
+          rewriter, arguments, elementTypes)))
+    return failure(); // Error already emitted or match failure notified
 
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+  Location loc = op->getLoc();
+  auto elementSize = getSizeInBytes(loc, elementTypes.front(), rewriter);
   arguments.push_back(elementSize);
   hostRegisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -738,14 +783,15 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
     return failure();
 
-  Location loc = op->getLoc();
-
-  auto memRefType = hostUnregisterOp.getValue().getType();
-  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
-  auto elementSize = getSizeInBytes(loc, elementType, rewriter);
+  SmallVector<Value> arguments;
+  SmallVector<Type> elementTypes;
+  if (failed(prepareHostRegisterUnregisterArguments(
+          op, hostUnregisterOp.getValue(), adaptor.getValue(),
+          getTypeConverter(), rewriter, arguments, elementTypes)))
+    return failure(); // Error already emitted or match failure notified
 
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+  Location loc = op->getLoc();
+  auto elementSize = getSizeInBytes(loc, elementTypes.front(), rewriter);
   arguments.push_back(elementSize);
   hostUnregisterCallBuilder.create(loc, rewriter, arguments);
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 0bd1d34c3504b..4f6ddab0821fb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -109,21 +109,17 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
 /// Writes from the host are guaranteed to be visible to device kernels
 /// that are launched afterwards. Writes from the device are guaranteed
 /// to be visible on the host after synchronizing with the device kernel
-/// completion. Needs to cast the buffer to a unranked buffer.
+/// completion. Returns `mem` unchanged.
 static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                    Value mem) {
-  MemRefType memTp = cast<MemRefType>(mem.getType());
-  UnrankedMemRefType resTp =
-      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
-  Value cast = memref::CastOp::create(builder, loc, resTp, mem);
-  gpu::HostRegisterOp::create(builder, loc, cast);
-  return cast;
+  gpu::HostRegisterOp::create(builder, loc, mem);
+  return mem;
 }
 
-/// Unmaps the provided buffer, expecting the casted buffer.
+/// Unmaps the provided buffer from the device address space.
 static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
-                                    Value cast) {
-  gpu::HostUnregisterOp::create(builder, loc, cast);
+                                    Value mem) {
+  gpu::HostUnregisterOp::create(builder, loc, mem);
 }
 
 /// Generates first wait in an asynchronous chain.

>From a79b498df1c06a7af37568d21cc93ccc3b6f8e22 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 1 Dec 2025 01:53:51 +0900
Subject: [PATCH 2/3] [mlir][GPU] Fix host_register runtime for multi-rank
 memrefs

---
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 32 +++++++----
 .../ExecutionEngine/RocmRuntimeWrappers.cpp   | 42 ++++++++-------
 .../host-register-bare-ptr-func.mlir          |  9 ++++
 .../GPUCommon/host-register-bare-ptr.mlir     | 54 +++++++++++++++++++
 .../GPU/CUDA/host-register-ranked-memref.mlir | 18 +++++++
 5 files changed, 126 insertions(+), 29 deletions(-)
 create mode 100644 mlir/test/Conversion/GPUCommon/host-register-bare-ptr-func.mlir
 create mode 100644 mlir/test/Conversion/GPUCommon/host-register-bare-ptr.mlir
 create mode 100644 mlir/test/Integration/GPU/CUDA/host-register-ranked-memref.mlir

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index f203363e16ea2..b9aa78e22e109 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -284,31 +284,40 @@ mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
-/// ranked memref descriptor struct of rank `rank`. Helpful until we have
+/// Registers a memref with the CUDA runtime. `descriptor` points to a ranked
+/// memref descriptor (StridedMemRefType) of rank `rank`. Helpful until we have
 /// transfer functions implemented.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
+mgpuMemHostRegisterMemRef(int64_t rank, void *descriptor,
                           int64_t elementSizeBytes) {
+  ::UnrankedMemRefType<char> unranked{rank, descriptor};
+  DynamicMemRefType<char> memRef(unranked);
+
+  // Rank-0 memref: a single element and no sizes/strides to inspect.
+  if (rank == 0) {
+    auto *ptr = memRef.data + memRef.offset * elementSizeBytes;
+    mgpuMemHostRegister(ptr, elementSizeBytes);
+    return;
+  }
+
   // Only densely packed tensors are currently supported.
 #ifdef _WIN32
   int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t));
 #else
   int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
 #endif // _WIN32
-  int64_t *sizes = descriptor->sizes;
+  const int64_t *sizes = memRef.sizes;
   for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
     denseStrides[i] = runningStride;
     runningStride *= sizes[i];
   }
   uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
-  int64_t *strides = &sizes[rank];
-  (void)strides;
-  for (unsigned i = 0; i < rank; ++i)
+  [[maybe_unused]] const int64_t *strides = memRef.strides;
+  for (int64_t i = 0; i < rank; ++i)
     assert(strides[i] == denseStrides[i] &&
            "Mismatch in computed dense strides");
 
-  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  auto *ptr = memRef.data + memRef.offset * elementSizeBytes;
   mgpuMemHostRegister(ptr, sizeBytes);
 }
 
@@ -321,10 +330,11 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
 /// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
 /// ranked memref descriptor struct of rank `rank`
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuMemHostUnregisterMemRef(int64_t rank,
-                            StridedMemRefType<char, 1> *descriptor,
+mgpuMemHostUnregisterMemRef(int64_t rank, void *descriptor,
                             int64_t elementSizeBytes) {
-  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  const DynamicMemRefType<char> view(
+      ::UnrankedMemRefType<char>{rank, descriptor});
+  char *ptr = view.data + view.offset * elementSizeBytes;
   mgpuMemHostUnregister(ptr);
 }
 
diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
index b984149ca6dea..e4267273b20c3 100644
--- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
@@ -13,7 +13,6 @@
 //===----------------------------------------------------------------------===//
 
 #include <cassert>
-#include <numeric>
 
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -143,25 +142,32 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
 
 // Allows to register a MemRef with the ROCm runtime. Helpful until we have
 // transfer functions implemented.
-extern "C" void
-mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
-                          int64_t elementSizeBytes) {
+extern "C" void mgpuMemHostRegisterMemRef(int64_t rank, void *descriptor,
+                                          int64_t elementSizeBytes) {
+  ::UnrankedMemRefType<char> unranked{rank, descriptor};
+  DynamicMemRefType<char> memRef(unranked);
+
+  // Rank-0 memref: a single element and no sizes/strides to inspect.
+  if (rank == 0) {
+    auto ptr = memRef.data + memRef.offset * elementSizeBytes;
+    mgpuMemHostRegister(ptr, elementSizeBytes);
+    return;
+  }
 
   llvm::SmallVector<int64_t, 4> denseStrides(rank);
-  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
-  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
+  llvm::ArrayRef<int64_t> sizes(memRef.sizes, rank);
+  llvm::ArrayRef<int64_t> strides(memRef.strides, rank);
 
-  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
-                   std::multiplies<int64_t>());
-  auto sizeBytes = denseStrides.front() * elementSizeBytes;
+  for (int64_t i = rank - 1, runningStride = 1; i >= 0; --i) {
+    denseStrides[i] = runningStride;
+    runningStride *= sizes[i];
+  }
+  auto sizeBytes = sizes.front() * denseStrides.front() * elementSizeBytes;
 
   // Only densely packed tensors are currently supported.
-  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
-              denseStrides.end());
-  denseStrides.back() = 1;
   assert(strides == llvm::ArrayRef(denseStrides));
 
-  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  auto ptr = memRef.data + memRef.offset * elementSizeBytes;
   mgpuMemHostRegister(ptr, sizeBytes);
 }
 
@@ -173,11 +179,11 @@ extern "C" void mgpuMemHostUnregister(void *ptr) {
 
 // Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
 // transfer functions implemented.
-extern "C" void
-mgpuMemHostUnregisterMemRef(int64_t rank,
-                            StridedMemRefType<char, 1> *descriptor,
-                            int64_t elementSizeBytes) {
-  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+extern "C" void mgpuMemHostUnregisterMemRef(int64_t rank, void *descriptor,
+                                            int64_t elementSizeBytes) {
+  const DynamicMemRefType<char> view(
+      ::UnrankedMemRefType<char>{rank, descriptor});
+  char *ptr = view.data + view.offset * elementSizeBytes;
   mgpuMemHostUnregister(ptr);
 }
 
diff --git a/mlir/test/Conversion/GPUCommon/host-register-bare-ptr-func.mlir b/mlir/test/Conversion/GPUCommon/host-register-bare-ptr-func.mlir
new file mode 100644
index 0000000000000..8b5c4c4f55d1e
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/host-register-bare-ptr-func.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-host=1" -verify-diagnostics
+
+module attributes {gpu.container_module} {
+  func.func @dynamic(%buf : memref<?xf32>) {
+    // expected-error @+1 {{cannot lower memref with bare pointer calling convention}}
+    gpu.host_register %buf : memref<?xf32>
+    return
+  }
+}
diff --git a/mlir/test/Conversion/GPUCommon/host-register-bare-ptr.mlir b/mlir/test/Conversion/GPUCommon/host-register-bare-ptr.mlir
new file mode 100644
index 0000000000000..e8013f86766d8
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/host-register-bare-ptr.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-host=1" -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=BARE
+
+module attributes {gpu.container_module} {
+  func.func @host_register(%arg0: memref<4x6xf16>) {
+    gpu.host_register %arg0 : memref<4x6xf16>
+    gpu.host_unregister %arg0 : memref<4x6xf16>
+    return
+  }
+}
+// Lowers to a call taking (rank, descriptor pointer, element size in bytes).
+// BARE-LABEL: llvm.func @host_register
+// BARE-SAME: ({{.*}}: !llvm.ptr) {
+// BARE: %[[DESC0:.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BARE: %[[DESC1:.+]] = llvm.insertvalue %arg0, %[[DESC0]][0]
+// BARE: %[[DESC2:.+]] = llvm.insertvalue %arg0, %[[DESC1]][1]
+// BARE: %[[OFF:.+]] = llvm.mlir.constant(0 : {{.*}}) : i64
+// BARE: %[[DESC3:.+]] = llvm.insertvalue %[[OFF]], %[[DESC2]][2]
+// BARE: %[[SIZE0:.+]] = llvm.mlir.constant(4 : {{.*}}) : i64
+// BARE: %[[DESC4:.+]] = llvm.insertvalue %[[SIZE0]], %[[DESC3]][3, 0]
+// BARE: %[[STRIDE0:.+]] = llvm.mlir.constant(6 : {{.*}}) : i64
+// BARE: %[[DESC5:.+]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0]
+// BARE: %[[SIZE1:.+]] = llvm.mlir.constant(6 : {{.*}}) : i64
+// BARE: %[[DESC6:.+]] = llvm.insertvalue %[[SIZE1]], %[[DESC5]][3, 1]
+// BARE: %[[STRIDE1:.+]] = llvm.mlir.constant(1 : {{.*}}) : i64
+// BARE: %[[DESC7:.+]] = llvm.insertvalue %[[STRIDE1]], %[[DESC6]][4, 1]
+// BARE: %[[RANK:.+]] = llvm.mlir.constant(2 : {{.*}}) : i64
+// BARE: %[[ALLOCA:.+]] = llvm.alloca %{{.*}} x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BARE: llvm.store %[[DESC7]], %[[ALLOCA]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
+// BARE: %[[NULL:.+]] = llvm.mlir.zero : !llvm.ptr
+// BARE: %[[GEP:.+]] = llvm.getelementptr %[[NULL]][1] : (!llvm.ptr) -> !llvm.ptr, f16
+// BARE: %[[ELTSZ:.+]] = llvm.ptrtoint %[[GEP]] : !llvm.ptr to i64
+// BARE: llvm.call @mgpuMemHostRegisterMemRef(%[[RANK]], %[[ALLOCA]], %[[ELTSZ]])
+// BARE: llvm.call @mgpuMemHostUnregisterMemRef(%{{.*}}, %{{.*}}, %{{.*}})
+
+// -----
+// A dynamically shaped memref is rejected under the bare-pointer convention.
+module attributes {gpu.container_module} {
+  func.func @dynamic(%n: index) {
+    %buf = memref.alloc(%n) : memref<?xf32>
+    // expected-error @+1 {{cannot lower memref with bare pointer calling convention}}
+    gpu.host_register %buf : memref<?xf32>
+    return
+  }
+}
+
+// -----
+// Unranked memrefs are rejected by the op's operand type constraint.
+module attributes {gpu.container_module} {
+  func.func @unranked(%arg0: memref<*xf32>) {
+    // expected-error @+1 {{custom op 'gpu.host_register' invalid kind of type specified: expected builtin.memref, but found 'memref<*xf32>'}}
+    gpu.host_register %arg0 : memref<*xf32>
+    return
+  }
+}
diff --git a/mlir/test/Integration/GPU/CUDA/host-register-ranked-memref.mlir b/mlir/test/Integration/GPU/CUDA/host-register-ranked-memref.mlir
new file mode 100644
index 0000000000000..3195ffb8b6f2d
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/host-register-ranked-memref.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void
+
+module attributes {gpu.container_module} {
+  func.func @main() {
+    %0 = memref.alloc() : memref<64x64xf32>
+
+    // Register a rank-2 memref, then unregister it before it is freed.
+    gpu.host_register %0 : memref<64x64xf32>
+    gpu.host_unregister %0 : memref<64x64xf32>
+    memref.dealloc %0 : memref<64x64xf32>
+    return
+  }
+}

>From 087f1cf67741648e55de96ddbac5ac2eadcca4d9 Mon Sep 17 00:00:00 2001
From: mencotton <mencotton0410 at gmail.com>
Date: Mon, 1 Dec 2025 19:32:35 +0900
Subject: [PATCH 3/3] [mlir][GPU] Use LLVM::ConstantOp::create instead of rewriter.create

---
 mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index c34005893ed8c..8d2fa0324e598 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -721,8 +721,8 @@ static LogicalResult prepareHostRegisterUnregisterArguments(
     elementTypes.push_back(elementType);
     Type llvmIntPtrType = IntegerType::get(
         rewriter.getContext(), typeConverter->getPointerBitwidth(0));
-    Value rank = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmIntPtrType,
+    Value rank = LLVM::ConstantOp::create(
+        rewriter, loc, llvmIntPtrType,
         rewriter.getIntegerAttr(llvmIntPtrType, memRefType.getRank()));
     Value descriptor = adaptorValue;
     Value descriptorPtr;



More information about the Mlir-commits mailing list