[Mlir-commits] [mlir] 3661827 - [MLIR][LLVMDialect] Added volatile and nontemporal attributes to load/store
George Mitenkov
llvmlistbot at llvm.org
Mon Jul 27 00:56:38 PDT 2020
Author: George Mitenkov
Date: 2020-07-27T10:55:56+03:00
New Revision: 36618274f3e2cdea98cd8202204b8ad2913aae46
URL: https://github.com/llvm/llvm-project/commit/36618274f3e2cdea98cd8202204b8ad2913aae46
DIFF: https://github.com/llvm/llvm-project/commit/36618274f3e2cdea98cd8202204b8ad2913aae46.diff
LOG: [MLIR][LLVMDialect] Added volatile and nontemporal attributes to load/store
This patch introduces 2 new optional attributes to `llvm.load`
and `llvm.store` ops: `volatile` and `nontemporal`. These attributes
are translated into LLVM IR as a `volatile` marker and a `!nontemporal` metadata node
respectively. They are also needed by the SPIR-V to LLVM dialect conversion,
since they are the counterparts of the SPIR-V `Volatile` and `NonTemporal` memory operands.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D84396
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Target/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 5322e243427a..29d7fd930030 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -235,6 +235,19 @@ class MemoryOpWithAlignmentBase {
}];
}
+// Extends MemoryOpWithAlignmentBase (which provides `setAlignmentCode`) with
+// `setNonTemporalMetadataCode`: a translation snippet that, when the op has
+// the `nontemporal` unit attribute, tags the freshly created LLVM instruction
+// `inst` with `!nontemporal !{i32 1}` metadata.
+class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
+  code setNonTemporalMetadataCode = [{
+    if ($nontemporal) {
+      // `nontemporal` is one of LLVM's fixed metadata kinds, so use its
+      // predefined ID instead of fetching the module and resolving the kind
+      // by name.
+      llvm::MDNode *metadata = llvm::MDNode::get(
+          inst->getContext(), llvm::ConstantAsMetadata::get(
+              builder.getInt32(1)));
+      inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata);
+    }
+  }];
+}
+
// Memory-related operations.
def LLVM_AllocaOp :
MemoryOpWithAlignmentBase,
@@ -266,52 +279,49 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
}];
}
def LLVM_LoadOp :
- MemoryOpWithAlignmentBase,
+ MemoryOpWithAlignmentAndAttributes,
LLVM_OneResultOp<"load">,
- Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
+ Arguments<(ins LLVM_Type:$addr,
+ OptionalAttr<I64Attr>:$alignment,
+ UnitAttr:$volatile_,
+ UnitAttr:$nontemporal)> {
string llvmBuilder = [{
- auto *inst = builder.CreateLoad($addr);
- }] # setAlignmentCode # [{
+ auto *inst = builder.CreateLoad($addr, $volatile_);
+ }] # setAlignmentCode # setNonTemporalMetadataCode # [{
$res = inst;
}];
let builders = [OpBuilder<
- "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
+ "OpBuilder &b, OperationState &result, Value addr, "
+ "unsigned alignment = 0, bool isVolatile = false, "
+ "bool isNonTemporal = false",
[{
auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
- build(b, result, type, addr, alignment);
+ build(b, result, type, addr, alignment, isVolatile, isNonTemporal);
}]>,
OpBuilder<
"OpBuilder &b, OperationState &result, Type t, Value addr, "
- "unsigned alignment = 0",
- [{
- if (alignment == 0)
- return build(b, result, t, addr, IntegerAttr());
- build(b, result, t, addr, b.getI64IntegerAttr(alignment));
- }]>];
+ "unsigned alignment = 0, bool isVolatile = false, "
+ "bool isNonTemporal = false">];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
let verifier = alignmentVerifierCode;
}
def LLVM_StoreOp :
- MemoryOpWithAlignmentBase,
+ MemoryOpWithAlignmentAndAttributes,
LLVM_ZeroResultOp<"store">,
Arguments<(ins LLVM_Type:$value,
LLVM_Type:$addr,
- OptionalAttr<I64Attr>:$alignment)> {
+ OptionalAttr<I64Attr>:$alignment,
+ UnitAttr:$volatile_,
+ UnitAttr:$nontemporal)> {
string llvmBuilder = [{
- auto *inst = builder.CreateStore($value, $addr);
- }] # setAlignmentCode;
- let builders = [
- OpBuilder<
+ auto *inst = builder.CreateStore($value, $addr, $volatile_);
+ }] # setAlignmentCode # setNonTemporalMetadataCode;
+ let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, Value value, Value addr, "
- "unsigned alignment = 0",
- [{
- if (alignment == 0)
- return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
- build(b, result, ArrayRef<Type>{}, value, addr,
- b.getI64IntegerAttr(alignment));
- }]
- >];
+ "unsigned alignment = 0, bool isVolatile = false, "
+ "bool isNonTemporal = false">
+ ];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
let verifier = alignmentVerifierCode;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f448427099a4..cf7a5d926528 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -31,6 +31,9 @@
using namespace mlir;
using namespace mlir::LLVM;
+// Names of the unit attributes that flag llvm.load / llvm.store as volatile
+// or nontemporal. The attribute is spelled "volatile_" so that the generated
+// accessor is `volatile_()`; plain `volatile` is a C++ keyword.
+static constexpr char kVolatileAttrName[] = "volatile_";
+static constexpr char kNonTemporalAttrName[] = "nontemporal";
+
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
//===----------------------------------------------------------------------===//
@@ -178,12 +181,28 @@ CondBrOp::getMutableSuccessorOperands(unsigned index) {
}
//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::LoadOp.
+// Builder, printer and parser for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
+/// Builds an llvm.load producing a value of type `t` from pointer `addr`.
+/// The `volatile_` / `nontemporal` unit attributes are attached only when the
+/// corresponding flag is set; `alignment` is attached only when non-zero
+/// (0 is the builder default and means "no alignment attribute").
+void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
+                   Value addr, unsigned alignment, bool isVolatile,
+                   bool isNonTemporal) {
+  result.addOperands(addr);
+  result.addTypes(t);
+
+  // The optional flags are encoded purely by the presence of a unit attribute.
+  UnitAttr flagPresent = builder.getUnitAttr();
+  if (isVolatile)
+    result.addAttribute(kVolatileAttrName, flagPresent);
+  if (isNonTemporal)
+    result.addAttribute(kNonTemporalAttrName, flagPresent);
+
+  if (alignment)
+    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+}
+
static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
- p << op.getOperationName() << ' ' << op.addr();
- p.printOptionalAttrDict(op.getAttrs());
+ p << op.getOperationName() << ' ';
+ if (op.volatile_())
+ p << "volatile ";
+ p << op.addr();
+ p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
p << " : " << op.addr().getType();
}
@@ -201,12 +220,15 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
return llvmTy.getPointerElementTy();
}
-// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
+// <operation> ::= `llvm.load` `volatile`? ssa-use attribute-dict? `:` type
static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType addr;
Type type;
llvm::SMLoc trailingTypeLoc;
+ if (succeeded(parser.parseOptionalKeyword("volatile")))
+ result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
+
if (parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
@@ -220,21 +242,41 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
}
//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::StoreOp.
+// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
+/// Builds an llvm.store of `value` into pointer `addr`.
+/// The `volatile_` / `nontemporal` unit attributes are attached only when the
+/// corresponding flag is set; `alignment` is attached only when non-zero
+/// (0 is the builder default and means "no alignment attribute").
+void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
+                    Value addr, unsigned alignment, bool isVolatile,
+                    bool isNonTemporal) {
+  // llvm.store is a zero-result op, so no result types are registered;
+  // appending an empty type list would be a no-op.
+  result.addOperands({value, addr});
+  if (isVolatile)
+    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
+  if (isNonTemporal)
+    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
+  if (alignment != 0)
+    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+}
+
static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
- p << op.getOperationName() << ' ' << op.value() << ", " << op.addr();
- p.printOptionalAttrDict(op.getAttrs());
+ p << op.getOperationName() << ' ';
+ if (op.volatile_())
+ p << "volatile ";
+ p << op.value() << ", " << op.addr();
+ p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
p << " : " << op.addr().getType();
}
-// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
+// <operation> ::= `llvm.store` `volatile`? ssa-use `,` ssa-use
+//                 attribute-dict? `:` type
static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType addr, value;
Type type;
llvm::SMLoc trailingTypeLoc;
+ if (succeeded(parser.parseOptionalKeyword("volatile")))
+ result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
+
if (parser.parseOperand(value) || parser.parseComma() ||
parser.parseOperand(addr) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 954b5b134541..d6180cbf1849 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1266,3 +1266,32 @@ llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i
}
// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}
+
+// -----
+
+// Check that the `volatile` keyword is exported as volatile LLVM load/store.
+llvm.func @volatile_store_and_load() {
+  %five = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %one = llvm.mlir.constant(1 : i64) : !llvm.i64
+  %ptr = llvm.alloca %one x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">)
+  // CHECK: store volatile i32 5, i32* %{{.*}}
+  llvm.store volatile %five, %ptr : !llvm<"i32*">
+  // CHECK: %{{.*}} = load volatile i32, i32* %{{.*}}
+  %loaded = llvm.load volatile %ptr : !llvm<"i32*">
+  llvm.return
+}
+
+// -----
+
+// Check that the `nontemporal` attribute is exported as a `!nontemporal`
+// metadata node (the shared `!{i32 1}` node) on both the store and the load.
+llvm.func @nontemporal_store_and_load() {
+  %val = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %size = llvm.mlir.constant(1 : i64) : !llvm.i64
+  %0 = llvm.alloca %size x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">)
+  // CHECK: !nontemporal ![[NODE:[0-9]+]]
+  llvm.store %val, %0 {nontemporal} : !llvm<"i32*">
+  // CHECK: !nontemporal ![[NODE]]
+  %1 = llvm.load %0 {nontemporal} : !llvm<"i32*">
+  llvm.return
+}
+
+// CHECK: ![[NODE]] = !{i32 1}
More information about the Mlir-commits
mailing list