[Mlir-commits] [mlir] 95ddc83 - [mlir][Linalg] Allow all build methods of Structured ops to specify additional attributes.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 19 11:14:51 PDT 2021


Author: MaheshRavishankar
Date: 2021-08-19T11:14:35-07:00
New Revision: 95ddc8341ae2c27229ad3dcf1d55abebcec15d02

URL: https://github.com/llvm/llvm-project/commit/95ddc8341ae2c27229ad3dcf1d55abebcec15d02
DIFF: https://github.com/llvm/llvm-project/commit/95ddc8341ae2c27229ad3dcf1d55abebcec15d02.diff

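LOG: [mlir][Linalg] Allow all build methods of Structured ops to specify additional attributes.

Every `GenericOp` builder, and every builder emitted by `mlir-linalg-ods-gen` (both the TC and YAML front ends), now takes a trailing `ArrayRef<NamedAttribute> attributes` argument. The argument defaults to `{}` and is passed to `addAttributes` on the operation state, so existing call sites compile unchanged.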

Differential Revision: https://reviews.llvm.org/D108338

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 33f4992e41f9d..1d4e6d546067e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -620,18 +620,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
       "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
       "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
       "StringRef":$doc, "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
       "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6b65a9ecd9e51..f4750ca390a88 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -502,13 +502,15 @@ void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs,
         builder.getAffineMapArrayAttr(indexingMaps),
         builder.getStrArrayAttr(iteratorTypes),
         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
         libraryCall.empty() ? StringAttr()
                             : builder.getStringAttr(libraryCall));
+  result.addAttributes(attributes);
   if (!bodyBuild)
     return;
 
@@ -527,30 +529,33 @@ void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
-        iteratorTypes, doc, libraryCall, bodyBuild);
+        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
 }
 
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
         /*doc=*/"",
-        /*libraryCall=*/"", bodyBuild);
+        /*libraryCall=*/"", bodyBuild, attributes);
 }
 
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
         iteratorTypes,
         /*doc=*/"",
-        /*libraryCall=*/"", bodyBuild);
+        /*libraryCall=*/"", bodyBuild, attributes);
 }
 
 static void print(OpAsmPrinter &p, GenericOp op) {

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 471961f837bf3..743bdbdb12d6a 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -169,7 +169,8 @@ It has one output.
 // ODS-LABEL: def Test7Op
 // ODS:         OpBuilder<
 // ODS:           (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-// ODS:            "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b)
+// ODS:            "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b,
+// ODS:           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)
 // ODS:           $_state.addAttribute("attr_a", attr_a);
 // ODS:           $_state.addAttribute("attr_b", attr_b);
 //

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 6613ab2a006f1..f8393a232e43e 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -75,6 +75,7 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:      $_builder.getI32VectorAttr({
 #  ODS-NEXT:        static_cast<int32_t>(inputs.size()),
 #  ODS-NEXT:        static_cast<int32_t>(outputs.size())}));
+#  ODS-NEXT:    $_state.addAttributes(attributes);
 #  ODS-NEXT:    createAndFillStructuredOpRegion<Test1Op>(
 #  ODS-NEXT:      $_builder,
 #  ODS-NEXT:      $_state,

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 1bdb5b8806d0d..590f17fdedfa2 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1910,7 +1910,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       let skipDefaultBuilders = 1;
       let builders = [
         OpBuilder<
-        (ins "ValueRange":$inputs, "ValueRange":$outputs),
+        (ins "ValueRange":$inputs, "ValueRange":$outputs,
+             CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
         [{{
           $_state.addOperands(inputs);
           $_state.addOperands(outputs);
@@ -1919,6 +1920,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
+          $_state.addAttributes(attributes);
           createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
@@ -1927,7 +1929,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         }]>,
         OpBuilder<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-             "ValueRange":$outputs),
+             "ValueRange":$outputs,
+             CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
         [{{
           $_state.addOperands(inputs);
           $_state.addOperands(outputs);
@@ -1937,6 +1940,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
+          $_state.addAttributes(attributes);
           createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
@@ -2020,7 +2024,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
     const char *builderFmt = R"FMT(
       , OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-           "ValueRange":$outputs, {1}),
+           "ValueRange":$outputs, {1},
+           CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
       [{{
         $_state.addOperands(inputs);
         $_state.addOperands(outputs);
@@ -2030,6 +2035,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
+        $_state.addAttributes(attributes);
         createAndFillStructuredOpRegion<{0}>(
           $_builder,
           $_state,

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index a0eb1dea88603..98e90b69d631d 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -457,7 +457,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
     let skipDefaultBuilders = 1;
     let builders = [
       OpBuilder<
-      (ins "ValueRange":$inputs, "ValueRange":$outputs),
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
       [{{
         $_state.addOperands(inputs);
         $_state.addOperands(outputs);
@@ -471,6 +472,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
+        $_state.addAttributes(attributes);
         createAndFillStructuredOpRegion<{0}>(
           $_builder,
           $_state,
@@ -539,7 +541,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
 static const char structuredOpBuilderFormat[] = R"FMT(
   , OpBuilder<
   (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-       "ValueRange":$outputs, {1}),
+       "ValueRange":$outputs, {1},
+       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
   [{{
     $_state.addOperands(inputs);
     $_state.addOperands(outputs);
@@ -555,6 +558,7 @@ static const char structuredOpBuilderFormat[] = R"FMT(
       TypeRange(inputs),
       TypeRange(outputs));
     {2}
+    $_state.addAttributes(attributes);
   }]>
 )FMT";
 


        


More information about the Mlir-commits mailing list