[Mlir-commits] [mlir] [mlir][LLVMIR] Support memory model relaxation annotations (MMRA) (PR #157770)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Sep 10 09:31:25 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/157770

>From b6d28d1205981480a2613a0f3217ca006ebd699d Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 10 Sep 2025 00:00:35 +0000
Subject: [PATCH 1/3] [mlir][LLVMIR] Support memory model relaxation
 annotations (MMRA)

This commit adds support for exporting and importing MMRA data in the
LLVM dialect. MMRA is a potentially discardable piece of metadata that
can be placed on any operation that touches memory (fences, loads,
stores, atomics, and intrinsics that operate on memory). It includes
one (technically zero) or more prefix:suffix string pairs which
indicate ways in which the LLVM memory model can be relaxed for the
annotated operations.

At the MLIR level, each tag is represented with a
`#llvm.mmra_tag<"prefix":"suffix">` attribute, and the MMRA metadata
as a whole is represented as a discardable llvm.mmra attribute. (This
discardability both allows us to transparently enable MMRA for wrapper
dialects like ROCDL and ensures that MLIR passes which don't know
about MMRA combining will, conservatively, discard the annotations,
per the LLVM spec).
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 41 ++++++++++++++++
 .../mlir/Dialect/LLVMIR/LLVMDialect.td        |  1 +
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        | 31 +++++++++++-
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        | 47 +++++++++++++++++++
 mlir/test/Dialect/LLVMIR/mmra.mlir            | 29 ++++++++++++
 .../Target/LLVMIR/Import/metadata-mmra.ll     | 22 +++++++++
 mlir/test/Target/LLVMIR/mmra.mlir             | 35 ++++++++++++++
 7 files changed, 205 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/mmra.mlir
 create mode 100644 mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
 create mode 100644 mlir/test/Target/LLVMIR/mmra.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index ac99b8aba073a..9980e1174d694 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1232,6 +1232,47 @@ def LLVM_TBAATagArrayAttr
   let constBuilderCall = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// Memory Model Relaxation Annotations (MMRA) Attributes
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// MMRATagAttr - single MMRA tag
+//===----------------------------------------------------------------------===//
+
+def LLVM_MMRATagAttr : LLVM_Attr<"MMRATag", "mmra_tag"> {
+  let parameters = (ins
+    StringRefParameter<>:$prefix,
+    StringRefParameter<>:$suffix
+  );
+
+  let summary = "MLIR wrapper around a prefix:suffix MMRA tag";
+
+  let description = [{
+    Defines a single memory model relaxation annotation (MMRA) entry
+    with prefix `$prefix` and suffix `$suffix`. This corresponds directly
+    to an LLVM `!{prefix, suffix}` metadata tuple, which is often written
+    `prefix:suffix` as shorthand.
+
+    Example:
+    ```mlir
+    #mmra_tag = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">
+    #mmra_tag1 = #llvm.mmra_tag<"foo":"bar">
+    ```
+
+    See the following link for more details:
+    https://llvm.org/docs/MemoryModelRelaxationAnnotations.html
+  }];
+
+  let assemblyFormat = "`<` $prefix `` `:` `` $suffix `>`";
+
+  let genMnemonicAlias = 1;
+}
+
+def LLVM_MMRATagArrayAttr : TypedArrayAttrBase<
+    LLVM_MMRATagAttr,
+    LLVM_MMRATagAttr.summary # " array">;
+
 //===----------------------------------------------------------------------===//
 // ConstantRangeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index ab0462f945a33..d2d71318a6118 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -36,6 +36,7 @@ def LLVM_Dialect : Dialect {
     static StringRef getIdentAttrName() { return "llvm.ident"; }
     static StringRef getModuleFlags() { return "llvm.module.flags"; }
     static StringRef getCommandlineAttrName() { return "llvm.commandline"; }
+    static StringRef getMmraAttrName() { return "llvm.mmra"; }
 
     /// Names of llvm parameter attributes.
     static StringRef getAlignAttrName() { return "llvm.align"; }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 8d5d7f9b649f2..9a548cf77e0f5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/Support/LLVM.h"
@@ -21,6 +22,7 @@
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
       llvm::LLVMContext::MD_alias_scope,
       llvm::LLVMContext::MD_dereferenceable,
       llvm::LLVMContext::MD_dereferenceable_or_null,
+      llvm::LLVMContext::MD_mmra,
       context.getMDKindID(vecTypeHintMDName),
       context.getMDKindID(workGroupSizeHintMDName),
       context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -212,6 +215,31 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
   return success();
 }
 
