[Mlir-commits] [mlir] [mlir][SMT] add python bindings (PR #135674)

Maksim Levental llvmlistbot at llvm.org
Wed Apr 16 10:24:07 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/135674

>From 3ba88a0333a158eac9ed6ce93fa738ade4477e89 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Mon, 14 Apr 2025 16:35:40 -0400
Subject: [PATCH 1/3] [mlir][SMT] add python bindings

---
 mlir/python/CMakeLists.txt          |  9 +++++++++
 mlir/python/mlir/dialects/SMTOps.td | 14 ++++++++++++++
 mlir/python/mlir/dialects/smt.py    |  5 +++++
 mlir/test/python/dialects/smt.py    | 16 ++++++++++++++++
 4 files changed, 44 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/SMTOps.td
 create mode 100644 mlir/python/mlir/dialects/smt.py
 create mode 100644 mlir/test/python/dialects/smt.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index fb115a5f43423..3985668486931 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings(
     "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
 )
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SMTOps.td
+  GEN_ENUM_BINDINGS
+  SOURCES
+    dialects/smt.py
+  DIALECT_NAME smt)
+
 declare_mlir_dialect_python_bindings(
     ADD_TO_PARENT MLIRPythonSources.Dialects
     ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/SMTOps.td b/mlir/python/mlir/dialects/SMTOps.td
new file mode 100644
index 0000000000000..e143f071eb658
--- /dev/null
+++ b/mlir/python/mlir/dialects/SMTOps.td
@@ -0,0 +1,14 @@
+//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BINDINGS_PYTHON_SMT_OPS
+#define BINDINGS_PYTHON_SMT_OPS
+// Pull in the SMT dialect ODS definitions consumed by the Python op generator.
+include "mlir/Dialect/SMT/IR/SMT.td"
+
+#endif // BINDINGS_PYTHON_SMT_OPS
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
new file mode 100644
index 0000000000000..7948486988b4c
--- /dev/null
+++ b/mlir/python/mlir/dialects/smt.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._smt_ops_gen import *
diff --git a/mlir/test/python/dialects/smt.py b/mlir/test/python/dialects/smt.py
new file mode 100644
index 0000000000000..3e10f3ca35321
--- /dev/null
+++ b/mlir/test/python/dialects/smt.py
@@ -0,0 +1,16 @@
+# REQUIRES: bindings_python
+# RUN: %PYTHON% %s | FileCheck %s
+
+import mlir
+
+from mlir.dialects import smt
+from mlir.ir import Context, Location, Module, InsertionPoint
+
+with Context() as ctx, Location.unknown():
+    m = Module.create()
+    with InsertionPoint(m.body):
+        true = smt.constant(True)
+        false = smt.constant(False)
+    # CHECK: smt.constant true
+    # CHECK: smt.constant false
+    print(m)

>From cc99edeee1edc35308855c531aa497d8d9662f1d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 16 Apr 2025 12:19:02 -0400
Subject: [PATCH 2/3] [mlir][SMT] rename SMT C APIs to use the mlirSMT prefix

---
 mlir/include/mlir-c/Dialect/SMT.h | 69 ++++++++++++-----------
 mlir/lib/CAPI/Dialect/SMT.cpp     | 52 +++++++++--------
 mlir/test/CAPI/smt.c              | 94 +++++++++++++++----------------
 3 files changed, 110 insertions(+), 105 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h
index d076dccce1b06..0ad64746f148b 100644
--- a/mlir/include/mlir-c/Dialect/SMT.h
+++ b/mlir/include/mlir-c/Dialect/SMT.h
@@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
 //===----------------------------------------------------------------------===//

 /// Checks if the given type is any non-func SMT value type.
-MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);

 /// Checks if the given type is any SMT value type.
-MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);

 /// Checks if the given type is a smt::ArrayType.
-MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);

 /// Creates an array type with the given domain and range types.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
-                                            MlirType domainType,
-                                            MlirType rangeType);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
+                                                MlirType domainType,
+                                                MlirType rangeType);

 /// Checks if the given type is a smt::BitVectorType.
-MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);

 /// Creates a smt::BitVectorType with the given width.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
