[Mlir-commits] [mlir] [mlir][ODS] Add `OptionalTypesMatchWith` and remove a custom assemblyFormat (PR #68876)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 12 07:29:46 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/68876

>From c01f78edbd4b47f8cd82cb6a8267ee2ee80f7de2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 12 Oct 2023 10:51:39 +0000
Subject: [PATCH 1/3] [mlir][ODS] Add `OptionalTypesMatchWith` and remove a
 custom assemblyFormat

This is just a slight specialization of `TypesMatchWith` that succeeds
(i.e. skips the type check) if either of the compared arguments is an
optional operand/result that is not present.

There may be other places this could help e.g.:
https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5
But I'm leaving those to avoid some churn.

(This constraint will be handy for us in some later patches)
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  7 ++--
 mlir/include/mlir/IR/OpBase.td                | 15 +++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 41 -------------------
 mlir/test/Dialect/Vector/invalid.mlir         |  7 ----
 mlir/test/Dialect/Vector/ops.mlir             |  2 +-
 5 files changed, 20 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..a22a082fb60ffb4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
   Vector_Op<"reduction", [Pure,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
+     OptionalTypesMatchWith<"dest and acc have the same type",
+                            "dest", "acc", "::llvm::cast<Type>($_self)">,
      DeclareOpInterfaceMethods<ArithFastMathInterface>,
      DeclareOpInterfaceMethods<MaskableOpInterface>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
                          "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
   ];
 
-  // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
-  // operands.
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` $fastmath^)?"
+                       " attr-dict `:` type($vector) `into` type($dest)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..61babc93d49875b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -568,6 +568,21 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
   string transformer = transform;
 }
 
+// Helper which makes the first letter of a string uppercase.
+// e.g. cat -> Cat
+class first_char_to_upper<string str>
+{
+  string ret = !toupper(!substr(str, 0, 1)) # !substr(str, 1);
+}
+
+// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
+// and not present returns success.
+class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
+                     string transform, string comparator = "std::equal_to<>()">
+  : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+     "!get" # first_char_to_upper<lhsArg>.ret # "()"
+     # " || !get" # first_char_to_upper<rhsArg>.ret # "() || " # comparator>;
+
 // Special variant of `TypesMatchWith` that provides a comparator suitable for
 // ranged arguments.
 class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..b63018dbd5d6aaa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() {
   return success();
 }
 
-ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
-  Type redType;
-  Type resType;
-  CombiningKindAttr kindAttr;
-  arith::FastMathFlagsAttr fastMathAttr;
-  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
-                                              result.attributes) ||
-      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
-      (succeeded(parser.parseOptionalKeyword("fastmath")) &&
-       parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
-                                               result.attributes)) ||
-      parser.parseColonType(redType) ||
-      parser.parseKeywordType("into", resType) ||
-      (!operandsInfo.empty() &&
-       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
-      (operandsInfo.size() > 1 &&
-       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
-      parser.addTypeToList(resType, result.types))
-    return failure();
-  if (operandsInfo.empty() || operandsInfo.size() > 2)
-    return parser.emitError(parser.getNameLoc(),
-                            "unsupported number of operands");
-  return success();
-}
-
-void ReductionOp::print(OpAsmPrinter &p) {
-  p << " ";
-  getKindAttr().print(p);
-  p << ", " << getVector();
-  if (getAcc())
-    p << ", " << getAcc();
-
-  if (getFastmathAttr() &&
-      getFastmathAttr().getValue() != arith::FastMathFlags::none) {
-    p << ' ' << getFastmathAttrName().getValue();
-    p.printStrippedAttrOrType(getFastmathAttr());
-  }
-  p << " : " << getVector().getType() << " into " << getDest().getType();
-}
-
 // MaskableOpInterface methods.
 
 /// Returns the mask type expected by this operation.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..ce8b56a5d57a2b6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1168,13 +1168,6 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
 
 // -----
 
-func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
-  // expected-error@+1 {{'vector.reduction' unsupported number of operands}}
-  %0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
-}
-
-// -----
-
 func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
   // expected-error at +1 {{'vector.reduction' op unsupported reduction rank: 2}}
   %0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 6cfddac94efd850..fbbb61959d12666 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1042,7 +1042,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
 
 // CHECK-LABEL:   func.func @fastmath(
 func.func @fastmath(%x: vector<42xf32>) -> f32 {
-  // CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
+  // CHECK: vector.reduction <minf>, %{{.*}} fastmath <reassoc,nnan,ninf>
   %min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
   return %min: f32
 }

>From 48e06537690e801a9a3e6f3a2c2f412a7574143e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 12 Oct 2023 14:19:39 +0000
Subject: [PATCH 2/3] Use snakeCaseToCamelCase() helper for generating getter
 names.

A !tocamelcase bang operator would be nice in the future :)
---
 mlir/include/mlir/IR/OpBase.td | 11 ++---------
 mlir/include/mlir/IR/Utils.td  | 24 ++++++++++++++++++++++++
 mlir/test/mlir-tblgen/utils.td | 23 +++++++++++++++++++++++
 3 files changed, 49 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/utils.td

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 61babc93d49875b..7866ac24c1ccbad 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -568,20 +568,13 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
   string transformer = transform;
 }
 
-// Helper which makes the first letter of a string uppercase.
-// e.g. cat -> Cat
-class first_char_to_upper<string str>
-{
-  string ret = !toupper(!substr(str, 0, 1)) # !substr(str, 1);
-}
-
 // The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
 // and not present returns success.
 class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
                      string transform, string comparator = "std::equal_to<>()">
   : TypesMatchWith<summary, lhsArg, rhsArg, transform,
-     "!get" # first_char_to_upper<lhsArg>.ret # "()"
-     # " || !get" # first_char_to_upper<rhsArg>.ret # "() || " # comparator>;
+     "!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
+     # " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;
 
 // Special variant of `TypesMatchWith` that provides a comparator suitable for
 // ranged arguments.
diff --git a/mlir/include/mlir/IR/Utils.td b/mlir/include/mlir/IR/Utils.td
index 651706099d9b1ce..d4a51add77ac644 100644
--- a/mlir/include/mlir/IR/Utils.td
+++ b/mlir/include/mlir/IR/Utils.td
@@ -66,4 +66,28 @@ class CArg<string ty, string value = ""> {
   string defaultValue = value;
 }
 
+// Helper which makes the first letter of a string uppercase.
+// e.g. cat -> Cat
+class firstCharToUpper<string str>
+{
+  string ret = !if(!gt(!size(str), 0),
+    !toupper(!substr(str, 0, 1)) # !substr(str, 1),
+    "");
+}
+
+class _snakeCaseHelper<string str> { // Drop first '_'; upcase the next char.
+  int idx = !find(str, "_"); // -1 when `str` contains no '_'.
+  string ret = !if(!ge(idx, 0),
+    !substr(str, 0, idx) # firstCharToUpper<!substr(str, !add(idx, 1))>.ret,
+    str);
+}
+
+// Converts a snake_case string to CamelCase.
+// TODO: Replace with a !tocamelcase bang operator.
+class snakeCaseToCamelCase<string str>
+{
+  string ret = !foldl(firstCharToUpper<str>.ret,
+    !range(0, !size(str)), acc, idx, _snakeCaseHelper<acc>.ret);
+}
+
 #endif // UTILS_TD
diff --git a/mlir/test/mlir-tblgen/utils.td b/mlir/test/mlir-tblgen/utils.td
new file mode 100644
index 000000000000000..28e0fecb2881bdd
--- /dev/null
+++ b/mlir/test/mlir-tblgen/utils.td
@@ -0,0 +1,23 @@
+// RUN: mlir-tblgen -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/Utils.td"
+
+// CHECK-DAG: string value = "CamelCaseTest"
+class already_camel_case {
+  string value = snakeCaseToCamelCase<"CamelCaseTest">.ret;
+}
+
+// CHECK-DAG: string value = "Foo"
+class single_word {
+  string value = snakeCaseToCamelCase<"foo">.ret;
+}
+
+// CHECK-DAG: string value = "ThisIsATest"
+class snake_case {
+  string value = snakeCaseToCamelCase<"this_is_a_test">.ret;
+}
+
+// CHECK-DAG: string value = "ThisIsATestAgain"
+class extra_underscores {
+  string value = snakeCaseToCamelCase<"__this__is_a_test__again__">.ret;
+}

>From 4925ff1fb83867d0c553788b2f8a42c3235b786d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 12 Oct 2023 14:24:50 +0000
Subject: [PATCH 3/3] Fixups

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/test/Dialect/Vector/invalid.mlir            | 7 +++++++
 mlir/test/Dialect/Vector/ops.mlir                | 2 +-
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a22a082fb60ffb4..917b27a40f26f13 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -265,7 +265,7 @@ def Vector_ReductionOp :
                          "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
   ];
 
-  let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` $fastmath^)?"
+  let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?"
                        " attr-dict `:` type($vector) `into` type($dest)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ce8b56a5d57a2b6..504ac89659fdb73 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1168,6 +1168,13 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
 
 // -----
 
+func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+  // expected-error@+1 {{expected ':'}}
+  %0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
+}
+
+// -----
+
 func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
   // expected-error at +1 {{'vector.reduction' op unsupported reduction rank: 2}}
   %0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index fbbb61959d12666..6cfddac94efd850 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1042,7 +1042,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
 
 // CHECK-LABEL:   func.func @fastmath(
 func.func @fastmath(%x: vector<42xf32>) -> f32 {
-  // CHECK: vector.reduction <minf>, %{{.*}} fastmath <reassoc,nnan,ninf>
+  // CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
   %min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
   return %min: f32
 }



More information about the Mlir-commits mailing list