[Mlir-commits] [mlir] ad89eb5 - Revert "Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`.""
Oleg Shyshkov
llvmlistbot at llvm.org
Fri Oct 28 02:35:21 PDT 2022
Author: Oleg Shyshkov
Date: 2022-10-28T11:31:58+02:00
New Revision: ad89eb5b1fccf002eb59dfbab0fdb515ea3e65b7
URL: https://github.com/llvm/llvm-project/commit/ad89eb5b1fccf002eb59dfbab0fdb515ea3e65b7
DIFF: https://github.com/llvm/llvm-project/commit/ad89eb5b1fccf002eb59dfbab0fdb515ea3e65b7.diff
LOG: Revert "Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`.""
This reverts commit 7eef3ea5f4fe4f4cc461b191bac031e3962d0347.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1692a0f9c5492..510f8831f019a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -267,6 +267,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
let results = (outs Variadic<AnyTensor>:$result);
let regions = (region SizedRegion<1>:$mapper);
+  // Builds the op from `inputs` and a single `init`. `bodyBuild` populates the
+  // mapper region; when it is null no body block is created.
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuild,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
SmallVector<StringRef> getIteratorTypesArray();
@@ -341,6 +347,13 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let results = (outs Variadic<AnyTensor>);
let regions = (region SizedRegion<1>:$combiner);
+  // Builds the op from `inputs`, `inits` and the reduction `dimensions`.
+  // `bodyBuild` populates the combiner region; when it is null no body block
+  // is created.
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+      "ArrayRef<int64_t>":$dimensions,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuild,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<StringRef> getIteratorTypesArray();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 896fcf44e3934..e9f26306d58b7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -661,6 +661,26 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
// GenericOp
//===----------------------------------------------------------------------===//
+/// Creates the body block of the first region of `result` and fills it.
+///
+/// The block gets one argument per value in `inputs` followed by one per value
+/// in `outputs`. Each argument has the element type of its value
+/// (`getElementTypeOrSelf`) and that value's location. `bodyBuild` is then
+/// called with `builder` positioned in the new block, the op location, and the
+/// block arguments. The caller's insertion point is restored on return.
+///
+/// The region must already have been added to `result`; this helper does not
+/// create it.
+static void buildGenericRegion(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange outputs,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  assert(!result.regions.empty() &&
+         "expected a region to be added to the OperationState");
+  SmallVector<Type, 4> blockArgTypes;
+  SmallVector<Location, 4> blockArgLocs;
+  for (ValueRange container : {inputs, outputs}) {
+    for (Value v : container) {
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      blockArgLocs.push_back(v.getLoc());
+    }
+  }
+
+  // `createBlock` moves the insertion point into the new block; the guard puts
+  // it back for the caller.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock =
+      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
void GenericOp::getAsmBlockArgumentNames(Region ®ion,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
@@ -678,23 +698,8 @@ void GenericOp::build(
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall);
result.addAttributes(attributes);
- if (!bodyBuild)
- return;
-
- SmallVector<Type, 4> blockArgTypes;
- SmallVector<Location, 4> blockArgLocs;
- for (ValueRange container : {inputs, outputs}) {
- for (Value v : container) {
- blockArgTypes.push_back(getElementTypeOrSelf(v));
- blockArgLocs.push_back(v.getLoc());
- }
- }
-
- OpBuilder::InsertionGuard guard(builder);
- auto ®ion = *result.regions.front();
- Block *bodyBlock =
- builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
- bodyBuild(builder, result.location, bodyBlock->getArguments());
+ if (bodyBuild)
+ buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
}
void GenericOp::build(
@@ -1329,6 +1334,22 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResults().front(), "mapped");
}
+/// Builds a `linalg.map` op from `inputs` and a single `init` operand.
+///
+/// A result of `init`'s type is added only when `init` is a ranked tensor.
+/// When `bodyBuild` is non-null it fills the mapper region, whose block takes
+/// one argument per input and none for `init`.
+void MapOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  // Start with no result types; they are derived from `init` below.
+  build(builder, result, TypeRange{}, inputs, init);
+  result.addAttributes(attributes);
+
+  Type initType = init.getType();
+  bool hasTensorResult = initType.isa<RankedTensorType>();
+  if (hasTensorResult)
+    result.addTypes(initType);
+
+  if (!bodyBuild)
+    return;
+  buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+}
+
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseDstStyleOp(parser, result))
return failure();
@@ -1436,6 +1457,25 @@ void ReduceOp::getAsmResultNames(
setNameFn(getResults().front(), "reduced");
}
+/// Builds a `linalg.reduce` op over `inputs` with the given `inits` and
+/// reduction `dimensions`.
+///
+/// For each `init` that is a ranked tensor, a result of the same type is added,
+/// in operand order. When `bodyBuild` is non-null it fills the combiner region,
+/// whose block takes one argument per input followed by one per init.
+void ReduceOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, ArrayRef<int64_t> dimensions,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, inits, dimensions);
+  result.addAttributes(attributes);
+
+  // Only ranked-tensor inits produce results.
+  for (Type initType : inits.getTypes())
+    if (initType.isa<RankedTensorType>())
+      result.addTypes(initType);
+
+  if (!bodyBuild)
+    return;
+  buildGenericRegion(builder, result, inputs, inits, bodyBuild);
+}
+
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1618,45 +1658,32 @@ TransposeOp::getRegionBuilder() {
};
}
-void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
- ::mlir::OperationState &odsState) {
- Region *region = odsState.addRegion();
-
- SmallVector<Type> argTypes;
- SmallVector<Location> argLocs;
- for (auto t : odsState.operands) {
- argTypes.push_back(getElementTypeOrSelf(t));
- argLocs.push_back(opBuilder.getUnknownLoc());
- }
-
- // RAII.
- OpBuilder::InsertionGuard guard(opBuilder);
- Block *body =
- opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
-
- ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
- getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
-}
-
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState, Value input,
- Value init, DenseI64ArrayAttr permutation,
+/// Builds a `linalg.transpose` of `input` into `init` with `permutation`.
+///
+/// The body region is created here. Its block takes the element types of
+/// `input` and `init` and yields the `input` element unchanged. A result of
+/// `init`'s type is added only when `init` is a ranked tensor.
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attributes) {
- odsState.addOperands(input);
- odsState.addOperands(init);
- odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
- odsState.addAttributes(attributes);
- odsState.addTypes(init.getType());
+  result.addOperands(input);
+  result.addOperands(init);
+  result.addAttribute(getPermutationAttrName(result.name), permutation);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
- createRegion(odsBuilder, odsState);
+
+  // This builder adds operands by hand, so nothing has created the op's region
+  // yet. `buildGenericRegion` only populates an existing region, so add it
+  // first; otherwise `result.regions.front()` is accessed on an empty list.
+  (void)result.addRegion();
+  buildGenericRegion(builder, result, input, init,
+                     [&](OpBuilder &b, Location loc, ValueRange args) {
+                       b.create<linalg::YieldOp>(loc, args[0]);
+                     });
}
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState, Value input,
- Value init, ArrayRef<int64_t> permutation,
+/// Convenience overload that takes the permutation as plain integers. It wraps
+/// them in a `DenseI64ArrayAttr` and forwards to the attribute-based builder.
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        ArrayRef<int64_t> permutation,
ArrayRef<NamedAttribute> attributes) {
- build(odsBuilder, odsState, input, init,
- odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
+  DenseI64ArrayAttr permutationAttr =
+      builder.getDenseI64ArrayAttr(permutation);
+  build(builder, result, input, init, permutationAttr, attributes);
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1666,8 +1693,13 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
})))
return failure();
- OpBuilder opBuilder(parser.getContext());
- createRegion(opBuilder, result);
+ (void)result.addRegion();
+ OpBuilder builder(parser.getContext());
+ buildGenericRegion(builder, result, /*inputs=*/result.operands,
+ /*outputs=*/{},
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ b.create<linalg::YieldOp>(loc, args[0]);
+ });
return success();
}
More information about the Mlir-commits
mailing list