+                                                    int32_t width);

 /// Checks if the given type is a smt::BoolType.
-MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);

 /// Creates a smt::BoolType.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);

 /// Checks if the given type is a smt::IntType.
-MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);

 /// Creates a smt::IntType.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);

 /// Checks if the given type is a smt::FuncType.
-MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);

 /// Creates a smt::FuncType with the given domain and range types.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
-                                              size_t numberOfDomainTypes,
-                                              const MlirType *domainTypes,
-                                              MlirType rangeType);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
+                                                  size_t numberOfDomainTypes,
+                                                  const MlirType *domainTypes,
+                                                  MlirType rangeType);

 /// Checks if the given type is a smt::SortType.
-MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);

 /// Creates a smt::SortType with the given identifier and sort parameters.
-MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
-                                           MlirIdentifier identifier,
-                                           size_t numberOfSortParams,
-                                           const MlirType *sortParams);
+MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
+                                               MlirIdentifier identifier,
+                                               size_t numberOfSortParams,
+                                               const MlirType *sortParams);

 //===----------------------------------------------------------------------===//
 // Attribute API.
 //===----------------------------------------------------------------------===//

 /// Checks if the given string is a valid smt::BVCmpPredicate.
-MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
-                                                   MlirStringRef str);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
+                                                       MlirStringRef str);

 /// Checks if the given string is a valid smt::IntPredicate.
-MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
-                                                 MlirStringRef str);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
+                                                     MlirStringRef str);

 /// Checks if the given attribute is a smt::SMTAttribute.
-MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);

 /// Creates a smt::BitVectorAttr with the given value and width.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
-                                                     uint64_t value,
-                                                     unsigned width);
+MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
+                                                         uint64_t value,
+                                                         unsigned width);

 /// Creates a smt::BVCmpPredicateAttr with the given string.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
-                                                          MlirStringRef str);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);

 /// Creates a smt::IntPredicateAttr with the given string.
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
-                                                        MlirStringRef str);
+MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
+                                                            MlirStringRef str);

 #ifdef __cplusplus
 }
diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp
index 3a4620df8ccdf..7e96bbb071533 100644
--- a/mlir/lib/CAPI/Dialect/SMT.cpp
+++ b/mlir/lib/CAPI/Dialect/SMT.cpp
@@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
 // Type API.
 //===----------------------------------------------------------------------===//

-bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
+bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
   return isAnyNonFuncSMTValueType(unwrap(type));
 }

-bool smtTypeIsAnySMTValueType(MlirType type) {
+bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
   return isAnySMTValueType(unwrap(type));
 }

-bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
+bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }

-MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
-                         MlirType rangeType) {
+MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
+                             MlirType rangeType) {
   return wrap(
       ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
 }

-bool smtTypeIsABitVector(MlirType type) {
+bool mlirSMTTypeIsABitVector(MlirType type) {
   return isa<BitVectorType>(unwrap(type));
 }

-MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
+MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
   return wrap(BitVectorType::get(unwrap(ctx), width));
 }

-bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
+bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }

-MlirType smtTypeGetBool(MlirContext ctx) {
+MlirType mlirSMTTypeGetBool(MlirContext ctx) {
   return wrap(BoolType::get(unwrap(ctx)));
 }

-bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
+bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }

-MlirType smtTypeGetInt(MlirContext ctx) {
+MlirType mlirSMTTypeGetInt(MlirContext ctx) {
   return wrap(IntType::get(unwrap(ctx)));
 }

-bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
+bool mlirSMTTypeIsASMTFunc(MlirType type) {
+  return isa<SMTFuncType>(unwrap(type));
+}

-MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
-                           const MlirType *domainTypes, MlirType rangeType) {
+MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
+                               const MlirType *domainTypes,
+                               MlirType rangeType) {
   SmallVector<Type> domainTypesVec;
   domainTypesVec.reserve(numberOfDomainTypes);

@@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
   return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
 }

-bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
+bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }

-MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
-                        size_t numberOfSortParams, const MlirType *sortParams) {
+MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
+                            size_t numberOfSortParams,
+                            const MlirType *sortParams) {
   SmallVector<Type> sortParamsVec;
   sortParamsVec.reserve(numberOfSortParams);

@@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
 // Attribute API.
 //===----------------------------------------------------------------------===//

-bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
+bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
   return symbolizeBVCmpPredicate(unwrap(str)).has_value();
 }

-bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
+bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
   return symbolizeIntPredicate(unwrap(str)).has_value();
 }

-bool smtAttrIsASMTAttribute(MlirAttribute attr) {
+bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
   return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
 }

-MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
-                                  unsigned width) {
+MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
+                                      unsigned width) {
   return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
 }

-MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
+MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
   auto predicate = symbolizeBVCmpPredicate(unwrap(str));
   assert(predicate.has_value() && "invalid predicate");

   return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
 }

-MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
+MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
   auto predicate = symbolizeIntPredicate(unwrap(str));
   assert(predicate.has_value() && "invalid predicate");

diff --git a/mlir/test/CAPI/smt.c b/mlir/test/CAPI/smt.c
index 77815d4f79657..d3810a24d929c 100644
--- a/mlir/test/CAPI/smt.c
+++ b/mlir/test/CAPI/smt.c
@@ -44,13 +44,13 @@ void testExportSMTLIB(MlirContext ctx) {
 }

 void testSMTType(MlirContext ctx) {
-  MlirType boolType = smtTypeGetBool(ctx);
-  MlirType intType = smtTypeGetInt(ctx);
-  MlirType arrayType = smtTypeGetArray(ctx, intType, boolType);
-  MlirType bvType = smtTypeGetBitVector(ctx, 32);
+  MlirType boolType = mlirSMTTypeGetBool(ctx);
+  MlirType intType = mlirSMTTypeGetInt(ctx);
+  MlirType arrayType = mlirSMTTypeGetArray(ctx, intType, boolType);
+  MlirType bvType = mlirSMTTypeGetBitVector(ctx, 32);
   MlirType funcType =
-      smtTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
-  MlirType sortType = smtTypeGetSort(
+      mlirSMTTypeGetSMTFunc(ctx, 2, (MlirType[]){intType, boolType}, boolType);
+  MlirType sortType = mlirSMTTypeGetSort(
       ctx, mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sort")), 0,
       NULL);

@@ -68,107 +68,107 @@ void testSMTType(MlirContext ctx) {
   mlirTypeDump(sortType);

   // CHECK: bool_is_any_non_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(boolType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(boolType)
                       ? "bool_is_any_non_func_smt_value_type\n"
                       : "bool_is_func_smt_value_type\n");
   // CHECK: int_is_any_non_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(intType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(intType)
                       ? "int_is_any_non_func_smt_value_type\n"
                       : "int_is_func_smt_value_type\n");
   // CHECK: array_is_any_non_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(arrayType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(arrayType)
                       ? "array_is_any_non_func_smt_value_type\n"
                       : "array_is_func_smt_value_type\n");
   // CHECK: bit_vector_is_any_non_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(bvType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(bvType)
                       ? "bit_vector_is_any_non_func_smt_value_type\n"
                       : "bit_vector_is_func_smt_value_type\n");
   // CHECK: sort_is_any_non_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(sortType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(sortType)
                       ? "sort_is_any_non_func_smt_value_type\n"
                       : "sort_is_func_smt_value_type\n");
   // CHECK: smt_func_is_func_smt_value_type
-  fprintf(stderr, smtTypeIsAnyNonFuncSMTValueType(funcType)
+  fprintf(stderr, mlirSMTTypeIsAnyNonFuncSMTValueType(funcType)
                       ? "smt_func_is_any_non_func_smt_value_type\n"
                       : "smt_func_is_func_smt_value_type\n");

   // CHECK: bool_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(boolType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(boolType)
                       ? "bool_is_any_smt_value_type\n"
                       : "bool_is_not_any_smt_value_type\n");
   // CHECK: int_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(intType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(intType)
                       ? "int_is_any_smt_value_type\n"
                       : "int_is_not_any_smt_value_type\n");
   // CHECK: array_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(arrayType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(arrayType)
                       ? "array_is_any_smt_value_type\n"
                       : "array_is_not_any_smt_value_type\n");
   // CHECK: array_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(bvType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(bvType)
                       ? "array_is_any_smt_value_type\n"
                       : "array_is_not_any_smt_value_type\n");
   // CHECK: smt_func_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(funcType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(funcType)
                       ? "smt_func_is_any_smt_value_type\n"
                       : "smt_func_is_not_any_smt_value_type\n");
   // CHECK: sort_is_any_smt_value_type
