[Mlir-commits] [mlir] fef08da - [mlir][llvm] Store memory op metadata using op attributes.

Tobias Gysi llvmlistbot at llvm.org
Fri Feb 10 06:29:52 PST 2023


Author: Tobias Gysi
Date: 2023-02-10T15:27:25+01:00
New Revision: fef08da4b75fc751c6117df2a0213a0b075d05f5

URL: https://github.com/llvm/llvm-project/commit/fef08da4b75fc751c6117df2a0213a0b075d05f5
DIFF: https://github.com/llvm/llvm-project/commit/fef08da4b75fc751c6117df2a0213a0b075d05f5.diff

LOG: [mlir][llvm] Store memory op metadata using op attributes.

The revision introduces operation attributes to store tbaa metadata on
load and store operations rather than relying on dialect attributes.
At the same time, the change also ensures the provided getters and
setters are used instead of a string-based lookup. The latter
is done for the tbaa, access groups, and alias scope attributes.

The goal of this change is to ensure the metadata attributes are only
placed on operations that have the corresponding operation attributes.
This is important since only these operations later on translate these
attributes to LLVM IR. Dialect attributes placed on other operations
are lost during the translation.

Reviewed By: vzakhari, Dinistro

Differential Revision: https://reviews.llvm.org/D143654

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/TBAABuilder.cpp
    flang/test/Fir/tbaa.fir
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
    mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir
    mlir/test/Target/LLVMIR/Import/import-failure.ll
    mlir/test/Target/LLVMIR/tbaa.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp
index 2d206ed2dcbf5..c42081869c7ad 100644
--- a/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp
+++ b/flang/lib/Optimizer/CodeGen/TBAABuilder.cpp
@@ -12,6 +12,7 @@
 
 #include "TBAABuilder.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -159,9 +160,13 @@ void TBAABuilder::attachTBAATag(Operation *op, Type baseFIRType,
   else
     tbaaTagSym = getDataAccessTag(baseFIRType, accessFIRType, gep);
 
-  if (tbaaTagSym)
-    op->setAttr(LLVMDialect::getTBAAAttrName(),
-                ArrayAttr::get(op->getContext(), tbaaTagSym));
+  if (!tbaaTagSym)
+    return;
+
+  auto tbaaAttr = ArrayAttr::get(op->getContext(), tbaaTagSym);
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>([&](auto memOp) { memOp.setTbaaAttr(tbaaAttr); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 } // namespace fir

diff  --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index ac94ea94cb3cf..66261e6ee002a 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -28,10 +28,10 @@ module {
 // CHECK:           %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_5:.*]] = llvm.mlir.constant(10 : i32) : i32
 // CHECK:           %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<ptr<struct<()>>>
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_10:.*]] = llvm.load %[[VAL_9]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_11:.*]] = llvm.mul %[[VAL_4]], %[[VAL_10]]  : i64
 // CHECK:           %[[VAL_12:.*]] = llvm.add %[[VAL_11]], %[[VAL_8]]  : i64
 // CHECK:           %[[VAL_13:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
