[llvm-branch-commits] [mlir] [mlir][ptr] Add translations to LLVMIR for ptr ops. (PR #156355)
Fabian Mora via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Sep 1 09:49:33 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/156355
>From d4befc04b5565fd13cc53694031bf8296fd22312 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Mon, 1 Sep 2025 16:32:01 +0000
Subject: [PATCH 1/2] add translations for ptr ops
---
.../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 3 +-
.../Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 210 +++++++++++++++++-
mlir/test/Target/LLVMIR/ptr.mlir | 75 +++++++
3 files changed, 278 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index d6aa9580870a8..bd59319c79ad3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -95,8 +95,7 @@ def LLVM_NonLoadableTargetExtType : Type<
// type that has size (not void, function, opaque struct type or target
// extension type which does not support memory operations).
def LLVM_LoadableType : Type<
- Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>,
- Neg<LLVM_NonLoadableTargetExtType.predicate>]>,
+ // The primitive / non-opaque-struct / loadable-target-extension check that
+ // used to be spelled out inline is now delegated to the C++ helper
+ // mlir::LLVM::isLoadableType.
+ // NOTE(review): confirm that helper is equivalent to the removed And<...>
+ // predicate so the set of accepted types does not change.
+ Or<[CPred<"mlir::LLVM::isLoadableType($_self)">,
LLVM_PointerElementTypeInterface.predicate]>,
"LLVM type with size">;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index 7b89ec8fcbffb..e3ccf728f25db 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -16,11 +16,193 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
using namespace mlir;
using namespace mlir::ptr;
namespace {
+
+/// Converts ptr::AtomicOrdering to llvm::AtomicOrdering.
+///
+/// Each ptr ordering maps to the LLVM ordering of the same meaning:
+/// not_atomic -> NotAtomic, unordered -> Unordered, monotonic -> Monotonic,
+/// acquire -> Acquire, release -> Release, acq_rel -> AcquireRelease,
+/// seq_cst -> SequentiallyConsistent.
+///
+/// NOTE(review): the switch has no `default`, so the trailing
+/// llvm_unreachable assumes these seven cases cover the whole enum (a newly
+/// added enumerator would be flagged by -Wswitch) - confirm against the ptr
+/// dialect's AtomicOrdering definition.
+/// NOTE(review): this helper sits inside the anonymous namespace opened
+/// above, so `static` is redundant; LLVM style prefers `static` functions
+/// declared outside the anonymous namespace.
+static llvm::AtomicOrdering
+convertAtomicOrdering(ptr::AtomicOrdering ordering) {
+ switch (ordering) {
+ case ptr::AtomicOrdering::not_atomic:
+ return llvm::AtomicOrdering::NotAtomic;
+ case ptr::AtomicOrdering::unordered:
+ return llvm::AtomicOrdering::Unordered;
+ case ptr::AtomicOrdering::monotonic:
+ return llvm::AtomicOrdering::Monotonic;
+ case ptr::AtomicOrdering::acquire:
+ return llvm::AtomicOrdering::Acquire;
+ case ptr::AtomicOrdering::release:
+ return llvm::AtomicOrdering::Release;
+ case ptr::AtomicOrdering::acq_rel:
+ return llvm::AtomicOrdering::AcquireRelease;
+ case ptr::AtomicOrdering::seq_cst:
+ return llvm::AtomicOrdering::SequentiallyConsistent;
+ }
+ llvm_unreachable("Unknown atomic ordering");
+}
+
+/// Convert a ptr.ptr_add operation into an LLVM `getelementptr i8`.
+///
+/// The offset operand is used as the single GEP index over i8, so it acts as
+/// a byte offset from the base pointer. The op's PtrAddFlags select the GEP
+/// no-wrap flags (none, nusw, nuw or inbounds), and the GEP is recorded as
+/// the LLVM value of the op's result.
+///
+/// Returns failure, after emitting an error on the op, if either operand has
+/// no mapped LLVM value yet; success otherwise.
+///
+/// NOTE(review): GetElementPtrInst::Create(..., builder.GetInsertBlock())
+/// appends the GEP at the end of the builder's block instead of going
+/// through the builder. It therefore ignores a non-end insertion point and
+/// does not receive the builder's current debug location.
+/// builder.CreateGEP(Ty, Ptr, Idx, Name, llvm::GEPNoWrapFlags) avoids both.
+/// If switching, do not cast<GetElementPtrInst> the result: the builder may
+/// constant-fold a GEP whose operands are all constants and return a
+/// Constant instead of an instruction.
+static LogicalResult
+convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *basePtr = moduleTranslation.lookupValue(ptrAddOp.getBase());
+ llvm::Value *offset = moduleTranslation.lookupValue(ptrAddOp.getOffset());
+
+ if (!basePtr || !offset)
+ return ptrAddOp.emitError("Failed to lookup operands");
+
+ // Create GEP instruction for pointer arithmetic
+ llvm::GetElementPtrInst *gep = llvm::GetElementPtrInst::Create(
+ builder.getInt8Ty(), basePtr, {offset}, "", builder.GetInsertBlock());
+
+ // Set the appropriate flags
+ switch (ptrAddOp.getFlags()) {
+ case ptr::PtrAddFlags::none:
+ break;
+ case ptr::PtrAddFlags::nusw:
+ gep->setNoWrapFlags(llvm::GEPNoWrapFlags::noUnsignedSignedWrap());
+ break;
+ case ptr::PtrAddFlags::nuw:
+ gep->setNoWrapFlags(llvm::GEPNoWrapFlags::noUnsignedWrap());
+ break;
+ case ptr::PtrAddFlags::inbounds:
+ gep->setNoWrapFlags(llvm::GEPNoWrapFlags::inBounds());
+ break;
+ }
+
+ moduleTranslation.mapValue(ptrAddOp.getResult(), gep);
+ return success();
+}
+
+/// Convert a ptr.load operation into an LLVM `load` instruction.
+///
+/// The loaded type is the converted type of the op's result. The alignment
+/// comes from the op's optional `alignment`. When it is absent, value_or(0)
+/// yields an empty llvm::MaybeAlign, and the builder falls back to the
+/// DataLayout's ABI alignment for the type. Next, the op's volatile flag,
+/// atomic ordering, optional sync scope, and the nontemporal / invariant /
+/// invariant_group flags are transferred to the instruction. Finally, the
+/// load is recorded as the LLVM value of the op's result.
+///
+/// Returns failure, after emitting an error on the op, if the pointer
+/// operand has no mapped LLVM value or the result type cannot be converted.
+static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *ptr = moduleTranslation.lookupValue(loadOp.getPtr());
+ if (!ptr)
+ return loadOp.emitError("Failed to lookup pointer operand");
+
+ // Convert result type to LLVM type
+ llvm::Type *resultType =
+ moduleTranslation.convertType(loadOp.getValue().getType());
+ if (!resultType)
+ return loadOp.emitError("Failed to convert result type");
+
+ // Create the load instruction.
+ llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
+ resultType, ptr, llvm::MaybeAlign(loadOp.getAlignment().value_or(0)),
+ loadOp.getVolatile_());
+
+ // Set op flags and metadata.
+ loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
+ // Set sync scope if specified
+ // This must come after setAtomic() above, because setAtomic() also resets
+ // the sync scope ID to its default and would overwrite a scope set earlier.
+ if (loadOp.getSyncscope().has_value()) {
+ llvm::LLVMContext &ctx = builder.getContext();
+ llvm::SyncScope::ID syncScope =
+ ctx.getOrInsertSyncScopeID(loadOp.getSyncscope().value());
+ loadInst->setSyncScopeID(syncScope);
+ }
+
+ // Set metadata for nontemporal, invariant, and invariant_group
+ // !nontemporal is a node holding the single value `i32 1`; the
+ // !invariant.load and !invariant.group nodes are empty.
+ if (loadOp.getNontemporal()) {
+ llvm::MDNode *nontemporalMD =
+ llvm::MDNode::get(builder.getContext(),
+ llvm::ConstantAsMetadata::get(builder.getInt32(1)));
+ loadInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
+ }
+
+ if (loadOp.getInvariant()) {
+ llvm::MDNode *invariantMD = llvm::MDNode::get(builder.getContext(), {});
+ loadInst->setMetadata(llvm::LLVMContext::MD_invariant_load, invariantMD);
+ }
+
+ if (loadOp.getInvariantGroup()) {
+ llvm::MDNode *invariantGroupMD =
+ llvm::MDNode::get(builder.getContext(), {});
+ loadInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
+ invariantGroupMD);
+ }
+
+ moduleTranslation.mapValue(loadOp.getResult(), loadInst);
+ return success();
+}
+
+/// Convert a ptr.store operation into an LLVM `store` instruction.
+///
+/// The alignment comes from the op's optional `alignment`. When it is
+/// absent, value_or(0) yields an empty llvm::MaybeAlign, and the builder
+/// falls back to the DataLayout's ABI alignment for the stored value's type.
+/// The op's volatile flag, atomic ordering, optional sync scope, and the
+/// nontemporal / invariant_group flags are transferred to the instruction.
+/// A store produces no result, so nothing is added to the value mapping.
+///
+/// Returns failure, after emitting an error on the op, if the value or the
+/// pointer operand has no mapped LLVM value.
+static LogicalResult
+convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::Value *value = moduleTranslation.lookupValue(storeOp.getValue());
+ llvm::Value *ptr = moduleTranslation.lookupValue(storeOp.getPtr());
+
+ if (!value || !ptr)
+ return storeOp.emitError("Failed to lookup operands");
+
+ // Create the store instruction.
+ llvm::StoreInst *storeInst = builder.CreateAlignedStore(
+ value, ptr, llvm::MaybeAlign(storeOp.getAlignment().value_or(0)),
+ storeOp.getVolatile_());
+
+ // Set op flags and metadata.
+ storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
+ // Set sync scope if specified
+ // This must come after setAtomic() above, because setAtomic() also resets
+ // the sync scope ID to its default and would overwrite a scope set earlier.
+ if (storeOp.getSyncscope().has_value()) {
+ llvm::LLVMContext &ctx = builder.getContext();
+ llvm::SyncScope::ID syncScope =
+ ctx.getOrInsertSyncScopeID(storeOp.getSyncscope().value());
+ storeInst->setSyncScopeID(syncScope);
+ }
+
+ // Set metadata for nontemporal and invariant_group
+ // !nontemporal is a node holding the single value `i32 1`; the
+ // !invariant.group node is empty.
+ if (storeOp.getNontemporal()) {
+ llvm::MDNode *nontemporalMD =
+ llvm::MDNode::get(builder.getContext(),
+ llvm::ConstantAsMetadata::get(builder.getInt32(1)));
+ storeInst->setMetadata(llvm::LLVMContext::MD_nontemporal, nontemporalMD);
+ }
+
+ if (storeOp.getInvariantGroup()) {
+ llvm::MDNode *invariantGroupMD =
+ llvm::MDNode::get(builder.getContext(), {});
+ storeInst->setMetadata(llvm::LLVMContext::MD_invariant_group,
+ invariantGroupMD);
+ }
+
+ return success();
+}
+
+/// Convert a ptr.type_offset operation into the allocation size (array
+/// stride) of its element type, in bytes.
+///
+/// The value is computed with the `ptrtoint (getelementptr T, ptr null,
+/// i32 1)` idiom: the address of element 1 of a T array based at a null
+/// address-space-0 pointer, converted to the op's converted result type.
+/// Whether this appears as an integer constant or as GEP + ptrtoint
+/// instructions depends on the IR builder's constant folder. The resulting
+/// value is recorded as the LLVM value of the op's result.
+///
+/// Returns failure, after emitting an error on the op, if the element type or
+/// the result type cannot be converted.
+static LogicalResult
+convertTypeOffsetOp(TypeOffsetOp typeOffsetOp, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ // Convert the element type to LLVM type
+ llvm::Type *elementType =
+ moduleTranslation.convertType(typeOffsetOp.getElementType());
+ if (!elementType)
+ return typeOffsetOp.emitError("Failed to convert the element type");
+
+ // Convert result type
+ llvm::Type *resultType =
+ moduleTranslation.convertType(typeOffsetOp.getResult().getType());
+ if (!resultType)
+ return typeOffsetOp.emitError("Failed to convert the result type");
+
+ // Use GEP with null pointer to compute type size/offset.
+ // Element 1 of a T array starting at address 0 lies exactly one T stride
+ // (including any tail padding) past the base, so its integer address is
+ // the size.
+ llvm::Value *nullPtr = llvm::Constant::getNullValue(builder.getPtrTy(0));
+ llvm::Value *offsetPtr =
+ builder.CreateGEP(elementType, nullPtr, {builder.getInt32(1)});
+ llvm::Value *offset = builder.CreatePtrToInt(offsetPtr, resultType);
+
+ moduleTranslation.mapValue(typeOffsetOp.getResult(), offset);
+ return success();
+}
+
/// Implementation of the dialect interface that converts operations belonging
/// to the `ptr` dialect to LLVM IR.
class PtrDialectLLVMIRTranslationInterface
@@ -33,10 +215,24 @@ class PtrDialectLLVMIRTranslationInterface
LogicalResult
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const final {
- // Translation for ptr dialect operations to LLVM IR is currently
- // unimplemented.
- return op->emitError("Translation for ptr dialect operations to LLVM IR is "
- "not implemented.");
+
+ // Dispatch on the concrete ptr op to the matching convert*Op helper in
+ // this file (ptr_add, load, store, type_offset). Any other op reaching
+ // this interface is rejected with an error naming the op rather than
+ // being silently skipped.
+ return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<PtrAddOp>([&](PtrAddOp ptrAddOp) {
+ return convertPtrAddOp(ptrAddOp, builder, moduleTranslation);
+ })
+ .Case<LoadOp>([&](LoadOp loadOp) {
+ return convertLoadOp(loadOp, builder, moduleTranslation);
+ })
+ .Case<StoreOp>([&](StoreOp storeOp) {
+ return convertStoreOp(storeOp, builder, moduleTranslation);
+ })
+ .Case<TypeOffsetOp>([&](TypeOffsetOp typeOffsetOp) {
+ return convertTypeOffsetOp(typeOffsetOp, builder, moduleTranslation);
+ })
+ .Default([&](Operation *op) {
+ return op->emitError("Translation for operation '")
+ << op->getName() << "' is not implemented.";
+ });
}
/// Attaches module-level metadata for functions marked as kernels.
@@ -44,10 +240,8 @@ class PtrDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- // Translation for ptr dialect operations to LLVM IR is currently
- // unimplemented.
- return op->emitError("Translation for ptr dialect operations to LLVM IR is "
- "not implemented.");
+ // No special amendments needed for ptr dialect operations
+ return success();
}
};
} // namespace
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index c1620cb9ed313..6e3b365b862e2 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -14,3 +14,78 @@ llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) {
llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr
llvm.return
}
+
+// ptr.ptr_add becomes `getelementptr i8`; the flag keyword (none / nusw /
+// nuw / inbounds) carries over to the GEP. The expected output still lists
+// %res0-%res3 even though only %res is returned.
+// CHECK-LABEL: define ptr @ptr_add
+// CHECK-SAME: (ptr %[[PTR:.*]], i32 %[[OFF:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
+// CHECK-NEXT: %[[RES0:.*]] = getelementptr i8, ptr %[[PTR]], i32 %[[OFF]]
+// CHECK-NEXT: %[[RES1:.*]] = getelementptr nusw i8, ptr %[[PTR]], i32 %[[OFF]]
+// CHECK-NEXT: %[[RES2:.*]] = getelementptr nuw i8, ptr %[[PTR]], i32 %[[OFF]]
+// CHECK-NEXT: %[[RES3:.*]] = getelementptr inbounds i8, ptr %[[PTR]], i32 %[[OFF]]
+// CHECK-NEXT: ret ptr %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add(%ptr: !ptr.ptr<#llvm.address_space<0>>, %off: i32) -> !ptr.ptr<#llvm.address_space<0>> {
+ %res = ptr.ptr_add %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+ %res0 = ptr.ptr_add none %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+ %res1 = ptr.ptr_add nusw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+ %res2 = ptr.ptr_add nuw %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+ %res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#llvm.address_space<0>>, i32
+ llvm.return %res : !ptr.ptr<#llvm.address_space<0>>
+}
+
+// ptr.type_offset yields the element type's allocation size in bytes
+// (f64 -> 8, i8 -> 1, i16 -> 2, i32 -> 4); here the values are expected to
+// appear folded to integer constants in the returned struct.
+// CHECK-LABEL: define { i32, i32, i32, i32 } @type_offset
+// CHECK-NEXT: ret { i32, i32, i32, i32 } { i32 8, i32 1, i32 2, i32 4 }
+llvm.func @type_offset(%arg0: !ptr.ptr<#llvm.address_space<0>>) -> !llvm.struct<(i32, i32, i32, i32)> {
+ %0 = ptr.type_offset f64 : i32
+ %1 = ptr.type_offset i8 : i32
+ %2 = ptr.type_offset i16 : i32
+ %3 = ptr.type_offset i32 : i32
+ %4 = llvm.mlir.poison : !llvm.struct<(i32, i32, i32, i32)>
+ %5 = llvm.insertvalue %0, %4[0] : !llvm.struct<(i32, i32, i32, i32)>
+ %6 = llvm.insertvalue %1, %5[1] : !llvm.struct<(i32, i32, i32, i32)>
+ %7 = llvm.insertvalue %2, %6[2] : !llvm.struct<(i32, i32, i32, i32)>
+ %8 = llvm.insertvalue %3, %7[3] : !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return %8 : !llvm.struct<(i32, i32, i32, i32)>
+}
+
+// ptr.load attributes: volatile, nontemporal, invariant, invariant_group,
+// atomic ordering, syncscope and explicit alignment. Loads without an
+// explicit alignment get the type's ABI alignment (align 4 for f32).
+// CHECK-LABEL: define void @load_ops
+// CHECK-SAME: (ptr %[[PTR:.*]]) {
+// CHECK-NEXT: %[[V0:.*]] = load float, ptr %[[PTR]], align 4
+// CHECK-NEXT: %[[V1:.*]] = load volatile float, ptr %[[PTR]], align 4
+// CHECK-NEXT: %[[V2:.*]] = load float, ptr %[[PTR]], align 4, !nontemporal !{{.*}}
+// CHECK-NEXT: %[[V3:.*]] = load float, ptr %[[PTR]], align 4, !invariant.load !{{.*}}
+// CHECK-NEXT: %[[V4:.*]] = load float, ptr %[[PTR]], align 4, !invariant.group !{{.*}}
+// CHECK-NEXT: %[[V5:.*]] = load atomic i64, ptr %[[PTR]] monotonic, align 8
+// CHECK-NEXT: %[[V6:.*]] = load atomic volatile i32, ptr %[[PTR]] syncscope("workgroup") acquire, align 4, !nontemporal !{{.*}}
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @load_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>) {
+ %0 = ptr.load %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
+ %1 = ptr.load volatile %arg0 : !ptr.ptr<#llvm.address_space<0>> -> f32
+ %2 = ptr.load %arg0 nontemporal : !ptr.ptr<#llvm.address_space<0>> -> f32
+ %3 = ptr.load %arg0 invariant : !ptr.ptr<#llvm.address_space<0>> -> f32
+ %4 = ptr.load %arg0 invariant_group : !ptr.ptr<#llvm.address_space<0>> -> f32
+ %5 = ptr.load %arg0 atomic monotonic alignment = 8 : !ptr.ptr<#llvm.address_space<0>> -> i64
+ %6 = ptr.load volatile %arg0 atomic syncscope("workgroup") acquire nontemporal alignment = 4 : !ptr.ptr<#llvm.address_space<0>> -> i32
+ llvm.return
+}
+
+// ptr.store attributes: volatile, nontemporal, invariant_group, atomic
+// ordering, syncscope and explicit alignment. Stores without an explicit
+// alignment get the value type's ABI alignment (align 4 for f32).
+// CHECK-LABEL: define void @store_ops
+// CHECK-SAME: (ptr %[[PTR:.*]], float %[[ARG1:.*]], i64 %[[ARG2:.*]], i32 %[[ARG3:.*]]) {
+// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4
+// CHECK-NEXT: store volatile float %[[ARG1]], ptr %[[PTR]], align 4
+// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !nontemporal !{{.*}}
+// CHECK-NEXT: store float %[[ARG1]], ptr %[[PTR]], align 4, !invariant.group !{{.*}}
+// CHECK-NEXT: store atomic i64 %[[ARG2]], ptr %[[PTR]] monotonic, align 8
+// CHECK-NEXT: store atomic volatile i32 %[[ARG3]], ptr %[[PTR]] syncscope("workgroup") release, align 4, !nontemporal !{{.*}}
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
+llvm.func @store_ops(%arg0: !ptr.ptr<#llvm.address_space<0>>, %arg1: f32, %arg2: i64, %arg3: i32) {
+ ptr.store %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
+ ptr.store volatile %arg1, %arg0 : f32, !ptr.ptr<#llvm.address_space<0>>
+ ptr.store %arg1, %arg0 nontemporal : f32, !ptr.ptr<#llvm.address_space<0>>
+ ptr.store %arg1, %arg0 invariant_group : f32, !ptr.ptr<#llvm.address_space<0>>
+ ptr.store %arg2, %arg0 atomic monotonic alignment = 8 : i64, !ptr.ptr<#llvm.address_space<0>>
+ ptr.store volatile %arg3, %arg0 atomic syncscope("workgroup") release nontemporal alignment = 4 : i32, !ptr.ptr<#llvm.address_space<0>>
+ llvm.return
+}
>From 17a511328cb6751a5b8b85feb7c5b5b0e79eb5a3 Mon Sep 17 00:00:00 2001
From: Fabian Mora <6982088+fabianmcg at users.noreply.github.com>
Date: Mon, 1 Sep 2025 16:48:25 +0000
Subject: [PATCH 2/2] address comments
---
.../Dialect/Ptr/PtrToLLVMIRTranslation.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
index e3ccf728f25db..906e19901617b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp
@@ -60,8 +60,8 @@ convertPtrAddOp(PtrAddOp ptrAddOp, llvm::IRBuilderBase &builder,
return ptrAddOp.emitError("Failed to lookup operands");
// Create GEP instruction for pointer arithmetic
- llvm::GetElementPtrInst *gep = llvm::GetElementPtrInst::Create(
- builder.getInt8Ty(), basePtr, {offset}, "", builder.GetInsertBlock());
+ auto *gep = cast<llvm::GetElementPtrInst>(
+ builder.CreateGEP(builder.getInt8Ty(), basePtr, {offset}));
// Set the appropriate flags
switch (ptrAddOp.getFlags()) {
@@ -96,9 +96,11 @@ static LogicalResult convertLoadOp(LoadOp loadOp, llvm::IRBuilderBase &builder,
return loadOp.emitError("Failed to convert result type");
// Create the load instruction.
+ llvm::MaybeAlign alignment = loadOp.getAlignment()
+ ? llvm::MaybeAlign(*loadOp.getAlignment())
+ : llvm::MaybeAlign();
llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
- resultType, ptr, llvm::MaybeAlign(loadOp.getAlignment().value_or(0)),
- loadOp.getVolatile_());
+ resultType, ptr, alignment, loadOp.getVolatile_());
// Set op flags and metadata.
loadInst->setAtomic(convertAtomicOrdering(loadOp.getOrdering()));
@@ -145,9 +147,11 @@ convertStoreOp(StoreOp storeOp, llvm::IRBuilderBase &builder,
return storeOp.emitError("Failed to lookup operands");
// Create the store instruction.
- llvm::StoreInst *storeInst = builder.CreateAlignedStore(
- value, ptr, llvm::MaybeAlign(storeOp.getAlignment().value_or(0)),
- storeOp.getVolatile_());
+ llvm::MaybeAlign alignment = storeOp.getAlignment()
+ ? llvm::MaybeAlign(*storeOp.getAlignment())
+ : llvm::MaybeAlign();
+ llvm::StoreInst *storeInst =
+ builder.CreateAlignedStore(value, ptr, alignment, storeOp.getVolatile_());
// Set op flags and metadata.
storeInst->setAtomic(convertAtomicOrdering(storeOp.getOrdering()));
More information about the llvm-branch-commits
mailing list