[Mlir-commits] [mlir] [mlir] MemrefToLLVM: Support `llvm.address_space` as memory space and do not generate noop `addrspacecast`s (PR #173387)

Ivan Butygin llvmlistbot at llvm.org
Tue Dec 23 07:12:47 PST 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/173387

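Having `llvm.address_space` as a memref memory space is useful for progressive lowering. In addition, an `addrspacecast` whose source and destination address spaces are the same is forbidden by the LLVM spec, so such no-op casts are no longer generated.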

>From 2fa36a72a5e275f2fa9a455b96ced3a0ff22f0e1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 23 Dec 2025 15:55:36 +0100
Subject: [PATCH] [mlir] MemrefToLLVM: Support `llvm.address_space` as memory
 space and do not generate noop `addrspacecast`s

Signed-off-by: Ivan Butygin <ivan.butygin at gmail.com>
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 15 +++++++++++-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 24 ++++++++++++-------
 .../MemRefToLLVM/memref-to-llvm.mlir          | 20 ++++++++++++++++
 3 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cb9dea108cc48..07661550d436e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,6 +269,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // Integer memory spaces map to themselves.
   addTypeAttributeConversion(
       [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
+
+  // LLVM address spaces map to themselves.
+  addTypeAttributeConversion(
+      [](BaseMemRefType memref, LLVM::AddressSpaceAttr addrspace) {
+        return addrspace;
+      });
 }
 
 /// Returns the MLIR context.
@@ -575,17 +581,24 @@ FailureOr<unsigned>
 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
   if (!type.getMemorySpace()) // Default memory space -> 0.
     return 0;
+
   std::optional<Attribute> converted =
       convertTypeAttribute(type, type.getMemorySpace());
   if (!converted)
     return failure();
+
   if (!(*converted)) // Conversion to default is 0.
     return 0;
-  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
+
+  if (auto explicitSpace = dyn_cast<IntegerAttr>(*converted)) {
     if (explicitSpace.getType().isIndex() ||
         explicitSpace.getType().isSignlessInteger())
       return explicitSpace.getInt();
   }
+
+  if (auto explicitSpace = dyn_cast<LLVM::AddressSpaceAttr>(*converted))
+    return explicitSpace.getAddressSpace();
+
   return failure();
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d37895d1fb1ad 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -114,6 +114,14 @@ static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
   return layout->getTypeSize(elementType);
 }
 
+static Value createAddrSpaceCast(ConversionPatternRewriter &rewriter,
+                                 Location loc, Type type, Value value) {
+  if (value.getType() == type)
+    return value;
+
+  return LLVM::AddrSpaceCastOp::create(rewriter, loc, type, value);
+}
+
 static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                  Location loc, Value allocatedPtr,
                                  MemRefType memRefType, Type elementPtrType,
@@ -124,7 +132,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
   assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
   unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
   if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
-    allocatedPtr = LLVM::AddrSpaceCastOp::create(
+    allocatedPtr = createAddrSpaceCast(
         rewriter, loc,
         LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
         allocatedPtr);
@@ -1262,10 +1270,8 @@ struct MemorySpaceCastOpLowering
       SmallVector<Value> descVals;
       MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                                descVals);
-      descVals[0] =
-          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
-      descVals[1] =
-          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
+      descVals[0] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[0]);
+      descVals[1] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[1]);
       Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                             resultTypeR, descVals);
       rewriter.replaceOp(op, result);
@@ -1314,10 +1320,10 @@ struct MemorySpaceCastOpLowering
       Value alignedPtr =
           sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                 sourceUnderlyingDesc, sourceElemPtrType);
-      allocatedPtr = LLVM::AddrSpaceCastOp::create(
-          rewriter, loc, resultElemPtrType, allocatedPtr);
-      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
-                                                 resultElemPtrType, alignedPtr);
+      allocatedPtr =
+          createAddrSpaceCast(rewriter, loc, resultElemPtrType, allocatedPtr);
+      alignedPtr =
+          createAddrSpaceCast(rewriter, loc, resultElemPtrType, alignedPtr);
 
       result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                              resultElemPtrType, allocatedPtr);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..ff20ccba123af 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -289,6 +289,26 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 
 // -----
 
+// ALL-LABEL: func @llvm_address_space_cast
+//  ALL-SAME:   (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>)
+func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> {
+  // ALL: %[[UNREALIZED:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32, #llvm.address_space<3>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[ALLOC:.*]] = llvm.extractvalue %[[UNREALIZED]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[ALIGNED:.*]] = llvm.extractvalue %[[UNREALIZED]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[OFFSET:.*]] = llvm.extractvalue %[[UNREALIZED]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[POISON:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS0:.*]] = llvm.insertvalue %[[ALLOC]], %[[POISON]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS1:.*]] = llvm.insertvalue %[[ALIGNED]], %[[INS0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS2:.*]] = llvm.insertvalue %[[OFFSET]], %[[INS1]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[RECAST:.*]] = builtin.unrealized_conversion_cast %[[INS2]] : !llvm.struct<(ptr<3>, ptr<3>, i64)> to memref<f32, 3 : i32>
+  // ALL: return %[[RECAST]] : memref<f32, 3 : i32>
+
+  %0 = memref.memory_space_cast %arg0 : memref<f32, #llvm.address_space<3>> to memref<f32, 3 : i32>
+  func.return %0 : memref<f32, 3 : i32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transpose
 //       CHECK:   llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>



More information about the Mlir-commits mailing list