[Mlir-commits] [mlir] [WIP][mlir][llvm] support new-struct-path-tbaa (PR #119698)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 12 05:01:17 PST 2024


https://github.com/PikachuHyA created https://github.com/llvm/llvm-project/pull/119698

None

>From 0db13b83e7a2829dae6088b3e568808a27e812c2 Mon Sep 17 00:00:00 2001
From: PikachuHy <pikachuhy at linux.alibaba.com>
Date: Thu, 12 Dec 2024 20:57:20 +0800
Subject: [PATCH] [mlir][llvm] support new-struct-path-tbaa

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 78 ++++++++++++++++++-
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp | 10 ++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 38 ++++++++-
 5 files changed, 125 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index e8eeafd09a9cba..198e1f8982ef14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1080,8 +1080,84 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> {
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> {
+  let parameters = (ins
+    "TBAANodeAttr":$typeDesc,
+    "int64_t":$offset,
+    "int64_t":$size
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
+def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> {
+  let printer = [{
+    $_printer << '{';
+    llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) {
+        $_printer.printStrippedAttrOrType(attr);
+    });
+    $_printer << '}';
+  }];
+
+  let parser = [{
+    [&]() -> FailureOr<SmallVector<TBAAStructFieldAttr>> {
+        using Result = SmallVector<TBAAStructFieldAttr>;
+        if ($_parser.parseLBrace())
+            return failure();
+        FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+        if (failed(result))
+            return failure();
+        if ($_parser.parseRBrace())
+            return failure();
+        return result;
+    }()
+  }];
+}
+
+def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> {
+  let parameters = (ins
+    "TBAANodeAttr":$parent,
+    "int64_t":$size,
+    StringRefParameter<>:$id,
+    LLVM_TBAAStructFieldAttrArray:$fields
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> {
+  let parameters = (ins
+    "TBAATypeNodeAttr":$base_type,
+    "TBAATypeNodeAttr":$access_type,
+    "int64_t":$offset,
+    "int64_t":$size
+  );
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType,
+                                        "TBAATypeNodeAttr":$accessType,
+                                        "int64_t":$offset,
+                                        "int64_t":$size), [{
+      return $_get(baseType.getContext(), baseType, accessType, offset, size);
+    }]>
+  ];
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagArrayAttr
+    : TypedArrayAttrBase<LLVM_TBAAAccessTagAttr,
+                         LLVM_TBAAAccessTagAttr.summary # " array"> {
+  let constBuilderCall = ?;
+}
+
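+// Array of TBAA tag attributes. Each element is either an LLVM_TBAATagAttr
+// or an LLVM_TBAAAccessTagAttr; the two alternatives are written inline as
+// an AnyAttrOf constraint instead of a separately named def (the
+// LLVM_TBAATagAttr2 variant that was tried during development).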
+
 def LLVM_TBAATagArrayAttr
-    : TypedArrayAttrBase<LLVM_TBAATagAttr,
+    : TypedArrayAttrBase<AnyAttrOf<[
+  LLVM_TBAATagAttr,
+  LLVM_TBAAAccessTagAttr
+]>,
                          LLVM_TBAATagAttr.summary # " array"> {
   let constBuilderCall = ?;
 }
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..c7a79aa330d3da 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -323,7 +323,7 @@ class ModuleTranslation {
 
   /// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
   /// TBAATagAttr.
-  llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
+  llvm::MDNode *getTBAANode(Attribute tbaaAttr) const;
 
   /// Process tbaa LLVM Metadata operations and create LLVM
   /// metadata nodes for them.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b2d8943bf4885..b2b0b9b331e0b4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3401,7 +3401,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
               LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
               LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
               LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
-              TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
+              TBAATagAttr, TBAATypeDescriptorAttr, TBAAAccessTagAttr,
+              TBAATypeNodeAttr>([&](auto attr) {
           os << decltype(attr)::getMnemonic();
           return AliasResult::OverridableAlias;
         })
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index cff16afc73af3f..6a9395b1f4a26e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -58,7 +58,15 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
   ArrayAttr tags = iface.getTBAATagsOrNull();
   if (!tags)
     return success();
-
+  if (tags.size() > 0) {
+    if (mlir::isa<TBAATagAttr>(tags[0])) {
+      return isArrayOf<TBAATagAttr>(op, tags);
+    }
+
+    if (mlir::isa<TBAAAccessTagAttr>(tags[0])) {
+      return isArrayOf<TBAAAccessTagAttr>(op, tags);
+    }
+  }
   return isArrayOf<TBAATagAttr>(op, tags);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..6a6c29869ba805 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1766,7 +1766,8 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
                         llvm::LLVMContext::MD_noalias);
 }
 
-llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
+// Accepts any Attribute so TBAATagAttr and TBAAAccessTagAttr share one lookup.
+llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const {
   return tbaaMetadataMapping.lookup(tbaaAttr);
 }
 
@@ -1786,7 +1787,8 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
     return;
   }
 
-  llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
+  // tagRefs[0] may be a TBAATagAttr or a TBAAAccessTagAttr, so it is not cast.
+  llvm::MDNode *node = getTBAANode(tagRefs[0]);
   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
 }
 
@@ -1806,6 +1808,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
 LogicalResult ModuleTranslation::createTBAAMetadata() {
   llvm::LLVMContext &ctx = llvmModule->getContext();
   llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
+  llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64);
 
   // Walk the entire module and create all metadata nodes for the TBAA
   // attributes. The code below relies on two invariants of the
@@ -1833,6 +1836,23 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
     tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
   });
 
+  walker.addWalk([&](TBAATypeNodeAttr descriptor) {
+    SmallVector<llvm::Metadata *> operands;
+    operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent()));
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(sizeTy, descriptor.getSize())));
+    operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
+    for (auto field : descriptor.getFields()) {
+      operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc()));
+      operands.push_back(llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(offsetTy, field.getOffset())));
+      operands.push_back(llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(sizeTy, field.getSize())));
+    }
+
+    tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
+  });
+
   walker.addWalk([&](TBAATagAttr tag) {
     SmallVector<llvm::Metadata *> operands;
 
@@ -1848,6 +1868,20 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
     tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
   });
 
+  walker.addWalk([&](TBAAAccessTagAttr tag) {
+    SmallVector<llvm::Metadata *> operands;
+
+    operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
+    operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
+
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(offsetTy, tag.getOffset())));
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(sizeTy, tag.getSize())));
+
+    tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
+  });
+
   mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
     if (auto attr = analysisOpInterface.getTBAATagsOrNull())
       walker.walk(attr);



More information about the Mlir-commits mailing list