[Mlir-commits] [mlir] d629645 - [mlir][OpDSL] Add support for adding canonicalization patterns.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 8 08:00:06 PST 2022


Author: gysit
Date: 2022-03-08T15:56:59Z
New Revision: d629645fcdf30576b1d4dc9ea2639321c4b33eae

URL: https://github.com/llvm/llvm-project/commit/d629645fcdf30576b1d4dc9ea2639321c4b33eae
DIFF: https://github.com/llvm/llvm-project/commit/d629645fcdf30576b1d4dc9ea2639321c4b33eae.diff

LOG: [mlir][OpDSL] Add support for adding canonicalization patterns.

Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation.

This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement.

Depends On D120725

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120726

Added: 
    mlir/test/python/dialects/linalg/opdsl/metadata.py

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    mlir/test/python/dialects/linalg/opdsl/interfaces.py


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index d7526bf9f3bab..99136f1472f18 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
   them to the same data type as the accumulator/output.
   """
   domain(D.m, D.n, D.k)
+  defines(Canonicalizer)
   implements(ContractionOpInterface)
   C[D.m, D.n] += TypeFn.cast_signed(
       U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
@@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
 Special identifying op interfaces can be declared for the op via
 `implements(interface1[, interface2...])`.
 
+Extra definitions, such as `Canonicalizer` (which sets `hasCanonicalizer`), can
+be declared for the op via `defines(definition1[, definition2...])`.
+
 ## Parameters
 
 Structured operations take two types of runtime parameters namely scalars and

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7511e268ae850..21f28cbd84c3b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
     the value operand, promoting it to the same data type as the output.
   implements:
   - LinalgFillOpInterface
+  defines:
+  - hasCanonicalizer
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d4e46f7619d72..53ff45a531049 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
            FoldInsertPadIntoFill>(context);
 }
 
+// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
+void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 1de5449e27e31..47083de625def 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -689,6 +689,16 @@ def __init__(self, cpp_name: str):
 FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
 
 
+class OpDefinitionDef:
+  """A method that an op implements."""
+
+  def __init__(self, def_name: str):
+    self.def_name = def_name
+
+
+Canonicalizer = OpDefinitionDef("hasCanonicalizer")
+
+
 class OpMetadataDef(YAMLObject):
   """Metadata about the op (generally not behavior impacting)."""
   yaml_tag = "!LinalgOpMetadata"
@@ -699,6 +709,7 @@ def __init__(self, name: str, cpp_class_name: Optional[str],
     self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
     self.doc = doc
     self.implements = []  # type: List[OpInterfaceDef]
+    self.defines = []  # type: List[OpDefinitionDef]
 
   def to_yaml_custom_dict(self):
     d = dict(
@@ -708,6 +719,8 @@ def to_yaml_custom_dict(self):
     )
     if self.implements:
       d["implements"] = [intr.cpp_name for intr in self.implements]
+    if self.defines:
+      d["defines"] = [defi.def_name for defi in self.defines]
     return d
 
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index bd9042ac0aacb..45b8d5ccd13d6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
   return DefinedOpCallable(op_name, op_def)
 
 
+def domain(*dimensions: DimDef):
+  if any(not isinstance(d, DimDef) for d in dimensions):
+    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+  current_op_def().domain.extend(dimensions)
+
+
 def implements(*interfaces: OpInterfaceDef):
+  if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
+    raise ValueError(
+        f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
   current_op_def().metadata.implements.extend(interfaces)
 
 
-def domain(*dimensions: DimDef):
-  if current_op_def().domain:
-    raise ValueError(f"Expected only one set of domain dimensions per operator")
-  if any(not isinstance(dim, DimDef) for dim in dimensions):
-    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
-  current_op_def().domain.extend(dimensions)
+def defines(*definitions: OpDefinitionDef):
+  if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
+    raise ValueError(
+        f"Expected definitions of type OpDefinitionDef but got {definitions}")
+  current_op_def().metadata.defines.extend(definitions)

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 7798d7f9498e3..39934131cb225 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
   the value operand, promoting it to the same data type as the output.
   """
   implements(FillOpInterface)