+/// Convert the given MMRA metadata (either an MMRA tag or an array of rhem)
+/// into corresponding MLIR attributes and set them on the given operation as a
+/// discardable `llvm.mmra` attribute.
+static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
+                                 LLVM::ModuleImport &moduleImport) {
+  llvm::MMRAMetadata wrapper(node);
+  if (wrapper.empty()) {
+    return success();
+  }
+  MLIRContext *ctx = op->getContext();
+  Attribute mlirMmra;
+  if (wrapper.size() == 1) {
+    auto [prefix, suffix] = *wrapper.begin();
+    mlirMmra = LLVM::MMRATagAttr::get(ctx, prefix, suffix);
+  } else {
+    SmallVector<Attribute> tags;
+    for (auto [prefix, suffix] : wrapper) {
+      tags.push_back(LLVM::MMRATagAttr::get(ctx, prefix, suffix));
+    }
+    mlirMmra = ArrayAttr::get(ctx, tags);
+  }
+  op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
+  return success();
+}
+
 /// Converts the given loop metadata node to an MLIR loop annotation attribute
 /// and attaches it to the imported operation if the translation succeeds.
 /// Returns failure otherwise.
@@ -432,7 +460,8 @@ class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
       return setDereferenceableAttr(
           node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
           moduleImport);
-
+    if (kind == llvm::LLVMContext::MD_mmra)
+      return setMmraAttr(node, op, moduleImport);
     llvm::LLVMContext &context = node->getContext();
     if (kind == context.getMDKindID(vecTypeHintMDName))
       return setVecTypeHintAttr(builder, node, op, moduleImport);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index fd8463ad1a8e2..ddf9b16b7b552 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -24,6 +24,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/MatrixBuilder.h"
+#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -723,6 +725,43 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   return failure();
 }
 
+static LogicalResult
+amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
+                   NamedAttribute attribute,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  StringRef name = attribute.getName();
+  if (name == LLVMDialect::getMmraAttrName()) {
+    SmallVector<llvm::MMRAMetadata::TagT> tags;
+    if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
+      tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
+    } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
+      for (auto a : manyTags) {
+        auto tag = dyn_cast<MMRATagAttr>(a);
+        if (tag) {
+          tags.emplace_back(tag.getPrefix(), tag.getSuffix());
+        } else {
+          return op.emitOpError(
+              "MMRA annotations array contains value that isn't an MMRA tag");
+        }
+      }
+    } else {
+      return op.emitOpError(
+          "llvm.mmra is something other than an MMRA tag or an array of them");
+    }
+    llvm::MDTuple *mmraMd =
+        llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags);
+    if (!mmraMd) {
+      // Empty list, canonicalizes to nothing
+      return success();
+    }
+    for (llvm::Instruction *inst : instructions) {
+      inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
+    }
+    return success();
+  }
+  return success();
+}
+
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
 /// to the LLVM dialect to LLVM IR.
@@ -738,6 +777,14 @@ class LLVMDialectLLVMIRTranslationInterface
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     return convertOperationImpl(*op, builder, moduleTranslation);
   }
+
+  /// Handle some metadata that is represented as a discardable attribute.
+  LogicalResult
+  amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+                 NamedAttribute attribute,
+                 LLVM::ModuleTranslation &moduleTranslation) const final {
+    return amendOperationImpl(*op, instructions, attribute, moduleTranslation);
+  }
 };
 } // namespace
 
