[Mlir-commits] [mlir] 8b9f8db - [mlir][matchers] Add m_Op(StringRef) and m_Attr matchers
Jacques Pienaar
llvmlistbot at llvm.org
Tue Apr 11 14:16:34 PDT 2023
Author: Devajith V S
Date: 2023-04-11T14:16:14-07:00
New Revision: 8b9f8db501cf9299f3bab5fb669d5f231b3a1aaa
URL: https://github.com/llvm/llvm-project/commit/8b9f8db501cf9299f3bab5fb669d5f231b3a1aaa
DIFF: https://github.com/llvm/llvm-project/commit/8b9f8db501cf9299f3bab5fb669d5f231b3a1aaa.diff
LOG: [mlir][matchers] Add m_Op(StringRef) and m_Attr matchers
This patch introduces support for m_Op with a StringRef argument and for m_Attr matchers. These matchers will be useful for mlir-query, which is currently being developed.
This patch is submitted separately to reduce the size of the final patch and to make mlir-query easier to upstream.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D147262
Added:
Modified:
mlir/include/mlir/IR/Matchers.h
mlir/test/IR/test-matchers.mlir
mlir/test/lib/IR/TestMatchers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 374f05ab49c5f..4dbc623916acf 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -52,6 +52,22 @@ struct constant_op_matcher {
bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
};
+/// The matcher that matches operations that have the specified op name.
+///
+/// The name is copied into an owned std::string rather than kept as a
+/// StringRef, so the matcher stays valid when it is constructed from a
+/// temporary string (e.g. an op name parsed at runtime by a query tool).
+struct NameOpMatcher {
+ NameOpMatcher(StringRef name) : name(name.str()) {}
+ /// Returns true if `op`'s registered/unregistered name equals `name`.
+ bool match(Operation *op) { return op->getName().getStringRef() == name; }
+
+ std::string name;
+};
+
+/// The matcher that matches operations that have the specified attribute name.
+///
+/// The attribute name is copied into an owned std::string rather than kept
+/// as a StringRef, so the matcher does not dangle when it is constructed from
+/// a temporary string.
+struct AttrOpMatcher {
+ AttrOpMatcher(StringRef attrName) : attrName(attrName.str()) {}
+ /// Returns true if `op` carries an attribute named `attrName`.
+ bool match(Operation *op) { return op->hasAttr(attrName); }
+
+ std::string attrName;
+};
+
/// The matcher that matches operations that have the `ConstantLike` trait, and
/// binds the folded attribute value.
template <typename AttrT>
@@ -83,6 +99,29 @@ struct constant_op_binder {
}
};
+/// The matcher that matches operations that have the specified attribute
+/// name, and binds the attribute value.
+///
+/// The attribute name is copied into an owned std::string rather than kept
+/// as a StringRef, so the matcher does not dangle when it is constructed from
+/// a temporary string.
+template <typename AttrT>
+struct AttrOpBinder {
+ /// Creates a matcher instance that binds the attribute value to
+ /// `bindValue` if match succeeds.
+ AttrOpBinder(StringRef attrName, AttrT *bindValue)
+ : attrName(attrName.str()), bindValue(bindValue) {}
+ /// Creates a matcher instance that doesn't bind if match succeeds.
+ AttrOpBinder(StringRef attrName)
+ : attrName(attrName.str()), bindValue(nullptr) {}
+
+ /// Returns true if `op->getAttrOfType<AttrT>(attrName)` yields a non-null
+ /// attribute; in that case the attribute is also stored into `*bindValue`
+ /// when `bindValue` is non-null.
+ bool match(Operation *op) {
+ if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
+ if (bindValue)
+ *bindValue = attr;
+ return true;
+ }
+ return false;
+ }
+
+ std::string attrName;
+ AttrT *bindValue;
+};
+
/// The matcher that matches a constant scalar / vector splat / tensor splat
/// float operation and binds the constant float value.
struct constant_float_op_binder {
@@ -249,6 +288,16 @@ inline detail::constant_op_matcher m_Constant() {
return detail::constant_op_matcher();
}
+/// Returns a matcher that accepts any operation for which
+/// `op->hasAttr(attrName)` holds; nothing is bound.
+inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
+ return detail::AttrOpMatcher{attrName};
+}
+
+/// Returns a matcher that accepts an operation whose name string is exactly
+/// `opName` (e.g. "test.name"); the op's operands are not inspected.
+inline detail::NameOpMatcher m_Op(StringRef opName) {
+ return detail::NameOpMatcher{opName};
+}
+
/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
template <typename AttrT>
@@ -256,6 +305,13 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
return detail::constant_op_binder<AttrT>(bind_value);
}
+/// Returns a matcher that accepts an operation carrying an attribute named
+/// `attrName` retrievable as `AttrT`, and on success stores that attribute
+/// into `*bindValue`.
+template <typename AttrT>
+inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
+ AttrT *bindValue) {
+ using Binder = detail::AttrOpBinder<AttrT>;
+ return Binder{attrName, bindValue};
+}
+
/// Matches a constant scalar / vector splat / tensor splat float (both positive
/// and negative) zero.
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
index 87c7bf9e7ebc8..31f1b6de18c4e 100644
--- a/mlir/test/IR/test-matchers.mlir
+++ b/mlir/test/IR/test-matchers.mlir
@@ -41,3 +41,14 @@ func.func @test2(%a: f32) -> f32 {
// CHECK-LABEL: test2
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
// CHECK: Pattern add(add(a, constant), a) matched
+
+// Input for the test3 matcher test: the mulf just before the return must
+// match mul(*, add(*, m_Op("test.name"))), and its fastmath attribute must be
+// matched and bound by m_Attr("fastmath", ...).
+func.func @test3(%a: f32) -> f32 {
+ %0 = "test.name"() {value = 1.0 : f32} : () -> f32
+ %1 = arith.addf %a, %0: f32
+ %2 = arith.mulf %a, %1 fastmath<fast>: f32
+ return %2: f32
+}
+
+// CHECK-LABEL: test3
+// CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched
+// CHECK: Pattern m_Attr("fastmath") matched and bound value to: fast
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
index 4f87517235e2d..d075d7a818859 100644
--- a/mlir/test/lib/IR/TestMatchers.cpp
+++ b/mlir/test/lib/IR/TestMatchers.cpp
@@ -148,6 +148,21 @@ void test2(FunctionOpInterface f) {
llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
}
+/// Checks m_Op(StringRef) and the binding m_Attr overload against the op that
+/// precedes the terminator of the function's first block.
+void test3(FunctionOpInterface f) {
+ // mulf(*, addf(*, <op named "test.name">))
+ auto mulOfAddOfNamed = m_Op<arith::MulFOp>(
+ m_Any(), m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
+
+ // Any op with a "fastmath" attribute; the attribute is bound to boundFlags.
+ arith::FastMathFlagsAttr boundFlags;
+ auto hasFastMath = m_Attr("fastmath", &boundFlags);
+
+ // The operation immediately before the block's last (terminator) op.
+ auto &entry = f.getFunctionBody().front();
+ Operation *candidate = entry.back().getPrevNode();
+
+ if (mulOfAddOfNamed.match(candidate))
+ llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
+ if (hasFastMath.match(candidate))
+ llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
+ << boundFlags.getValue() << "\n";
+}
+
void TestMatchers::runOnOperation() {
auto f = getOperation();
llvm::outs() << f.getName() << "\n";
@@ -155,6 +170,8 @@ void TestMatchers::runOnOperation() {
test1(f);
if (f.getName() == "test2")
test2(f);
+ if (f.getName() == "test3")
+ test3(f);
}
namespace mlir {
More information about the Mlir-commits mailing list