[Mlir-commits] [mlir] 348bfc8 - [mlir][linalg] Add attributes to region builder (NFC).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 14 05:17:18 PST 2022


Author: gysit
Date: 2022-02-14T13:14:14Z
New Revision: 348bfc8e50ea69e17e057636dd2823f60b88a034

URL: https://github.com/llvm/llvm-project/commit/348bfc8e50ea69e17e057636dd2823f60b88a034
DIFF: https://github.com/llvm/llvm-project/commit/348bfc8e50ea69e17e057636dd2823f60b88a034.diff

LOG: [mlir][linalg] Add attributes to region builder (NFC).

Adapt the region builder signature to pass in the attributes of the created ops. This revision is a preparation step to support named ops that need access to the operation attributes during op creation.

Depends On D119692

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D119693

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/CAPI/Dialect/Linalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 995aab50f9ee8..84f5690a8649f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -47,8 +47,8 @@ def Linalg_Dialect : Dialect {
     constexpr const static ::llvm::StringLiteral
         kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
 
-    using RegionBuilderFunType =
-      llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
+    using RegionBuilderFunType = llvm::function_ref<
+      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
       return namedStructuredOpRegionBuilders.lookup(name);
     }

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index f40f82ed7cd94..00de7ba0265bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1025,7 +1025,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &)>",
+      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8bdec0971ee18..33e2422060f2f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -83,8 +83,10 @@ def FillOp : LinalgStructured_Op<"fill", []> {
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
 
-    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
-    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
+    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                              ArrayRef<NamedAttribute> attrs);
+    static std::function<void(ImplicitLocOpBuilder&,
+                              Block&, ArrayRef<NamedAttribute>)>
     getRegionBuilder() {
       return &regionBuilder;
     }
@@ -254,7 +256,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
         library_call()->str() : "op_has_no_registered_library_name";
     }
 
-    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
     getRegionBuilder() {
       return nullptr;
     }

diff  --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 8862b6b154ea5..bfb3313d1a21d 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
   b.setInsertionPointToStart(body);
-  fun(b, *body);
+  fun(b, *body, op->getAttrs());
 }
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4868fdb99341a..87278dcba0896 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -49,7 +49,7 @@ using namespace mlir::linalg;
 template <typename NamedStructuredOpType>
 static void fillStructuredOpRegion(
     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
-    TypeRange outputTypes,
+    TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
     llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);
 
 /// Generic entry point to create both the region and the block of a LinalgOp.
@@ -72,7 +72,8 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes);
+                             TypeRange inputTypes, TypeRange outputTypes,
+                             ArrayRef<NamedAttribute> attrs);
 
 static ParseResult
 parseNamedStructuredOpResults(OpAsmParser &parser,
@@ -375,7 +376,8 @@ class RegionBuilderHelper {
 //===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
+void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                           ArrayRef<NamedAttribute> attrs) {
   assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
   b.create<linalg::YieldOp>(block.getArgument(0));
 }
@@ -384,16 +386,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
                    Value output) {
   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
         output);
-  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
-                                 TypeRange{value.getType()},
-                                 TypeRange{output.getType()}, {});
+  fillStructuredOpRegion<FillOp>(
+      builder, *result.regions.front(), TypeRange{value.getType()},
+      TypeRange{output.getType()}, result.attributes.getAttrs(), {});
 }
 
 ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
                               Type outputType) {
   OpBuilder opBuilder(parser.getContext());
   fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
-                                 TypeRange{outputType});
+                                 TypeRange{outputType}, {});
   return success();
 }
 
@@ -1820,7 +1822,7 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 template <typename NamedStructuredOpType>
 static void fillStructuredOpRegion(
     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
-    TypeRange outputTypes,
+    TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
     llvm::function_ref<void(unsigned, unsigned)> errorHandler) {
   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
 
@@ -1851,7 +1853,7 @@ static void fillStructuredOpRegion(
 
   opBuilder.setInsertionPointToStart(body);
   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  NamedStructuredOpType::regionBuilder(b, *body);
+  NamedStructuredOpType::regionBuilder(b, *body, attrs);
 
   // indexing_maps is an auto-generated method.
 
@@ -1866,7 +1868,7 @@ void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
                                      TypeRange outputTypes) {
   Region &region = *result.addRegion();
   fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes,
+      opBuilder, region, inputTypes, outputTypes, result.attributes.getAttrs(),
       [&](unsigned expected, unsigned actual) {
         assert(expected != actual && "incorrect number of arguments");
       });
@@ -1929,14 +1931,15 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes) {
+                             TypeRange inputTypes, TypeRange outputTypes,
+                             ArrayRef<NamedAttribute> attrs) {
   ParseResult res = success();
   OpBuilder opBuilder(parser.getContext());
   // Resolve `captures` into `capturedValues` at parse time so we can build the
   // region with captures.
   SmallVector<Value> capturedValues;
   fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes,
+      opBuilder, region, inputTypes, outputTypes, attrs,
       [&](unsigned expected, unsigned actual) {
         res = parser.emitError(
             parser.getCurrentLocation(),
@@ -1973,7 +1976,8 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
 
   std::unique_ptr<Region> region = std::make_unique<Region>();
   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
-          parser, *region, inputTypes, outputTypes))
+          parser, *region, inputTypes, outputTypes,
+          result.attributes.getAttrs()))
     return failure();
   result.addRegion(std::move(region));
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5b7d973429269..f5834efe9cb5a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2555,11 +2555,13 @@ def TestLinalgConvOp :
   let extraClassDeclaration = [{
     bool hasIndexSemantics() { return false; }
 
-    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block) {
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
       b.create<mlir::linalg::YieldOp>(block.getArguments().back());
     }
 
-    static std::function<void(mlir::ImplicitLocOpBuilder &b, mlir::Block &block)>
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
     getRegionBuilder() {
       return &regionBuilder;
     }

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 347923825cee3..ba44e1eeb6262 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -83,8 +83,8 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:      TypeRange(inputs),
 #  ODS-NEXT:      TypeRange(outputs)
 
-# IMPL-LABEL:  void Test1Op::regionBuilder(
-#       IMPL:    ImplicitLocOpBuilder &b, Block &block)
+# IMPL-LABEL:  void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
 #       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
 #   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
@@ -174,7 +174,8 @@ structured_op: !LinalgStructuredOpConfig
 #       IMPL:  auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
 #       IMPL:  "incorrect element type for index attribute 'strides'"
 #       IMPL:  "incorrect shape for index attribute 'strides'"
-#       IMPL:  void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
+#       IMPL:  void Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
 #  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&
 
 #       IMPL:   yields.push_back(block.getArgument(0));

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f1fac9f578612..eb1af791f1cf9 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -523,8 +523,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
       // Auto-generated.
       ArrayAttr iterator_types();
       ArrayAttr indexing_maps();
-      static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
-      static std::function<void(ImplicitLocOpBuilder &b, Block &)>
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
       getRegionBuilder() {{
         return regionBuilder;
       }
@@ -952,7 +954,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
     // {1}: Number of args
     // {2}: Statements
     static const char structuredOpRegionBuilderFormat[] = R"FMT(
-void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
+void {0}::regionBuilder(ImplicitLocOpBuilder &b,
+                        Block &block, ArrayRef<NamedAttribute> attrs) {{
   assert({1} > 0 && block.getNumArguments() == {1} &&
          "{0} regionBuilder expects {1} (>=0) args");
   RegionBuilderHelper helper(block.getArgument(0).getContext(), block);


        


More information about the Mlir-commits mailing list