[Mlir-commits] [mlir] 038f2a3 - Move LLVM::FMFAttr definition to TableGen (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Mon Mar 8 21:30:19 PST 2021
Author: Mehdi Amini
Date: 2021-03-09T05:29:54Z
New Revision: 038f2a337d09e114469ddcfba5b613cdb8c0fe1d
URL: https://github.com/llvm/llvm-project/commit/038f2a337d09e114469ddcfba5b613cdb8c0fe1d
DIFF: https://github.com/llvm/llvm-project/commit/038f2a337d09e114469ddcfba5b613cdb8c0fe1d.diff
LOG: Move LLVM::FMFAttr definition to TableGen (NFC)
This uses the new attribute storage generation support in
TableGen to define the LLVM FastMathFlags attribute (FMFAttr).
Differential Revision: https://reviews.llvm.org/D98007
Added:
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Modified:
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index ae0fc152ddc4..20b989616d7a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -1,5 +1,10 @@
add_subdirectory(Transforms)
+set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td)
+mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRLLVMAttrsIncGen)
+
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
new file mode 100644
index 000000000000..8cdf36565a3e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -0,0 +1,29 @@
+//===-- LLVMAttrDefs.td - LLVM Attributes definition file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_ATTRDEFS
+#define LLVMIR_ATTRDEFS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+
// Base class for every attribute the LLVM dialect defines in TableGen. It
// binds the AttrDef to LLVM_Dialect so the generated attribute classes belong
// to that dialect.
class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;

// The "FastMath" flags associated with floating point LLVM instructions.
// Generates the C++ FMFAttr class (built via FMFAttr::get(context, flags)),
// which replaces the former hand-written attribute and its BitmaskEnumStorage.
def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
 // Keyword that introduces this attribute in the dialect's textual form; the
 // dialect's parseAttribute/printAttribute dispatch to the generated
 // parser/printer through it.
 let mnemonic = "fastmath";

 // List of attribute parameters: a single FastmathFlags value, read back
 // through the generated getFlags() accessor.
 let parameters = (
 ins
 "FastmathFlags":$flags
 );
}
+
+#endif // LLVMIR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index ac5b5907bf82..bd0cc61c1e3d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -30,6 +30,8 @@
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc"
namespace llvm {
class Type;
@@ -47,24 +49,9 @@ class LLVMDialect;
namespace detail {
struct LLVMTypeStorage;
struct LLVMDialectImpl;
-struct BitmaskEnumStorage;
struct LoopOptionAttrStorage;
} // namespace detail
-/// An attribute that specifies LLVM instruction fastmath flags.
-class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
- detail::BitmaskEnumStorage> {
-public:
- using Base::Base;
-
- static FMFAttr get(FastmathFlags flags, MLIRContext *context);
-
- FastmathFlags getFlags() const;
-
- void print(DialectAsmPrinter &p) const;
- static Attribute parse(DialectAsmParser &parser);
-};
-
/// An attribute that specifies LLVM loop codegen options.
class LoopOptionAttr
: public Attribute::AttrBase<LoopOptionAttr, Attribute,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 07583866621e..4b2e99d52bd6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -44,7 +44,7 @@ def LLVM_FMFAttr : DialectAttr<
let returnType = "::mlir::LLVM::FastmathFlags";
let convertFromStorage = "$_self.getFlags()";
let constBuilderCall =
- "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
+ "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)";
}
def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
@@ -249,7 +249,7 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [
[{
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
- ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext()));
+ ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf));
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
let printer = [{ printFCmpOp(p, *this); }];
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 4421a9fb4808..00ab63790fd4 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -30,7 +30,7 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());
- auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+ auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value sqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
@@ -133,7 +133,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+ auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
@@ -161,7 +161,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+ auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@@ -206,7 +206,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+ auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@@ -243,7 +243,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
- auto fmf = LLVM::FMFAttr::get({}, op.getContext());
+ auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1ca1650e2750..871f54b3aa09 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -829,7 +829,7 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
operation, dstType,
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
operation.operand1(), operation.operand2(),
- LLVM::FMFAttr::get({}, operation.getContext()));
+ LLVM::FMFAttr::get(operation.getContext(), {}));
return success();
}
};
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 663a922c5039..3601af43c73d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2981,7 +2981,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
- auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext());
+ auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {});
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 941792dc9c5c..99e12de73442 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
@@ -37,25 +38,12 @@ static constexpr const char kNonTemporalAttrName[] = "nontemporal";
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
namespace mlir {
namespace LLVM {
namespace detail {
-struct BitmaskEnumStorage : public AttributeStorage {
- using KeyTy = uint64_t;
-
- BitmaskEnumStorage(KeyTy val) : value(val) {}
-
- bool operator==(const KeyTy &key) const { return value == key; }
-
- static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<BitmaskEnumStorage>())
- BitmaskEnumStorage(key);
- }
-
- KeyTy value = 0;
-};
struct LoopOptionAttrStorage : public AttributeStorage {
using KeyTy = std::pair<uint64_t, int32_t>;
@@ -84,7 +72,7 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute, 8> filteredAttrs(
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
if (attr.first == "fastmathFlags") {
- auto defAttr = FMFAttr::get({}, attr.second.getContext());
+ auto defAttr = FMFAttr::get(attr.second.getContext(), {});
return defAttr != attr.second;
}
return true;
@@ -2387,14 +2375,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
- return Base::get(context, static_cast<uint64_t>(flags));
-}
-
-FastmathFlags FMFAttr::getFlags() const {
- return static_cast<FastmathFlags>(getImpl()->value);
-}
-
static constexpr const FastmathFlags FastmathFlagsList[] = {
// clang-format off
FastmathFlags::nnan,
@@ -2418,7 +2398,8 @@ void FMFAttr::print(DialectAsmPrinter &printer) const {
printer << ">";
}
-Attribute FMFAttr::parse(DialectAsmParser &parser) {
+Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser,
+ Type type) {
if (failed(parser.parseLess()))
return {};
@@ -2443,7 +2424,7 @@ Attribute FMFAttr::parse(DialectAsmParser &parser) {
return {};
}
- return FMFAttr::get(flags, parser.getBuilder().getContext());
+ return FMFAttr::get(parser.getBuilder().getContext(), flags);
}
LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context,
@@ -2558,9 +2539,9 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
-
- if (attrKind == "fastmath")
- return FMFAttr::parse(parser);
+ if (auto attr =
+ generatedAttributeParser(getContext(), parser, attrKind, type))
+ return attr;
if (attrKind == "loopopt")
return LoopOptionAttr::parse(parser);
@@ -2570,9 +2551,9 @@ Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
}
void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
- if (auto fmf = attr.dyn_cast<FMFAttr>())
- fmf.print(os);
- else if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
+ if (succeeded(generatedAttributePrinter(attr, os)))
+ return;
+ if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
lopt.print(os);
else
llvm_unreachable("Unknown attribute type");
More information about the Mlir-commits
mailing list