[Mlir-commits] [mlir] [mlir][SMT] restore custom builder for forall/exists (PR #135470)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 11 20:29:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
This reverts commit 54e70ac7650f1c22f687937d1a082e4152f97b22.
The necessary change was to explicitly call `->destroy()` on the ops in each test. I believe this is because the `OpBuilder` used in the tests has no insertion point, so the ops it creates are never inserted into a block or module; without an explicit `->destroy()`, nothing else owns or frees them.
---
Full diff: https://github.com/llvm/llvm-project/pull/135470.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SMT/IR/SMTOps.td (+12)
- (modified) mlir/lib/Dialect/SMT/IR/SMTOps.cpp (+20)
- (modified) mlir/unittests/Dialect/SMT/CMakeLists.txt (+1)
- (added) mlir/unittests/Dialect/SMT/QuantifierTest.cpp (+199)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index af73955caee54..1872c00b74f1a 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -448,6 +448,18 @@ class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
VariadicRegion<SizedRegion<1>>:$patterns);
let results = (outs BoolType:$result);
+ let builders = [
+ OpBuilder<(ins
+ "TypeRange":$boundVarTypes,
+ "function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
+ CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
+ CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
+ "{}">:$patternBuilder,
+ CArg<"uint32_t", "0">:$weight,
+ CArg<"bool", "false">:$noPattern)>
+ ];
+ let skipDefaultBuilders = true;
+
let assemblyFormat = [{
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
attr-dict-with-keyword $body (`patterns` $patterns^)?
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
index 604dd26da1982..8977a3abc125d 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -432,6 +432,16 @@ LogicalResult ForallOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
+/// Custom builder declared in SMTOps.td: forwards every argument unchanged to
+/// the buildQuantifier helper, which ExistsOp::build also uses.
+void ForallOp::build(
+    OpBuilder &builder, OperationState &state, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<Properties>(builder, state, boundVarTypes, bodyBuilder,
+                              boundVarNames, patternBuilder, weight,
+                              noPattern);
+}
+
//===----------------------------------------------------------------------===//
// ExistsOp
//===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@ LogicalResult ExistsOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
+/// Custom builder declared in SMTOps.td: forwards every argument unchanged to
+/// the buildQuantifier helper, which ForallOp::build also uses.
+void ExistsOp::build(
+    OpBuilder &builder, OperationState &state, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<Properties>(builder, state, boundVarTypes, bodyBuilder,
+                              boundVarNames, patternBuilder, weight,
+                              noPattern);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
diff --git a/mlir/unittests/Dialect/SMT/CMakeLists.txt b/mlir/unittests/Dialect/SMT/CMakeLists.txt
index 86e16d6194ea9..a1331467febaa 100644
--- a/mlir/unittests/Dialect/SMT/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SMT/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRSMTTests
AttributeTest.cpp
+ QuantifierTest.cpp
TypeTest.cpp
)
diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
new file mode 100644
index 0000000000000..328dba75d8655
--- /dev/null
+++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
@@ -0,0 +1,199 @@
+//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace smt;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ExistsOp
+//===----------------------------------------------------------------------===//
+
+// ExistsOp with unnamed bound variables, a pattern region that yields the
+// bound variables, and weight 2.
+TEST(QuantifierTest, ExistsBuilderWithPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ExistsOp existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      std::nullopt,
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return boundVars;
+      },
+      /*weight=*/2);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp.print(stream);
+
+  // Free the detached op before asserting. A failing ASSERT_STREQ returns
+  // from the test immediately and would skip a trailing destroy(), leaking
+  // the op. The printed text lives in `buffer`, so it stays valid.
+  existsOp->destroy();
+
+  ASSERT_STREQ(
+      stream.str().str().c_str(),
+      "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
+      "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+      "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+      "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+// ExistsOp with named bound variables ("a", "b"), no pattern builder, and the
+// no_pattern flag set.
+TEST(QuantifierTest, ExistsBuilderNoPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ExistsOp existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp.print(stream);
+
+  // Free the detached op before asserting, so that a failing ASSERT_STREQ
+  // (which returns early) cannot skip the cleanup and leak it.
+  existsOp->destroy();
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+               "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+               "smt.yield %0 : !smt.bool\n}\n");
+}
+
+// ExistsOp with named bound variables and the default pattern builder,
+// weight and no_pattern arguments.
+TEST(QuantifierTest, ExistsBuilderDefault) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ExistsOp existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"});
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp.print(stream);
+
+  // Free the detached op before asserting, so that a failing ASSERT_STREQ
+  // (which returns early) cannot skip the cleanup and leak it.
+  existsOp->destroy();
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
+               "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+               "%0 : !smt.bool\n}\n");
+}
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ForallOp
+//===----------------------------------------------------------------------===//
+
+// ForallOp with named bound variables, a pattern region that yields the
+// bound variables, and weight 2.
+TEST(QuantifierTest, ForallBuilderWithPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ForallOp forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return boundVars;
+      },
+      /*weight=*/2);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp.print(stream);
+
+  // Free the detached op before asserting. A failing ASSERT_STREQ returns
+  // from the test immediately and would skip a trailing destroy(), leaking
+  // the op. The printed text lives in `buffer`, so it stays valid.
+  forallOp->destroy();
+
+  ASSERT_STREQ(
+      stream.str().str().c_str(),
+      "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
+      "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+      "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+      "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+// ForallOp with named bound variables ("a", "b"), no pattern builder, and the
+// no_pattern flag set.
+TEST(QuantifierTest, ForallBuilderNoPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ForallOp forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp.print(stream);
+
+  // Free the detached op before asserting, so that a failing ASSERT_STREQ
+  // (which returns early) cannot skip the cleanup and leak it.
+  forallOp->destroy();
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+               "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+               "smt.yield %0 : !smt.bool\n}\n");
+}
+
+// ForallOp with unnamed bound variables and the default pattern builder,
+// weight and no_pattern arguments.
+TEST(QuantifierTest, ForallBuilderDefault) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  // No insertion point is set, so the created op is not placed in any block.
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  ForallOp forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      std::nullopt);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp.print(stream);
+
+  // Free the detached op before asserting, so that a failing ASSERT_STREQ
+  // (which returns early) cannot skip the cleanup and leak it.
+  forallOp->destroy();
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
+               "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+               "%0 : !smt.bool\n}\n");
+}
+
+} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/135470
More information about the Mlir-commits
mailing list