[Mlir-commits] [mlir] dfafba3 - [mlir][linalg] Add callback-based builders for `linalg.(indexed_)generic`.

Alexander Belyaev llvmlistbot at llvm.org
Fri Jun 19 05:00:52 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-19T13:55:20+02:00
New Revision: dfafba3989648a0d16292a36c57865c1e28b9f5a

URL: https://github.com/llvm/llvm-project/commit/dfafba3989648a0d16292a36c57865c1e28b9f5a
DIFF: https://github.com/llvm/llvm-project/commit/dfafba3989648a0d16292a36c57865c1e28b9f5a.diff

LOG: [mlir][linalg] Add callback-based builders for `linalg.(indexed_)generic`.

Differential Revision: https://reviews.llvm.org/D82045

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index cddd4f9b22f8..85fdd1e3f34e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -508,20 +508,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
      }
   }];
 
-  let builders = [
-    OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "ArrayRef<Type> resultTypes, ValueRange args, "
-              "int64_t inputCount, int64_t outputCount, "
-              "ArrayRef<AffineMap> indexingMaps, "
-              "ArrayRef<StringRef> iteratorTypes", [{
-        return build(builder, result, resultTypes, args,
-                     builder.getI64IntegerAttr(inputCount),
-                     builder.getI64IntegerAttr(outputCount),
-                     builder.getAffineMapArrayAttr(indexingMaps),
-                     builder.getStrArrayAttr(iteratorTypes),
-                     /*doc=*/nullptr, /*library_call=*/nullptr);
-  }]>];
-
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseGenericOp(parser, result); }];
 }
@@ -637,6 +623,14 @@ def GenericOp : GenericOpBase<"generic"> {
     future.
   }];
 
+  let builders = [
+    OpBuilder<
+      "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+      "ValueRange args, int64_t inputCount, int64_t outputCount, "
+      "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+      "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">
+  ];
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
@@ -763,6 +757,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     future.
   }];
 
+  let builders = [
+    OpBuilder<
+      "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+      "ValueRange args, int64_t inputCount, int64_t outputCount, "
+      "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+      "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> "
+      "= nullptr">
+  ];
+
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c8401977d612..8012a1087ee1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -70,6 +70,58 @@ static LogicalResult foldMemRefCast(Operation *op) {
 // GenericOps
 //===----------------------------------------------------------------------===//
 
+void GenericOp::build(  // Plain-value builder; optionally fills the body.
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  build(builder, result, resultTypes, args,  // Forward values as attributes.
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)  // No callback: no body block is created here.
+    return;
+
+  SmallVector<Type, 4> blockArgTypes;  // One block arg per value in `args`,
+  for (Value arg : args)  // typed as its element type (cast<> asserts shaped).
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  OpBuilder::InsertionGuard guard(builder);  // Restore caller's insert point.
+  auto &region = *result.regions.front();  // Assumes build() above added it.
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());  // Emit body.
+}
+
+void IndexedGenericOp::build(  // Like GenericOp::build, plus index block args.
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+        bodyBuild) {
+  build(builder, result, resultTypes, args,  // Forward values as attributes.
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)  // No callback: no body block is created here.
+    return;
+
+  unsigned nLoops = iteratorTypes.size();  // One index arg per iterator type,
+  SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
+  for (Value arg : args)  // then one element-typed arg per value in `args`.
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  OpBuilder::InsertionGuard guard(builder);  // Restore caller's insert point.
+  auto &region = *result.regions.front();  // Assumes build() above added it.
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location,
+            bodyBlock->getArguments().take_front(nLoops),  // index args
+            bodyBlock->getArguments().drop_front(nLoops));  // operand args
+}
+
 template <typename GenericOpType>
 static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
   auto attrNames = op.linalgTraitAttrNames();


        


More information about the Mlir-commits mailing list