[Mlir-commits] [mlir] 03d1c99 - [mlir][ODS] Add `OptionalTypesMatchWith` and remove a custom assemblyFormat (#68876)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 09:00:06 PDT 2023
Author: Benjamin Maxwell
Date: 2023-10-19T17:00:02+01:00
New Revision: 03d1c99d99d7918115b993f14dcb6fc39cf09f72
URL: https://github.com/llvm/llvm-project/commit/03d1c99d99d7918115b993f14dcb6fc39cf09f72
DIFF: https://github.com/llvm/llvm-project/commit/03d1c99d99d7918115b993f14dcb6fc39cf09f72.diff
LOG: [mlir][ODS] Add `OptionalTypesMatchWith` and remove a custom assemblyFormat (#68876)
This is just a slight specialization of `TypesMatchWith` that returns
success if an optional parameter is missing.
There may be other places where this could help, e.g.:
https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58-L59
...but I'm leaving those alone for now to avoid churn.
This constraint will be handy for us in some later patches; it formalizes a
short-circuiting trick with the `comparator` of the `TypesMatchWith`
constraint (devised for #69195).
```
TypesMatchWith<
"padding type matches element type of result (if present)",
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()",
// This returns true if no padding is present, or it's present with a type that matches the element type of `result`.
"!getPadding() || std::equal_to<>()">
```
This is a little non-obvious, so after this patch you can instead do:
```
OptionalTypesMatchWith<
"padding type matches element type of result (if present)",
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()">
```
Added:
mlir/test/mlir-tblgen/utils.td
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/Utils.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..917b27a40f26f13 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
+ OptionalTypesMatchWith<"dest and acc have the same type",
+ "dest", "acc", "::llvm::cast<Type>($_self)">,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];
- // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
- // operands.
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?"
+ " attr-dict `:` type($vector) `into` type($dest)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..7866ac24c1ccbad 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -568,6 +568,14 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
string transformer = transform;
}
+// The same as TypesMatchWith, but returns success if either `lhsArg` or
+// `rhsArg` is an optional argument that is not present.
+//
+// This is done by prefixing the comparator with presence checks, i.e. the
+// comparator string becomes
+//   !get<LhsArg>() || !get<RhsArg>() || <comparator>
+// where the accessor names are "get" followed by the argument name converted
+// with snakeCaseToCamelCase (e.g. "dest" -> getDest(), "acc" -> getAcc()).
+// Because `||` short-circuits, the type comparison only runs when both
+// accessors return a present value. This is the same trick as writing
+// "!getPadding() || std::equal_to<>()" by hand as a TypesMatchWith comparator.
+// NOTE(review): `comparator` is concatenated without parentheses, so it is
+// assumed to be a plain callable expression such as "std::equal_to<>()".
+class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
+ string transform, string comparator = "std::equal_to<>()">
+ : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+ "!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
+ # " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;
+
// Special variant of `TypesMatchWith` that provides a comparator suitable for
// ranged arguments.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/include/mlir/IR/Utils.td b/mlir/include/mlir/IR/Utils.td
index 651706099d9b1ce..e75b3ae7a50d171 100644
--- a/mlir/include/mlir/IR/Utils.td
+++ b/mlir/include/mlir/IR/Utils.td
@@ -66,4 +66,28 @@ class CArg<string ty, string value = ""> {
string defaultValue = value;
}
+// Helper which makes the first character of a string uppercase and leaves the
+// rest of the string unchanged.
+// e.g. cat -> Cat, fooBar -> FooBar, "" -> ""
+class firstCharToUpper<string str>
+{
+ // Guard the empty string: there is no first character to take a substring
+ // of, so the result is simply "".
+ string ret = !if(!gt(!size(str), 0),
+ !toupper(!substr(str, 0, 1)) # !substr(str, 1),
+ "");
+}
+
+// Internal helper for snakeCaseToCamelCase that performs a single conversion
+// step. It finds the first '_' in `str`, drops it, and uppercases the character
+// that followed it (via firstCharToUpper). If `str` contains no '_', it is
+// returned unchanged.
+// e.g. "this_is_a" -> "thisIs_a", "_foo" -> "Foo", "foo_" -> "foo",
+//      "__x" -> "_x" (consecutive underscores are consumed one per step).
+class _snakeCaseHelper<string str> {
+ // Position of the first '_'; negative when `str` has no underscore.
+ int idx = !find(str, "_");
+ string ret = !if(!ge(idx, 0),
+ !substr(str, 0, idx) # firstCharToUpper<!substr(str, !add(idx, 1))>.ret,
+ str);
+}
+
+// Converts a snake_case string to CamelCase.
+// e.g. this_is_a_test -> ThisIsATest, foo -> Foo,
+//      __this__is_a_test__again__ -> ThisIsATestAgain
+// The first character is uppercased, then every '_' is removed and the
+// character following it is uppercased. Characters that are already uppercase
+// are kept, so CamelCase input (e.g. CamelCaseTest) comes back unchanged.
+// Leading, trailing and repeated underscores are dropped.
+// Implementation note: each _snakeCaseHelper step removes at most one '_', and
+// a string cannot contain more than !size(str) underscores. Folding
+// !size(str) times is therefore always enough to remove them all. The fold
+// variable `idx` is unused; !range only fixes the number of iterations.
+// TODO: Replace with a !tocamelcase bang operator.
+class snakeCaseToCamelCase<string str>
+{
+ string ret = !foldl(firstCharToUpper<str>.ret,
+ !range(0, !size(str)), acc, idx, _snakeCaseHelper<acc>.ret);
+}
+
#endif // UTILS_TD
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68a5cf209f2fb49..9e7de1d1e11f782 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -524,47 +524,6 @@ LogicalResult ReductionOp::verify() {
return success();
}
-ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
- Type redType;
- Type resType;
- CombiningKindAttr kindAttr;
- arith::FastMathFlagsAttr fastMathAttr;
- if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
- result.attributes) ||
- parser.parseComma() || parser.parseOperandList(operandsInfo) ||
- (succeeded(parser.parseOptionalKeyword("fastmath")) &&
- parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
- result.attributes)) ||
- parser.parseColonType(redType) ||
- parser.parseKeywordType("into", resType) ||
- (!operandsInfo.empty() &&
- parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
- (operandsInfo.size() > 1 &&
- parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
- parser.addTypeToList(resType, result.types))
- return failure();
- if (operandsInfo.empty() || operandsInfo.size() > 2)
- return parser.emitError(parser.getNameLoc(),
- "unsupported number of operands");
- return success();
-}
-
-void ReductionOp::print(OpAsmPrinter &p) {
- p << " ";
- getKindAttr().print(p);
- p << ", " << getVector();
- if (getAcc())
- p << ", " << getAcc();
-
- if (getFastmathAttr() &&
- getFastmathAttr().getValue() != arith::FastMathFlags::none) {
- p << ' ' << getFastmathAttrName().getValue();
- p.printStrippedAttrOrType(getFastmathAttr());
- }
- p << " : " << getVector().getType() << " into " << getDest().getType();
-}
-
// MaskableOpInterface methods.
/// Returns the mask type expected by this operation.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..504ac89659fdb73 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1169,7 +1169,7 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
// -----
func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
- // expected-error at +1 {{'vector.reduction' unsupported number of operands}}
+ // expected-error at +1 {{expected ':'}}
%0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
}
diff --git a/mlir/test/mlir-tblgen/utils.td b/mlir/test/mlir-tblgen/utils.td
new file mode 100644
index 000000000000000..28e0fecb2881bdd
--- /dev/null
+++ b/mlir/test/mlir-tblgen/utils.td
@@ -0,0 +1,23 @@
+// RUN: mlir-tblgen -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/Utils.td"
+
+// CHECK-DAG: string value = "CamelCaseTest"
+class already_camel_case {
+ string value = snakeCaseToCamelCase<"CamelCaseTest">.ret;
+}
+
+// CHECK-DAG: string value = "Foo"
+class single_word {
+ string value = snakeCaseToCamelCase<"foo">.ret;
+}
+
+// CHECK-DAG: string value = "ThisIsATest"
+class snake_case {
+ string value = snakeCaseToCamelCase<"this_is_a_test">.ret;
+}
+
+// CHECK-DAG: string value = "ThisIsATestAgain"
+class extra_underscores {
+ string value = snakeCaseToCamelCase<"__this__is_a_test__again__">.ret;
+}
More information about the Mlir-commits mailing list