@@ -40,11 +40,11 @@ module {
 // CHECK:           %[[VAL_16:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_17:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_0]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_20:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_21:.*]] = llvm.load %[[VAL_20]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_0]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_24:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_24]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_26:.*]] = llvm.mlir.constant(20180515 : i32) : i32
@@ -64,15 +64,15 @@ module {
 // CHECK:           %[[VAL_40:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_38]][7] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_41:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
 // CHECK:           %[[VAL_42:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_40]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>
-// CHECK:           llvm.store %[[VAL_42]], %[[VAL_2]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_42]], %[[VAL_2]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_43:.*]] = llvm.getelementptr %[[VAL_2]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i8>
-// CHECK:           %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
+// CHECK:           %[[VAL_44:.*]] = llvm.load %[[VAL_43]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i8>
 // CHECK:           %[[VAL_45:.*]] = llvm.icmp "eq" %[[VAL_44]], %[[VAL_3]] : i8
 // CHECK:           llvm.cond_br %[[VAL_45]], ^bb1, ^bb2
 // CHECK:         ^bb1:
 // CHECK:           %[[VAL_46:.*]] = llvm.getelementptr %[[VAL_2]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK:           %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
-// CHECK:           llvm.store %[[VAL_5]], %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_47:.*]] = llvm.load %[[VAL_46]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK:           llvm.store %[[VAL_5]], %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[DATAT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.br ^bb2
 // CHECK:         ^bb2:
 // CHECK:           llvm.return
@@ -133,24 +133,24 @@ module {
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQcl.2E2F64756D6D792E66393000 : !llvm.ptr<array<12 x i8>>
 // CHECK:           %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<array<12 x i8>> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_9]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr<i8>, i32) -> !llvm.ptr<i8>
-// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
-// CHECK:           llvm.store %[[VAL_11]], %[[VAL_3]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_7]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_11]], %[[VAL_3]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_13:.*]] = llvm.load %[[VAL_12]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_15:.*]] = llvm.load %[[VAL_14]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>, i64) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_17:.*]] = llvm.load %[[VAL_16]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_18:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_19:.*]] = llvm.load %[[VAL_18]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_20:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_21:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_3]][0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_23:.*]] = llvm.load %[[VAL_22]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_24:.*]] = llvm.getelementptr %[[VAL_3]][0, 4] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_25:.*]] = llvm.load %[[VAL_24]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_3]][0, 8] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<i8>>
-// CHECK:           %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
+// CHECK:           %[[VAL_27:.*]] = llvm.load %[[VAL_26]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i8>>
 // CHECK:           %[[VAL_28:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_29:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_28]][1] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_30:.*]] = llvm.mlir.constant(20180515 : i32) : i32
@@ -169,13 +169,13 @@ module {
 // CHECK:           %[[VAL_43:.*]] = llvm.bitcast %[[VAL_27]] : !llvm.ptr<i8> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_44:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_42]][8] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_45:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_46:.*]] = llvm.load %[[VAL_45]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_47:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 1] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_48:.*]] = llvm.load %[[VAL_47]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_49:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_50:.*]] = llvm.load %[[VAL_49]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_51:.*]] = llvm.getelementptr %[[VAL_3]][0, 0] : (!llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>) -> !llvm.ptr<ptr<struct<()>>>
-// CHECK:           %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
+// CHECK:           %[[VAL_52:.*]] = llvm.load %[[VAL_51]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<struct<()>>>
 // CHECK:           %[[VAL_53:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK:           %[[VAL_54:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:           %[[VAL_55:.*]] = llvm.icmp "eq" %[[VAL_48]], %[[VAL_53]] : i64
@@ -185,7 +185,7 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
 // CHECK:           %[[VAL_60:.*]] = llvm.bitcast %[[VAL_52]] : !llvm.ptr<struct<()>> to !llvm.ptr<struct<()>>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_59]][0] : !llvm.struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>
-// CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
+// CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_62:.*]] = llvm.bitcast %[[VAL_1]] : !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr<i8>, array<1 x i64>)>> to !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>
 // CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_62]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>, !llvm.ptr<struct<(ptr<struct<()>>, i64, i32, i8, i8, i8, i8, ptr<i8>, array<1 x i64>)>>) -> i1
 // CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr<i8>) -> i32
@@ -253,7 +253,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.return %[[VAL_2]] : i32
 // CHECK:         }
 
@@ -275,7 +275,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i32
 // CHECK:           llvm.return %[[VAL_4]] : i1
@@ -299,7 +299,7 @@ func.func @tbaa(%arg0: !fir.box<f32>) -> i32 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                               %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> i32 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 1] : (!llvm.ptr<struct<(ptr<f32>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           llvm.return %[[VAL_2]] : i32
 // CHECK:         }
 
@@ -321,7 +321,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> i1 {
 // CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr<struct<(ptr<f64>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<i32>
-// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
+// CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i32>
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]]  : i32
 // CHECK:           %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -353,11 +353,11 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<?xi32>>) {
 // CHECK:           %[[VAL_4:.*]] = llvm.sub %[[VAL_1]], %[[VAL_2]]  : i64
 // CHECK:           %[[VAL_5:.*]] = llvm.mul %[[VAL_4]], %[[VAL_2]]  : i64
 // CHECK:           %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]][0, 7, 0, 2] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<i64>
-// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
+// CHECK:           %[[VAL_7:.*]] = llvm.load %[[VAL_6]] {tbaa = [@__flang_tbaa::@[[BOXT:tag_[0-9]*]]]} : !llvm.ptr<i64>
 // CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_5]], %[[VAL_7]]  : i64
 // CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_3]]  : i64
 // CHECK:           %[[VAL_10:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>>) -> !llvm.ptr<ptr<i32>>
-// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {llvm.tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
+// CHECK:           %[[VAL_11:.*]] = llvm.load %[[VAL_10]] {tbaa = [@__flang_tbaa::@[[BOXT]]]} : !llvm.ptr<ptr<i32>>
 // CHECK:           %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr<i32> to !llvm.ptr<i8>
 // CHECK:           %[[VAL_13:.*]] = llvm.getelementptr %[[VAL_12]]{{\[}}%[[VAL_9]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
 // CHECK:           %[[VAL_14:.*]] = llvm.bitcast %[[VAL_13]] : !llvm.ptr<i8> to !llvm.ptr<i32>

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 4fd758639ba15..9c9ebfcf255bf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -35,11 +35,7 @@ def LLVM_Dialect : Dialect {
   let extraClassDeclaration = [{
     /// Name of the data layout attributes.
     static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
-    static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
-    static StringRef getAliasScopesAttrName() { return "alias_scopes"; }
     static StringRef getLoopAttrName() { return "llvm.loop"; }
-    static StringRef getAccessGroupsAttrName() { return "access_groups"; }
-    static StringRef getTBAAAttrName() { return "llvm.tbaa"; }
 
     /// Names of llvm parameter attributes.
     static StringRef getAlignAttrName() { return "llvm.align"; }

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 39b79f1ed2284..8e4b8340a52ba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -350,6 +350,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                    OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
                    OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+                   OptionalAttr<SymbolRefArrayAttr>:$tbaa,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
@@ -390,6 +391,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
                    OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
                    OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
+                   OptionalAttr<SymbolRefArrayAttr>:$tbaa,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmInstName = "Store";

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 59b29f506609b..2b08d96b680b9 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -120,15 +120,14 @@ class ModuleTranslation {
   /// in these blocks.
   void forgetMapping(Region &region);
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect access group operation.
-  llvm::MDNode *getAccessGroup(Operation &opInst,
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect access group operation.
+  llvm::MDNode *getAccessGroup(Operation *op,
                                SymbolRefAttr accessGroupRef) const;
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect alias scope operation
-  llvm::MDNode *getAliasScope(Operation &opInst,
-                              SymbolRefAttr aliasScopeRef) const;
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect alias scope operation
+  llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
 
   // Sets LLVM metadata for memory operations that are in a parallel loop.
   void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
@@ -287,9 +286,9 @@ class ModuleTranslation {
   /// metadata nodes for them and their domains.
   LogicalResult createAliasScopeMetadata();
 
-  /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
-  /// dialect TBAATagOp operation.
-  llvm::MDNode *getTBAANode(Operation &memOp, SymbolRefAttr tagRef) const;
+  /// Returns the LLVM metadata corresponding to a symbol reference to an mlir
+  /// LLVM dialect TBAATagOp operation.
+  llvm::MDNode *getTBAANode(Operation *op, SymbolRefAttr tagRef) const;
 
   /// Process tbaa LLVM Metadata operations and create LLVM
   /// metadata nodes for them.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1ef24c63902f4..1dc501f0c4021 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -668,53 +668,60 @@ Type LLVM::GEPOp::getSourceElementType() {
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
-LogicalResult verifySymbolAttribute(
-    Operation *op, StringRef attributeName,
+/// Verifies the given array attribute contains symbol references and checks the
+/// referenced symbol types using the provided verification function.
+LogicalResult verifyMemOpSymbolRefs(
+    Operation *op, StringRef name, ArrayAttr symbolRefs,
     llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
         verifySymbolType) {
-  if (Attribute attribute = op->getAttr(attributeName)) {
-    // Verify that the attribute is a symbol ref array attribute,
-    // because this constraint is not verified for all attribute
-    // names processed here (e.g. 'tbaa'). This verification
-    // is redundant in some cases.
-    if (!(attribute.isa<ArrayAttr>() &&
-          llvm::all_of(attribute.cast<ArrayAttr>(), [&](Attribute attr) {
-            return attr && attr.isa<SymbolRefAttr>();
-          })))
-      return op->emitOpError("attribute '")
-             << attributeName
-             << "' failed to satisfy constraint: symbol ref array attribute";
-
-    for (SymbolRefAttr symbolRef :
-         attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
-      StringAttr metadataName = symbolRef.getRootReference();
-      StringAttr symbolName = symbolRef.getLeafReference();
-      // We want @metadata::@symbol, not just @symbol
-      if (metadataName == symbolName) {
-        return op->emitOpError() << "expected '" << symbolRef
-                                 << "' to specify a fully qualified reference";
-      }
-      auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-          op->getParentOp(), metadataName);
-      if (!metadataOp)
-        return op->emitOpError()
-               << "expected '" << symbolRef << "' to reference a metadata op";
-      Operation *symbolOp =
-          SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
-      if (!symbolOp)
-        return op->emitOpError()
-               << "expected '" << symbolRef << "' to be a valid reference";
-      if (failed(verifySymbolType(symbolOp, symbolRef))) {
-        return failure();
-      }
+  assert(symbolRefs && "expected a non-null attribute");
+
+  // Verify that the attribute is a symbol ref array attribute,
+  // because this constraint is not verified for all attribute
+  // names processed here (e.g. 'tbaa'). This verification
+  // is redundant in some cases.
+  if (!llvm::all_of(symbolRefs, [](Attribute attr) {
+        return attr && attr.isa<SymbolRefAttr>();
+      }))
+    return op->emitOpError("attribute '")
+           << name
+           << "' failed to satisfy constraint: symbol ref array attribute";
+
+  for (SymbolRefAttr symbolRef : symbolRefs.getAsRange<SymbolRefAttr>()) {
+    StringAttr metadataName = symbolRef.getRootReference();
+    StringAttr symbolName = symbolRef.getLeafReference();
+    // We want @metadata::@symbol, not just @symbol
+    if (metadataName == symbolName) {
+      return op->emitOpError() << "expected '" << symbolRef
+                               << "' to specify a fully qualified reference";
+    }
+    auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+        op->getParentOp(), metadataName);
+    if (!metadataOp)
+      return op->emitOpError()
+             << "expected '" << symbolRef << "' to reference a metadata op";
+    Operation *symbolOp =
+        SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
+    if (!symbolOp)
+      return op->emitOpError()
+             << "expected '" << symbolRef << "' to be a valid reference";
+    if (failed(verifySymbolType(symbolOp, symbolRef))) {
+      return failure();
     }
   }
+
   return success();
 }
 
-// Verifies that metadata ops are wired up properly.
+/// Verifies the given array attribute contains symbol references that point to
+/// metadata operations of the given type.
 template <typename OpTy>
-static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
+static LogicalResult
+verifyMemOpSymbolRefsPointTo(Operation *op, StringRef name,
+                             std::optional<ArrayAttr> symbolRefs) {
+  if (!symbolRefs)
+    return success();
+
   auto verifySymbolType = [op](Operation *symbolOp,
                                SymbolRefAttr symbolRef) -> LogicalResult {
     if (!isa<OpTy>(symbolOp)) {
@@ -724,35 +731,33 @@ static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
     }
     return success();
   };
-
-  return verifySymbolAttribute(op, attributeName, verifySymbolType);
+  return verifyMemOpSymbolRefs(op, name, *symbolRefs, verifySymbolType);
 }
 
-static LogicalResult verifyMemoryOpMetadata(Operation *op) {
-  // access_groups
-  if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
-          op, LLVMDialect::getAccessGroupsAttrName())))
+/// Verifies the types of the metadata operations referenced by aliasing and
+/// access group metadata.
+template <typename OpTy>
+LogicalResult verifyMemOpMetadata(OpTy memOp) {
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AccessGroupMetadataOp>(
+          memOp, memOp.getAccessGroupsAttrName(), memOp.getAccessGroups())))
     return failure();
 
-  // alias_scopes
-  if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
-          op, LLVMDialect::getAliasScopesAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+          memOp, memOp.getAliasScopesAttrName(), memOp.getAliasScopes())))
     return failure();
 
-  // noalias_scopes
-  if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
-          op, LLVMDialect::getNoAliasScopesAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
+          memOp, memOp.getNoaliasScopesAttrName(), memOp.getNoaliasScopes())))
     return failure();
 
-  // tbaa
-  if (failed(verifyOpMetadata<LLVM::TBAATagOp>(op,
-                                               LLVMDialect::getTBAAAttrName())))
+  if (failed(verifyMemOpSymbolRefsPointTo<LLVM::TBAATagOp>(
+          memOp, memOp.getTbaaAttrName(), memOp.getTbaa())))
     return failure();
 
   return success();
 }
 
-LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
 
 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
                    Value addr, unsigned alignment, bool isVolatile,
@@ -828,7 +833,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
 // Builder, printer and parser for LLVM::StoreOp.
 //===----------------------------------------------------------------------===//
 
-LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
+LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
 
 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Value addr, unsigned alignment, bool isVolatile,

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 207ef8b428099..7d36d40c70491 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -76,6 +76,30 @@ static ArrayRef<unsigned> getSupportedMetadataImpl() {
   return convertibleMetadata;
 }
 
+namespace {
+/// Helper class to attach metadata attributes to specific operation types. It
+/// specializes TypeSwitch to take an Operation and return a LogicalResult.
+template <typename... OpTys>
+struct AttributeSetter {
+  AttributeSetter(Operation *op) : op(op) {}
+
+  /// Calls `attachFn` on the provided Operation if it has one of
+  /// the given operation types. Returns failure otherwise.
+  template <typename CallableT>
+  LogicalResult apply(CallableT &&attachFn) {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+        .Case<OpTys...>([&attachFn](auto concreteOp) {
+          attachFn(concreteOp);
+          return success();
+        })
+        .Default([&](auto) { return failure(); });
+  }
+
+private:
+  Operation *op;
+};
+} // namespace
+
 /// Converts the given profiling metadata `node` to an MLIR profiling attribute
 /// and attaches it to the imported operation if the translation succeeds.
 /// Returns failure otherwise.
@@ -129,16 +153,10 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     branchWeights.push_back(branchWeight->getZExtValue());
   }
 