diff --git a/mlir/test/Dialect/LLVMIR/mmra.mlir b/mlir/test/Dialect/LLVMIR/mmra.mlir
new file mode 100644
index 0000000000000..95da9666d053c
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/mmra.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file --verify-roundtrip --mlir-print-local-scope | FileCheck %s
+
+// CHECK-LABEL: llvm.func @native
+// CHECK: llvm.load
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
+// CHECK: llvm.fence
+// CHECK-SAME: llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #llvm.mmra_tag<"foo":"bar">]
+// CHECK: llvm.store
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
+
+#mmra_tag = #llvm.mmra_tag<"foo":"bar">
+
+llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
+  %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
+  llvm.fence syncscope("workgroup-one-as") release
+    {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
+  llvm.store %0, %y {llvm.mmra = #llvm.mmra_tag<"foo":"bar">} : i32, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @foreign_op
+// CHECK: rocdl.load.to.lds
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"fake":"example">
+llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
+  rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
new file mode 100644
index 0000000000000..180d438eca70e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
@@ -0,0 +1,22 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-DAG: #[[$MMRA0:.+]] = #llvm.mmra_tag<"foo":"bar">
+; CHECK-DAG: #[[$MMRA1:.+]] = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">
+
+; CHECK-LABEL: llvm.func @native
+define void @native(ptr %x, ptr %y) {
+  ; CHECK: llvm.load
+  ; CHECK-SAME: llvm.mmra = #[[$MMRA0]]
+  %v = load i32, ptr %x, align 4, !mmra !0
+  ; CHECK: llvm.fence
+  ; CHECK-SAME: llvm.mmra = [#[[$MMRA0]], #[[$MMRA1]]]
+  fence syncscope("workgroup-one-as") release, !mmra !2
+  ; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}}
+  store i32 %v, ptr %y, align 4, !mmra !3
+  ret void
+}
+
+!0 = !{!"foo", !"bar"}
+!1 = !{!"amdgpu-synchronize-as", !"local"}
+!2 = !{!0, !1}
+!3 = !{}
diff --git a/mlir/test/Target/LLVMIR/mmra.mlir b/mlir/test/Target/LLVMIR/mmra.mlir
new file mode 100644
index 0000000000000..5864e0e0759e6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/mmra.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: define void @native
+// CHECK: load
+// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
+// CHECK: fence
+// CHECK-SAME: !mmra ![[MMRA1:[0-9]+]]
+// CHECK: store {{.*}}, align 4{{$}}
+
+#mmra_tag = #llvm.mmra_tag<"foo":"bar">
+
+llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
+  %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
+  llvm.fence syncscope("workgroup-one-as") release
+    {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
+  llvm.store %0, %y {llvm.mmra = []} : i32, !llvm.ptr
+  llvm.return
+}
+
+// Actual MMRA metadata
+// CHECK-DAG: ![[MMRA0]] = !{!"foo", !"bar"}
+// CHECK-DAG: ![[MMRA_PART0:[0-9]+]] = !{!"amdgpu-synchronize-as", !"local"}
+// CHECK-DAG: ![[MMRA1]] = !{![[MMRA_PART0]], ![[MMRA0]]}
+
+// -----
+
+// CHECK-LABEL: define void @foreign_op
+// CHECK: call void @llvm.amdgcn.load.to.lds
+// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
+llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
+  rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
+  llvm.return
+}
+
+// CHECK: ![[MMRA0]] = !{!"fake", !"example"}

>From ac028712e3f91002edabb0f4aa8dc4dcd424e079 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 10 Sep 2025 11:05:16 -0500
Subject: [PATCH 2/3] Apply style nits from review

Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
---
 .../LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp    | 9 ++++-----
 .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp    | 5 ++---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 9a548cf77e0f5..736b1fe5b850c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -215,15 +215,15 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
   return success();
 }
 
-/// Convert the given MMRA metadata (either an MMRA tag or an array of rhem)
+/// Convert the given MMRA metadata (either an MMRA tag or an array of them)
 /// into corresponding MLIR attributes and set them on the given operation as a
 /// discardable `llvm.mmra` attribute.
 static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
                                  LLVM::ModuleImport &moduleImport) {
   llvm::MMRAMetadata wrapper(node);
-  if (wrapper.empty()) {
+  if (wrapper.empty()) 
     return success();
-  }
+ 
   MLIRContext *ctx = op->getContext();
   Attribute mlirMmra;
   if (wrapper.size() == 1) {
@@ -231,9 +231,8 @@ static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
     mlirMmra = LLVM::MMRATagAttr::get(ctx, prefix, suffix);
   } else {
     SmallVector<Attribute> tags;
-    for (auto [prefix, suffix] : wrapper) {
+    for (auto [prefix, suffix] : wrapper) 
       tags.push_back(LLVM::MMRATagAttr::get(ctx, prefix, suffix));
-    }
     mlirMmra = ArrayAttr::get(ctx, tags);
   }
   op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ddf9b16b7b552..5093b2a47d10c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -735,7 +735,7 @@ amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
     if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
       tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
     } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
