[Mlir-commits] [mlir] dfafba3 - [mlir][linalg] Add callback-based builders for `linalg.(indexed_)generic`.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Jun 19 05:00:52 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-19T13:55:20+02:00
New Revision: dfafba3989648a0d16292a36c57865c1e28b9f5a
URL: https://github.com/llvm/llvm-project/commit/dfafba3989648a0d16292a36c57865c1e28b9f5a
DIFF: https://github.com/llvm/llvm-project/commit/dfafba3989648a0d16292a36c57865c1e28b9f5a.diff
LOG: [mlir][linalg] Add callback-based builders for `linalg.(indexed_)generic`.
Differential Revision: https://reviews.llvm.org/D82045
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index cddd4f9b22f8..85fdd1e3f34e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -508,20 +508,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
}
}];
- let builders = [
- OpBuilder<"OpBuilder &builder, OperationState &result, "
- "ArrayRef<Type> resultTypes, ValueRange args, "
- "int64_t inputCount, int64_t outputCount, "
- "ArrayRef<AffineMap> indexingMaps, "
- "ArrayRef<StringRef> iteratorTypes", [{
- return build(builder, result, resultTypes, args,
- builder.getI64IntegerAttr(inputCount),
- builder.getI64IntegerAttr(outputCount),
- builder.getAffineMapArrayAttr(indexingMaps),
- builder.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr, /*library_call=*/nullptr);
- }]>];
-
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseGenericOp(parser, result); }];
}
@@ -637,6 +623,14 @@ def GenericOp : GenericOpBase<"generic"> {
future.
}];
+ let builders = [
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+ "ValueRange args, int64_t inputCount, int64_t outputCount, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">
+ ];
+
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
@@ -763,6 +757,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
future.
}];
+ let builders = [
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+ "ValueRange args, int64_t inputCount, int64_t outputCount, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> "
+ "= nullptr">
+ ];
+
+
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c8401977d612..8012a1087ee1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -70,6 +70,78 @@ static LogicalResult foldMemRefCast(Operation *op) {
// GenericOps
//===----------------------------------------------------------------------===//
+/// Builds a `linalg.generic` op from plain C++ values: `inputCount`,
+/// `outputCount`, `indexingMaps` and `iteratorTypes` are converted to the
+/// corresponding attributes and `doc`/`library_call` are left unset.
+/// If `bodyBuild` is non-null, a block is appended to the op's region with one
+/// argument per value in `args`, typed as that value's element type, and
+/// `bodyBuild` is called with the builder positioned inside that block.
+void GenericOp::build(
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  build(builder, result, resultTypes, args,
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)
+    return;
+
+  // `cast` asserts that every operand has a ShapedType.
+  SmallVector<Type, 4> blockArgTypes;
+  for (Value arg : args)
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  // createBlock moves the builder's insertion point into the new block; the
+  // guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
+/// Builds a `linalg.indexed_generic` op from plain C++ values: `inputCount`,
+/// `outputCount`, `indexingMaps` and `iteratorTypes` are converted to the
+/// corresponding attributes and `doc`/`library_call` are left unset.
+/// If `bodyBuild` is non-null, a block is appended to the op's region whose
+/// arguments are one `index` per iterator type (i.e. per loop) followed by
+/// the element type of each value in `args`; `bodyBuild` receives the builder
+/// positioned inside that block and the two groups of arguments separately.
+void IndexedGenericOp::build(
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+        bodyBuild) {
+  build(builder, result, resultTypes, args,
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)
+    return;
+
+  // The leading `nLoops` block arguments are the loop indices; the rest mirror
+  // `args` (`cast` asserts that every operand has a ShapedType).
+  unsigned nLoops = iteratorTypes.size();
+  SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
+  for (Value arg : args)
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  // createBlock moves the builder's insertion point into the new block; the
+  // guard restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location,
+            bodyBlock->getArguments().take_front(nLoops),
+            bodyBlock->getArguments().drop_front(nLoops));
+}
+
template <typename GenericOpType>
static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
auto attrNames = op.linalgTraitAttrNames();
More information about the Mlir-commits mailing list