-  fprintf(stderr, smtTypeIsAnySMTValueType(sortType)
+  fprintf(stderr, mlirSMTTypeIsAnySMTValueType(sortType)
                       ? "sort_is_any_smt_value_type\n"
                       : "sort_is_not_any_smt_value_type\n");

   // CHECK: int_type_is_not_a_bool
-  fprintf(stderr, smtTypeIsABool(intType) ? "int_type_is_a_bool\n"
-                                          : "int_type_is_not_a_bool\n");
+  fprintf(stderr, mlirSMTTypeIsABool(intType) ? "int_type_is_a_bool\n"
+                                              : "int_type_is_not_a_bool\n");
   // CHECK: bool_type_is_not_a_int
-  fprintf(stderr, smtTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
-                                          : "bool_type_is_not_a_int\n");
+  fprintf(stderr, mlirSMTTypeIsAInt(boolType) ? "bool_type_is_a_int\n"
+                                              : "bool_type_is_not_a_int\n");
   // CHECK: bv_type_is_not_a_array
-  fprintf(stderr, smtTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
-                                          : "bv_type_is_not_a_array\n");
+  fprintf(stderr, mlirSMTTypeIsAArray(bvType) ? "bv_type_is_a_array\n"
+                                              : "bv_type_is_not_a_array\n");
   // CHECK: array_type_is_not_a_bit_vector
-  fprintf(stderr, smtTypeIsABitVector(arrayType)
+  fprintf(stderr, mlirSMTTypeIsABitVector(arrayType)
                       ? "array_type_is_a_bit_vector\n"
                       : "array_type_is_not_a_bit_vector\n");
   // CHECK: sort_type_is_not_a_smt_func
-  fprintf(stderr, smtTypeIsASMTFunc(sortType)
+  fprintf(stderr, mlirSMTTypeIsASMTFunc(sortType)
                       ? "sort_type_is_a_smt_func\n"
                       : "sort_type_is_not_a_smt_func\n");
   // CHECK: func_type_is_not_a_sort
-  fprintf(stderr, smtTypeIsASort(funcType) ? "func_type_is_a_sort\n"
-                                           : "func_type_is_not_a_sort\n");
+  fprintf(stderr, mlirSMTTypeIsASort(funcType) ? "func_type_is_a_sort\n"
+                                               : "func_type_is_not_a_sort\n");
 }

 void testSMTAttribute(MlirContext ctx) {
   // CHECK: slt_is_BVCmpPredicate
-  fprintf(stderr,
-          smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt"))
-              ? "slt_is_BVCmpPredicate\n"
-              : "slt_is_not_BVCmpPredicate\n");
+  fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
+                      ctx, mlirStringRefCreateFromCString("slt"))
+                      ? "slt_is_BVCmpPredicate\n"
+                      : "slt_is_not_BVCmpPredicate\n");
   // CHECK: lt_is_not_BVCmpPredicate
-  fprintf(stderr,
-          smtAttrCheckBVCmpPredicate(ctx, mlirStringRefCreateFromCString("lt"))
-              ? "lt_is_BVCmpPredicate\n"
-              : "lt_is_not_BVCmpPredicate\n");
+  fprintf(stderr, mlirSMTAttrCheckBVCmpPredicate(
+                      ctx, mlirStringRefCreateFromCString("lt"))
+                      ? "lt_is_BVCmpPredicate\n"
+                      : "lt_is_not_BVCmpPredicate\n");
   // CHECK: slt_is_not_IntPredicate
-  fprintf(stderr,
-          smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("slt"))
-              ? "slt_is_IntPredicate\n"
-              : "slt_is_not_IntPredicate\n");
+  fprintf(stderr, mlirSMTAttrCheckIntPredicate(
+                      ctx, mlirStringRefCreateFromCString("slt"))
+                      ? "slt_is_IntPredicate\n"
+                      : "slt_is_not_IntPredicate\n");
   // CHECK: lt_is_IntPredicate
