[Mlir-commits] [mlir] 4e39335 - [mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.
Alex Zinenko
llvmlistbot at llvm.org
Thu Mar 4 09:17:30 PST 2021
Author: Arpith C. Jacob
Date: 2021-03-04T18:17:23+01:00
New Revision: 4e393350c547edb8144592168c3b176646747a98
URL: https://github.com/llvm/llvm-project/commit/4e393350c547edb8144592168c3b176646747a98
DIFF: https://github.com/llvm/llvm-project/commit/4e393350c547edb8144592168c3b176646747a98.diff
LOG: [mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.
This also includes LLVM dialect ops created from intrinsics.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D97944
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/llvmir.mlir
mlir/test/mlir-tblgen/llvm-intrinsics.td
mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index f0b4c69b6ae6..8c83dbc0c9d1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -35,6 +35,7 @@ def LLVM_Dialect : Dialect {
static StringRef getLoopAttrName() { return "llvm.loop"; }
static StringRef getParallelAccessAttrName() { return "parallel_access"; }
static StringRef getLoopOptionsAttrName() { return "options"; }
+ static StringRef getAccessGroupsAttrName() { return "access_groups"; }
/// Verifies if the given string is a well-formed data layout descriptor.
/// Uses `reportError` to report errors.
@@ -247,7 +248,8 @@ def LLVM_IntrPatterns {
// `llvm::Intrinsic` enum; one usually wants these to be related.
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
- list<OpTrait> traits, int numResults>
+ list<OpTrait> traits, int numResults,
+ bit requiresAccessGroup = 0>
: LLVM_OpBase<dialect, opName, traits>,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
string resultPattern = !if(!gt(numResults, 1),
@@ -264,19 +266,21 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
overloadedOperands>.lst), ", ") # [{
});
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
- }] # !if(!gt(numResults, 0), "$res = ", "")
- # [{builder.CreateCall(fn, operands);
- }];
+ }] # [{auto *inst = builder.CreateCall(fn, operands);
+ }] # !if(!gt(requiresAccessGroup, 0),
+ "moduleTranslation.setAccessGroupsMetadata(op, inst);",
+ "(void) inst;")
+ # !if(!gt(numResults, 0), "$res = inst;", "");
}
// Base class for LLVM intrinsic operations, should not be used directly. Places
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<OpTrait> traits,
- int numResults>
+ int numResults, bit requiresAccessGroup = 0>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
- numResults>;
+ numResults, requiresAccessGroup>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 661f83c7c5f4..07583866621e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -287,6 +287,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
}
}];
+
+ code setAccessGroupsMetadataCode = [{
+ moduleTranslation.setAccessGroupsMetadata(op, inst);
+ }];
}
// Memory-related operations.
@@ -326,12 +330,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]>,
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
+ OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
auto *inst = builder.CreateLoad($addr, $volatile_);
- }] # setAlignmentCode # setNonTemporalMetadataCode # [{
+ }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{
$res = inst;
}];
let builders = [
@@ -346,16 +351,18 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
CArg<"bool", "false">:$isNonTemporal)>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
}
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_LoadableType:$value,
LLVM_PointerTo<LLVM_LoadableType>:$addr,
+ OptionalAttr<SymbolRefArrayAttr>:$access_groups,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
string llvmBuilder = [{
auto *inst = builder.CreateStore($value, $addr, $volatile_);
- }] # setAlignmentCode # setNonTemporalMetadataCode;
+ }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode;
let builders = [
OpBuilder<(ins "Value":$value, "Value":$addr,
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
@@ -363,6 +370,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
+ let verifier = [{ return ::verify(*this); }];
}
// Casts.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 748268575f86..e046ada3b004 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -128,6 +128,9 @@ class ModuleTranslation {
"attempting to map loop options that was already mapped");
}
+  /// Sets LLVM metadata for memory operations that are in a parallel loop:
+  /// attaches `llvm.access.group` metadata to `inst`, built from the access
+  /// groups listed in the `access_groups` attribute of `op` (if present).
+  void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
+
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0538862b56e1..941792dc9c5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -404,6 +404,34 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
// Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
+/// Verifies the optional `access_groups` attribute of `op`: for every symbol
+/// reference, the root must resolve to an `LLVM::MetadataOp` visible from the
+/// parent of `op`, and the leaf must resolve to an `LLVM::AccessGroupMetadataOp`
+/// nested in that metadata op. Emits an op error and fails otherwise.
+static LogicalResult verifyAccessGroups(Operation *op) {
+  if (Attribute attribute =
+          op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
+    // The attribute is already verified to be a symbol ref array attribute via
+    // a constraint in the operation definition.
+    for (SymbolRefAttr accessGroupRef :
+         attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
+      StringRef metadataName = accessGroupRef.getRootReference();
+      auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+          op->getParentOp(), metadataName);
+      if (!metadataOp)
+        return op->emitOpError() << "expected '" << accessGroupRef
+                                 << "' to reference a metadata op";
+      StringRef accessGroupName = accessGroupRef.getLeafReference();
+      // Typed lookup: a symbol of any other kind inside the metadata op must be
+      // rejected, as promised by the diagnostic below.
+      auto accessGroupOp = SymbolTable::lookupNearestSymbolFrom<
+          LLVM::AccessGroupMetadataOp>(metadataOp, accessGroupName);
+      if (!accessGroupOp)
+        return op->emitOpError() << "expected '" << accessGroupRef
+                                 << "' to reference an access_group op";
+    }
+  }
+  return success();
+}
+
+static LogicalResult verify(LoadOp op) {
+ return verifyAccessGroups(op.getOperation());
+}
+
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
@@ -462,6 +490,10 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
+static LogicalResult verify(StoreOp op) {
+ return verifyAccessGroups(op.getOperation());
+}
+
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3a03b278e264..891f30b95b66 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -656,6 +656,27 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
return success();
}
+void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
+ llvm::Instruction *inst) {
+ auto accessGroups =
+ op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
+ if (accessGroups && !accessGroups.empty()) {
+ llvm::Module *module = inst->getModule();
+ SmallVector<llvm::Metadata *> metadatas;
+ for (SymbolRefAttr accessGroupRef :
+ accessGroups.getAsRange<SymbolRefAttr>())
+ metadatas.push_back(getAccessGroup(*op, accessGroupRef));
+
+ llvm::MDNode *unionMD = nullptr;
+ if (metadatas.size() == 1)
+ unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
+ else if (metadatas.size() >= 2)
+ unionMD = llvm::MDNode::get(module->getContext(), metadatas);
+
+ inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
+ }
+}
+
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6a45b1f67e71..e83706b88a9e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -796,3 +796,39 @@ module {
llvm.return
}
}
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+    // expected-error@below {{attribute 'access_groups' failed to satisfy constraint: symbol ref array attribute}}
+    %0 = llvm.load %arg0 { "access_groups" = "test" } : !llvm.ptr<i32>
+    llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+    // expected-error@below {{expected '@func1' to reference a metadata op}}
+    %0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr<i32>
+    llvm.return
+  }
+  llvm.func @func1() {
+    llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+    // expected-error@below {{expected '@metadata' to reference an access_group op}}
+    %0 = llvm.load %arg0 { "access_groups" = [@metadata] } : !llvm.ptr<i32>
+    llvm.return
+  }
+  llvm.metadata @metadata {
+    llvm.return
+  }
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 1109345231f2..85d0d16737b4 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1483,6 +1483,7 @@ module {
llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
^bb4:
%3 = llvm.add %1, %arg2 : i32
+ // CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]]
%5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
// CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
@@ -1504,3 +1505,4 @@ module {
// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}
+// CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]}
diff --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td
index 511f062e8e4a..a6932b381284 100644
--- a/mlir/test/mlir-tblgen/llvm-intrinsics.td
+++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td
@@ -23,11 +23,33 @@
// It has no side effects.
// CHECK: [NoSideEffect]
// It has a result.
-// CHECK: 1>
+// CHECK: 1,
+// It does not require an access group.
+// CHECK: 0>
// CHECK: Arguments<(ins LLVM_Type, LLVM_Type
//---------------------------------------------------------------------------//
+// This checks that we can define an op that takes in an access group metadata.
+//
+// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \
+// RUN: | grep -v "llvm/IR/Intrinsics" \
+// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=ptrmask --llvmir-intrinsics-access-group-regexp=ptrmask \
+// RUN: | FileCheck --check-prefix=GROUPS %s
+
+// GROUPS-LABEL: def LLVM_ptrmask
+// GROUPS: LLVM_IntrOp<"ptrmask
+// It has no side effects.
+// GROUPS: [NoSideEffect]
+// It has a result.
+// GROUPS: 1,
+// It requires generation of an access group LLVM metadata.
+// GROUPS: 1>
+// It has an access group attribute.
+// GROUPS: OptionalAttr<SymbolRefArrayAttr>:$access_groups
+
+//---------------------------------------------------------------------------//
+
// This checks that the ODS we produce can be consumed by MLIR tablegen. We only
// make sure the entire process does not fail and produces some C++. The shape
// of this C++ code is tested by ODS tests.
diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 72554a1f0c24..dc76962b4d28 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MachineValueType.h"
#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
@@ -37,6 +38,12 @@ static llvm::cl::opt<std::string>
"are planning to emit"),
llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(IntrinsicGenCat));
+static llvm::cl::opt<std::string> accessGroupRegexp(
+ "llvmir-intrinsics-access-group-regexp",
+ llvm::cl::desc("Mark intrinsics that match the specified "
+ "regexp as taking an access group metadata"),
+ llvm::cl::cat(IntrinsicGenCat));
+
// Used to represent the indices of overloadable operands/results.
using IndicesTy = llvm::SmallBitVector;
@@ -185,6 +192,10 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
LLVMIntrinsic intr(record);
+ llvm::Regex accessGroupMatcher(accessGroupRegexp);
+ bool requiresAccessGroup =
+ !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
+
// Prepare strings for traits, if any.
llvm::SmallVector<llvm::StringRef, 2> traits;
if (intr.isCommutative())
@@ -195,6 +206,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
// Prepare strings for operands.
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
"LLVM_Type");
+ if (requiresAccessGroup)
+ operands.push_back("OptionalAttr<SymbolRefArrayAttr>:$access_groups");
// Emit the definition.
os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
@@ -204,7 +217,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
- os << ", " << intr.getNumResults() << ">, Arguments<(ins"
+ os << ", " << intr.getNumResults() << ", "
+ << (requiresAccessGroup ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
llvm::interleaveComma(operands, os);
os << ")>;\n\n";
More information about the Mlir-commits
mailing list