[Mlir-commits] [mlir] 240c6f2 - [mlir][llvm] Improve LoadOp and StoreOp import.

Tobias Gysi llvmlistbot at llvm.org
Sun Feb 12 23:56:43 PST 2023


Author: Tobias Gysi
Date: 2023-02-13T08:42:43+01:00
New Revision: 240c6f2643991b1fe69c03ddcc8be0104e699327

URL: https://github.com/llvm/llvm-project/commit/240c6f2643991b1fe69c03ddcc8be0104e699327
DIFF: https://github.com/llvm/llvm-project/commit/240c6f2643991b1fe69c03ddcc8be0104e699327.diff

LOG: [mlir][llvm] Improve LoadOp and StoreOp import.

The revision supports importing the alignment, the volatile keyword, and the
nontemporal metadata for the LoadOp and StoreOp. Additionally, it updates the
builders and uses an assembly format for printing and parsing.

The address, value, and result types of the operations still require custom
parse and print methods (parseLoadType/printLoadType and
parseStoreType/printStoreType) due to the current handling of typed and
opaque pointers.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D143714

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Target/LLVMIR/Import/instructions.ll
    mlir/test/Target/LLVMIR/Import/metadata-loop.ll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8e4b8340a52b..d79e511b210d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -189,11 +189,10 @@ class MemoryOpBase {
   }];
   code setNonTemporalMetadataCode = [{
     if ($nontemporal) {
-      llvm::Module *module = builder.GetInsertBlock()->getModule();
       llvm::MDNode *metadata = llvm::MDNode::get(
           inst->getContext(), llvm::ConstantAsMetadata::get(
               builder.getInt32(1)));
-      inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
+      inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata);
     }
   }];
   code setAccessGroupsMetadataCode = [{
@@ -355,6 +354,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
   string llvmInstName = "Load";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $addr attr-dict `:`
+    custom<LoadType>(type($addr), type($res))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_);
   }] # setAlignmentCode
@@ -365,9 +368,12 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
      # [{
     $res = inst;
   }];
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr);
+    auto *loadInst = cast<llvm::LoadInst>(inst);
+    unsigned alignment = loadInst->getAlign().value();
+    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr,
+        alignment, loadInst->isVolatile(),
+        loadInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
@@ -378,9 +384,10 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
                      "when the pointer type is opaque");
       build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
     }]>,
-    OpBuilder<(ins "Type":$t, "Value":$addr,
+    OpBuilder<(ins "Type":$type, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
-      CArg<"bool", "false">:$isNonTemporal)>,];
+      CArg<"bool", "false">:$isNonTemporal)>
+  ];
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
@@ -395,6 +402,10 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmInstName = "Store";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $value `,` $addr attr-dict `:`
+    custom<StoreType>(type($value), type($addr))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateStore($value, $addr, $volatile_);
   }] # setAlignmentCode
@@ -402,16 +413,18 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
      # setAccessGroupsMetadataCode
      # setAliasScopeMetadataCode
      # setTBAAMetadataCode;
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr);
+    auto *storeInst = cast<llvm::StoreInst>(inst);
+    unsigned alignment = storeInst->getAlign().value();
+    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr,
+        alignment, storeInst->isVolatile(),
+        storeInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$value, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
       CArg<"bool", "false">:$isNonTemporal)>
-    ];
-  let hasCustomAssemblyFormat = 1;
+  ];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1dc501f0c402..c19d268d745f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,8 +44,6 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
-static constexpr const char kVolatileAttrName[] = "volatile_";
-static constexpr const char kNonTemporalAttrName[] = "nontemporal";
 static constexpr const char kElemTypeAttrName[] = "elem_type";
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
@@ -665,7 +663,7 @@ Type LLVM::GEPOp::getSourceElementType() {
 }
 
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for for LLVM::LoadOp.
+// LoadOp
 //===----------------------------------------------------------------------===//
 
 /// Verifies the given array attribute contains symbol references and checks the
@@ -759,29 +757,13 @@ LogicalResult verifyMemOpMetadata(OpTy memOp) {
 
 LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
 
-void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
                    Value addr, unsigned alignment, bool isVolatile,
                    bool isNonTemporal) {
-  result.addOperands(addr);
-  result.addTypes(t);
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
-}
-
-void LoadOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          {kVolatileAttrName, kElemTypeAttrName});
-  p << " : " << getAddr().getType();
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << " -> " << getType();
+  build(builder, state, type, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
@@ -797,105 +779,85 @@ getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) {
   return llvmTy.getElementType();
 }
 
-// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
-//                 (`->` type)?
-ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr;
-  Type type;
+/// Parses the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseLoadType(OpAsmParser &parser, Type &type,
+                                 Type &elementType) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
-      parser.resolveOperand(addr, type, result.operands))
+  if (parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
     return failure();
 
-  std::optional<Type> elemTy =
+  std::optional<Type> pointerElementType =
       getLoadStoreElementType(parser, type, trailingTypeLoc);
-  if (!elemTy)
+  if (!pointerElementType)
     return failure();
-  if (*elemTy) {
-    result.addTypes(*elemTy);
+  if (*pointerElementType) {
+    elementType = *pointerElementType;
     return success();
   }
 
-  Type trailingType;
-  if (parser.parseArrow() || parser.parseType(trailingType))
+  if (parser.parseArrow() || parser.parseType(elementType))
     return failure();
-  result.addTypes(trailingType);
   return success();
 }
 
+/// Prints the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printLoadType(OpAsmPrinter &printer, Operation *op, Type type,
+                          Type elementType) {
+  printer << type;
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << " -> " << elementType;
+}
+
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for LLVM::StoreOp.
+// StoreOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
 
-void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
                     Value addr, unsigned alignment, bool isVolatile,
                     bool isNonTemporal) {
-  result.addOperands({value, addr});
-  result.addTypes({});
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+  build(builder, state, value, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
-void StoreOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getValue() << ", " << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
-  p << " : ";
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << getValue().getType() << ", ";
-  p << getAddr().getType();
-}
-
-// <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
-//                 attribute-dict? `:` type (`,` type)?
-ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr, value;
-  Type type;
+/// Parses the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseStoreType(OpAsmParser &parser, Type &elementType,
+                                  Type &type) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(value) || parser.parseComma() ||
-      parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
+  if (parser.getCurrentLocation(&trailingTypeLoc) ||
+      parser.parseType(elementType))
     return failure();
 
-  Type operandType;
-  if (succeeded(parser.parseOptionalComma())) {
-    operandType = type;
-    if (parser.parseType(type))
-      return failure();
-  } else {
-    std::optional<Type> maybeOperandType =
-        getLoadStoreElementType(parser, type, trailingTypeLoc);
-    if (!maybeOperandType)
-      return failure();
-    operandType = *maybeOperandType;
-  }
+  if (succeeded(parser.parseOptionalComma()))
+    return parser.parseType(type);
 
-  if (parser.resolveOperand(value, operandType, result.operands) ||
-      parser.resolveOperand(addr, type, result.operands))
+  // Extract the element type from the pointer type.
+  type = elementType;
+  std::optional<Type> pointerElementType =
+      getLoadStoreElementType(parser, type, trailingTypeLoc);
+  if (!pointerElementType)
     return failure();
-
+  elementType = *pointerElementType;
   return success();
 }
 
+/// Prints the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printStoreType(OpAsmPrinter &printer, Operation *op,
+                           Type elementType, Type type) {
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << elementType << ", ";
+  printer << type;
+}
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 14dbf07371bf..cd54411de4ef 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -251,7 +251,7 @@ define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) {
 ; CHECK-SAME:  %[[VEC:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[IDX:[a-zA-Z0-9]+]]
 define half @extract_element(ptr %vec, i32 %idx) {
-  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK:  %[[V2:.+]] = llvm.extractelement %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK:  llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -266,7 +266,7 @@ define half @extract_element(ptr %vec, i32 %idx) {
 ; CHECK-SAME:  %[[VAL:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[IDX:[a-zA-Z0-9]+]]
 define <4 x half> @insert_element(ptr %vec, half %val, i32 %idx) {
-  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK:  %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK:  %[[V2:.+]] = llvm.insertelement %[[VAL]], %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK:  llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -352,13 +352,20 @@ define ptr @alloca(i64 %size) {
 ; CHECK-LABEL: @load_store
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @load_store(ptr %ptr) {
-  ; CHECK:  %[[V1:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
-  ; CHECK:  llvm.store %[[V1]], %[[PTR]] : f64, !llvm.ptr
+  ; CHECK:  %[[V1:[0-9]+]] = llvm.load %[[PTR]] {alignment = 8 : i64} : !llvm.ptr -> f64
+  ; CHECK:  %[[V2:[0-9]+]] = llvm.load volatile %[[PTR]] {alignment = 16 : i64, nontemporal} : !llvm.ptr -> f64
   %1 = load double, ptr %ptr
+  %2 = load volatile double, ptr %ptr, align 16, !nontemporal !0
+
+  ; CHECK:  llvm.store %[[V1]], %[[PTR]] {alignment = 8 : i64} : f64, !llvm.ptr
+  ; CHECK:  llvm.store volatile %[[V2]], %[[PTR]] {alignment = 16 : i64, nontemporal} : f64, !llvm.ptr
   store double %1, ptr %ptr
+  store volatile double %2, ptr %ptr, align 16, !nontemporal !0
   ret void
 }
 
+!0 = !{i32 1}
+
 ; // -----
 
 ; CHECK-LABEL: @atomic_rmw

diff  --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
index 1ddd5e2b8d20..9aecb13c0095 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
@@ -8,13 +8,12 @@
 ; CHECK: }
 
 ; CHECK-LABEL: llvm.func @access_group
-; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 define void @access_group(ptr %arg1) {
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]
   %1 = load i32, ptr %arg1, !llvm.access.group !0
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]
   %2 = load i32, ptr %arg1, !llvm.access.group !1
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]}
+  ; CHECK:  access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]
   %3 = load i32, ptr %arg1, !llvm.access.group !2
   ret void
 }


        


More information about the Mlir-commits mailing list