[Mlir-commits] [mlir] 348bfc8 - [mlir][linalg] Add attributes to region builder (NFC).
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 14 05:17:18 PST 2022
Author: gysit
Date: 2022-02-14T13:14:14Z
New Revision: 348bfc8e50ea69e17e057636dd2823f60b88a034
URL: https://github.com/llvm/llvm-project/commit/348bfc8e50ea69e17e057636dd2823f60b88a034
DIFF: https://github.com/llvm/llvm-project/commit/348bfc8e50ea69e17e057636dd2823f60b88a034.diff
LOG: [mlir][linalg] Add attributes to region builder (NFC).
Adapt the region builder signature to hand in the attributes of the created ops. The revision is a preparation step to support named ops that need access to the operation attributes during op creation.
Depends On D119692
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D119693
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 995aab50f9ee8..84f5690a8649f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -47,8 +47,8 @@ def Linalg_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
- using RegionBuilderFunType =
- llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
+ using RegionBuilderFunType = llvm::function_ref<
+ void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index f40f82ed7cd94..00de7ba0265bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1025,7 +1025,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Returns a null function if this named op does not define a region
builder.
}],
- /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &)>",
+ /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 8bdec0971ee18..33e2422060f2f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -83,8 +83,10 @@ def FillOp : LinalgStructured_Op<"fill", []> {
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
- static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
- static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
+ static void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder&,
+ Block&, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return &regionBuilder;
}
@@ -254,7 +256,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> {
library_call()->str() : "op_has_no_registered_library_name";
}
- static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return nullptr;
}
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 8862b6b154ea5..bfb3313d1a21d 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
Region &region = op->getRegion(0);
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
b.setInsertionPointToStart(body);
- fun(b, *body);
+ fun(b, *body, op->getAttrs());
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4868fdb99341a..87278dcba0896 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -49,7 +49,7 @@ using namespace mlir::linalg;
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
- TypeRange outputTypes,
+ TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);
/// Generic entry point to create both the region and the block of a LinalgOp.
@@ -72,7 +72,8 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
- TypeRange inputTypes, TypeRange outputTypes);
+ TypeRange inputTypes, TypeRange outputTypes,
+ ArrayRef<NamedAttribute> attrs);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
@@ -375,7 +376,8 @@ class RegionBuilderHelper {
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
+void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
b.create<linalg::YieldOp>(block.getArgument(0));
}
@@ -384,16 +386,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
Value output) {
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
output);
- fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
- TypeRange{value.getType()},
- TypeRange{output.getType()}, {});
+ fillStructuredOpRegion<FillOp>(
+ builder, *result.regions.front(), TypeRange{value.getType()},
+ TypeRange{output.getType()}, result.attributes.getAttrs(), {});
}
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
Type outputType) {
OpBuilder opBuilder(parser.getContext());
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
- TypeRange{outputType});
+ TypeRange{outputType}, {});
return success();
}
@@ -1820,7 +1822,7 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
- TypeRange outputTypes,
+ TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
llvm::function_ref<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
@@ -1851,7 +1853,7 @@ static void fillStructuredOpRegion(
opBuilder.setInsertionPointToStart(body);
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
- NamedStructuredOpType::regionBuilder(b, *body);
+ NamedStructuredOpType::regionBuilder(b, *body, attrs);
// indexing_maps is an auto-generated method.
@@ -1866,7 +1868,7 @@ void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
TypeRange outputTypes) {
Region &region = *result.addRegion();
fillStructuredOpRegion<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes,
+ opBuilder, region, inputTypes, outputTypes, result.attributes.getAttrs(),
[&](unsigned expected, unsigned actual) {
assert(expected != actual && "incorrect number of arguments");
});
@@ -1929,14 +1931,15 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
- TypeRange inputTypes, TypeRange outputTypes) {
+ TypeRange inputTypes, TypeRange outputTypes,
+ ArrayRef<NamedAttribute> attrs) {
ParseResult res = success();
OpBuilder opBuilder(parser.getContext());
// Resolve `captures` into `capturedValues` at parse time so we can build the
// region with captures.
SmallVector<Value> capturedValues;
fillStructuredOpRegion<NamedStructuredOpType>(
- opBuilder, region, inputTypes, outputTypes,
+ opBuilder, region, inputTypes, outputTypes, attrs,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(
parser.getCurrentLocation(),
@@ -1973,7 +1976,8 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
- parser, *region, inputTypes, outputTypes))
+ parser, *region, inputTypes, outputTypes,
+ result.attributes.getAttrs()))
return failure();
result.addRegion(std::move(region));
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5b7d973429269..f5834efe9cb5a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2555,11 +2555,13 @@ def TestLinalgConvOp :
let extraClassDeclaration = [{
bool hasIndexSemantics() { return false; }
- static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block) {
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute> attrs) {
b.create<mlir::linalg::YieldOp>(block.getArguments().back());
}
- static std::function<void(mlir::ImplicitLocOpBuilder &b, mlir::Block &block)>
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
return &regionBuilder;
}
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 347923825cee3..ba44e1eeb6262 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -83,8 +83,8 @@ structured_op: !LinalgStructuredOpConfig
# ODS-NEXT: TypeRange(inputs),
# ODS-NEXT: TypeRange(outputs)
-# IMPL-LABEL: void Test1Op::regionBuilder(
-# IMPL: ImplicitLocOpBuilder &b, Block &block)
+# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
+# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]);
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
@@ -174,7 +174,8 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
# IMPL: "incorrect element type for index attribute 'strides'"
# IMPL: "incorrect shape for index attribute 'strides'"
-# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
+# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
+# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
# IMPL: yields.push_back(block.getArgument(0));
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f1fac9f578612..eb1af791f1cf9 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -523,8 +523,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
- static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
- static std::function<void(ImplicitLocOpBuilder &b, Block &)>
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {{
return regionBuilder;
}
@@ -952,7 +954,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
// {1}: Number of args
// {2}: Statements
static const char structuredOpRegionBuilderFormat[] = R"FMT(
-void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
+void {0}::regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs) {{
assert({1} > 0 && block.getNumArguments() == {1} &&
"{0} regionBuilder expects {1} (>=0) args");
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
More information about the Mlir-commits
mailing list