+  defines(Canonicalizer)
   O[None] = TypeFn.cast_signed(U, value)
 
 

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 3f6c763470146..a31984764ebbf 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
 #       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
 #  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
 #  IMPL-NEXT:  yields.push_back([[VAL1]])
+
+# @linalg_structured_op
+# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   implements(FillOpInterface)
+#   defines(Canonicalizer)
+#   O[None] = TypeFn.cast(U, value)
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test5
+  cpp_class_name: Test5Op
+  doc: |-
+    Title.
+
+    Detailed description.
+  implements:
+  - LinalgFillOpInterface
+  defines:
+  - hasCanonicalizer
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: value
+    kind: scalar
+    type_var: T1
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: type
+        fn_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+
+# ODS-LABEL:  def Test5Op : LinalgStructuredBase_Op<"test5"
+#  ODS-NEXT:  /*extraInterfaces=*/[LinalgFillOpInterface])>
+
+#       ODS:  let hasCanonicalizer = 1;

diff  --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py
similarity index 86%
rename from mlir/test/python/dialects/linalg/opdsl/interfaces.py
rename to mlir/test/python/dialects/linalg/opdsl/metadata.py
index ca9bd04cd9671..a7502e9eb1aae 100644
--- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py
+++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py
@@ -7,11 +7,14 @@
 # CHECK-LABEL: matmul
 # CHECK:      implements:
 # CHECK-NEXT: - LinalgContractionOpInterface
+# CHECK:      defines:
+# CHECK-NEXT: - hasCanonicalizer
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   implements(ContractionOpInterface)
+  defines(Canonicalizer)
   C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
       U, B[D.k, D.n])

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 1cf1247262e09..5cade2a24f430 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -53,6 +53,7 @@ struct LinalgOpMetadata {
   std::string cppClassName;
   Optional<std::string> doc;
   SmallVector<std::string> implements;
+  SmallVector<std::string> defines;
 };
 
 struct SerializedAffineMap {
@@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
     io.mapRequired("cpp_class_name", info.cppClassName);
     io.mapOptional("doc", info.doc);
     io.mapOptional("implements", info.implements);
+    io.mapOptional("defines", info.defines);
   }
 };
 
@@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
 // {3}: documentation (summary + description)
 // {4}: op attribute list
 // {5}: builder methods taking standalone attribute parameters
-// {6}: additional methods for attributes used by indexing maps
+// {6}: additional ODS definitions (e.g. `let hasCanonicalizer = 1;`)
+// {7}: additional methods for attributes used by indexing maps
 static const char structuredOpOdsHeaderFormat[] = R"FMT(
 //===----------------------------------------------------------------------===//
 // Op definition for {0}
@@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
     ];
     let hasCustomAssemblyFormat = 1;
     let hasFolder = 1;
+    {6}
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
@@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
       // Generic methods.
       static unsigned getNumRegionArgs();
       std::string getLibraryCallName();
-      {6}
+      {7}
     }];
 }
 )FMT";
@@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
 
   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
 
+  std::string definitionList;
+  for (const std::string &definition : opConfig.metadata->defines) {
+    static const char definitionFmt[] = "let {0} = 1;\n";
+    definitionList.append(llvm::formatv(definitionFmt, definition));
+  }
+
   if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
         return isAttribute(arg.kind);
       })) {
@@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
   os << llvm::formatv(structuredOpOdsHeaderFormat,
                       opConfig.metadata->cppClassName, opConfig.metadata->name,
                       interfaceNameList, doc, attrList, attrBuilder,
-                      attrMethods);
+                      definitionList, attrMethods);
 
   return success();
 }


        


More information about the Mlir-commits mailing list