[Mlir-commits] [mlir] [WIP][mlir][llvm] support new-struct-path-tbaa (PR #119698)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 12 05:01:17 PST 2024
https://github.com/PikachuHyA created https://github.com/llvm/llvm-project/pull/119698
None
>From 0db13b83e7a2829dae6088b3e568808a27e812c2 Mon Sep 17 00:00:00 2001
From: PikachuHy <pikachuhy at linux.alibaba.com>
Date: Thu, 12 Dec 2024 20:57:20 +0800
Subject: [PATCH] [mlir][llvm] support new-struct-path-tbaa
---
.../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 78 ++++++++++++++++++-
.../mlir/Target/LLVMIR/ModuleTranslation.h | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp | 10 ++-
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 38 ++++++++-
5 files changed, 125 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index e8eeafd09a9cba..198e1f8982ef14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1080,8 +1080,90 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> {
let assemblyFormat = "`<` struct(params) `>`";
}
+def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> {
+ let summary = "LLVM dialect TBAA struct field (new struct-path format)";
+ let parameters = (ins
+ "TBAANodeAttr":$typeDesc,
+ "int64_t":$offset,
+ "int64_t":$size
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+// Brace-delimited list of struct fields; the list may be empty (`{}`), e.g.
+// for a type node without members. Each field is printed and parsed in
+// stripped form, i.e. without the `#llvm.tbaa_struct_field` prefix.
+def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> {
+ let printer = [{
+ $_printer << '{';
+ llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) {
+ $_printer.printStrippedAttrOrType(attr);
+ });
+ $_printer << '}';
+ }];
+
+ // Parse with the Braces delimiter: FieldParser<SmallVector<...>> requires at
+ // least one element, so it could not re-parse the `{}` printed above.
+ let parser = [{
+ [&]() -> FailureOr<SmallVector<TBAAStructFieldAttr>> {
+ SmallVector<TBAAStructFieldAttr> fields;
+ auto parseField = [&]() -> ParseResult {
+ FailureOr<TBAAStructFieldAttr> field =
+ FieldParser<TBAAStructFieldAttr>::parse($_parser);
+ if (failed(field))
+ return failure();
+ fields.push_back(*field);
+ return success();
+ };
+ if ($_parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces,
+ parseField))
+ return failure();
+ return fields;
+ }()
+ }];
+}
+
+def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> {
+ let summary = "LLVM dialect TBAA type node (new struct-path format)";
+ let parameters = (ins
+ "TBAANodeAttr":$parent,
+ "int64_t":$size,
+ StringRefParameter<>:$id,
+ LLVM_TBAAStructFieldAttrArray:$fields
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> {
+ let summary = "LLVM dialect TBAA access tag (new struct-path format)";
+ let parameters = (ins
+ "TBAATypeNodeAttr":$base_type,
+ "TBAATypeNodeAttr":$access_type,
+ "int64_t":$offset,
+ "int64_t":$size
+ );
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType,
+ "TBAATypeNodeAttr":$accessType,
+ "int64_t":$offset,
+ "int64_t":$size), [{
+ return $_get(baseType.getContext(), baseType, accessType, offset, size);
+ }]>
+ ];
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagArrayAttr
+ : TypedArrayAttrBase<LLVM_TBAAAccessTagAttr,
+ LLVM_TBAAAccessTagAttr.summary # " array"> {
+ let constBuilderCall = ?;
+}
+
def LLVM_TBAATagArrayAttr
- : TypedArrayAttrBase<LLVM_TBAATagAttr,
+ : TypedArrayAttrBase<AnyAttrOf<[
+ LLVM_TBAATagAttr,
+ LLVM_TBAAAccessTagAttr
+]>,
LLVM_TBAATagAttr.summary # " array"> {
let constBuilderCall = ?;
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..c7a79aa330d3da 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -323,7 +323,7 @@ class ModuleTranslation {
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
- /// TBAATagAttr.
- llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
+ /// TBAA tag attribute, i.e. either a TBAATagAttr or a TBAAAccessTagAttr.
+ llvm::MDNode *getTBAANode(Attribute tbaaAttr) const;
/// Process tbaa LLVM Metadata operations and create LLVM
/// metadata nodes for them.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b2d8943bf4885..b2b0b9b331e0b4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3401,7 +3401,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
- TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
+ TBAATagAttr, TBAATypeDescriptorAttr, TBAATypeNodeAttr,
+ TBAAAccessTagAttr>([&](auto attr) {
os << decltype(attr)::getMnemonic();
return AliasResult::OverridableAlias;
})
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index cff16afc73af3f..6a9395b1f4a26e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -58,7 +58,12 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
ArrayAttr tags = iface.getTBAATagsOrNull();
if (!tags)
return success();
+ // Tag arrays must be homogeneous: if the first element is a
+ // TBAAAccessTagAttr (new struct-path format) every element must be one;
+ // otherwise (including the empty array) every element must be a TBAATagAttr.
+ if (!tags.empty() && isa<TBAAAccessTagAttr>(tags[0]))
+ return isArrayOf<TBAAAccessTagAttr>(op, tags);
return isArrayOf<TBAATagAttr>(op, tags);
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..6a6c29869ba805 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1766,7 +1766,7 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
llvm::LLVMContext::MD_noalias);
}
-llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
+llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const {
return tbaaMetadataMapping.lookup(tbaaAttr);
}
@@ -1786,7 +1786,7 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
return;
}
- llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
+ llvm::MDNode *node = getTBAANode(tagRefs[0]);
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
@@ -1806,6 +1806,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
LogicalResult ModuleTranslation::createTBAAMetadata() {
llvm::LLVMContext &ctx = llvmModule->getContext();
llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
+ llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64);
// Walk the entire module and create all metadata nodes for the TBAA
// attributes. The code below relies on two invariants of the
@@ -1833,6 +1834,25 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
});
+ // New struct-path type node operands:
+ // !{parent, i64 size, !"id", [field type, i64 offset, i64 size]...}
+ walker.addWalk([&](TBAATypeNodeAttr descriptor) {
+ SmallVector<llvm::Metadata *> operands;
+ operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent()));
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, descriptor.getSize())));
+ operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
+ for (auto field : descriptor.getFields()) {
+ operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc()));
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(offsetTy, field.getOffset())));
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, field.getSize())));
+ }
+
+ tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
+ });
+
walker.addWalk([&](TBAATagAttr tag) {
SmallVector<llvm::Metadata *> operands;
@@ -1848,6 +1868,22 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
});
+ // New struct-path access tag operands:
+ // !{base type, access type, i64 offset, i64 size}
+ walker.addWalk([&](TBAAAccessTagAttr tag) {
+ SmallVector<llvm::Metadata *> operands;
+
+ operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
+ operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
+
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(offsetTy, tag.getOffset())));
+ operands.push_back(llvm::ConstantAsMetadata::get(
+ llvm::ConstantInt::get(sizeTy, tag.getSize())));
+
+ tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
+ });
+
mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
if (auto attr = analysisOpInterface.getTBAATagsOrNull())
walker.walk(attr);
More information about the Mlir-commits
mailing list