-      for (auto a : manyTags) {
+      for (Attribute attr : manyTags) {
         auto tag = dyn_cast<MMRATagAttr>(a);
         if (tag) {
           tags.emplace_back(tag.getPrefix(), tag.getSuffix());
@@ -754,9 +754,8 @@ amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
       // Empty list, canonicalizes to nothing
       return success();
     }
-    for (llvm::Instruction *inst : instructions) {
+    for (llvm::Instruction *inst : instructions) 
       inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
-    }
     return success();
   }
   return success();

>From 0c749fcace4219918dca7e14a2c8c658a283dfd0 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 10 Sep 2025 16:31:11 +0000
Subject: [PATCH 3/3] Fix nondeterminism, non-trivial nits

---
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        | 25 +++++++++++++------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        | 10 +++-----
 .../Target/LLVMIR/Import/metadata-mmra.ll     |  4 +--
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 736b1fe5b850c..44732d5466f6d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -220,19 +220,28 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
 /// discardable `llvm.mmra` attribute.
 static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
                                  LLVM::ModuleImport &moduleImport) {
-  llvm::MMRAMetadata wrapper(node);
-  if (wrapper.empty()) 
+  if (!node)
     return success();
- 
+
+  // We don't use the LLVM wrappers here because we care about the order
+  // of the metadata for deterministic roundtripping.
   MLIRContext *ctx = op->getContext();
+  auto toAttribute = [&](llvm::MDNode *tag) -> Attribute {
+    return LLVM::MMRATagAttr::get(
+        ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(),
+        cast<llvm::MDString>(tag->getOperand(1))->getString());
+  };
   Attribute mlirMmra;
-  if (wrapper.size() == 1) {
-    auto [prefix, suffix] = *wrapper.begin();
-    mlirMmra = LLVM::MMRATagAttr::get(ctx, prefix, suffix);
+  if (llvm::MMRAMetadata::isTagMD(node)) {
+    mlirMmra = toAttribute(node);
   } else {
     SmallVector<Attribute> tags;
-    for (auto [prefix, suffix] : wrapper) 
-      tags.push_back(LLVM::MMRATagAttr::get(ctx, prefix, suffix));
+    for (const llvm::MDOperand &operand : node->operands()) {
+      auto *tagNode = dyn_cast<llvm::MDNode>(operand.get());
+      if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode))
+        return failure();
+      tags.push_back(toAttribute(tagNode));
+    }
     mlirMmra = ArrayAttr::get(ctx, tags);
   }
   op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 5093b2a47d10c..eaf1d20da63c7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -736,13 +736,11 @@ amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
       tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
     } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
       for (Attribute attr : manyTags) {
-        auto tag = dyn_cast<MMRATagAttr>(a);
-        if (tag) {
-          tags.emplace_back(tag.getPrefix(), tag.getSuffix());
-        } else {
+        auto tag = dyn_cast<MMRATagAttr>(attr);
+        if (!tag)
           return op.emitOpError(
               "MMRA annotations array contains value that isn't an MMRA tag");
-        }
+        tags.emplace_back(tag.getPrefix(), tag.getSuffix());
       }
     } else {
       return op.emitOpError(
@@ -754,7 +752,7 @@ amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
       // Empty list, canonicalizes to nothing
       return success();
     }
-    for (llvm::Instruction *inst : instructions) 
+    for (llvm::Instruction *inst : instructions)
       inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
     return success();
   }
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
index 180d438eca70e..5e1ed37d559ef 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
@@ -9,7 +9,7 @@ define void @native(ptr %x, ptr %y) {
   ; CHECK-SAME: llvm.mmra = #[[$MMRA0]]
   %v = load i32, ptr %x, align 4, !mmra !0
   ; CHECK: llvm.fence
-  ; CHECK-SAME: llvm.mmra = [#[[$MMRA0]], #[[$MMRA1]]]
+  ; CHECK-SAME: llvm.mmra = [#[[$MMRA1]], #[[$MMRA0]]]
   fence syncscope("workgroup-one-as") release, !mmra !2
   ; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}}
   store i32 %v, ptr %y, align 4, !mmra !3
@@ -18,5 +18,5 @@ define void @native(ptr %x, ptr %y) {
 
 !0 = !{!"foo", !"bar"}
 !1 = !{!"amdgpu-synchronize-as", !"local"}
-!2 = !{!0, !1}
+!2 = !{!1, !0}
 !3 = !{}



More information about the Mlir-commits mailing list