-  // Attach the branch weights to the operations that support it.
-  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
-      .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
+  return AttributeSetter<CondBrOp, SwitchOp, CallOp, InvokeOp>(op).apply(
+      [&](auto branchWeightOp) {
         branchWeightOp.setBranchWeightsAttr(
             builder.getI32VectorAttr(branchWeights));
-        return success();
-      })
-      .Default([op](auto) {
-        return op->emitWarning()
-               << op->getName() << " does not support branch weights";
       });
 }
 
@@ -151,9 +169,9 @@ static LogicalResult setTBAAAttr(const llvm::MDNode *node, Operation *op,
   if (!tbaaTagSym)
     return failure();
 
-  op->setAttr(LLVMDialect::getTBAAAttrName(),
-              ArrayAttr::get(op->getContext(), tbaaTagSym));
-  return success();
+  return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+    memOp.setTbaaAttr(ArrayAttr::get(memOp.getContext(), tbaaTagSym));
+  });
 }
 
 /// Looks up all the symbol references pointing to the access group operations
@@ -169,9 +187,10 @@ static LogicalResult setAccessGroupAttr(const llvm::MDNode *node, Operation *op,
 
   SmallVector<Attribute> accessGroupAttrs(accessGroups->begin(),
                                           accessGroups->end());
-  op->setAttr(LLVMDialect::getAccessGroupsAttrName(),
-              ArrayAttr::get(op->getContext(), accessGroupAttrs));
-  return success();
+  return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
+    memOp.setAccessGroupsAttr(
+        ArrayAttr::get(memOp.getContext(), accessGroupAttrs));
+  });
 }
 
 /// Converts the given loop metadata node to an MLIR loop annotation attribute

