[Mlir-commits] [mlir] f3b1361 - [mlir] Generate constructor in generic adaptors allowing construction from an op instance
Markus Böck
llvmlistbot at llvm.org
Thu Aug 10 03:52:23 PDT 2023
Author: Markus Böck
Date: 2023-08-10T12:38:54+02:00
New Revision: f3b1361f35fcbd2b551f224bebc4068d7de8c3d2
URL: https://github.com/llvm/llvm-project/commit/f3b1361f35fcbd2b551f224bebc4068d7de8c3d2
DIFF: https://github.com/llvm/llvm-project/commit/f3b1361f35fcbd2b551f224bebc4068d7de8c3d2.diff
LOG: [mlir] Generate constructor in generic adaptors allowing construction from an op instance
In essentially all occurrences of adaptor constructions in the codebase, an instance of the op is available and only a different value range is being used. Nevertheless, one had to perform the ritual of manually calling `getAttrDictionary()`, `getProperties()` and `getRegions()` and passing the results to the adaptor constructor.
This patch changes that by teaching TableGen to generate a new constructor in the adaptor that is constructable using `GenericAdaptor(valueRange, op)`. The (discardable) attr dictionary, properties and the regions are then taken directly from the passed op, with only the value range being taken from the first parameter.
This simplifies a lot of code and also guarantees that all the various getters of the adaptor work in all scenarios.
Differential Revision: https://reviews.llvm.org/D157516
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/TableGen/Class.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/include/mlir/Transforms/OneToNTypeConversion.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/TableGen/Class.cpp
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 0aee13818df4d5..075d753ea6ed82 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -149,13 +149,8 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- if constexpr (SourceOp::hasProperties())
- return rewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getDiscardableAttrDictionary(),
- cast<SourceOp>(op).getProperties()),
- rewriter);
- rewrite(cast<SourceOp>(op),
- OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+ rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
+ rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
@@ -163,15 +158,8 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- if constexpr (SourceOp::hasProperties())
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands,
- op->getDiscardableAttrDictionary(),
- cast<SourceOp>(op).getProperties()),
- rewriter);
- return matchAndRewrite(
- cast<SourceOp>(op),
- OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+ return matchAndRewrite(cast<SourceOp>(op),
+ OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index afbd0395b466a3..cba9bb30006184 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1915,15 +1915,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
SmallVectorImpl<OpFoldResult> &results) {
OpFoldResult result;
if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
- if constexpr (hasProperties()) {
- result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getDiscardableAttrDictionary(),
- cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
- } else {
- result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
- operands, op->getDiscardableAttrDictionary(), {},
- op->getRegions()));
- }
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)));
} else {
result = cast<ConcreteOpT>(op).fold(operands);
}
@@ -1946,19 +1939,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
SmallVectorImpl<OpFoldResult> &results) {
auto result = LogicalResult::failure();
if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
- if constexpr (hasProperties()) {
- result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(
- operands, op->getDiscardableAttrDictionary(),
- cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
- results);
- } else {
- result = cast<ConcreteOpT>(op).fold(
- typename ConcreteOpT::FoldAdaptor(
- operands, op->getDiscardableAttrDictionary(), {},
- op->getRegions()),
- results);
- }
+ result = cast<ConcreteOpT>(op).fold(
+ typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)),
+ results);
} else {
result = cast<ConcreteOpT>(op).fold(operands, results);
}
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 27b5d5c57e9135..81cdf7dbef5f70 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -166,6 +166,21 @@ class MethodSignature {
/// method definition).
void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const;
+ /// Write the template parameters of the signature.
+ void writeTemplateParamsTo(raw_indented_ostream &os) const;
+
+ /// Add a template parameter.
+ template <typename ParamT>
+ void addTemplateParam(ParamT param) {
+ templateParams.push_back(stringify(param));
+ }
+
+ /// Add a list of template parameters. Every element of `container` is
+ /// appended, in iteration order, after any template parameters that were
+ /// already added, using the same conversion as `addTemplateParam`.
+ template <typename ContainerT>
+ void addTemplateParams(ContainerT &&container) {
+   // Append element by element: a range `insert` on the vector would need an
+   // explicit insertion position, and going through `addTemplateParam` keeps
+   // the stringification identical to the single-parameter overload.
+   for (auto &&param : container)
+     addTemplateParam(param);
+ }
+
private:
/// The method's C++ return type.
std::string returnType;
@@ -173,6 +188,8 @@ class MethodSignature {
std::string methodName;
/// The method's parameter list.
MethodParameters parameters;
+ /// An optional list of template parameters.
+ SmallVector<std::string, 0> templateParams;
};
/// This class contains the body of a C++ method.
@@ -367,6 +384,14 @@ class Method : public ClassDeclarationBase<ClassDeclaration::Method> {
void writeDefTo(raw_indented_ostream &os,
StringRef namePrefix) const override;
+ /// Add a template parameter.
+ template <typename ParamT>
+ void addTemplateParam(ParamT param);
+
+ /// Add a list of template parameters.
+ template <typename ContainerT>
+ void addTemplateParams(ContainerT &&container);
+
protected:
/// A collection of method properties.
Properties properties;
@@ -459,6 +484,20 @@ operator|=(mlir::tblgen::Method::Properties &lhs,
namespace mlir {
namespace tblgen {
+template <typename ParamT>
+void Method::addTemplateParam(ParamT param) {
+ // Templates imply inline.
+ properties |= Method::Inline;
+ methodSignature.addTemplateParam(param);
+}
+
+/// Add a list of template parameters to this method's signature and mark the
+/// method as inline.
+template <typename ContainerT>
+void Method::addTemplateParams(ContainerT &&container) {
+ // Templates imply inline.
+ properties |= Method::Inline;
+ // Forward to the list overload of the signature; the single-parameter
+ // overload would treat the whole container as one template parameter.
+ methodSignature.addTemplateParams(std::forward<ContainerT>(container));
+}
+
/// This class describes a C++ parent class declaration.
class ParentClass {
public:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 559c0c00013d45..e8a0e6ec6991b0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -523,24 +523,13 @@ class OpConversionPattern : public ConversionPattern {
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
- rewrite(sourceOp,
- OpAdaptor(operands, op->getDiscardableAttrDictionary(),
- sourceOp.getProperties()),
- rewriter);
+ rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
- if constexpr (SourceOp::hasProperties())
- return matchAndRewrite(sourceOp,
- OpAdaptor(operands,
- op->getDiscardableAttrDictionary(),
- sourceOp.getProperties()),
- rewriter);
- return matchAndRewrite(
- sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
- rewriter);
+ return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index b86a6b5f76a3bc..933961814cbe40 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -231,12 +231,9 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
OpAdaptor(const OneToNTypeMapping *operandMapping,
const OneToNTypeMapping *resultMapping,
- const ValueRange *convertedOperands, RangeT values,
- DictionaryAttr attrs = nullptr, Properties &properties = {},
- RegionRange regions = {})
- : BaseT(values, attrs, properties, regions),
- operandMapping(operandMapping), resultMapping(resultMapping),
- convertedOperands(convertedOperands) {}
+ const ValueRange *convertedOperands, RangeT values, SourceOp op)
+ : BaseT(values, op), operandMapping(operandMapping),
+ resultMapping(resultMapping), convertedOperands(convertedOperands) {}
/// Get the type mapping of the original operands to the converted operands.
const OneToNTypeMapping &getOperandMapping() const {
@@ -276,8 +273,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
valueRanges.push_back(values);
}
OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
- valueRanges, op->getAttrDictionary(),
- cast<SourceOp>(op).getProperties(), op->getRegions());
+ valueRanges, cast<SourceOp>(op));
// Call overload implemented by the derived class.
return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 80ae59e16e5a95..4b27dcb6cda281 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -136,11 +136,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto reallocOp = cast<memref::ReallocOp>(op);
- return matchAndRewrite(reallocOp,
- OpAdaptor(operands,
- op->getDiscardableAttrDictionary(),
- reallocOp.getProperties()),
- rewriter);
+ return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter);
}
// A `realloc` is converted as follows:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 750fe9c021673f..b4dae244825364 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -292,7 +292,7 @@ ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
void ConditionOp::getSuccessorRegions(
ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
- FoldAdaptor adaptor(operands);
+ FoldAdaptor adaptor(operands, *this);
WhileOp whileOp = getParentOp();
@@ -2031,7 +2031,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
- FoldAdaptor adaptor(operands);
+ FoldAdaptor adaptor(operands, *this);
auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
if (!boolAttr || boolAttr.getValue())
regions.emplace_back(&getThenRegion());
@@ -4039,7 +4039,7 @@ void IndexSwitchOp::getSuccessorRegions(
void IndexSwitchOp::getEntrySuccessorRegions(
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &successors) {
- FoldAdaptor adaptor(operands);
+ FoldAdaptor adaptor(operands, *this);
// If a constant was not provided, all regions are possible successors.
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index fd2ba65e568e7d..36038230568501 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -93,6 +93,17 @@ void MethodSignature::writeDefTo(raw_indented_ostream &os,
os << ")";
}
+void MethodSignature::writeTemplateParamsTo(
+ mlir::raw_indented_ostream &os) const {
+ if (templateParams.empty())
+ return;
+
+ os << "template <";
+ llvm::interleaveComma(templateParams, os,
+ [&](StringRef param) { os << "typename " << param; });
+ os << ">\n";
+}
+
//===----------------------------------------------------------------------===//
// MethodBody definitions
//===----------------------------------------------------------------------===//
@@ -114,6 +125,7 @@ void MethodBody::writeTo(raw_indented_ostream &os) const {
//===----------------------------------------------------------------------===//
void Method::writeDeclTo(raw_indented_ostream &os) const {
+ methodSignature.writeTemplateParamsTo(os);
if (deprecationMessage) {
os << "[[deprecated(\"";
os.write_escaped(*deprecationMessage);
@@ -153,6 +165,7 @@ void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
//===----------------------------------------------------------------------===//
void Constructor::writeDeclTo(raw_indented_ostream &os) const {
+ methodSignature.writeTemplateParamsTo(os);
if (properties & ConstexprValue)
os << "constexpr ";
methodSignature.writeDeclTo(os);
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 077aa750352e04..8107194b584b8a 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -57,6 +57,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: namespace detail {
// CHECK: class AOpGenericAdaptorBase {
// CHECK: public:
+// CHECK: AOpGenericAdaptorBase(AOp{{[[:space:]]}}
// CHECK: ::mlir::IntegerAttr getAttr1Attr();
// CHECK: uint32_t getAttr1();
// CHECK: ::mlir::FloatAttr getSomeAttr2Attr();
@@ -127,6 +128,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// DEFS-LABEL: NS::AOp definitions
// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
+
+// Check that `getAttrDictionary()` is used when not using properties.
+
+// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(AOp op)
+// DEFS-SAME: op->getAttrDictionary()
+// DEFS-SAME: p.getProperties()
+// DEFS-SAME: op->getRegions()
+
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
// DEFS-NEXT: return odsRegions.drop_front(1);
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
@@ -330,6 +339,17 @@ def NS_MOp : NS_Op<"op_with_single_result_and_fold_adaptor_fold", []> {
// CHECK-LABEL: class MOp :
// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
+def NS_NOp : NS_Op<"op_with_properties", []> {
+ let arguments = (ins Property<"unsigned">:$value);
+}
+
+// Check that `getDiscardableAttrDictionary()` is used with properties.
+
+// DEFS: NOpGenericAdaptorBase::NOpGenericAdaptorBase(NOp op) : NOpGenericAdaptorBase(
+// DEFS-SAME: op->getDiscardableAttrDictionary()
+// DEFS-SAME: op.getProperties()
+// DEFS-SAME: op->getRegions()
+
// Test that type defs have the proper namespaces when used as a constraint.
// ---
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index fec1fd859fe710..795a2e3ce934e3 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3961,6 +3961,35 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
}
+ // Create constructors constructing the adaptor from an instance of the op.
+ // This takes the attributes, properties and regions from the op instance
+ // and the value range from the parameter.
+ {
+ // Base class is in the cpp file and can simply access the members of the op
+ // class to initialize the template independent fields.
+ auto *constructor = genericAdaptorBase.addConstructor(
+ MethodParameter(op.getCppClassName(), "op"));
+ constructor->addMemberInitializer(
+ genericAdaptorBase.getClassName(),
+ llvm::Twine(!useProperties ? "op->getAttrDictionary()"
+ : "op->getDiscardableAttrDictionary()") +
+ ", op.getProperties(), op->getRegions()");
+
+ // Generic adaptor is templated and therefore defined inline in the header.
+ // We cannot use the Op class here as it is an incomplete type (we have a
+ // circular reference between the two).
+ // Use a template trick to make the constructor be instantiated at call site
+ // when the op class is complete.
+ constructor = genericAdaptor.addConstructor(
+ MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
+ constructor->addTemplateParam("LateInst = " + op.getCppClassName());
+ constructor->addTemplateParam(
+ "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
+ ">>");
+ constructor->addMemberInitializer("Base", "op");
+ constructor->addMemberInitializer("odsOperands", "values");
+ }
+
std::string sizeAttrInit;
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
if (op.getDialect().usePropertiesForAttributes())
@@ -4074,9 +4103,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
// Constructor taking the Op as single parameter.
auto *constructor =
adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
- constructor->addMemberInitializer(
- adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), "
- "op.getProperties(), op->getRegions()");
+ constructor->addMemberInitializer(genericAdaptorClassName,
+ "op->getOperands(), op");
}
// Add verification function.
More information about the Mlir-commits
mailing list