[Mlir-commits] [mlir] 9744606 - [MLIR] Change default builders generated by TableGen to use TypeRange for result types

Rahul Joshi llvmlistbot at llvm.org
Wed Sep 23 09:06:33 PDT 2020


Author: Rahul Joshi
Date: 2020-09-23T09:06:07-07:00
New Revision: 9744606614df4ba85a4d546c94b3b5ef9d3a3a96

URL: https://github.com/llvm/llvm-project/commit/9744606614df4ba85a4d546c94b3b5ef9d3a3a96
DIFF: https://github.com/llvm/llvm-project/commit/9744606614df4ba85a4d546c94b3b5ef9d3a3a96.diff

LOG: [MLIR] Change default builders generated by TableGen to use TypeRange for result types

- Change the default builders to use TypeRange instead of ArrayRef<Type>
- Custom builders defined in LinalgStructuredOps now conflict with the default
  separate-parameter builders, but the default collective-parameter builder is
  still needed. Resolve this by replicating the collective-parameter builder as a
  custom builder and skipping the generation of default builders for these ops.

Differential Revision: https://reviews.llvm.org/D87926

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/mlir-tblgen/op-decl.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 330cdca6c0ab..f1befa4d980b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2006,8 +2006,8 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   //
   // ```c++
   // static void build(OpBuilder &, OperationState &odsState,
-  //                   ArrayRef<Type> resultTypes,
-  //                   ArrayRef<Value> operands,
+  //                   TypeRange resultTypes,
+  //                   ValueRange operands,
   //                   ArrayRef<NamedAttribute> attributes);
   // ```
   list<OpBuilder> builders = ?;

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 0392264440f1..4ff77dc2c3f7 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -81,9 +81,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ::mlir::FloatAttr attr2Attr()
 // CHECK:   ::llvm::Optional< ::llvm::APFloat > attr2();
 // CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val);
-// CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
-// CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
-// CHECK:   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
+// CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
+// CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount)
+// CHECK:   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
 // CHECK:   static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
 // CHECK:   void print(::mlir::OpAsmPrinter &p);
 // CHECK:   ::mlir::LogicalResult verify();
@@ -180,8 +180,8 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
 
 // CHECK_LABEL: class NS_HCollectiveParamsOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
 
 // Check suppression of "separate arg, separate result" build method for an op
 // with single variadic arg and single variadic result (since it will be
@@ -192,8 +192,8 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
 }
 
 // CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
-// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 // Check suppression of "separate arg, collective result" build method for an op
 // with single variadic arg and non variadic result (since it will be
@@ -204,8 +204,8 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
 }
 
 // CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
-// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 // Check suppression of "separate arg, collective result" build method for an op
 // with single variadic arg and > 1 variadic result (since it will be
@@ -217,9 +217,9 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
   let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
 }
 // CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a);
-// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
+// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 // Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder
 def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> {
@@ -228,8 +228,8 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
 }
 // CHECK_LABEL: class NS_IOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
@@ -241,8 +241,8 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
 // CHECK_LABEL: class NS_JOp :
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
-// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b);
-// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
+// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
+// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
 // Check that default builders can be suppressed.

diff  --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 68492202b4a6..ac8f1cbd49f6 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -14,7 +14,7 @@ def OpA : NS_Op<"one_normal_result_op", []> {
 }
 
 // CHECK-LABEL: void OpA::build
-// CHECK:         ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands
+// CHECK:         ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands
 // CHECK:         assert(resultTypes.size() == 1u && "mismatched number of return types");
 // CHECK-NEXT:    odsState.addTypes(resultTypes);
 
@@ -39,7 +39,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
 // CHECK-NEXT:   odsState.addTypes(resultType1)
 // CHECK-NEXT:   odsState.addTypes(z)
 
-// CHECK:      void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes) {
+// CHECK:      void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) {
 // CHECK-NEXT:   assert(resultTypes.size() == 3u && "mismatched number of results");
 // CHECK-NEXT:   odsState.addTypes(resultTypes);
 
@@ -67,7 +67,7 @@ def OpF : NS_Op<"one_variadic_result_op", []> {
 }
 
 // CHECK-LABEL: void OpF::build
-// CHECK-SAME:    ::llvm::ArrayRef<::mlir::Type> x
+// CHECK-SAME:    ::mlir::TypeRange x
 // CHECK-NOT:     assert
 // CHECK:         odsState.addTypes(x);
 
@@ -78,12 +78,12 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
 // CHECK-LABEL: OpG definitions
 
-// CHECK:      void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::llvm::ArrayRef<::mlir::Type> y)
+// CHECK:      void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y)
 // CHECK-NEXT:   odsState.addTypes(x);
 // CHECK-NEXT:   odsState.addTypes(y);
 
 // CHECK:       void OpG::build
-// CHECK:         ::llvm::ArrayRef<::mlir::Type> resultTypes
+// CHECK:         ::mlir::TypeRange resultTypes
 // CHECK:         assert(resultTypes.size() >= 1u && "mismatched number of return types");
 // CHECK-NEXT:    odsState.addTypes(resultTypes);
 

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 4efdaf6d1d86..4fe3cd1ee174 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1455,8 +1455,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
       let regions = (region AnyRegion:$region);
 
+      let skipDefaultBuilders = 1;
       let builders = [ OpBuilder<
-        "OpBuilder &b, OperationState &result,"
+        "OpBuilder &b, OperationState &result, "
         "ValueRange inputs, ValueRange outputBuffers",
         [{{
           result.addOperands(inputs);
@@ -1493,6 +1494,14 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             TypeRange(outputBuffers),
             TypeRange(initTensors),
             resultTensorTypes);
+        }]>, OpBuilder<
+        "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes,"
+        "ValueRange operands, ArrayRef<NamedAttribute> attributes = {{}",
+        [{{
+          result.addOperands(operands);
+          result.addAttributes(attributes);
+          result.addTypes(resultTensorTypes);
+          (void)result.addRegion();
         }]>
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f89547d90493..18b5d26d880c 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1252,7 +1252,7 @@ void OpEmitter::genCollectiveParamBuilder() {
   SmallVector<OpMethodParameter, 4> paramList;
   paramList.emplace_back("::mlir::OpBuilder &", "");
   paramList.emplace_back("::mlir::OperationState &", builderOpState);
-  paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes");
+  paramList.emplace_back("::mlir::TypeRange", "resultTypes");
   paramList.emplace_back("::mlir::ValueRange", "operands");
   // Provide default value for `attributes` when its the last parameter
   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
@@ -1322,8 +1322,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
       if (resultName.empty())
         resultName = std::string(formatv("resultType{0}", i));
 
-      StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>"
-                                           : "::mlir::Type";
+      StringRef type =
+          result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
       if (result.isOptional())
         properties = OpMethodParameter::PP_Optional;
@@ -1333,7 +1333,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
     }
   } break;
   case TypeParamKind::Collective: {
-    paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes");
+    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
     resultTypeNames.push_back("resultTypes");
   } break;
   }


        


More information about the Mlir-commits mailing list