[Mlir-commits] [mlir] 7eef3ea - Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."
Oleg Shyshkov
llvmlistbot at llvm.org
Fri Oct 28 00:57:50 PDT 2022
Author: Oleg Shyshkov
Date: 2022-10-28T09:56:59+02:00
New Revision: 7eef3ea5f4fe4f4cc461b191bac031e3962d0347
URL: https://github.com/llvm/llvm-project/commit/7eef3ea5f4fe4f4cc461b191bac031e3962d0347
DIFF: https://github.com/llvm/llvm-project/commit/7eef3ea5f4fe4f4cc461b191bac031e3962d0347.diff
LOG: Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."
This reverts commit aebde280476943e58f5bcd9993fdd7e36cdbe47e.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 510f8831f019a..1692a0f9c5492 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -267,12 +267,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
let results = (outs Variadic<AnyTensor>:$result);
let regions = (region SizedRegion<1>:$mapper);
- let builders = [
- OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
- "function_ref<void(OpBuilder &, Location, ValueRange)>",
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
- ];
-
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
SmallVector<StringRef> getIteratorTypesArray();
@@ -347,13 +341,6 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let results = (outs Variadic<AnyTensor>);
let regions = (region SizedRegion<1>:$combiner);
- let builders = [
- OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
- "ArrayRef<int64_t>":$dimensions,
- "function_ref<void(OpBuilder &, Location, ValueRange)>",
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
- ];
-
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<StringRef> getIteratorTypesArray();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 443bc5cc7407e..03a959acfb2c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -661,26 +661,6 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
// GenericOp
//===----------------------------------------------------------------------===//
-static void buildGenericRegion(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange outputs,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- SmallVector<Type, 4> blockArgTypes;
- SmallVector<Location, 4> blockArgLocs;
- for (ValueRange container : {inputs, outputs}) {
- for (Value v : container) {
- blockArgTypes.push_back(getElementTypeOrSelf(v));
- blockArgLocs.push_back(v.getLoc());
- }
- }
-
- OpBuilder::InsertionGuard guard(builder);
- auto &region = *result.regions.front();
- Block *bodyBlock =
- builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
- bodyBuild(builder, result.location, bodyBlock->getArguments());
-}
-
void GenericOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
@@ -698,8 +678,23 @@ void GenericOp::build(
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall);
result.addAttributes(attributes);
- if (bodyBuild)
- buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
+ if (!bodyBuild)
+ return;
+
+ SmallVector<Type, 4> blockArgTypes;
+ SmallVector<Location, 4> blockArgLocs;
+ for (ValueRange container : {inputs, outputs}) {
+ for (Value v : container) {
+ blockArgTypes.push_back(getElementTypeOrSelf(v));
+ blockArgLocs.push_back(v.getLoc());
+ }
+ }
+
+ OpBuilder::InsertionGuard guard(builder);
+ auto &region = *result.regions.front();
+ Block *bodyBlock =
+ builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
+ bodyBuild(builder, result.location, bodyBlock->getArguments());
}
void GenericOp::build(
@@ -1334,22 +1329,6 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResults().front(), "mapped");
}
-void MapOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
- ArrayRef<NamedAttribute> attributes) {
- build(builder, result, TypeRange{}, inputs, init);
- result.addAttributes(attributes);
-
- // Add output types for `RankedTensorType` output arguments.
- Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
- result.addTypes(initType);
-
- if (bodyBuild)
- buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
-}
-
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseDstStyleOp(parser, result))
return failure();
@@ -1457,25 +1436,6 @@ void ReduceOp::getAsmResultNames(
setNameFn(getResults().front(), "reduced");
}
-void ReduceOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange inits, ArrayRef<int64_t> dimensions,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
- ArrayRef<NamedAttribute> attributes) {
- build(builder, result, TypeRange{}, inputs, inits, dimensions);
- result.addAttributes(attributes);
-
- // Add output types for `RankedTensorType` output arguments.
- for (Value init : inits) {
- Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
- result.addTypes(initType);
- }
-
- if (bodyBuild)
- buildGenericRegion(builder, result, inputs, inits, bodyBuild);
-}
-
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1658,32 +1618,45 @@ TransposeOp::getRegionBuilder() {
};
}
-void TransposeOp::build(::mlir::OpBuilder &builder,
- ::mlir::OperationState &result, Value input, Value init,
- DenseI64ArrayAttr permutation,
- ArrayRef<NamedAttribute> attributes) {
- result.addOperands(input);
- result.addOperands(init);
- result.addAttribute(getPermutationAttrName(result.name), permutation);
- result.addAttributes(attributes);
+void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
+ ::mlir::OperationState &odsState) {
+ Region *region = odsState.addRegion();
- // Add output types for `RankedTensorType` output arguments.
- Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
- result.addTypes(initType);
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ for (auto t : odsState.operands) {
+ argTypes.push_back(getElementTypeOrSelf(t));
+ argLocs.push_back(opBuilder.getUnknownLoc());
+ }
- buildGenericRegion(builder, result, input, init,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args[0]);
- });
+ // RAII.
+ OpBuilder::InsertionGuard guard(opBuilder);
+ Block *body =
+ opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
+
+ ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
+ getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
}
-void TransposeOp::build(::mlir::OpBuilder &builder,
- ::mlir::OperationState &result, Value input, Value init,
- ArrayRef<int64_t> permutation,
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, Value input,
+ Value init, DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attributes) {
- build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
- attributes);
+ odsState.addOperands(input);
+ odsState.addOperands(init);
+ odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
+ odsState.addAttributes(attributes);
+ odsState.addTypes(init.getType());
+
+ createRegion(odsBuilder, odsState);
+}
+
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, Value input,
+ Value init, ArrayRef<int64_t> permutation,
+ ArrayRef<NamedAttribute> attributes) {
+ build(odsBuilder, odsState, input, init,
+ odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1693,13 +1666,8 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
})))
return failure();
- (void)result.addRegion();
- OpBuilder builder(parser.getContext());
- buildGenericRegion(builder, result, /*inputs=*/result.operands,
- /*outputs=*/{},
- [&](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args[0]);
- });
+ OpBuilder opBuilder(parser.getContext());
+ createRegion(opBuilder, result);
return success();
}
More information about the Mlir-commits mailing list