[Mlir-commits] [mlir] d629645 - [mlir][OpDSL] Add support for adding canonicalization patterns.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 8 08:00:06 PST 2022
Author: gysit
Date: 2022-03-08T15:56:59Z
New Revision: d629645fcdf30576b1d4dc9ea2639321c4b33eae
URL: https://github.com/llvm/llvm-project/commit/d629645fcdf30576b1d4dc9ea2639321c4b33eae
DIFF: https://github.com/llvm/llvm-project/commit/d629645fcdf30576b1d4dc9ea2639321c4b33eae.diff
LOG: [mlir][OpDSL] Add support for adding canonicalization patterns.
Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)`, the operation needs to implement the `getCanonicalizationPatterns` method. This revision sets the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation.
This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are functionally equivalent only if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement.
Depends On D120725
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D120726
Added:
mlir/test/python/dialects/linalg/opdsl/metadata.py
Modified:
mlir/docs/Dialects/Linalg/OpDSL.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
mlir/test/python/dialects/linalg/opdsl/interfaces.py
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index d7526bf9f3bab..99136f1472f18 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
+ defines(Canonicalizer)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed(
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
@@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
Special identifying op interfaces can be declared for the op via
`implements(interface1[, interface2...])`.
+Extra method definitions can be declared for the op via
+`defines(definition1[, definition2...])`.
+
## Parameters
Structured operations take two types of runtime parameters namely scalars and
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7511e268ae850..21f28cbd84c3b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
the value operand, promoting it to the same data type as the output.
implements:
- LinalgFillOpInterface
+ defines:
+ - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d4e46f7619d72..53ff45a531049 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
FoldInsertPadIntoFill>(context);
}
+// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
+void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {}
+
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 1de5449e27e31..47083de625def 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -689,6 +689,16 @@ def __init__(self, cpp_name: str):
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
+class OpDefinitionDef:
+ """An ODS definition flag (e.g. `hasCanonicalizer`) that an op sets."""
+
+ def __init__(self, def_name: str):
+ self.def_name = def_name
+
+
+Canonicalizer = OpDefinitionDef("hasCanonicalizer")
+
+
class OpMetadataDef(YAMLObject):
"""Metadata about the op (generally not behavior impacting)."""
yaml_tag = "!LinalgOpMetadata"
@@ -699,6 +709,7 @@ def __init__(self, name: str, cpp_class_name: Optional[str],
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
self.doc = doc
self.implements = [] # type: List[OpInterfaceDef]
+ self.defines = [] # type: List[OpDefinitionDef]
def to_yaml_custom_dict(self):
d = dict(
@@ -708,6 +719,8 @@ def to_yaml_custom_dict(self):
)
if self.implements:
d["implements"] = [intr.cpp_name for intr in self.implements]
+ if self.defines:
+ d["defines"] = [defi.def_name for defi in self.defines]
return d
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index bd9042ac0aacb..45b8d5ccd13d6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
return DefinedOpCallable(op_name, op_def)
+def domain(*dimensions: DimDef):
+ if any(not isinstance(d, DimDef) for d in dimensions):
+ raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+ current_op_def().domain.extend(dimensions)
+
+
def implements(*interfaces: OpInterfaceDef):
+ if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
+ raise ValueError(
+ f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
current_op_def().metadata.implements.extend(interfaces)
-def domain(*dimensions: DimDef):
- if current_op_def().domain:
- raise ValueError(f"Expected only one set of domain dimensions per operator")
- if any(not isinstance(dim, DimDef) for dim in dimensions):
- raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
- current_op_def().domain.extend(dimensions)
+def defines(*definitions: OpDefinitionDef):
+ if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
+ raise ValueError(
+ f"Expected definitions of type OpDefinitionDef but got {definitions}")
+ current_op_def().metadata.defines.extend(definitions)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 7798d7f9498e3..39934131cb225 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
the value operand, promoting it to the same data type as the output.
"""
implements(FillOpInterface)
+ defines(Canonicalizer)
O[None] = TypeFn.cast_signed(U, value)
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 3f6c763470146..a31984764ebbf 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]])
+
+# @linalg_structured_op
+# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+# """Title.
+
+# Detailed description.
+# """
+# implements(FillOpInterface)
+# defines(Canonicalizer)
+# O[None] = TypeFn.cast(U, value)
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: test5
+ cpp_class_name: Test5Op
+ doc: |-
+ Title.
+
+ Detailed description.
+ implements:
+ - LinalgFillOpInterface
+ defines:
+ - hasCanonicalizer
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: value
+ kind: scalar
+ type_var: T1
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<() -> ()>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<() -> ()>
+ - affine_map<() -> ()>
+ iterator_types: []
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: value
+
+# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5"
+# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])>
+
+# ODS: let hasCanonicalizer = 1;
diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py
similarity index 86%
rename from mlir/test/python/dialects/linalg/opdsl/interfaces.py
rename to mlir/test/python/dialects/linalg/opdsl/metadata.py
index ca9bd04cd9671..a7502e9eb1aae 100644
--- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py
+++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py
@@ -7,11 +7,14 @@
# CHECK-LABEL: matmul
# CHECK: implements:
# CHECK-NEXT: - LinalgContractionOpInterface
+# CHECK: defines:
+# CHECK-NEXT: - hasCanonicalizer
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
implements(ContractionOpInterface)
+ defines(Canonicalizer)
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
U, B[D.k, D.n])
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 1cf1247262e09..5cade2a24f430 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -53,6 +53,7 @@ struct LinalgOpMetadata {
std::string cppClassName;
Optional<std::string> doc;
SmallVector<std::string> implements;
+ SmallVector<std::string> defines;
};
struct SerializedAffineMap {
@@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
io.mapRequired("cpp_class_name", info.cppClassName);
io.mapOptional("doc", info.doc);
io.mapOptional("implements", info.implements);
+ io.mapOptional("defines", info.defines);
}
};
@@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
-// {6}: additional methods for attributes used by indexing maps
+// {6}: additional method definitions
+// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
@@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
+ {6}
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
@@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
- {6}
+ {7}
}];
}
)FMT";
@@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
+ std::string definitionList;
+ for (const std::string &definition : opConfig.metadata->defines) {
+ static const char definitionFmt[] = "let {0} = 1;\n";
+ definitionList.append(llvm::formatv(definitionFmt, definition));
+ }
+
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return isAttribute(arg.kind);
})) {
@@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
os << llvm::formatv(structuredOpOdsHeaderFormat,
opConfig.metadata->cppClassName, opConfig.metadata->name,
interfaceNameList, doc, attrList, attrBuilder,
- attrMethods);
+ definitionList, attrMethods);
return success();
}
More information about the Mlir-commits
mailing list