[Mlir-commits] [mlir] [mlir][llvm] support -new-struct-path-tbaa (PR #119698)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 25 01:12:04 PST 2025
https://github.com/PikachuHyA updated https://github.com/llvm/llvm-project/pull/119698
>From df5635f39d889462913f11ccbf1927f283bc5cdc Mon Sep 17 00:00:00 2001
From: PikachuHy <pikachuhy at linux.alibaba.com>
Date: Tue, 25 Feb 2025 17:10:20 +0800
Subject: [PATCH] [mlir][llvm] support new-struct-path-tbaa
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 73 ++++++++++++++++++-
.../mlir/Target/LLVMIR/ModuleTranslation.h | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp | 10 ++-
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 36 ++++++++-
5 files changed, 118 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 267389774bd5a..64a22e6e70ff7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1084,8 +1084,79 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> {
let assemblyFormat = "`<` struct(params) `>`";
}
+def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> {
+ let parameters = (ins // one element of a TBAATypeNodeAttr's `fields` list
+ "TBAANodeAttr":$typeDesc, // TBAA node describing the field's type
+ "int64_t":$offset, // exported as an i64 constant; presumably a byte offset - confirm
+ "int64_t":$size // exported as an i64 constant; presumably a byte size - confirm
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
+def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> {
+ let printer = [{
+ $_printer << '{';
+ llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) {
+ $_printer.printStrippedAttrOrType(attr);
+ });
+ $_printer << '}';
+ }]; // Printed form: `{`, the comma-separated elements (each via printStrippedAttrOrType), `}`.
+ // The parser mirrors the printer: `{`, an element list delegated to FieldParser, then `}`.
+ let parser = [{
+ [&]() -> FailureOr<SmallVector<TBAAStructFieldAttr>> {
+ using Result = SmallVector<TBAAStructFieldAttr>;
+ if ($_parser.parseLBrace())
+ return failure();
+ FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+ if (failed(result))
+ return failure();
+ if ($_parser.parseRBrace())
+ return failure();
+ return result;
+ }()
+ }]; // Yields failure() if either brace or the element list does not parse.
+}
+
+def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> {
+ let parameters = (ins // ModuleTranslation exports these, in this order, as the MDNode operands
+ "TBAANodeAttr":$parent, // first operand: the node already created for the parent
+ "int64_t":$size, // exported as an i64 constant
+ StringRefParameter<>:$id, // exported as an MDString
+ LLVM_TBAAStructFieldAttrArray:$fields // each field adds (type node, offset, size) operands
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> {
+ let parameters = (ins // ModuleTranslation exports these four, in this order, as the MDNode operands
+ "TBAATypeNodeAttr":$base_type, // type node the access is based on
+ "TBAATypeNodeAttr":$access_type, // type node of the accessed value
+ "int64_t":$offset, // exported as an i64 constant
+ "int64_t":$size // exported as an i64 constant
+ );
+ let builders = [ // convenience builder: the MLIRContext is taken from baseType
+ AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType,
+ "TBAATypeNodeAttr":$accessType,
+ "int64_t":$offset,
+ "int64_t":$size), [{
+ return $_get(baseType.getContext(), baseType, accessType, offset, size);
+ }]>
+ ];
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagArrayAttr // NOTE(review): not referenced by any other hunk of this patch - confirm it is needed
+ : TypedArrayAttrBase<LLVM_TBAAAccessTagAttr, // element constraint: access tags only
+ LLVM_TBAAAccessTagAttr.summary # " array"> {
+ let constBuilderCall = ?; // deliberately left unset
+}
+
def LLVM_TBAATagArrayAttr
- : TypedArrayAttrBase<LLVM_TBAATagAttr,
+ : TypedArrayAttrBase<AnyAttrOf<[ // each element may now be either kind of tag
+ LLVM_TBAATagAttr,
+ LLVM_TBAAAccessTagAttr
+]>, // NOTE(review): the summary string below still names only LLVM_TBAATagAttr
LLVM_TBAATagAttr.summary # " array"> {
let constBuilderCall = ?;
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index eb59ef8c62266..badde4f2af72f 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -346,7 +346,7 @@ class ModuleTranslation {
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
/// TBAATagAttr.
- llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
+ llvm::MDNode *getTBAANode(Attribute tbaaAttr) const;
/// Process tbaa LLVM Metadata operations and create LLVM
/// metadata nodes for them.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ccf8f72b2b230..dff81e9b82534 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3594,7 +3594,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
- TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
+ TBAATagAttr, TBAATypeDescriptorAttr, TBAAAccessTagAttr,
+ TBAATypeNodeAttr>([&](auto attr) {
os << decltype(attr)::getMnemonic();
return AliasResult::OverridableAlias;
})
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index ca1277c09323b..3ece05f104195 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -58,7 +58,15 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
ArrayAttr tags = iface.getTBAATagsOrNull();
if (!tags)
return success();
-
+ if (tags.size() > 0) { // the first element decides which tag kind the array is checked against
+ if (mlir::isa<TBAATagAttr>(tags[0])) {
+ return isArrayOf<TBAATagAttr>(op, tags);
+ }
+ // Same dispatch for access tags; any other first element falls through to the TBAATagAttr check below.
+ if (mlir::isa<TBAAAccessTagAttr>(tags[0])) {
+ return isArrayOf<TBAAAccessTagAttr>(op, tags);
+ }
+ }
return isArrayOf<TBAATagAttr>(op, tags);
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5cd841ee2df91..ffeb44837da2e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1882,7 +1882,7 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
llvm::LLVMContext::MD_noalias);
}
-llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
+llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const { // generic key: either tag kind can be looked up
return tbaaMetadataMapping.lookup(tbaaAttr);
}
@@ -1902,7 +1902,7 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
return;
}
- llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
+ llvm::MDNode *node = getTBAANode(tagRefs[0]);
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
@@ -1922,6 +1922,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
LogicalResult ModuleTranslation::createTBAAMetadata() {
llvm::LLVMContext &ctx = llvmModule->getContext();
llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
+ llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64);
// Walk the entire module and create all metadata nodes for the TBAA
// attributes. The code below relies on two invariants of the
@@ -1949,6 +1950,23 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
});
+ walker.addWalk([&](TBAATypeNodeAttr descriptor) { // builds one MDNode per type node
+ SmallVector<llvm::Metadata *> operands;
+ operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent())); // parent node; lookup result is used unchecked
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, descriptor.getSize()))); // node size
+ operands.push_back(llvm::MDString::get(ctx, descriptor.getId())); // node id
+ for (auto field : descriptor.getFields()) { // appends three operands per field
+ operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc())); // field type node
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(offsetTy, field.getOffset()))); // field offset
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, field.getSize()))); // field size
+ }
+ // Register the finished node under its attribute so other lookups in tbaaMetadataMapping can find it.
+ tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
+ });
+
walker.addWalk([&](TBAATagAttr tag) {
SmallVector<llvm::Metadata *> operands;
@@ -1964,6 +1982,20 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
});
+ walker.addWalk([&](TBAAAccessTagAttr tag) { // builds one MDNode per access tag
+ SmallVector<llvm::Metadata *> operands;
+ // Operand order: base type node, access type node, offset, size.
+ operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
+ operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
+ // NOTE(review): both lookups above are used unchecked; presumably the type nodes are always created first - confirm.
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(offsetTy, tag.getOffset())));
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, tag.getSize())));
+ // Register the tag's node; setTBAAMetadata fetches it through getTBAANode.
+ tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
+ });
+
mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
if (auto attr = analysisOpInterface.getTBAATagsOrNull())
walker.walk(attr);
More information about the Mlir-commits
mailing list