[Mlir-commits] [mlir] [mlir] MemrefToLLVM: Support `llvm.address_space` as memory space and do not generate noop `addrspacecast`s (PR #173387)
Ivan Butygin
llvmlistbot at llvm.org
Tue Dec 23 07:12:47 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/173387
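Supporting `llvm.address_space` as a memref memory space is useful for progressive lowering. Additionally, the LLVM spec forbids an `addrspacecast` whose source and destination address spaces are identical, so the MemRef-to-LLVM lowering no longer emits such no-op casts.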
>From 2fa36a72a5e275f2fa9a455b96ced3a0ff22f0e1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 23 Dec 2025 15:55:36 +0100
Subject: [PATCH] [mlir] MemrefToLLVM: Support `llvm.address_space` as memory
space and do not generate noop `addrspacecast`s
Signed-off-by: Ivan Butygin <ivan.butygin at gmail.com>
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 15 +++++++++++-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 24 ++++++++++++-------
.../MemRefToLLVM/memref-to-llvm.mlir | 20 ++++++++++++++++
3 files changed, 49 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cb9dea108cc48..07661550d436e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,6 +269,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// Integer memory spaces map to themselves.
addTypeAttributeConversion(
[](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
+
+ // LLVM address spaces map to themselves.
+ addTypeAttributeConversion(
+ [](BaseMemRefType memref, LLVM::AddressSpaceAttr addrspace) {
+ return addrspace;
+ });
}
/// Returns the MLIR context.
@@ -575,17 +581,24 @@ FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
if (!type.getMemorySpace()) // Default memory space -> 0.
return 0;
+
std::optional<Attribute> converted =
convertTypeAttribute(type, type.getMemorySpace());
if (!converted)
return failure();
+
if (!(*converted)) // Conversion to default is 0.
return 0;
- if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
+
+ if (auto explicitSpace = dyn_cast<IntegerAttr>(*converted)) {
if (explicitSpace.getType().isIndex() ||
explicitSpace.getType().isSignlessInteger())
return explicitSpace.getInt();
}
+
+ if (auto explicitSpace = dyn_cast<LLVM::AddressSpaceAttr>(*converted))
+ return explicitSpace.getAddressSpace();
+
return failure();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d37895d1fb1ad 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -114,6 +114,14 @@ static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
return layout->getTypeSize(elementType);
}
+static Value createAddrSpaceCast(ConversionPatternRewriter &rewriter,
+ Location loc, Type type, Value value) {
+ if (value.getType() == type)
+ return value;
+
+ return LLVM::AddrSpaceCastOp::create(rewriter, loc, type, value);
+}
+
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
Location loc, Value allocatedPtr,
MemRefType memRefType, Type elementPtrType,
@@ -124,7 +132,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ allocatedPtr = createAddrSpaceCast(
rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
allocatedPtr);
@@ -1262,10 +1270,8 @@ struct MemorySpaceCastOpLowering
SmallVector<Value> descVals;
MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
descVals);
- descVals[0] =
- LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
- descVals[1] =
- LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
+ descVals[0] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[0]);
+ descVals[1] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[1]);
Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
resultTypeR, descVals);
rewriter.replaceOp(op, result);
@@ -1314,10 +1320,10 @@ struct MemorySpaceCastOpLowering
Value alignedPtr =
sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
sourceUnderlyingDesc, sourceElemPtrType);
- allocatedPtr = LLVM::AddrSpaceCastOp::create(
- rewriter, loc, resultElemPtrType, allocatedPtr);
- alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
- resultElemPtrType, alignedPtr);
+ allocatedPtr =
+ createAddrSpaceCast(rewriter, loc, resultElemPtrType, allocatedPtr);
+ alignedPtr =
+ createAddrSpaceCast(rewriter, loc, resultElemPtrType, alignedPtr);
result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
resultElemPtrType, allocatedPtr);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..ff20ccba123af 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -289,6 +289,26 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
// -----
+// ALL-LABEL: func @llvm_address_space_cast
+// ALL-SAME: (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>)
+func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> {
+ // ALL: %[[UNREALIZED:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32, #llvm.address_space<3>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[ALLOC:.*]] = llvm.extractvalue %[[UNREALIZED]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[ALIGNED:.*]] = llvm.extractvalue %[[UNREALIZED]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[OFFSET:.*]] = llvm.extractvalue %[[UNREALIZED]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[POISON:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS0:.*]] = llvm.insertvalue %[[ALLOC]], %[[POISON]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS1:.*]] = llvm.insertvalue %[[ALIGNED]], %[[INS0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS2:.*]] = llvm.insertvalue %[[OFFSET]], %[[INS1]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[RECAST:.*]] = builtin.unrealized_conversion_cast %[[INS2]] : !llvm.struct<(ptr<3>, ptr<3>, i64)> to memref<f32, 3 : i32>
+ // ALL: return %[[RECAST]] : memref<f32, 3 : i32>
+
+ %0 = memref.memory_space_cast %arg0 : memref<f32, #llvm.address_space<3>> to memref<f32, 3 : i32>
+ func.return %0 : memref<f32, 3 : i32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose
// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
More information about the Mlir-commits
mailing list