-  fprintf(stderr,
-          smtAttrCheckIntPredicate(ctx, mlirStringRefCreateFromCString("lt"))
-              ? "lt_is_IntPredicate\n"
-              : "lt_is_not_IntPredicate\n");
+  fprintf(stderr, mlirSMTAttrCheckIntPredicate(
+                      ctx, mlirStringRefCreateFromCString("lt"))
+                      ? "lt_is_IntPredicate\n"
+                      : "lt_is_not_IntPredicate\n");

   // CHECK: #smt.bv<5> : !smt.bv<32>
-  mlirAttributeDump(smtAttrGetBitVector(ctx, 5, 32));
+  mlirAttributeDump(mlirSMTAttrGetBitVector(ctx, 5, 32));
   // CHECK: 0 : i64
   mlirAttributeDump(
-      smtAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
+      mlirSMTAttrGetBVCmpPredicate(ctx, mlirStringRefCreateFromCString("slt")));
   // CHECK: 0 : i64
   mlirAttributeDump(
-      smtAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
+      mlirSMTAttrGetIntPredicate(ctx, mlirStringRefCreateFromCString("lt")));
 }

 int main(void) {

>From d864f80cdd5512cbfc052339a49c8a83657364c8 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 16 Apr 2025 13:23:56 -0400
Subject: [PATCH 3/3] [mlir][SMT] bind SMT types and SMTLIB export in Python

---
 mlir/include/mlir-c/Target/ExportSMTLIB.h | 10 ++-
 mlir/lib/Bindings/Python/DialectSMT.cpp   | 83 ++++++++++++++++++++
 mlir/lib/CAPI/Target/ExportSMTLIB.cpp     | 21 ++++-
 mlir/python/CMakeLists.txt                | 15 ++++
 mlir/python/mlir/dialects/smt.py          | 28 +++++++
 mlir/test/CAPI/smt.c                      |  3 +-
 mlir/test/python/dialects/smt.py          | 93 ++++++++++++++++++++---
 7 files changed, 235 insertions(+), 18 deletions(-)
 create mode 100644 mlir/lib/Bindings/Python/DialectSMT.cpp

diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h
index 31f411c4a89c2..c4e746a4f4540 100644
--- a/mlir/include/mlir-c/Target/ExportSMTLIB.h
+++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h
@@ -21,9 +21,13 @@ extern "C" {

 /// Emits SMTLIB for the specified module using the provided callback and user
 /// data
-MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
-                                                      MlirStringCallback,
-                                                      void *userData);
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirTranslateModuleSMTLIB(MlirModule, MlirStringCallback, void *userData,
+                          bool inlineSingleUseValues, bool indentLetBody);
+/// Same as above, but emits SMTLIB for an arbitrary operation.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
+    MlirOperation, MlirStringCallback, void *userData,
+    bool inlineSingleUseValues, bool indentLetBody);

 #ifdef __cplusplus
 }
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
new file mode 100644
index 0000000000000..4e7647729fb0a
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -0,0 +1,91 @@
+//===- DialectSMT.cpp - Nanobind module for SMT dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "NanobindUtils.h"
+
+#include "mlir-c/Dialect/SMT.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir-c/Target/ExportSMTLIB.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+
+using namespace nanobind::literals;
+
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Populates `m` with the SMT dialect types (`BoolType`, `BitVectorType`) and
+/// the `export_smtlib` translation entry point.
+static void populateDialectSMTSubmodule(nanobind::module_ &m) {
+  // `BoolType.get(context=None)` wraps mlirSMTTypeGetBool.
+  mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
+      .def_classmethod(
+          "get",
+          [](const nb::object &, MlirContext context) {
+            return mlirSMTTypeGetBool(context);
+          },
+          "cls"_a, "context"_a.none() = nb::none());
+
+  // `BitVectorType.get(width, context=None)` wraps mlirSMTTypeGetBitVector.
+  mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
+      .def_classmethod(
+          "get",
+          [](const nb::object &, int32_t width, MlirContext context) {
+            return mlirSMTTypeGetBitVector(context, width);
+          },
+          "cls"_a, "width"_a, "context"_a.none() = nb::none());
+
+  // Shared implementation of both `export_smtlib` overloads. Diagnostics
+  // emitted during translation are collected and attached to the ValueError
+  // raised on failure.
+  auto exportSMTLIB = [](MlirOperation op, bool inlineSingleUseValues,
+                         bool indentLetBody) {
+    mlir::python::CollectDiagnosticsToStringScope scope(
+        mlirOperationGetContext(op));
+    PyPrintAccumulator printAccum;
+    MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
+        op, printAccum.getCallback(), printAccum.getUserData(),
+        inlineSingleUseValues, indentLetBody);
+    if (mlirLogicalResultIsSuccess(result))
+      return printAccum.join();
+    throw nb::value_error(
+        ("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
+            .c_str());
+  };
+
+  // The bound functions outlive this scope, so `exportSMTLIB` is captured by
+  // value; capturing the local by reference would leave a dangling reference.
+  m.def(
+      "export_smtlib",
+      [exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
+                     bool indentLetBody) {
+        return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
+      },
+      "module"_a, "inline_single_use_values"_a = false,
+      "indent_let_body"_a = false);
+  m.def(
+      "export_smtlib",
+      [exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
+                     bool indentLetBody) {
+        return exportSMTLIB(mlirModuleGetOperation(module),
+                            inlineSingleUseValues, indentLetBody);
+      },
+      "module"_a, "inline_single_use_values"_a = false,
+      "indent_let_body"_a = false);
+}
+
+NB_MODULE(_mlirDialectsSMT, m) {
+  m.doc() = "MLIR SMT Dialect";
+
+  populateDialectSMTSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
index c9ac7ce704af8..0d5c60ad779d5 100644
--- a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
+++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp
@@ -19,9 +19,25 @@

 using namespace mlir;

-MlirLogicalResult mlirExportSMTLIB(MlirModule module,
-                                   MlirStringCallback callback,
-                                   void *userData) {
+MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
+                                                 MlirStringCallback callback,
+                                                 void *userData,
+                                                 bool inlineSingleUseValues,
+                                                 bool indentLetBody) {
   mlir::detail::CallbackOstream stream(callback, userData);
+  // Forward the caller-selected emission options to the exporter.
+  smt::SMTEmissionOptions options;
+  options.inlineSingleUseValues = inlineSingleUseValues;
+  options.indentLetBody = indentLetBody;
-  return wrap(smt::exportSMTLIB(unwrap(module), stream));
+  return wrap(smt::exportSMTLIB(unwrap(module), stream, options));
 }
+
+MlirLogicalResult mlirTranslateModuleSMTLIB(MlirModule module,
+                                            MlirStringCallback callback,
+                                            void *userData,
+                                            bool inlineSingleUseValues,
+                                            bool indentLetBody) {
+  return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
+                                        callback, userData,
+                                        inlineSingleUseValues, indentLetBody);
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 3985668486931..bbf6819608bb9 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -673,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
     MLIRCAPILinalg
 )

