[Mlir-commits] [mlir] [mlir][SMT] add python bindings (PR #135674)
Maksim Levental
llvmlistbot at llvm.org
Wed Apr 16 10:24:07 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/135674
>From 3ba88a0333a158eac9ed6ce93fa738ade4477e89 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Mon, 14 Apr 2025 16:35:40 -0400
Subject: [PATCH 1/3] [mlir][SMT] add python bindings
---
mlir/python/CMakeLists.txt | 9 +++++++++
mlir/python/mlir/dialects/SMTOps.td | 14 ++++++++++++++
mlir/python/mlir/dialects/smt.py | 5 +++++
mlir/test/python/dialects/smt.py | 16 ++++++++++++++++
4 files changed, 44 insertions(+)
create mode 100644 mlir/python/mlir/dialects/SMTOps.td
create mode 100644 mlir/python/mlir/dialects/smt.py
create mode 100644 mlir/test/python/dialects/smt.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index fb115a5f43423..3985668486931 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings(
"../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
)
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/SMTOps.td
+ GEN_ENUM_BINDINGS
+ SOURCES
+ dialects/smt.py
+ DIALECT_NAME smt)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/SMTOps.td b/mlir/python/mlir/dialects/SMTOps.td
new file mode 100644
index 0000000000000..e143f071eb658
--- /dev/null
+++ b/mlir/python/mlir/dialects/SMTOps.td
@@ -0,0 +1,14 @@
+//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BINDINGS_PYTHON_SMT_OPS
+#define BINDINGS_PYTHON_SMT_OPS
+
+include "mlir/Dialect/SMT/IR/SMT.td"
+
+#endif // BINDINGS_PYTHON_SMT_OPS
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
new file mode 100644
index 0000000000000..7948486988b4c
--- /dev/null
+++ b/mlir/python/mlir/dialects/smt.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._smt_ops_gen import *
diff --git a/mlir/test/python/dialects/smt.py b/mlir/test/python/dialects/smt.py
new file mode 100644
index 0000000000000..3e10f3ca35321
--- /dev/null
+++ b/mlir/test/python/dialects/smt.py
@@ -0,0 +1,16 @@
+# REQUIRES: bindings_python
+# RUN: %PYTHON% %s | FileCheck %s
+
+import mlir
+
+from mlir.dialects import smt
+from mlir.ir import Context, Location, Module, InsertionPoint
+
+with Context() as ctx, Location.unknown():
+    m = Module.create()
+    with InsertionPoint(m.body):
+        true = smt.constant(True)
+        false = smt.constant(False)
+    # CHECK: smt.constant true
+    # CHECK: smt.constant false
+    print(m)
>From cc99edeee1edc35308855c531aa497d8d9662f1d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 16 Apr 2025 12:19:02 -0400
Subject: [PATCH 2/3] rename SMT C APIs
---
mlir/include/mlir-c/Dialect/SMT.h | 69 ++++++++++++-----------
mlir/lib/CAPI/Dialect/SMT.cpp | 52 +++++++++--------
mlir/test/CAPI/smt.c | 94 +++++++++++++++----------------
3 files changed, 110 insertions(+), 105 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h
index d076dccce1b06..0ad64746f148b 100644
--- a/mlir/include/mlir-c/Dialect/SMT.h
+++ b/mlir/include/mlir-c/Dialect/SMT.h
@@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
//===----------------------------------------------------------------------===//
/// Checks if the given type is any non-func SMT value type.
-MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);
/// Checks if the given type is any SMT value type.
-MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);
/// Checks if the given type is a smt::ArrayType.
-MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);
/// Creates an array type with the given domain and range types.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
- MlirType domainType,
- MlirType rangeType);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
+ MlirType domainType,
+ MlirType rangeType);
/// Checks if the given type is a smt::BitVectorType.
-MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);
/// Creates a smt::BitVectorType with the given width.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
+ int32_t width);
/// Checks if the given type is a smt::BoolType.
-MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);
/// Creates a smt::BoolType.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);
/// Checks if the given type is a smt::IntType.
-MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);
/// Creates a smt::IntType.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);
/// Checks if the given type is a smt::FuncType.
-MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);
/// Creates a smt::FuncType with the given domain and range types.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
- size_t numberOfDomainTypes,
- const MlirType *domainTypes,
- MlirType rangeType);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
+ size_t numberOfDomainTypes,
+ const MlirType *domainTypes,
+ MlirType rangeType);
/// Checks if the given type is a smt::SortType.
-MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);
/// Creates a smt::SortType with the given identifier and sort parameters.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
- MlirIdentifier identifier,
- size_t numberOfSortParams,
- const MlirType *sortParams);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
+ MlirIdentifier identifier,
+ size_t numberOfSortParams,
+ const MlirType *sortParams);
//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
/// Checks if the given string is a valid smt::BVCmpPredicate.
-MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
- MlirStringRef str);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
+ MlirStringRef str);
/// Checks if the given string is a valid smt::IntPredicate.
-MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
- MlirStringRef str);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
+ MlirStringRef str);
/// Checks if the given attribute is a smt::SMTAttribute.
-MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);
/// Creates a smt::BitVectorAttr with the given value and width.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
- uint64_t value,
- unsigned width);
+MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
+ uint64_t value,
+ unsigned width);
/// Creates a smt::BVCmpPredicateAttr with the given string.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
- MlirStringRef str);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);
/// Creates a smt::IntPredicateAttr with the given string.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
- MlirStringRef str);
+MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
+ MlirStringRef str);
#ifdef __cplusplus
}
diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp
index 3a4620df8ccdf..7e96bbb071533 100644
--- a/mlir/lib/CAPI/Dialect/SMT.cpp
+++ b/mlir/lib/CAPI/Dialect/SMT.cpp
@@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
// Type API.
//===----------------------------------------------------------------------===//
-bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
+bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
return isAnyNonFuncSMTValueType(unwrap(type));
}
-bool smtTypeIsAnySMTValueType(MlirType type) {
+bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
return isAnySMTValueType(unwrap(type));
}
-bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
+bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
-MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
- MlirType rangeType) {
+MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
+ MlirType rangeType) {
return wrap(
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
}
-bool smtTypeIsABitVector(MlirType type) {
+bool mlirSMTTypeIsABitVector(MlirType type) {
return isa<BitVectorType>(unwrap(type));
}
-MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
+MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
return wrap(BitVectorType::get(unwrap(ctx), width));
}
-bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
+bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
-MlirType smtTypeGetBool(MlirContext ctx) {
+MlirType mlirSMTTypeGetBool(MlirContext ctx) {
return wrap(BoolType::get(unwrap(ctx)));
}
-bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
+bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
-MlirType smtTypeGetInt(MlirContext ctx) {
+MlirType mlirSMTTypeGetInt(MlirContext ctx) {
return wrap(IntType::get(unwrap(ctx)));
}
-bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
+bool mlirSMTTypeIsASMTFunc(MlirType type) {
+ return isa<SMTFuncType>(unwrap(type));
+}
-MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
- const MlirType *domainTypes, MlirType rangeType) {
+MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
+ const MlirType *domainTypes,
+ MlirType rangeType) {
SmallVector<Type> domainTypesVec;
domainTypesVec.reserve(numberOfDomainTypes);
@@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
}
-bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
+bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
-MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
- size_t numberOfSortParams, const MlirType *sortParams) {
+MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
+ size_t numberOfSortParams,
+ const MlirType *sortParams) {
SmallVector<Type> sortParamsVec;
sortParamsVec.reserve(numberOfSortParams);
@@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
// Attribute API.
//===----------------------------------------------------------------------===//
-bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
+bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
}
-bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
+bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeIntPredicate(unwrap(str)).has_value();
}
-bool smtAttrIsASMTAttribute(MlirAttribute attr) {
+bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
}
-MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
- unsigned width) {
+MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
+ unsigned width) {
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
}
-MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
+MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");
return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
}
-MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
+MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeIntPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");
diff --git a/mlir/test/CAPI/smt.c b/mlir/test/CAPI/smt.c
index 77815d4f79657..d3810a24d929c 100644
--- a/mlir/test/CAPI/smt.c
+++ b/mlir/test/CAPI/smt.c
@@ -44,13 +44,13 @@ void testExportSMTLIB(MlirContext ctx) {
}
void testSMTType(MlirContext ctx) {
- MlirType boolType = smtTypeGetBool(ctx);
- MlirType intType = smtTypeGetInt(ctx);
- MlirType arrayType = smtTypeGetArray(ctx, intType, boolType);
- MlirType bvType = smtTypeGetBitVector(ctx, 32);
+ MlirType boolType = mlirSMTTypeGetBool(ctx);
+ MlirType intType = mlirSMTTypeGetInt(ctx);
+ MlirType arrayType = mlirSMTTypeGetArray(ctx, intType, boolType);
+ MlirType bvType = mlirSMTTypeGetBitVector(ctx, 32);
MlirType funcType =
- smtTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
- MlirType sortType = smtTypeGetSort(
+ mlirSMTTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
+ MlirType sortType = mlirSMTTypeGetSort(
ctx, mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sort")), 0,
NULL);
@@ -68,107 +68,107 @@ void testSMTType(MlirContext ctx) {
mlirTypeDump(sortType);
// CHECK: bool_is_any_non_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(boolType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(boolType)
? "bool_is_any_non_func_smt_value_type\n"
: "bool_is_func_smt_value_type\n");
// CHECK: int_is_any_non_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(intType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(intType)
? "int_is_any_non_func_smt_value_type\n"
: "int_is_func_smt_value_type\n");
// CHECK: array_is_any_non_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(arrayType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(arrayType)
? "array_is_any_non_func_smt_value_type\n"
: "array_is_func_smt_value_type\n");
// CHECK: bit_vector_is_any_non_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(bvType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(bvType)
? "bit_vector_is_any_non_func_smt_value_type\n"
: "bit_vector_is_func_smt_value_type\n");
// CHECK: sort_is_any_non_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(sortType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(sortType)
? "sort_is_any_non_func_smt_value_type\n"
: "sort_is_func_smt_value_type\n");
// CHECK: smt_func_is_func_smt_value_type
- fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(funcType)
+ fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(funcType)
? "smt_func_is_any_non_func_smt_value_type\n"
: "smt_func_is_func_smt_value_type\n");
// CHECK: bool_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(boolType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(boolType)
? "bool_is_any_smt_value_type\n"
: "bool_is_not_any_smt_value_type\n");
// CHECK: int_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(intType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(intType)
? "int_is_any_smt_value_type\n"
: "int_is_not_any_smt_value_type\n");
// CHECK: array_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(arrayType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(arrayType)
? "array_is_any_smt_value_type\n"
: "array_is_not_any_smt_value_type\n");
// CHECK: array_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(bvType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(bvType)
? "array_is_any_smt_value_type\n"
: "array_is_not_any_smt_value_type\n");
// CHECK: smt_func_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(funcType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(funcType)
? "smt_func_is_any_smt_value_type\n"
: "smt_func_is_not_any_smt_value_type\n");
// CHECK: sort_is_any_smt_value_type
- fprintf(stderr, smtTypeIsAnySMTValueType(sortType)
+ fprintf(stderr, mlirSMTTypeIsAnySMTValueType(sortType)
? "sort_is_any_smt_value_type\n"
: "sort_is_not_any_smt_value_type\n");
// CHECK: int_type_is_not_a_bool
- fprintf(stderr, smtTypeIsABool(intType) ? "int_type_is_a_bool\n"
- : "int_type_is_not_a_bool\n");
+ fprintf(stderr, mlirSMTTypeIsABool(intType) ? "int_type_is_a_bool\n"
+ : "int_type_is_not_a_bool\n");
// CHECK: bool_type_is_not_a_int
- fprintf(stderr, smtTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
- : "bool_type_is_not_a_int\n");
+ fprintf(stderr, mlirSMTTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
+ : "bool_type_is_not_a_int\n");
// CHECK: bv_type_is_not_a_array
- fprintf(stderr, smtTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
- : "bv_type_is_not_a_array\n");
+ fprintf(stderr, mlirSMTTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
+ : "bv_type_is_not_a_array\n");
// CHECK: array_type_is_not_a_bit_vector
- fprintf(stderr, smtTypeIsABitVector(arrayType)
+ fprintf(stderr, mlirSMTTypeIsABitVector(arrayType)
? "array_type_is_a_bit_vector\n"
: "array_type_is_not_a_bit_vector\n");
// CHECK: sort_type_is_not_a_smt_func
- fprintf(stderr, smtTypeIsASMTFunc(sortType)
+ fprintf(stderr, mlirSMTTypeIsASMTFunc(sortType)
? "sort_type_is_a_smt_func\n"
: "sort_type_is_not_a_smt_func\n");
// CHECK: func_type_is_not_a_sort
- fprintf(stderr, smtTypeIsASort(funcType) ? "func_type_is_a_sort\n"
- : "func_type_is_not_a_sort\n");
+ fprintf(stderr, mlirSMTTypeIsASort(funcType) ? "func_type_is_a_sort\n"
+ : "func_type_is_not_a_sort\n");
}
void testSMTAttribute(MlirContext ctx) {
// CHECK: slt_is_BVCmpPredicate
- fprintf(stderr,
- smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))
- ? "slt_is_BVCmpPredicate\n"
- : "slt_is_not_BVCmpPredicate\n");
+ fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
+ ctx, mlirStringRefCreateFromCString("slt"))
+ ? "slt_is_BVCmpPredicate\n"
+ : "slt_is_not_BVCmpPredicate\n");
// CHECK: lt_is_not_BVCmpPredicate
- fprintf(stderr,
- smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("lt"))
- ? "lt_is_BVCmpPredicate\n"
- : "lt_is_not_BVCmpPredicate\n");
+ fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
+ ctx, mlirStringRefCreateFromCString("lt"))
+ ? "lt_is_BVCmpPredicate\n"
+ : "lt_is_not_BVCmpPredicate\n");
// CHECK: slt_is_not_IntPredicate
- fprintf(stderr,
- smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("slt"))
- ? "slt_is_IntPredicate\n"
- : "slt_is_not_IntPredicate\n");
+ fprintf(stderr, mlirSMTAttrCheckIntPredicate(
+ ctx, mlirStringRefCreateFromCString("slt"))
+ ? "slt_is_IntPredicate\n"
+ : "slt_is_not_IntPredicate\n");
// CHECK: lt_is_IntPredicate
- fprintf(stderr,
- smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))
- ? "lt_is_IntPredicate\n"
- : "lt_is_not_IntPredicate\n");
+ fprintf(stderr, mlirSMTAttrCheckIntPredicate(
+ ctx, mlirStringRefCreateFromCString("lt"))
+ ? "lt_is_IntPredicate\n"
+ : "lt_is_not_IntPredicate\n");
// CHECK: #smt.bv<5> : !smt.bv<32>
- mlirAttributeDump(smtAttrGetBitVector(ctx, 5, 32));
+ mlirAttributeDump(mlirSMTAttrGetBitVector(ctx, 5, 32));
// CHECK: 0 : i64
mlirAttributeDump(
- smtAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
+ mlirSMTAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
// CHECK: 0 : i64
mlirAttributeDump(
- smtAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
+ mlirSMTAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
}
int main(void) {
>From d864f80cdd5512cbfc052339a49c8a83657364c8 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 16 Apr 2025 13:23:56 -0400
Subject: [PATCH 3/3] bind more stuff
---
mlir/include/mlir-c/Target/ExportSMTLIB.h | 10 ++-
mlir/lib/Bindings/Python/DialectSMT.cpp | 83 ++++++++++++++++++++
mlir/lib/CAPI/Target/ExportSMTLIB.cpp | 21 ++++-
mlir/python/CMakeLists.txt | 15 ++++
mlir/python/mlir/dialects/smt.py | 28 +++++++
mlir/test/CAPI/smt.c | 3 +-
mlir/test/python/dialects/smt.py | 93 ++++++++++++++++++++---
7 files changed, 235 insertions(+), 18 deletions(-)
create mode 100644 mlir/lib/Bindings/Python/DialectSMT.cpp
diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h
index 31f411c4a89c2..c4e746a4f4540 100644
--- a/mlir/include/mlir-c/Target/ExportSMTLIB.h
+++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h
@@ -21,9 +21,16 @@ extern "C" {
 
 /// Emits SMTLIB for the specified module using the provided callback and user
-/// data
-MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
-                                                      MlirStringCallback,
-                                                      void *userData);
+/// data. `inlineSingleUseValues` and `indentLetBody` set the corresponding
+/// fields of the smt::SMTEmissionOptions used by the emitter.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirTranslateModuleSMTLIB(MlirModule, MlirStringCallback, void *userData,
+                          bool inlineSingleUseValues, bool indentLetBody);
+
+/// Emits SMTLIB for the specified operation using the provided callback and
+/// user data. Accepts the same emission flags as mlirTranslateModuleSMTLIB.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
+    MlirOperation, MlirStringCallback, void *userData,
+    bool inlineSingleUseValues, bool indentLetBody);
 
 #ifdef __cplusplus
 }
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
new file mode 100644
index 0000000000000..4e7647729fb0a
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -0,0 +1,87 @@
+//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "NanobindUtils.h"
+
+#include "mlir-c/Dialect/SMT.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir-c/Target/ExportSMTLIB.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+
+using namespace nanobind::literals;
+
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Registers the SMT dialect types (BoolType, BitVectorType) and the
+/// `export_smtlib` entry points on the extension module `m`.
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+  mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
+      .def_classmethod(
+          "get",
+          [](const nb::object &, MlirContext context) {
+            return mlirSMTTypeGetBool(context);
+          },
+          "cls"_a, "context"_a.none() = nb::none());
+  mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
+      .def_classmethod(
+          "get",
+          [](const nb::object &, int32_t width, MlirContext context) {
+            return mlirSMTTypeGetBitVector(context, width);
+          },
+          "cls"_a, "width"_a, "context"_a.none() = nb::none());
+
+  // Translates `op` to SMT-LIB text. Diagnostics emitted in the op's context
+  // are collected and included in the ValueError raised on failure.
+  auto exportSMTLIB = [](MlirOperation op, bool inlineSingleUseValues,
+                         bool indentLetBody) {
+    mlir::python::CollectDiagnosticsToStringScope scope(
+        mlirOperationGetContext(op));
+    PyPrintAccumulator printAccum;
+    MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
+        op, printAccum.getCallback(), printAccum.getUserData(),
+        inlineSingleUseValues, indentLetBody);
+    if (mlirLogicalResultIsSuccess(result))
+      return printAccum.join();
+    std::string message =
+        "Failed to export smtlib.\nDiagnostic message: " + scope.takeMessage();
+    throw nb::value_error(message.c_str());
+  };
+
+  // Capture `exportSMTLIB` by value: the bound functions are invoked from
+  // Python long after this function returns, so a reference would dangle.
+  m.def(
+      "export_smtlib",
+      [exportSMTLIB](MlirOperation op, bool inlineSingleUseValues,
+                     bool indentLetBody) {
+        return exportSMTLIB(op, inlineSingleUseValues, indentLetBody);
+      },
+      "module"_a, "inline_single_use_values"_a = false,
+      "indent_let_body"_a = false);
+  m.def(
+      "export_smtlib",
+      [exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
+                     bool indentLetBody) {
+        return exportSMTLIB(mlirModuleGetOperation(module),
+                            inlineSingleUseValues, indentLetBody);
+      },
+      "module"_a, "inline_single_use_values"_a = false,
+      "indent_let_body"_a = false);
+}
+
+NB_MODULE(_mlirDialectsSMT, m) {
+  m.doc() = "MLIR SMT Dialect";
+
+  populateDialectSMTSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
index c9ac7ce704af8..0d5c60ad779d5 100644
--- a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
+++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
@@ -19,9 +19,25 @@
 
 using namespace mlir;
 
-MlirLogicalResult mlirExportSMTLIB(MlirModule module,
-                                   MlirStringCallback callback,
-                                   void *userData) {
+MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation op,
+                                                 MlirStringCallback callback,
+                                                 void *userData,
+                                                 bool inlineSingleUseValues,
+                                                 bool indentLetBody) {
   mlir::detail::CallbackOstream stream(callback, userData);
-  return wrap(smt::exportSMTLIB(unwrap(module), stream));
+  smt::SMTEmissionOptions options;
+  options.inlineSingleUseValues = inlineSingleUseValues;
+  options.indentLetBody = indentLetBody;
+  // Forward the options; without this both flags would be silently ignored.
+  return wrap(smt::exportSMTLIB(unwrap(op), stream, options));
 }
+
+MlirLogicalResult mlirTranslateModuleSMTLIB(MlirModule module,
+                                            MlirStringCallback callback,
+                                            void *userData,
+                                            bool inlineSingleUseValues,
+                                            bool indentLetBody) {
+  return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
+                                        callback, userData,
+                                        inlineSingleUseValues, indentLetBody);
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 3985668486931..bbf6819608bb9 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -673,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
MLIRCAPILinalg
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
+ MODULE_NAME _mlirDialectsSMT
+ ADD_TO_PARENT MLIRPythonSources.Dialects.smt
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectSMT.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPISMT
+ MLIRCAPIExportSMTLIB
+)
+
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MODULE_NAME _mlirSparseTensorPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
index 7948486988b4c..ae7a4c41cbc3a 100644
--- a/mlir/python/mlir/dialects/smt.py
+++ b/mlir/python/mlir/dialects/smt.py
@@ -3,3 +3,48 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._smt_ops_gen import *
+
+from .._mlir_libs._mlirDialectsSMT import *
+from ..extras.meta import region_op
+
+
+def bool_t(context=None):
+    """Return the ``!smt.bool`` type.
+
+    ``context`` is forwarded to ``BoolType.get``; when it is None the
+    binding falls back to the current ``mlir.ir.Context``.
+    """
+    return BoolType.get(context=context)
+
+
+def bv_t(width, context=None):
+    """Return the ``!smt.bv<width>`` bit-vector type.
+
+    ``context`` is handled exactly as in :func:`bool_t`.
+    """
+    return BitVectorType.get(width, context=context)
+
+
+def _solver(
+    inputs=None,
+    results=None,
+    loc=None,
+    ip=None,
+):
+    """Create an ``smt.solver`` op taking ``inputs`` and producing ``results``.
+
+    ``inputs`` (operand values) and ``results`` (result types) default to
+    empty; ``None`` defaults avoid sharing a mutable list between calls.
+    """
+    if inputs is None:
+        inputs = []
+    if results is None:
+        results = []
+
+    return SolverOp(results, inputs, loc=loc, ip=ip)
+
+
+# Decorator form, usable bare (``@smt.solver``) or with keyword arguments
+# (``@smt.solver(inputs=..., results=...)``). The decorated function fills the
+# solver region and whatever it returns is yielded through ``smt.yield``.
+solver = region_op(_solver, terminator=YieldOp)
diff --git a/mlir/test/CAPI/smt.c b/mlir/test/CAPI/smt.c
index d3810a24d929c..6c177e27e71ee 100644
--- a/mlir/test/CAPI/smt.c
+++ b/mlir/test/CAPI/smt.c
@@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) {
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT));
- MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL);
+ MlirLogicalResult result =
+ mlirTranslateModuleSMTLIB(module, dumpCallback, NULL, false, false);
(void)result;
assert(mlirLogicalResultIsSuccess(result));
diff --git a/mlir/test/python/dialects/smt.py b/mlir/test/python/dialects/smt.py
index 3e10f3ca35321..6f0cd8835b65b 100644
--- a/mlir/test/python/dialects/smt.py
+++ b/mlir/test/python/dialects/smt.py
@@ -1,16 +1,88 @@
-# REQUIRES: bindings_python
-# RUN: %PYTHON% %s | FileCheck %s
+# RUN: %PYTHON %s | FileCheck %s
 
