[Mlir-commits] [mlir] f3b1361 - [mlir] Generate constructor in generic adaptors allowing construction from an op instance

Markus Böck llvmlistbot at llvm.org
Thu Aug 10 03:52:23 PDT 2023


Author: Markus Böck
Date: 2023-08-10T12:38:54+02:00
New Revision: f3b1361f35fcbd2b551f224bebc4068d7de8c3d2

URL: https://github.com/llvm/llvm-project/commit/f3b1361f35fcbd2b551f224bebc4068d7de8c3d2
DIFF: https://github.com/llvm/llvm-project/commit/f3b1361f35fcbd2b551f224bebc4068d7de8c3d2.diff

LOG: [mlir] Generate constructor in generic adaptors allowing construction from an op instance

In essentially all places where adaptors are constructed in the codebase, an instance of the op is available and only a different value range is being used. Nevertheless, one had to perform the ritual of calling `getAttrDictionary()`, `getProperties()` and `getRegions()` and passing their results manually.

This patch changes that by teaching TableGen to generate a new adaptor constructor that allows the adaptor to be constructed using `GenericAdaptor(valueRange, op)`. The (discardable) attribute dictionary, the properties and the regions are then taken directly from the passed op, with only the value range being taken from the first parameter.

This simplifies a lot of code and also guarantees that all the various getters of the adaptor work in all scenarios.

Differential Revision: https://reviews.llvm.org/D157516

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/TableGen/Class.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/include/mlir/Transforms/OneToNTypeConversion.h
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/TableGen/Class.cpp
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 0aee13818df4d5..075d753ea6ed82 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -149,13 +149,8 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    if constexpr (SourceOp::hasProperties())
-      return rewrite(cast<SourceOp>(op),
-                     OpAdaptor(operands, op->getDiscardableAttrDictionary(),
-                               cast<SourceOp>(op).getProperties()),
-                     rewriter);
-    rewrite(cast<SourceOp>(op),
-            OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
+            rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -163,15 +158,8 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    if constexpr (SourceOp::hasProperties())
-      return matchAndRewrite(cast<SourceOp>(op),
-                             OpAdaptor(operands,
-                                       op->getDiscardableAttrDictionary(),
-                                       cast<SourceOp>(op).getProperties()),
-                             rewriter);
-    return matchAndRewrite(
-        cast<SourceOp>(op),
-        OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+    return matchAndRewrite(cast<SourceOp>(op),
+                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index afbd0395b466a3..cba9bb30006184 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1915,15 +1915,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
                        SmallVectorImpl<OpFoldResult> &results) {
     OpFoldResult result;
     if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
-      if constexpr (hasProperties()) {
-        result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getDiscardableAttrDictionary(),
-            cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
-      } else {
-        result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getDiscardableAttrDictionary(), {},
-            op->getRegions()));
-      }
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)));
     } else {
       result = cast<ConcreteOpT>(op).fold(operands);
     }
@@ -1946,19 +1939,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
                                 SmallVectorImpl<OpFoldResult> &results) {
     auto result = LogicalResult::failure();
     if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
-      if constexpr (hasProperties()) {
-        result = cast<ConcreteOpT>(op).fold(
-            typename ConcreteOpT::FoldAdaptor(
-                operands, op->getDiscardableAttrDictionary(),
-                cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
-            results);
-      } else {
-        result = cast<ConcreteOpT>(op).fold(
-            typename ConcreteOpT::FoldAdaptor(
-                operands, op->getDiscardableAttrDictionary(), {},
-                op->getRegions()),
-            results);
-      }
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)),
+          results);
     } else {
       result = cast<ConcreteOpT>(op).fold(operands, results);
     }

diff  --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 27b5d5c57e9135..81cdf7dbef5f70 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -166,6 +166,21 @@ class MethodSignature {
   /// method definition).
   void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const;
 
+  /// Write the template parameters of the signature.
+  void writeTemplateParamsTo(raw_indented_ostream &os) const;
+
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param) {
+    templateParams.push_back(stringify(param));
+  }
+
+  /// Add a list of template parameters, appended in iteration order.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container) {
+    templateParams.append(std::begin(container), std::end(container));
+  }
+
 private:
   /// The method's C++ return type.
   std::string returnType;
@@ -173,6 +188,8 @@ class MethodSignature {
   std::string methodName;
   /// The method's parameter list.
   MethodParameters parameters;
+  /// An optional list of template parameters.
+  SmallVector<std::string, 0> templateParams;
 };
 
 /// This class contains the body of a C++ method.