+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
+  MODULE_NAME _mlirDialectsSMT
+  ADD_TO_PARENT MLIRPythonSources.Dialects.smt
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
+  SOURCES
+    DialectSMT.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPISMT
+    MLIRCAPIExportSMTLIB
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
   MODULE_NAME _mlirSparseTensorPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
index 7948486988b4c..ae7a4c41cbc3a 100644
--- a/mlir/python/mlir/dialects/smt.py
+++ b/mlir/python/mlir/dialects/smt.py
@@ -3,3 +3,33 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._smt_ops_gen import *
+
+from .._mlir_libs._mlirDialectsSMT import *
+from ..extras.meta import region_op
+
+
+def bool_t():
+    """Return the SMT boolean type (prints as `!smt.bool`)."""
+    return BoolType.get()
+
+
+def bv_t(width):
+    """Return the SMT bit-vector type of `width` bits (`!smt.bv<width>`)."""
+    return BitVectorType.get(width)
+
+
+def _solver(
+    inputs=None,
+    results=None,
+    loc=None,
+    ip=None,
+):
+    """Create an `smt.solver` op; `inputs`/`results` default to empty lists."""
+    operands = [] if inputs is None else inputs
+    result_types = [] if results is None else results
+    return SolverOp(result_types, operands, loc=loc, ip=ip)
+
+
+# Decorator form: `@solver(...)` fills the solver region from the decorated
+# function's body and terminates it with `smt.yield`.
+solver = region_op(_solver, terminator=YieldOp)
diff --git a/mlir/test/CAPI/smt.c b/mlir/test/CAPI/smt.c
index d3810a24d929c..6c177e27e71ee 100644
--- a/mlir/test/CAPI/smt.c
+++ b/mlir/test/CAPI/smt.c
@@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) {
   MlirModule module =
       mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT));
 