-import mlir
+from mlir.dialects import smt, arith
+from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
 
-from mlir.dialects import smt
-from mlir.ir import Context, Location, Module, InsertionPoint
 
-with Context() as ctx, Location.unknown():
-    m = Module.create()
-    with InsertionPoint(m.body):
-        true = smt.constant(True)
-        false = smt.constant(False)
+# Run `f` on a fresh module, print the module for FileCheck, then verify it.
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f(module)
+        print(module)
+        assert module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_smoke
+@run
+def test_smoke(_module):
+    true = smt.constant(True)
+    false = smt.constant(False)
     # CHECK: smt.constant true
     # CHECK: smt.constant false
-    print(m)
+
+
+# CHECK-LABEL: TEST: test_types
+@run
+def test_types(_module):
+    bool_t = smt.bool_t()
+    bitvector_t = smt.bv_t(5)
+    # CHECK: !smt.bool
+    print(bool_t)
+    # CHECK: !smt.bv<5>
+    print(bitvector_t)
+
+
+# CHECK-LABEL: TEST: test_solver_op
+@run
+def test_solver_op(_module):
+    @smt.solver
+    def foo1():
+        true = smt.constant(True)
+        false = smt.constant(False)
+
+    # CHECK: smt.solver() : () -> () {
+    # CHECK: %true = smt.constant true
+    # CHECK: %false = smt.constant false
+    # CHECK: }
+
+    f32 = F32Type.get()
+
+    @smt.solver(results=[f32])
+    def foo2():
+        return arith.ConstantOp(f32, 1.0)
+
+    # CHECK: %{{.*}} = smt.solver() : () -> f32 {
+    # CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
+    # CHECK: smt.yield %[[CST1]] : f32
+    # CHECK: }
+
+    two = arith.ConstantOp(f32, 2.0)
+    # CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+    print(two)
+
+    @smt.solver(inputs=[two], results=[f32])
+    def foo3(z: f32):
+        return z
+
+    # CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
+    # CHECK: ^bb0(%[[ARG0:.*]]: f32):
+    # CHECK: smt.yield %[[ARG0]] : f32
+    # CHECK: }
+
+
+# CHECK-LABEL: TEST: test_export_smtlib
+@run
+def test_export_smtlib(module):
+    @smt.solver
+    def foo1():
+        true = smt.constant(True)
+        smt.assert_(true)
+
+    query = smt.export_smtlib(module.operation)
+    # CHECK: ; solver scope 0
+    # CHECK: (assert true)
+    # CHECK: (reset)
+    print(query)
More information about the Mlir-commits
mailing list