[Mlir-commits] [mlir] 4a2930f - [mlir] Add loop codegen options to some LLVM dialect ops.
Alex Zinenko
llvmlistbot at llvm.org
Thu Mar 4 00:02:07 PST 2021
Author: Arpith C. Jacob
Date: 2021-03-04T09:01:57+01:00
New Revision: 4a2930f4950dbeacaf4da6fe9445215934296cce
URL: https://github.com/llvm/llvm-project/commit/4a2930f4950dbeacaf4da6fe9445215934296cce
DIFF: https://github.com/llvm/llvm-project/commit/4a2930f4950dbeacaf4da6fe9445215934296cce.diff
LOG: [mlir] Add loop codegen options to some LLVM dialect ops.
Add a Loop Option attribute and generate llvm metadata attached to
branch instructions to control code generation.
Reviewed By: ftynse, mehdi_amini
Differential Revision: https://reviews.llvm.org/D96820
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 22ff1517f77b..ac5b5907bf82 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -48,6 +48,7 @@ namespace detail {
struct LLVMTypeStorage;
struct LLVMDialectImpl;
struct BitmaskEnumStorage;
+struct LoopOptionAttrStorage;
} // namespace detail
/// An attribute that specifies LLVM instruction fastmath flags.
@@ -64,6 +65,38 @@ class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
static Attribute parse(DialectAsmParser &parser);
};
+/// An attribute holding one LLVM loop codegen option: a LoopOptionCase and its value.
+class LoopOptionAttr
+ : public Attribute::AttrBase<LoopOptionAttr, Attribute,
+ detail::LoopOptionAttrStorage> {
+public:
+ using Base::Base;
+
+ /// Creates the option that maps to the llvm.loop.unroll.disable metadata.
+ static LoopOptionAttr getDisableUnroll(MLIRContext *context,
+ bool disable = true);
+
+ /// Creates the option that maps to the llvm.licm.disable metadata.
+ static LoopOptionAttr getDisableLICM(MLIRContext *context,
+ bool disable = true);
+
+ /// Creates the option that maps to the llvm.loop.interleave.count metadata.
+ static LoopOptionAttr getInterleaveCount(MLIRContext *context, int32_t count);
+
+ /// Returns which loop option this attribute holds, e.g. disable_unroll.
+ LoopOptionCase getCase() const;
+
+ /// Returns the flag of a boolean option (disable_unroll, disable_licm) only.
+ bool getBool() const;
+
+ /// Returns the integer value associated with a loop option. Only valid for
+ /// integer options (interleave_count).
+ int32_t getInt() const;
+
+ void print(DialectAsmPrinter &p) const; // Prints `loopopt<name = value>`.
+ static Attribute parse(DialectAsmParser &parser); // Parses `<name = value>`; null on error.
+};
+
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 541f7ebfadfa..f0b4c69b6ae6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -32,6 +32,9 @@ def LLVM_Dialect : Dialect {
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
static StringRef getAlignAttrName() { return "llvm.align"; }
static StringRef getNoAliasAttrName() { return "llvm.noalias"; }
+ static StringRef getLoopAttrName() { return "llvm.loop"; } // Loop annotation; verified to be a dictionary.
+ static StringRef getParallelAccessAttrName() { return "parallel_access"; } // Key in llvm.loop: array of access group symbol refs.
+ static StringRef getLoopOptionsAttrName() { return "options"; } // Key in llvm.loop: array of LoopOptionAttr.
/// Verifies if the given string is a well-formed data layout descriptor.
/// Uses `reportError` to report errors.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 076b8ed96e4e..956c5551ab78 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -47,6 +47,18 @@ def LLVM_FMFAttr : DialectAttr<
"::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
}
+def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>; // Translated to llvm.loop.unroll.disable.
+def LOptDisableLICM : I32EnumAttrCase<"disable_licm", 2>; // Translated to llvm.licm.disable.
+def LOptInterleaveCount : I32EnumAttrCase<"interleave_count", 3>; // Translated to llvm.loop.interleave.count.
+
+def LoopOptionCase : I32EnumAttr< // Kinds of option a LoopOptionAttr can carry.
+ "LoopOptionCase",
+ "LLVM loop option",
+ [LOptDisableUnroll, LOptDisableLICM, LOptInterleaveCount
+ ]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
class LLVM_Builder<string builder> {
string llvmBuilder = builder;
}
@@ -827,6 +839,48 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
let verifier = "return ::verify(*this);";
}
+def LLVM_MetadataOp : LLVM_Op<"metadata", [
+ NoRegionArguments, SymbolTable, Symbol // Named op with its own symbol table for nested metadata ops.
+]> {
+ let arguments = (ins
+ SymbolNameAttr:$sym_name // Name used to reference this op, e.g. @metadata.
+ );
+ let summary = "LLVM dialect metadata.";
+ let description = [{
+ llvm.metadata op defines one or more metadata nodes. Currently the
+ llvm.access_group metadata op is supported.
+
+ Example:
+ llvm.metadata @metadata {
+ llvm.access_group @group1
+ llvm.access_group @group2
+ llvm.return
+ }
+ }];
+ let regions = (region SizedRegion<1>:$body); // Body region holding the metadata ops.
+ let assemblyFormat = "$sym_name attr-dict-with-keyword $body";
+}
+
+def LLVM_AccessGroupMetadataOp : LLVM_Op<"access_group", [
+ HasParent<"MetadataOp">, Symbol // Only valid directly inside llvm.metadata; referenced by symbol.
+]> {
+ let arguments = (ins
+ SymbolNameAttr:$sym_name // Name of the access group, e.g. @group1.
+ );
+ let summary = "LLVM dialect access group metadata.";
+ let description = [{
+ Defines an access group metadata that can be attached to any instruction
+ that potentially accesses memory. The access group may be attached to a
+ memory accessing instruction via the `llvm.access.group` metadata and
+ a branch instruction in the loop latch block via the
+ `llvm.loop.parallel_accesses` metadata.
+
+ See the following link for more details:
+ https://llvm.org/docs/LangRef.html#llvm-access-group-metadata
+ }];
+ let assemblyFormat = "$sym_name attr-dict";
+}
+
def LLVM_GlobalOp : LLVM_Op<"mlir.global",
[IsolatedFromAbove, SingleBlockImplicitTerminator<"ReturnOp">, Symbol]> {
let arguments = (ins
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 698641deb99d..748268575f86 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -110,6 +110,24 @@ class ModuleTranslation {
return branchMapping.lookup(op);
}
+ /// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
+ /// dialect access group operation.
+ llvm::MDNode *getAccessGroup(Operation &opInst,
+ SymbolRefAttr accessGroupRef) const;
+
+ /// Returns the LLVM metadata previously mapped to a loop's codegen options
+ /// attribute, or nullptr when `options` has not been mapped yet.
+ llvm::MDNode *lookupLoopOptionsMetadata(Attribute options) const {
+ return loopOptionsMetadataMapping.lookup(options);
+ }
+
+  /// Records `metadata` as the LLVM loop metadata created for the loop codegen
+  /// options attribute `options`. An attribute may be mapped only once; mapping
+  /// it a second time trips the assertion.
+  void mapLoopOptionsMetadata(Attribute options, llvm::MDNode *metadata) {
+    bool inserted =
+        loopOptionsMetadataMapping.try_emplace(options, metadata).second;
+    (void)inserted;
+    assert(inserted &&
+           "attempting to map loop options that was already mapped");
+  }
+
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);
@@ -167,6 +185,10 @@ class ModuleTranslation {
LogicalResult convertGlobals();
LogicalResult convertOneFunction(LLVMFuncOp func);
+ /// Process access_group LLVM Metadata operations and create LLVM
+ /// metadata nodes.
+ LogicalResult createAccessGroupMetadata();
+
/// Translates dialect attributes attached to the given operation.
LogicalResult convertDialectAttributes(Operation *op);
@@ -198,6 +220,16 @@ class ModuleTranslation {
/// they are converted to. This allows for connecting PHI nodes to the source
/// values after all operations are converted.
DenseMap<Operation *, llvm::Instruction *> branchMapping;
+
+ /// Mapping from an access group metadata operation to its LLVM metadata.
+ /// This map is populated on module entry and is used to annotate loops (as
+ /// identified via their branches) and contained memory accesses.
+ DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
+
+ /// Mapping from an attribute describing loop codegen options to its LLVM
+ /// metadata. The metadata is attached to Latch block branches with this
+ /// attribute.
+ DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
};
namespace detail {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d63e7753f93e..0538862b56e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -56,6 +56,26 @@ struct BitmaskEnumStorage : public AttributeStorage {
KeyTy value = 0;
};
+
+/// Uniqued storage behind LoopOptionAttr: the option kind (a LoopOptionCase
+/// stored as an integer) together with the value attached to it.
+struct LoopOptionAttrStorage : public AttributeStorage {
+  /// Uniquing key: (option kind, option value).
+  using KeyTy = std::pair<uint64_t, int32_t>;
+
+  explicit LoopOptionAttrStorage(uint64_t optionKind, int32_t optionValue)
+      : option(optionKind), value(optionValue) {}
+
+  /// Two storages are the same attribute when both key components match.
+  bool operator==(const KeyTy &key) const {
+    return key.first == option && key.second == value;
+  }
+
+  /// Builds a new storage instance in memory owned by `allocator`.
+  static LoopOptionAttrStorage *
+  construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
+    void *memory = allocator.allocate<LoopOptionAttrStorage>();
+    return new (memory) LoopOptionAttrStorage(key.first, key.second);
+  }
+
+  uint64_t option;
+  int32_t value;
+};
} // namespace detail
} // namespace LLVM
} // namespace mlir
@@ -2158,7 +2178,7 @@ static LogicalResult verify(FenceOp &op) {
//===----------------------------------------------------------------------===//
void LLVMDialect::initialize() {
- addAttributes<FMFAttr>();
+ addAttributes<FMFAttr, LoopOptionAttr>();
// clang-format off
addTypes<LLVMVoidType,
@@ -2213,6 +2233,57 @@ LogicalResult LLVMDialect::verifyDataLayoutString(
/// Verify LLVM dialect attributes.
LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
+ // If the `llvm.loop` attribute is present, enforce the following structure,
+ // which the module translation can assume.
+ if (attr.first.strref() == LLVMDialect::getLoopAttrName()) {
+ auto loopAttr = attr.second.dyn_cast<DictionaryAttr>();
+ if (!loopAttr)
+ return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
+ << "' to be a dictionary attribute";
+ Optional<NamedAttribute> parallelAccessGroup =
+ loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
+ if (parallelAccessGroup.hasValue()) {
+ auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>();
+ if (!accessGroups)
+ return op->emitOpError()
+ << "expected '" << LLVMDialect::getParallelAccessAttrName()
+ << "' to be an array attribute";
+ for (Attribute attr : accessGroups) {
+ auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
+ if (!accessGroupRef)
+ return op->emitOpError()
+ << "expected '" << attr << "' to be a symbol reference";
+ StringRef metadataName = accessGroupRef.getRootReference();
+ auto metadataOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+ op->getParentOp(), metadataName);
+ if (!metadataOp)
+ return op->emitOpError()
+ << "expected '" << attr << "' to reference a metadata op";
+ StringRef accessGroupName = accessGroupRef.getLeafReference();
+ Operation *accessGroupOp =
+ SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
+ if (!accessGroupOp)
+ return op->emitOpError()
+ << "expected '" << attr << "' to reference an access_group op";
+ }
+ }
+
+ Optional<NamedAttribute> loopOptions =
+ loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
+ if (loopOptions.hasValue()) {
+ auto options = loopOptions->second.dyn_cast<ArrayAttr>();
+ if (!options)
+ return op->emitOpError()
+ << "expected '" << LLVMDialect::getLoopOptionsAttrName()
+ << "' to be an array attribute";
+ if (!llvm::all_of(options, [](Attribute option) {
+ return option.isa<LoopOptionAttr>();
+ }))
+ return op->emitOpError() << "invalid loop options list " << options;
+ }
+ }
+
// If the data layout attribute is present, it must use the LLVM data layout
// syntax. Try parsing it and report errors in case of failure. Users of this
// attribute may assume it is well-formed and can pass it to the (asserting)
@@ -2343,6 +2414,109 @@ Attribute FMFAttr::parse(DialectAsmParser &parser) {
return FMFAttr::get(flags, parser.getBuilder().getContext());
}
+/// Builds the `disable_unroll` loop option, storing `disable` as a 0/1 value.
+LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context,
+                                                bool disable) {
+  uint64_t optionKind = static_cast<uint64_t>(LoopOptionCase::disable_unroll);
+  int32_t flag = disable ? 1 : 0;
+  return Base::get(context, optionKind, flag);
+}
+
+/// Builds the `disable_licm` loop option, storing `disable` as a 0/1 value.
+LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context,
+                                              bool disable) {
+  uint64_t optionKind = static_cast<uint64_t>(LoopOptionCase::disable_licm);
+  int32_t flag = disable ? 1 : 0;
+  return Base::get(context, optionKind, flag);
+}
+
+/// Builds the `interleave_count` loop option carrying `count` as its value.
+LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context,
+                                                  int32_t count) {
+  uint64_t optionKind =
+      static_cast<uint64_t>(LoopOptionCase::interleave_count);
+  int32_t interleave = count;
+  return Base::get(context, optionKind, interleave);
+}
+
+/// Returns the kind of loop option stored in this attribute.
+LoopOptionCase LoopOptionAttr::getCase() const {
+  uint64_t rawOption = getImpl()->option;
+  return static_cast<LoopOptionCase>(rawOption);
+}
+
+/// Returns the flag stored in a boolean loop option. Asserts that the option is
+/// one of the boolean kinds (disable_licm or disable_unroll).
+bool LoopOptionAttr::getBool() const {
+  LoopOptionCase option = getCase();
+  (void)option;
+  // The disjunction is parenthesized so the message string is attached to the
+  // whole condition rather than only to its second operand; the unparenthesized
+  // `a || b && "msg"` form also triggers -Wparentheses.
+  assert((option == LoopOptionCase::disable_licm ||
+          option == LoopOptionCase::disable_unroll) &&
+         "expected a boolean loop option");
+  return static_cast<bool>(getImpl()->value);
+}
+
+/// Returns the value stored in an integer loop option. Asserts that the option
+/// is interleave_count.
+int32_t LoopOptionAttr::getInt() const {
+  assert(getCase() == LoopOptionCase::interleave_count &&
+         "expected an integer loop option");
+  return getImpl()->value;
+}
+
+/// Prints the attribute as `loopopt<name = value>`, where the value is
+/// `true`/`false` for boolean options and a number for interleave_count.
+void LoopOptionAttr::print(DialectAsmPrinter &printer) const {
+  const LoopOptionCase optionCase = getCase();
+  printer << "loopopt<" << stringifyEnum(optionCase) << " = ";
+  switch (optionCase) {
+  case LoopOptionCase::interleave_count:
+    printer << getInt();
+    break;
+  case LoopOptionCase::disable_licm:
+  case LoopOptionCase::disable_unroll:
+    if (getBool())
+      printer << "true";
+    else
+      printer << "false";
+    break;
+  }
+  printer << ">";
+}
+
+/// Parses the body of a `loopopt` attribute: `<` option-name `=` value `>`.
+/// Boolean options (disable_licm, disable_unroll) take `true` or `false`;
+/// interleave_count takes an integer. On malformed input a diagnostic is
+/// emitted and a null attribute is returned.
+Attribute LoopOptionAttr::parse(DialectAsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  StringRef optionName;
+  if (failed(parser.parseKeyword(&optionName)))
+    return {};
+
+  auto option = symbolizeLoopOptionCase(optionName);
+  if (!option) {
+    parser.emitError(parser.getNameLoc(), "unknown loop option: ")
+        << optionName;
+    return {};
+  }
+
+  if (failed(parser.parseEqual()))
+    return {};
+
+  // Zero-initialized: the switch below has no default label, so the compiler
+  // cannot prove that every path assigns `value` before it is read.
+  int32_t value = 0;
+  switch (*option) {
+  case LoopOptionCase::disable_licm:
+  case LoopOptionCase::disable_unroll:
+    if (succeeded(parser.parseOptionalKeyword("true")))
+      value = 1;
+    else if (succeeded(parser.parseOptionalKeyword("false")))
+      value = 0;
+    else {
+      parser.emitError(parser.getNameLoc(),
+                       "expected boolean value 'true' or 'false'");
+      return {};
+    }
+    break;
+  case LoopOptionCase::interleave_count:
+    if (failed(parser.parseInteger(value))) {
+      parser.emitError(parser.getNameLoc(), "expected integer value");
+      return {};
+    }
+    break;
+  }
+
+  if (failed(parser.parseGreater()))
+    return {};
+
+  return Base::get(parser.getBuilder().getContext(),
+                   static_cast<uint64_t>(*option), value);
+}
+
Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
if (type) {
@@ -2356,14 +2530,18 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
if (attrKind == "fastmath")
return FMFAttr::parse(parser);
- parser.emitError(parser.getNameLoc(), "Unknown attrribute type: ")
- << attrKind;
+ if (attrKind == "loopopt")
+ return LoopOptionAttr::parse(parser);
+
+ parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
return {};
}
void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
if (auto fmf = attr.dyn_cast<FMFAttr>())
fmf.print(os);
+ else if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
+ lopt.print(os); // Loop options print themselves as loopopt<name = value>.
else
llvm_unreachable("Unknown attribute type");
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index baf7107db80f..1f03092628a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -168,6 +168,82 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
return ret;
}
+/// Builds the two-operand node `!{!"<metadata name>", <constant>}` describing a
+/// single loop option. The result is meant to be used as an operand of an
+/// llvm.loop node.
+static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx,
+                                           LoopOptionAttr option) {
+  StringRef mdName;
+  llvm::Constant *mdValue = nullptr;
+  switch (option.getCase()) {
+  case LoopOptionCase::disable_unroll:
+    mdName = "llvm.loop.unroll.disable";
+    mdValue = llvm::ConstantInt::getBool(ctx, option.getBool());
+    break;
+  case LoopOptionCase::disable_licm:
+    mdName = "llvm.licm.disable";
+    mdValue = llvm::ConstantInt::getBool(ctx, option.getBool());
+    break;
+  case LoopOptionCase::interleave_count: {
+    llvm::IntegerType *i32Ty = llvm::IntegerType::get(ctx, /*NumBits=*/32);
+    mdName = "llvm.loop.interleave.count";
+    mdValue = llvm::ConstantInt::get(i32Ty, option.getInt());
+    break;
+  }
+  }
+
+  llvm::Metadata *operands[] = {llvm::MDString::get(ctx, mdName),
+                                llvm::ConstantAsMetadata::get(mdValue)};
+  return llvm::MDNode::get(ctx, operands);
+}
+
+/// Attaches `llvm.loop` metadata to `llvmInst` when `opInst` carries the
+/// `llvm.loop` dictionary attribute, and does nothing otherwise. The metadata
+/// node is built from the `parallel_access` and `options` entries the first
+/// time a given attribute is seen and is looked up from `moduleTranslation`
+/// on later occurrences of the same attribute.
+static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst,
+                            llvm::IRBuilderBase &builder,
+                            LLVM::ModuleTranslation &moduleTranslation) {
+  Attribute attr = opInst.getAttr(LLVMDialect::getLoopAttrName());
+  if (!attr)
+    return;
+
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  llvm::MDNode *loopMD = moduleTranslation.lookupLoopOptionsMetadata(attr);
+  if (!loopMD) {
+    llvm::LLVMContext &ctx = module->getContext();
+
+    // Operand 0 of the loop node refers to the node itself. A temporary node
+    // holds that slot until the real node exists.
+    SmallVector<llvm::Metadata *> operands;
+    auto selfPlaceholder = llvm::MDNode::getTemporary(ctx, llvm::None);
+    operands.push_back(selfPlaceholder.get());
+
+    auto loopDict = attr.cast<DictionaryAttr>();
+
+    // `parallel_access` becomes one "llvm.loop.parallel_accesses" node listing
+    // the metadata of every referenced access group.
+    auto accessGroups =
+        loopDict.getNamed(LLVMDialect::getParallelAccessAttrName());
+    if (accessGroups.hasValue()) {
+      SmallVector<llvm::Metadata *> parallelAccesses;
+      parallelAccesses.push_back(
+          llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
+      auto groupRefs = accessGroups->second.cast<ArrayAttr>();
+      for (SymbolRefAttr groupRef : groupRefs.getAsRange<SymbolRefAttr>())
+        parallelAccesses.push_back(
+            moduleTranslation.getAccessGroup(opInst, groupRef));
+      operands.push_back(llvm::MDNode::get(ctx, parallelAccesses));
+    }
+
+    // Each entry of `options` contributes one operand.
+    auto optionList = loopDict.getNamed(LLVMDialect::getLoopOptionsAttrName());
+    if (optionList.hasValue()) {
+      auto optionAttrs = optionList->second.cast<ArrayAttr>();
+      for (LoopOptionAttr optionAttr :
+           optionAttrs.getAsRange<LoopOptionAttr>())
+        operands.push_back(getLoopOptionMetadata(ctx, optionAttr));
+    }
+
+    // Create the node, then point operand 0 back at the node itself.
+    loopMD = llvm::MDNode::get(ctx, operands);
+    loopMD->replaceOperandWith(0, loopMD);
+
+    // Remember the node so the same attribute reuses it next time.
+    moduleTranslation.mapLoopOptionsMetadata(attr, loopMD);
+  }
+
+  llvmInst.setMetadata(module->getMDKindID("llvm.loop"), loopMD);
+}
+
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
@@ -295,6 +371,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::BranchInst *branch =
builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
moduleTranslation.mapBranch(&opInst, branch);
+ setLoopMetadata(opInst, *branch, builder, moduleTranslation);
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -316,6 +393,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
moduleTranslation.mapBranch(&opInst, branch);
+ setLoopMetadata(opInst, *branch, builder, moduleTranslation);
return success();
}
if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d52cc78a48fc..3a03b278e264 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -593,7 +593,7 @@ LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
/// Check whether the module contains only supported ops directly in its body.
static LogicalResult checkSupportedModuleOps(Operation *m) {
for (Operation &o : getModuleBody(m).getOperations())
- if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) &&
+ if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::MetadataOp>(&o) &&
!o.hasTrait<OpTrait::IsTerminator>())
return o.emitOpError("unsupported module-level operation");
return success();
@@ -633,6 +633,29 @@ LogicalResult ModuleTranslation::convertFunctions() {
return success();
}
+/// Resolves `accessGroupRef` starting from the parent of `opInst` and returns
+/// the LLVM metadata mapped to the referenced access group operation. The root
+/// of the reference names the metadata op and the leaf names the access group
+/// nested in it.
+llvm::MDNode *
+ModuleTranslation::getAccessGroup(Operation &opInst,
+                                  SymbolRefAttr accessGroupRef) const {
+  auto rootName = accessGroupRef.getRootReference();
+  auto leafName = accessGroupRef.getLeafReference();
+  auto enclosingMetadata =
+      SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+          opInst.getParentOp(), rootName);
+  Operation *groupOp =
+      SymbolTable::lookupNearestSymbolFrom(enclosingMetadata, leafName);
+  return accessGroupMetadataMapping.lookup(groupOp);
+}
+
+/// Creates one distinct, empty LLVM metadata node for every access group
+/// operation nested in a metadata operation of the module, and records it in
+/// `accessGroupMetadataMapping`. Always succeeds.
+LogicalResult ModuleTranslation::createAccessGroupMetadata() {
+  llvm::LLVMContext &ctx = llvmModule->getContext();
+  mlirModule->walk([&](LLVM::MetadataOp metadataOp) {
+    metadataOp.walk([&](LLVM::AccessGroupMetadataOp groupOp) {
+      llvm::MDNode *groupNode = llvm::MDNode::getDistinct(ctx, {});
+      accessGroupMetadataMapping.insert({groupOp, groupNode});
+    });
+  });
+  return success();
+}
+
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
@@ -697,6 +720,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
return nullptr;
if (failed(translator.convertGlobals()))
return nullptr;
+ if (failed(translator.createAccessGroupMetadata()))
+ return nullptr;
if (failed(translator.convertFunctions()))
return nullptr;
if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index b7984805b374..6a45b1f67e71 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -713,3 +713,86 @@ llvm.mlir.global common @non_zero_compound_global_common_linkage(dense<[0, 0, 0,
// expected-error@below {{expected array type for 'appending' linkage}}
llvm.mlir.global appending @non_array_type_global_appending_linkage() : i32
+
+// -----
+
+// The `llvm.loop` attribute must be a dictionary; a string is rejected.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected 'llvm.loop' to be a dictionary attribute}}
+    llvm.br ^bb4 {llvm.loop = "test"}
+  ^bb4:
+    llvm.return
+  }
+}
+
+// -----
+
+// The `parallel_access` entry must be an array; a string is rejected.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected 'parallel_access' to be an array attribute}}
+    llvm.br ^bb4 {llvm.loop = {parallel_access = "loop"}}
+  ^bb4:
+    llvm.return
+  }
+}
+
+// -----
+
+// Every element of `parallel_access` must be a symbol reference.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected '"loop"' to be a symbol reference}}
+    llvm.br ^bb4 {llvm.loop = {parallel_access = ["loop"]}}
+  ^bb4:
+    llvm.return
+  }
+}
+
+// -----
+
+// The root of a `parallel_access` reference must name an llvm.metadata op;
+// here it names a function instead.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected '@func1' to reference a metadata op}}
+    llvm.br ^bb4 {llvm.loop = {parallel_access = [@func1]}}
+  ^bb4:
+    llvm.return
+  }
+  llvm.func @func1() {
+    llvm.return
+  }
+}
+
+// -----
+
+// The leaf of a `parallel_access` reference must resolve to a symbol nested in
+// the metadata op; a bare reference to the metadata op itself is rejected.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected '@metadata' to reference an access_group op}}
+    llvm.br ^bb4 {llvm.loop = {parallel_access = [@metadata]}}
+  ^bb4:
+    llvm.return
+  }
+  llvm.metadata @metadata {
+    llvm.return
+  }
+}
+
+// -----
+
+// The `options` entry must be an array; a string is rejected.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{expected 'options' to be an array attribute}}
+    llvm.br ^bb4 {llvm.loop = {options = "name"}}
+  ^bb4:
+    llvm.return
+  }
+}
+
+// -----
+
+// Every element of `options` must be a loop option attribute.
+module {
+  llvm.func @loopOptions() {
+    // expected-error@below {{invalid loop options list}}
+    llvm.br ^bb4 {llvm.loop = {options = ["name"]}}
+  ^bb4:
+    llvm.return
+  }
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 27c5783a4d8d..b202432ccbae 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -419,3 +419,12 @@ func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
%10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
return
}
+
+module { // Round-trips llvm.metadata (with an extra attribute) and a nested llvm.access_group.
+ // CHECK: llvm.metadata @metadata attributes {test_attribute} {
+ llvm.metadata @metadata attributes {test_attribute} {
+ // CHECK: llvm.access_group @group1
+ llvm.access_group @group1
+ llvm.return
+ }
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index abe2deb685bb..1109345231f2 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1469,3 +1469,38 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
}
// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}
+
+// -----
+
+module { // Both branches carry the same llvm.loop attribute and must share one loop metadata node.
+ llvm.func @loopOptions(%arg1 : i32, %arg2 : i32) {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ %4 = llvm.alloca %arg1 x i32 : (i32) -> (!llvm.ptr<i32>)
+ llvm.br ^bb3(%0 : i32)
+ ^bb3(%1: i32):
+ %2 = llvm.icmp "slt" %1, %arg1 : i32
+ // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]]
+ llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
+ ^bb4:
+ %3 = llvm.add %1, %arg2 : i32
+ %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
+ // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
+ llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
+ ^bb5:
+ llvm.return
+ }
+
+ llvm.metadata @metadata {
+ llvm.access_group @group1
+ llvm.access_group @group2
+ llvm.return
+ }
+}
+
+// CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], ![[PA_NODE:[0-9]+]], ![[UNROLL_DISABLE_NODE:[0-9]+]], ![[LICM_DISABLE_NODE:[0-9]+]], ![[INTERLEAVE_NODE:[0-9]+]]}
+// CHECK: ![[PA_NODE]] = !{!"llvm.loop.parallel_accesses", ![[GROUP_NODE1:[0-9]+]], ![[GROUP_NODE2:[0-9]+]]}
+// CHECK: ![[GROUP_NODE1]] = distinct !{}
+// CHECK: ![[GROUP_NODE2]] = distinct !{}
+// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
+// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
+// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}
More information about the Mlir-commits
mailing list