diff  --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
index 5864b48e0d733..b02433b33edfe 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
@@ -210,7 +210,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
         llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
     for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
       parallelAccess.push_back(
-          moduleTranslation.getAccessGroup(*op, accessGroupRef));
+          moduleTranslation.getAccessGroup(op, accessGroupRef));
     metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
   }
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index aa5627e7a15f7..3834bf02dcbcf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -986,12 +986,12 @@ LogicalResult ModuleTranslation::convertFunctions() {
 }
 
 llvm::MDNode *
-ModuleTranslation::getAccessGroup(Operation &opInst,
+ModuleTranslation::getAccessGroup(Operation *op,
                                   SymbolRefAttr accessGroupRef) const {
   auto metadataName = accessGroupRef.getRootReference();
   auto accessGroupName = accessGroupRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      opInst.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   auto *accessGroupOp =
       SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
   return accessGroupMetadataMapping.lookup(accessGroupOp);
@@ -1010,23 +1010,28 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
 
 void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
                                                 llvm::Instruction *inst) {
-  auto accessGroups =
-      op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
-  if (accessGroups && !accessGroups.empty()) {
+  auto populateGroupsMetadata = [&](std::optional<ArrayAttr> groupRefs) {
+    if (!groupRefs || groupRefs->empty())
+      return;
+
     llvm::Module *module = inst->getModule();
-    SmallVector<llvm::Metadata *> metadatas;
-    for (SymbolRefAttr accessGroupRef :
-         accessGroups.getAsRange<SymbolRefAttr>())
-      metadatas.push_back(getAccessGroup(*op, accessGroupRef));
-
-    llvm::MDNode *unionMD = nullptr;
-    if (metadatas.size() == 1)
-      unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
-    else if (metadatas.size() >= 2)
-      unionMD = llvm::MDNode::get(module->getContext(), metadatas);
-
-    inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
-  }
+    SmallVector<llvm::Metadata *> groupMDs;
+    for (SymbolRefAttr groupRef : groupRefs->getAsRange<SymbolRefAttr>())
+      groupMDs.push_back(getAccessGroup(op, groupRef));
+
+    llvm::MDNode *node = nullptr;
+    if (groupMDs.size() == 1)
+      node = llvm::cast<llvm::MDNode>(groupMDs.front());
+    else if (groupMDs.size() >= 2)
+      node = llvm::MDNode::get(module->getContext(), groupMDs);
+
+    inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
+  };
+
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>(
+          [&](auto memOp) { populateGroupsMetadata(memOp.getAccessGroups()); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 LogicalResult ModuleTranslation::createAliasScopeMetadata() {
@@ -1067,12 +1072,12 @@ LogicalResult ModuleTranslation::createAliasScopeMetadata() {
 }
 
 llvm::MDNode *
-ModuleTranslation::getAliasScope(Operation &opInst,
+ModuleTranslation::getAliasScope(Operation *op,
                                  SymbolRefAttr aliasScopeRef) const {
   StringAttr metadataName = aliasScopeRef.getRootReference();
   StringAttr scopeName = aliasScopeRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      opInst.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   Operation *aliasScopeOp =
       SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
   return aliasScopeMetadataMapping.lookup(aliasScopeOp);
@@ -1080,50 +1085,63 @@ ModuleTranslation::getAliasScope(Operation &opInst,
 
 void ModuleTranslation::setAliasScopeMetadata(Operation *op,
                                               llvm::Instruction *inst) {
-  auto populateScopeMetadata = [this, op, inst](StringRef attrName,
-                                                StringRef llvmMetadataName) {
-    auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
-    if (!scopes || scopes.empty())
+  auto populateScopeMetadata = [&](std::optional<ArrayAttr> scopeRefs,
+                                   unsigned kind) {
+    if (!scopeRefs || scopeRefs->empty())
       return;
     llvm::Module *module = inst->getModule();
     SmallVector<llvm::Metadata *> scopeMDs;
-    for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
-      scopeMDs.push_back(getAliasScope(*op, scopeRef));
-    llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
-    inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
+    for (SymbolRefAttr scopeRef : scopeRefs->getAsRange<SymbolRefAttr>())
+      scopeMDs.push_back(getAliasScope(op, scopeRef));
+    llvm::MDNode *node = llvm::MDNode::get(module->getContext(), scopeMDs);
+    inst->setMetadata(kind, node);
   };
 
-  populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
-  populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>([&](auto memOp) {
+        populateScopeMetadata(memOp.getAliasScopes(),
+                              llvm::LLVMContext::MD_alias_scope);
+        populateScopeMetadata(memOp.getNoaliasScopes(),
+                              llvm::LLVMContext::MD_noalias);
+      })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
-llvm::MDNode *ModuleTranslation::getTBAANode(Operation &memOp,
+llvm::MDNode *ModuleTranslation::getTBAANode(Operation *op,
                                              SymbolRefAttr tagRef) const {
   StringAttr metadataName = tagRef.getRootReference();
   StringAttr tagName = tagRef.getLeafReference();
   auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
-      memOp.getParentOp(), metadataName);
+      op->getParentOp(), metadataName);
   Operation *tagOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, tagName);
   return tbaaMetadataMapping.lookup(tagOp);
 }
 
 void ModuleTranslation::setTBAAMetadata(Operation *op,
                                         llvm::Instruction *inst) {
-  auto tbaa = op->getAttrOfType<ArrayAttr>(LLVMDialect::getTBAAAttrName());
-  if (!tbaa || tbaa.empty())
-    return;
-  // LLVM IR currently does not support attaching more than one
-  // TBAA access tag to a memory accessing instruction.
-  // It may be useful to support this in future, but for the time being
-  // just ignore the metadata if MLIR operation has multiple access tags.
-  if (tbaa.size() > 1) {
-    op->emitWarning() << "TBAA access tags were not translated, because LLVM "
-                         "IR only supports a single tag per instruction";
-    return;
-  }
-  SymbolRefAttr tagRef = tbaa[0].cast<SymbolRefAttr>();
-  llvm::MDNode *tagNode = getTBAANode(*op, tagRef);
-  inst->setMetadata(llvm::LLVMContext::MD_tbaa, tagNode);
+  auto populateTBAAMetadata = [&](std::optional<ArrayAttr> tagRefs) {
+    if (!tagRefs || tagRefs->empty())
+      return;
+
+    // LLVM IR currently does not support attaching more than one
+    // TBAA access tag to a memory accessing instruction.
+    // It may be useful to support this in future, but for the time being
+    // just ignore the metadata if MLIR operation has multiple access tags.
+    if (tagRefs->size() > 1) {
+      op->emitWarning() << "TBAA access tags were not translated, because LLVM "
+                           "IR only supports a single tag per instruction";
+      return;
+    }
+
+    SymbolRefAttr tagRef = (*tagRefs)[0].cast<SymbolRefAttr>();
+    llvm::MDNode *node = getTBAANode(op, tagRef);
+    inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
+  };
+
+  llvm::TypeSwitch<Operation *>(op)
+      .Case<LoadOp, StoreOp>(
+          [&](auto memOp) { populateTBAAMetadata(memOp.getTbaa()); })
+      .Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
 }
 
 LogicalResult ModuleTranslation::createTBAAMetadata() {

diff  --git a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir
index 0513dc77ff869..a747d596bd968 100644
--- a/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/tbaa-invalid.mlir
@@ -8,7 +8,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@tbaa_tag_1' to specify a fully qualified reference}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@tbaa_tag_1]} : i8, !llvm.ptr
     llvm.return
   }
 }
@@ -17,8 +17,8 @@ module {
 
 llvm.func @tbaa(%arg0: !llvm.ptr) {
   %0 = llvm.mlir.constant(1 : i8) : i8
-  // expected-error@below {{attribute 'llvm.tbaa' failed to satisfy constraint: symbol ref array attribute}}
-  llvm.store %0, %arg0 {llvm.tbaa = ["sym"]} : i8, !llvm.ptr
+  // expected-error@below {{attribute 'tbaa' failed to satisfy constraint: symbol ref array attribute}}
+  llvm.store %0, %arg0 {tbaa = ["sym"]} : i8, !llvm.ptr
   llvm.return
 }
 
@@ -28,7 +28,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@metadata::@group1' to resolve to a llvm.tbaa_tag}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@group1]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@metadata::@group1]} : i8, !llvm.ptr
     llvm.return
   }
   llvm.metadata @metadata {
@@ -42,7 +42,7 @@ module {
   llvm.func @tbaa(%arg0: !llvm.ptr) {
     %0 = llvm.mlir.constant(1 : i8) : i8
     // expected-error@below {{expected '@metadata::@sym' to be a valid reference}}
-    llvm.store %0, %arg0 {llvm.tbaa = [@metadata::@sym]} : i8, !llvm.ptr
+    llvm.store %0, %arg0 {tbaa = [@metadata::@sym]} : i8, !llvm.ptr
     llvm.return
   }
   llvm.metadata @metadata {
@@ -54,7 +54,7 @@ module {
 llvm.func @tbaa(%arg0: !llvm.ptr) {
   %0 = llvm.mlir.constant(1 : i8) : i8
   // expected-error@below {{expected '@tbaa::@sym' to reference a metadata op}}
-  llvm.store %0, %arg0 {llvm.tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
+  llvm.store %0, %arg0 {tbaa = [@tbaa::@sym]} : i8, !llvm.ptr
   llvm.return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 12a605a72d344..3286d3b904001 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -566,8 +566,6 @@ bb2:
 
 ; // -----
 
-; CHECK:      import-failure.ll
-; CHECK-SAME: warning: llvm.func does not support branch weights
 ; CHECK:      import-failure.ll:{{.*}} warning: unhandled function metadata: !0 = !{!"branch_weights", i32 64}
 define void @cond_br(i1 %arg) !prof !0 {
   ret void

diff  --git a/mlir/test/Target/LLVMIR/tbaa.mlir b/mlir/test/Target/LLVMIR/tbaa.mlir
index 26a96e47aee69..84f27be923268 100644
--- a/mlir/test/Target/LLVMIR/tbaa.mlir
+++ b/mlir/test/Target/LLVMIR/tbaa.mlir
@@ -16,11 +16,11 @@ module {
     %1 = llvm.mlir.constant(1 : i32) : i32
     %2 = llvm.getelementptr inbounds %arg1[%0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (i64, i64)>
     // CHECK: load i64, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
-    %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
+    %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> i64
     %4 = llvm.trunc %3 : i64 to i32
     %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
     // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
-    llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+    llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
     llvm.return
   }
 }
@@ -60,11 +60,11 @@ module {
     %1 = llvm.mlir.constant(1 : i32) : i32
     %2 = llvm.getelementptr inbounds %arg1[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg2_t", (f32, f32)>
     // CHECK: load float, ptr %{{.*}},{{.*}}!tbaa ![[LTAG:[0-9]*]]
-    %3 = llvm.load %2 {llvm.tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
+    %3 = llvm.load %2 {tbaa = [@__tbaa::@tbaa_tag_4]} : !llvm.ptr -> f32
     %4 = llvm.fptosi %3 : f32 to i32
     %5 = llvm.getelementptr inbounds %arg0[%0, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<"struct.agg1_t", (i32, i32)>
     // CHECK: store i32 %{{.*}}, ptr %{{.*}},{{.*}}!tbaa ![[STAG:[0-9]*]]
-    llvm.store %4, %5 {llvm.tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
+    llvm.store %4, %5 {tbaa = [@__tbaa::@tbaa_tag_7]} : i32, !llvm.ptr
     llvm.return
   }
 }


        


More information about the Mlir-commits mailing list