@@ -367,6 +384,14 @@ class Method : public ClassDeclarationBase<ClassDeclaration::Method> {
   void writeDefTo(raw_indented_ostream &os,
                   StringRef namePrefix) const override;
 
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param);
+
+  /// Add a list of template parameters.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container);
+
 protected:
   /// A collection of method properties.
   Properties properties;
@@ -459,6 +484,20 @@ operator|=(mlir::tblgen::Method::Properties &lhs,
 namespace mlir {
 namespace tblgen {
 
+template <typename ParamT>
+void Method::addTemplateParam(ParamT param) {
+  // Templates imply inline.
+  properties |= Method::Inline;
+  methodSignature.addTemplateParam(param);
+}
+
+template <typename ContainerT>
+void Method::addTemplateParams(ContainerT &&container) {
+  // Templates imply inline: the body must be visible where it is instantiated.
+  properties |= Method::Inline;
+  methodSignature.addTemplateParams(std::forward<ContainerT>(container));
+}
+
 /// This class describes a C++ parent class declaration.
 class ParentClass {
 public:

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 559c0c00013d45..e8a0e6ec6991b0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -523,24 +523,13 @@ class OpConversionPattern : public ConversionPattern {
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp,
-            OpAdaptor(operands, op->getDiscardableAttrDictionary(),
-                      sourceOp.getProperties()),
-            rewriter);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
-    if constexpr (SourceOp::hasProperties())
-      return matchAndRewrite(sourceOp,
-                             OpAdaptor(operands,
-                                       op->getDiscardableAttrDictionary(),
-                                       sourceOp.getProperties()),
-                             rewriter);
-    return matchAndRewrite(
-        sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
-        rewriter);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be

diff  --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index b86a6b5f76a3bc..933961814cbe40 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -231,12 +231,9 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
 
     OpAdaptor(const OneToNTypeMapping *operandMapping,
               const OneToNTypeMapping *resultMapping,
-              const ValueRange *convertedOperands, RangeT values,
-              DictionaryAttr attrs = nullptr, Properties &properties = {},
-              RegionRange regions = {})
-        : BaseT(values, attrs, properties, regions),
-          operandMapping(operandMapping), resultMapping(resultMapping),
-          convertedOperands(convertedOperands) {}
+              const ValueRange *convertedOperands, RangeT values, SourceOp op)
+        : BaseT(values, op), operandMapping(operandMapping),
+          resultMapping(resultMapping), convertedOperands(convertedOperands) {}
 
     /// Get the type mapping of the original operands to the converted operands.
     const OneToNTypeMapping &getOperandMapping() const {
@@ -276,8 +273,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
       valueRanges.push_back(values);
     }
     OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
-                      valueRanges, op->getAttrDictionary(),
-                      cast<SourceOp>(op).getProperties(), op->getRegions());
+                      valueRanges, cast<SourceOp>(op));
 
     // Call overload implemented by the derived class.
     return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 80ae59e16e5a95..4b27dcb6cda281 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -136,11 +136,7 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto reallocOp = cast<memref::ReallocOp>(op);
-    return matchAndRewrite(reallocOp,
-                           OpAdaptor(operands,
-                                     op->getDiscardableAttrDictionary(),
-                                     reallocOp.getProperties()),
-                           rewriter);
+    return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter);
   }
 
   // A `realloc` is converted as follows:

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 750fe9c021673f..b4dae244825364 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -292,7 +292,7 @@ ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
 
 void ConditionOp::getSuccessorRegions(
     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
 
   WhileOp whileOp = getParentOp();
 
@@ -2031,7 +2031,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
 
 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
     regions.emplace_back(&getThenRegion());
@@ -4039,7 +4039,7 @@ void IndexSwitchOp::getSuccessorRegions(
 void IndexSwitchOp::getEntrySuccessorRegions(
     ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &successors) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
   auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());

diff  --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index fd2ba65e568e7d..36038230568501 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -93,6 +93,17 @@ void MethodSignature::writeDefTo(raw_indented_ostream &os,
   os << ")";
 }
 
+void MethodSignature::writeTemplateParamsTo(
+    mlir::raw_indented_ostream &os) const {
+  if (templateParams.empty())
+    return;
+
+  os << "template <";
+  llvm::interleaveComma(templateParams, os,
+                        [&](StringRef param) { os << "typename " << param; });
+  os << ">\n";
+}
+
 //===----------------------------------------------------------------------===//
 // MethodBody definitions
 //===----------------------------------------------------------------------===//
@@ -114,6 +125,7 @@ void MethodBody::writeTo(raw_indented_ostream &os) const {
 //===----------------------------------------------------------------------===//
 
 void Method::writeDeclTo(raw_indented_ostream &os) const {
+  methodSignature.writeTemplateParamsTo(os);
   if (deprecationMessage) {
     os << "[[deprecated(\"";
     os.write_escaped(*deprecationMessage);
@@ -153,6 +165,7 @@ void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const {
 //===----------------------------------------------------------------------===//
 
 void Constructor::writeDeclTo(raw_indented_ostream &os) const {
+  methodSignature.writeTemplateParamsTo(os);
   if (properties & ConstexprValue)
     os << "constexpr ";
   methodSignature.writeDeclTo(os);

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 077aa750352e04..8107194b584b8a 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -57,6 +57,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK: namespace detail {
 // CHECK: class AOpGenericAdaptorBase {
 // CHECK: public:
+// CHECK:   AOpGenericAdaptorBase(AOp{{[[:space:]]}}
 // CHECK:   ::mlir::IntegerAttr getAttr1Attr();
 // CHECK:   uint32_t getAttr1();
 // CHECK:   ::mlir::FloatAttr getSomeAttr2Attr();
@@ -127,6 +128,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // DEFS-LABEL: NS::AOp definitions
 
 // DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
+
+// Check that `getAttrDictionary()` is used when not using properties.
+
+// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(AOp op)
+// DEFS-SAME: op->getAttrDictionary()
+// DEFS-SAME: op.getProperties()
+// DEFS-SAME: op->getRegions()
+
 // DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
 // DEFS-NEXT: return odsRegions.drop_front(1);
 // DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
@@ -330,6 +339,17 @@ def NS_MOp : NS_Op<"op_with_single_result_and_fold_adaptor_fold", []> {
 // CHECK-LABEL: class MOp :
 // CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
 
+def NS_NOp : NS_Op<"op_with_properties", []> {
+  let arguments = (ins Property<"unsigned">:$value);
+}
+
+// Check that `getDiscardableAttrDictionary()` is used with properties.
+
+// DEFS: NOpGenericAdaptorBase::NOpGenericAdaptorBase(NOp op) : NOpGenericAdaptorBase(
+// DEFS-SAME: op->getDiscardableAttrDictionary()
+// DEFS-SAME: op.getProperties()
+// DEFS-SAME: op->getRegions()
+
 // Test that type defs have the proper namespaces when used as a constraint.
 // ---
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index fec1fd859fe710..795a2e3ce934e3 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3961,6 +3961,35 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     }
   }
 
+  // Create constructors constructing the adaptor from an instance of the op.
+  // This takes the attributes, properties and regions from the op instance
+  // and the value range from the parameter.
+  {
+    // Base class is in the cpp file and can simply access the members of the op
+    // class to initialize the template independent fields.
+    auto *constructor = genericAdaptorBase.addConstructor(
+        MethodParameter(op.getCppClassName(), "op"));
+    constructor->addMemberInitializer(
+        genericAdaptorBase.getClassName(),
+        llvm::Twine(!useProperties ? "op->getAttrDictionary()"
+                                   : "op->getDiscardableAttrDictionary()") +
+            ", op.getProperties(), op->getRegions()");
+
+    // Generic adaptor is templated and therefore defined inline in the header.
+    // We cannot use the Op class here as it is an incomplete type (we have a
+    // circular reference between the two).
+    // Use a template trick to make the constructor be instantiated at call site
+    // when the op class is complete.
+    constructor = genericAdaptor.addConstructor(
+        MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
+    constructor->addTemplateParam("LateInst = " + op.getCppClassName());
+    constructor->addTemplateParam(
+        "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
+        ">>");
+    constructor->addMemberInitializer("Base", "op");
+    constructor->addMemberInitializer("odsOperands", "values");
+  }
+
   std::string sizeAttrInit;
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
     if (op.getDialect().usePropertiesForAttributes())
@@ -4074,9 +4103,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     // Constructor taking the Op as single parameter.
     auto *constructor =
         adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
-    constructor->addMemberInitializer(
-        adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), "
-                                "op.getProperties(), op->getRegions()");
+    constructor->addMemberInitializer(genericAdaptorClassName,
+                                      "op->getOperands(), op");
   }
 
   // Add verification function.


        


More information about the Mlir-commits mailing list