-  MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL);
+  MlirLogicalResult result =
+      mlirTranslateModuleSMTLIB(module, dumpCallback, NULL, false, false);
   (void)result;
   assert(mlirLogicalResultIsSuccess(result));
 
diff --git a/mlir/test/python/dialects/smt.py b/mlir/test/python/dialects/smt.py
index 3e10f3ca35321..6f0cd8835b65b 100644
--- a/mlir/test/python/dialects/smt.py
+++ b/mlir/test/python/dialects/smt.py
@@ -1,16 +1,88 @@
-# REQUIRES: bindings_python
-# RUN: %PYTHON% %s | FileCheck %s
+# RUN: %PYTHON %s | FileCheck %s

-import mlir
+from mlir.dialects import smt, arith
+from mlir.ir import Context, Location, Module, InsertionPoint, F32Type

-from mlir.dialects import smt
-from mlir.ir import Context, Location, Module, InsertionPoint

-with Context() as ctx, Location.unknown():
-    m = Module.create()
-    with InsertionPoint(m.body):
-        true = smt.constant(True)
-        false = smt.constant(False)
+def run(f):
+    """Run `f` on a fresh module under a `TEST:` label, then print and verify it."""
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f(module)
+        print(module)
+        assert module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_smoke
+@run
+def test_smoke(_module):
+    true = smt.constant(True)
+    false = smt.constant(False)
     # CHECK: smt.constant true
     # CHECK: smt.constant false
-    print(m)
+
+
+# CHECK-LABEL: TEST: test_types
+@run
+def test_types(_module):
+    bool_t = smt.bool_t()
+    bitvector_t = smt.bv_t(5)
+    # CHECK: !smt.bool
+    print(bool_t)
+    # CHECK: !smt.bv<5>
+    print(bitvector_t)
+
+
+# CHECK-LABEL: TEST: test_solver_op
+@run
+def test_solver_op(_module):
+    @smt.solver
+    def foo1():
+        true = smt.constant(True)
+        false = smt.constant(False)
+
+    # CHECK: smt.solver() : () -> () {
+    # CHECK:   %true = smt.constant true
+    # CHECK:   %false = smt.constant false
+    # CHECK: }
+
+    f32 = F32Type.get()
+
+    @smt.solver(results=[f32])
+    def foo2():
+        return arith.ConstantOp(f32, 1.0)
+
+    # CHECK: %{{.*}} = smt.solver() : () -> f32 {
+    # CHECK:   %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
+    # CHECK:   smt.yield %[[CST1]] : f32
+    # CHECK: }
+
+    two = arith.ConstantOp(f32, 2.0)
+    # CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+    print(two)
+
+    @smt.solver(inputs=[two], results=[f32])
+    def foo3(z: f32):
+        return z
+
+    # CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
+    # CHECK: ^bb0(%[[ARG0:.*]]: f32):
+    # CHECK:   smt.yield %[[ARG0]] : f32
+    # CHECK: }
+
+
+# CHECK-LABEL: TEST: test_export_smtlib
+@run
+def test_export_smtlib(module):
+    @smt.solver
+    def foo1():
+        true = smt.constant(True)
+        smt.assert_(true)
+
+    query = smt.export_smtlib(module.operation)
+    # CHECK: ; solver scope 0
+    # CHECK: (assert true)
+    # CHECK: (reset)
+    print(query)



More information about the Mlir-commits mailing list