[Mlir-commits] [mlir] [mlir][smt] add arith-to-smt (PR #131484)
Maksim Levental
llvmlistbot at llvm.org
Sat Mar 15 18:54:31 PDT 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/131484
None
>From 25f19c2208f3d5134346fad2effbd949c3807b89 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sat, 15 Mar 2025 19:14:37 -0400
Subject: [PATCH 1/2] [mlir][smt] upstream SMT dialect
---
mlir/include/mlir/Dialect/CMakeLists.txt | 1 +
mlir/include/mlir/Dialect/SMT/CMakeLists.txt | 1 +
.../mlir/Dialect/SMT/IR/CMakeLists.txt | 16 +
mlir/include/mlir/Dialect/SMT/IR/SMT.td | 22 +
.../mlir/Dialect/SMT/IR/SMTArrayOps.td | 99 ++++
.../mlir/Dialect/SMT/IR/SMTAttributes.h | 29 +
.../mlir/Dialect/SMT/IR/SMTAttributes.td | 74 +++
.../mlir/Dialect/SMT/IR/SMTBitVectorOps.td | 255 +++++++++
mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h | 20 +
.../include/mlir/Dialect/SMT/IR/SMTDialect.td | 30 ++
mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td | 137 +++++
mlir/include/mlir/Dialect/SMT/IR/SMTOps.h | 25 +
mlir/include/mlir/Dialect/SMT/IR/SMTOps.td | 477 +++++++++++++++++
mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h | 30 ++
mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td | 145 +++++
.../include/mlir/Dialect/SMT/IR/SMTVisitors.h | 201 +++++++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/lib/Dialect/CMakeLists.txt | 1 +
mlir/lib/Dialect/SMT/CMakeLists.txt | 1 +
mlir/lib/Dialect/SMT/IR/CMakeLists.txt | 27 +
mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp | 201 +++++++
mlir/lib/Dialect/SMT/IR/SMTDialect.cpp | 47 ++
mlir/lib/Dialect/SMT/IR/SMTOps.cpp | 472 +++++++++++++++++
mlir/lib/Dialect/SMT/IR/SMTTypes.cpp | 92 ++++
mlir/test/Dialect/SMT/array-errors.mlir | 13 +
mlir/test/Dialect/SMT/array.mlir | 14 +
mlir/test/Dialect/SMT/basic.mlir | 200 +++++++
mlir/test/Dialect/SMT/bitvector-errors.mlir | 112 ++++
mlir/test/Dialect/SMT/bitvectors.mlir | 81 +++
mlir/test/Dialect/SMT/core-errors.mlir | 497 ++++++++++++++++++
mlir/test/Dialect/SMT/cse-test.mlir | 12 +
mlir/test/Dialect/SMT/integers.mlir | 36 ++
32 files changed, 3370 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/SMT/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMT.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTOps.h
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td
create mode 100644 mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h
create mode 100644 mlir/lib/Dialect/SMT/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/SMT/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp
create mode 100644 mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
create mode 100644 mlir/lib/Dialect/SMT/IR/SMTOps.cpp
create mode 100644 mlir/lib/Dialect/SMT/IR/SMTTypes.cpp
create mode 100644 mlir/test/Dialect/SMT/array-errors.mlir
create mode 100644 mlir/test/Dialect/SMT/array.mlir
create mode 100644 mlir/test/Dialect/SMT/basic.mlir
create mode 100644 mlir/test/Dialect/SMT/bitvector-errors.mlir
create mode 100644 mlir/test/Dialect/SMT/bitvectors.mlir
create mode 100644 mlir/test/Dialect/SMT/core-errors.mlir
create mode 100644 mlir/test/Dialect/SMT/cse-test.mlir
create mode 100644 mlir/test/Dialect/SMT/integers.mlir
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index f710235197334..9d1a840d6644b 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -33,6 +33,7 @@ add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
+add_subdirectory(SMT)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Tensor)
diff --git a/mlir/include/mlir/Dialect/SMT/CMakeLists.txt b/mlir/include/mlir/Dialect/SMT/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..bd743ed510a9e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect(SMT smt)
+add_mlir_doc(SMT SMT Dialects/SMTOps -gen-op-doc)
+# TODO(maX)
+#add_mlir_doc(SMT SMT Dialects/SMTTypes -gen-typedef-doc -dialect smt)
+
+set(LLVM_TARGET_DEFINITIONS SMT.td)
+
+mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRSMTAttrIncGen)
+add_dependencies(mlir-headers MLIRSMTAttrIncGen)
+
+mlir_tablegen(SMTEnums.h.inc -gen-enum-decls)
+mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRSMTEnumsIncGen)
+add_dependencies(mlir-headers MLIRSMTEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMT.td b/mlir/include/mlir/Dialect/SMT/IR/SMT.td
new file mode 100644
index 0000000000000..dd7bd033c9fa5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMT.td
@@ -0,0 +1,22 @@
+//===- SMT.td - SMT dialect definition ---------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMT_TD
+#define MLIR_DIALECT_SMT_SMT_TD
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Dialect/SMT/IR/SMTAttributes.td"
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/Dialect/SMT/IR/SMTTypes.td"
+include "mlir/Dialect/SMT/IR/SMTOps.td"
+include "mlir/Dialect/SMT/IR/SMTArrayOps.td"
+include "mlir/Dialect/SMT/IR/SMTBitVectorOps.td"
+include "mlir/Dialect/SMT/IR/SMTIntOps.td"
+
+#endif // MLIR_DIALECT_SMT_SMT_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td
new file mode 100644
index 0000000000000..05b5398b6a7f9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTArrayOps.td
@@ -0,0 +1,99 @@
+//===- SMTArrayOps.td - SMT array operations ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTARRAYOPS_TD
+#define MLIR_DIALECT_SMT_SMTARRAYOPS_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/Dialect/SMT/IR/SMTAttributes.td"
+include "mlir/Dialect/SMT/IR/SMTTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class SMTArrayOp<string mnemonic, list<Trait> traits = []> :
+ SMTOp<"array." # mnemonic, traits>;
+
+def ArrayStoreOp : SMTArrayOp<"store", [
+ Pure,
+ TypesMatchWith<"index type matches the array domain type", "array", "index",
+ "cast<ArrayType>($_self).getDomainType()">,
+ TypesMatchWith<"value type matches the array range type", "array", "value",
+ "cast<ArrayType>($_self).getRangeType()">,
+ AllTypesMatch<["array", "result"]>,
+]> {
+ let summary = "stores a value at a given index and returns the new array";
+ let description = [{
+ This operation returns a new array which is the same as the 'array' operand
+ except that the value at the given 'index' is changed to the given 'value'.
+ The semantics are equivalent to the `store` operator described in the
+ [SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of
+ the SMT-LIB standard 2.6.
+ }];
+
+ let arguments = (ins ArrayType:$array, AnySMTType:$index, AnySMTType:$value);
+ let results = (outs ArrayType:$result);
+
+ let assemblyFormat = [{
+ $array `[` $index `]` `,` $value attr-dict `:` qualified(type($array))
+ }];
+}
+
+def ArraySelectOp : SMTArrayOp<"select", [
+ Pure,
+ TypesMatchWith<"index type matches the array domain type", "array", "index",
+ "cast<ArrayType>($_self).getDomainType()">,
+ TypesMatchWith<"result type matches the array range type", "array", "result",
+ "cast<ArrayType>($_self).getRangeType()">,
+]> {
+ let summary = "get the value stored in the array at the given index";
+ let description = [{
+ This operation returns the value stored in the given array at the given
+ index. The semantics are equivalent to the `select` operator defined in the
+ [SMT ArrayEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of
+ the SMT-LIB standard 2.6.
+ }];
+
+ let arguments = (ins ArrayType:$array, AnySMTType:$index);
+ let results = (outs AnySMTType:$result);
+
+ let assemblyFormat = [{
+ $array `[` $index `]` attr-dict `:` qualified(type($array))
+ }];
+}
+
+def ArrayBroadcastOp : SMTArrayOp<"broadcast", [
+ Pure,
+ TypesMatchWith<"value type matches the array range type", "result", "value",
+ "cast<ArrayType>($_self).getRangeType()">,
+]> {
+ let summary = "construct an array with the given value stored at every index";
+ let description = [{
+ This operation represents a broadcast of the 'value' operand to all indices
+ of the array. It is equivalent to
+ ```
+ %0 = smt.declare "array" : !smt.array<[!smt.int -> !smt.bool]>
+ %1 = smt.forall ["idx"] {
+ ^bb0(%idx: !smt.int):
+ %2 = smt.array.select %0[%idx] : !smt.array<[!smt.int -> !smt.bool]>
+ %3 = smt.eq %value, %2 : !smt.bool
+ smt.yield %3 : !smt.bool
+ }
+ smt.assert %1
+ // return %0
+ ```
+
+ In SMT-LIB, this is frequently written as
+ `((as const (Array Int Bool)) value)`.
+ }];
+
+ let arguments = (ins AnySMTType:$value);
+ let results = (outs ArrayType:$result);
+
+ let assemblyFormat = "$value attr-dict `:` qualified(type($result))";
+}
+
+#endif // MLIR_DIALECT_SMT_SMTARRAYOPS_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h
new file mode 100644
index 0000000000000..590364d572699
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.h
@@ -0,0 +1,29 @@
+//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_H
+#define MLIR_DIALECT_SMT_SMTATTRIBUTES_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir {
+namespace smt {
+namespace detail {
+
+struct BitVectorAttrStorage;
+
+} // namespace detail
+} // namespace smt
+} // namespace mlir
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/SMT/IR/SMTAttributes.h.inc"
+
+#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_H
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td
new file mode 100644
index 0000000000000..4231363fdf05b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td
@@ -0,0 +1,74 @@
+//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines SMT dialect specific attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTATTRIBUTES_TD
+#define MLIR_DIALECT_SMT_SMTATTRIBUTES_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+
+def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
+ DeclareAttrInterfaceMethods<TypedAttrInterface>
+]> {
+ let mnemonic = "bv";
+ let description = [{
+ This attribute represents a constant value of the `(_ BitVec width)` sort as
+ described in the [SMT bit-vector
+ theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).
+
+ The constant is written as #bX (binary) or #xX (hexadecimal) in SMT-LIB
+ where X is the value in the corresponding format without any further
+ prefixing. Here, the bit-vector constant is given as a regular integer
+ literal and the associated bit-vector type indicating the bit-width.
+
+ Examples:
+ ```mlir
+ #smt.bv<5> : !smt.bv<4>
+ #smt.bv<92> : !smt.bv<8>
+ ```
+
+ The explicit type-suffix is mandatory to uniquely represent the attribute,
+ i.e., this attribute should always be used in the extended form (using the
+ `qualified` keyword in the operation assembly format string).
+
+ The bit-width must be greater than zero (i.e., at least one digit has to be
+ present).
+ }];
+
+ let parameters = (ins "llvm::APInt":$value);
+
+ let hasCustomAssemblyFormat = true;
+ let genVerifyDecl = true;
+
+ // We need to manually define the storage class because the generated one is
+ // buggy (because the APInt asserts matching bitwidth in the `==` operator and
+ // the generated storage uses that directly).
+ // Alternatively: add a type parameter to redundantly store the bitwidth of
+ // the attribute type; if it's in the order before the 'value' it will be
+ // checked before the APInt equality (this is the reason it works for the
+ // builtin integer attribute), but would be more fragile (and we'd store
+ // duplicate data).
+ let genStorageClass = false;
+
+ let builders = [
+ AttrBuilder<(ins "llvm::StringRef":$value)>,
+ AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return the bit-vector constant as a SMT-LIB formatted string.
+ std::string getValueAsString(bool prefix = true) const;
+ }];
+}
+
+#endif // MLIR_DIALECT_SMT_SMTATTRIBUTES_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td
new file mode 100644
index 0000000000000..b6ca34e142d82
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTBitVectorOps.td
@@ -0,0 +1,255 @@
+//===- SMTBitVectorOps.td - SMT bit-vector dialect ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD
+#define MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/Dialect/SMT/IR/SMTAttributes.td"
+include "mlir/Dialect/SMT/IR/SMTTypes.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class SMTBVOp<string mnemonic, list<Trait> traits = []> :
+ Op<SMTDialect, "bv." # mnemonic, traits>;
+
+def BVConstantOp : SMTBVOp<"constant", [
+ Pure,
+ ConstantLike,
+ FirstAttrDerivedResultType,
+ DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "produce a constant bit-vector";
+ let description = [{
+ This operation produces an SSA value equal to the bit-vector constant
+ specified by the 'value' attribute.
+ Refer to the `BitVectorAttr` documentation for more information about
+ the semantics of bit-vector constants, their format, and associated sort.
+ The result type always matches the attribute's type.
+
+ Examples:
+ ```mlir
+ %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8>
+ %c5_bv4 = smt.bv.constant #smt.bv<5> : !smt.bv<4>
+ ```
+ }];
+
+ let arguments = (ins BitVectorAttr:$value);
+ let results = (outs BitVectorType:$result);
+
+ let assemblyFormat = "qualified($value) attr-dict";
+
+ let builders = [
+ OpBuilder<(ins "const llvm::APInt &":$value), [{
+ build($_builder, $_state,
+ BitVectorAttr::get($_builder.getContext(), value));
+ }]>,
+ OpBuilder<(ins "uint64_t":$value, "unsigned":$width), [{
+ build($_builder, $_state,
+ BitVectorAttr::get($_builder.getContext(), value, width));
+ }]>,
+ ];
+
+ let hasFolder = true;
+}
+
+class BVArithmeticOrBitwiseOp<string mnemonic, string desc> :
+ SMTBVOp<mnemonic, [Pure, SameOperandsAndResultType]> {
+ let summary = "equivalent to bv" # mnemonic # " in SMT-LIB";
+ let description = "This operation performs " # desc # [{. The semantics are
+ equivalent to the `bv}] # mnemonic # [{` operator defined in the SMT-LIB 2.6
+ standard. More precisely in the [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2)
+ and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2)
+ describing closed quantifier-free formulas over the theory of fixed-size
+ bit-vectors.
+ }];
+
+ let results = (outs BitVectorType:$result);
+}
+
+class BinaryBVOp<string mnemonic, string desc> :
+ BVArithmeticOrBitwiseOp<mnemonic, desc> {
+ let arguments = (ins BitVectorType:$lhs, BitVectorType:$rhs);
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))";
+}
+
+class UnaryBVOp<string mnemonic, string desc> :
+ BVArithmeticOrBitwiseOp<mnemonic, desc> {
+ let arguments = (ins BitVectorType:$input);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($result))";
+}
+
+def BVNotOp : UnaryBVOp<"not", "bitwise negation">;
+def BVNegOp : UnaryBVOp<"neg", "two's complement unary minus">;
+
+def BVAndOp : BinaryBVOp<"and", "bitwise AND">;
+def BVOrOp : BinaryBVOp<"or", "bitwise OR">;
+def BVXOrOp : BinaryBVOp<"xor", "bitwise exclusive OR">;
+
+def BVAddOp : BinaryBVOp<"add", "addition">;
+def BVMulOp : BinaryBVOp<"mul", "multiplication">;
+def BVUDivOp : BinaryBVOp<"udiv", "unsigned division (rounded towards zero)">;
+def BVSDivOp : BinaryBVOp<"sdiv", "two's complement signed division">;
+def BVURemOp : BinaryBVOp<"urem", "unsigned remainder">;
+def BVSRemOp : BinaryBVOp<"srem",
+ "two's complement signed remainder (sign follows dividend)">;
+def BVSModOp : BinaryBVOp<"smod",
+ "two's complement signed remainder (sign follows divisor)">;
+def BVShlOp : BinaryBVOp<"shl", "shift left">;
+def BVLShrOp : BinaryBVOp<"lshr", "logical shift right">;
+def BVAShrOp : BinaryBVOp<"ashr", "arithmetic shift right">;
+
+def PredicateSLT : I64EnumAttrCase<"slt", 0>;
+def PredicateSLE : I64EnumAttrCase<"sle", 1>;
+def PredicateSGT : I64EnumAttrCase<"sgt", 2>;
+def PredicateSGE : I64EnumAttrCase<"sge", 3>;
+def PredicateULT : I64EnumAttrCase<"ult", 4>;
+def PredicateULE : I64EnumAttrCase<"ule", 5>;
+def PredicateUGT : I64EnumAttrCase<"ugt", 6>;
+def PredicateUGE : I64EnumAttrCase<"uge", 7>;
+let cppNamespace = "mlir::smt" in
+def BVCmpPredicate : I64EnumAttr<
+ "BVCmpPredicate",
+ "smt bit-vector comparison predicate",
+ [PredicateSLT, PredicateSLE, PredicateSGT, PredicateSGE,
+ PredicateULT, PredicateULE, PredicateUGT, PredicateUGE]>;
+
+def BVCmpOp : SMTBVOp<"cmp", [Pure, SameTypeOperands]> {
+ let summary = "compare bit-vectors interpreted as signed or unsigned";
+ let description = [{
+ This operation compares bit-vector values, interpreting them as signed or
+ unsigned values depending on the predicate. The semantics are equivalent to
+ the `bvslt`, `bvsle`, `bvsgt`, `bvsge`, `bvult`, `bvule`, `bvugt`, or
+ `bvuge` operator defined in the SMT-LIB 2.6 standard depending on the
+ specified predicate. More precisely in the
+ [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2)
+ and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2)
+ describing closed quantifier-free formulas over the theory of fixed-size
+ bit-vectors.
+ }];
+
+ let arguments = (ins BVCmpPredicate:$pred,
+ BitVectorType:$lhs,
+ BitVectorType:$rhs);
+ let results = (outs BoolType:$result);
+
+ let assemblyFormat = [{
+ $pred $lhs `,` $rhs attr-dict `:` qualified(type($lhs))
+ }];
+}
+
+def ConcatOp : SMTBVOp<"concat", [
+ Pure,
+ DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>
+]> {
+ let summary = "bit-vector concatenation";
+ let description = [{
+ This operation concatenates bit-vector values with semantics equivalent to
+ the `concat` operator defined in the SMT-LIB 2.6 standard. More precisely in
+ the [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2)
+ and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2)
+ describing closed quantifier-free formulas over the theory of fixed-size
+ bit-vectors.
+
+ Note that the following equivalences hold:
+ * `smt.bv.concat %a, %b : !smt.bv<4>, !smt.bv<4>` is equivalent to
+ `(concat a b)` in SMT-LIB
+ * `(= (concat #xf #x0) #xf0)`
+ }];
+
+ let arguments = (ins BitVectorType:$lhs, BitVectorType:$rhs);
+ let results = (outs BitVectorType:$result);
+
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type(operands))";
+}
+
+def ExtractOp : SMTBVOp<"extract", [Pure]> {
+ let summary = "bit-vector extraction";
+ let description = [{
+ This operation extracts the range of bits starting at the 'lowBit' index
+ (inclusive) up to the 'lowBit' + result-width index (exclusive). The
+ semantics are equivalent to the `extract` operator defined in the SMT-LIB
+ 2.6 standard. More precisely in the
+ [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2)
+ and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2)
+ describing closed quantifier-free formulas over the theory of fixed-size
+ bit-vectors.
+
+ Note that `smt.bv.extract %bv from 2 : (!smt.bv<32>) -> !smt.bv<16>` is
+ equivalent to `((_ extract 17 2) bv)`, i.e., the SMT-LIB operator takes the
+ low and high indices where both are inclusive. The following equivalence
+ holds: `(= ((_ extract 3 0) #x0f) #xf)`
+ }];
+
+ let arguments = (ins I32Attr:$lowBit, BitVectorType:$input);
+ let results = (outs BitVectorType:$result);
+
+ let assemblyFormat = [{
+ $input `from` $lowBit attr-dict `:` functional-type($input, $result)
+ }];
+
+ let hasVerifier = true;
+}
+
+def RepeatOp : SMTBVOp<"repeat", [Pure]> {
+ let summary = "repeated bit-vector concatenation of one value";
+ let description = [{
+ This operation is a shorthand for repeated concatenation of the same
+ bit-vector value, i.e.,
+ ```mlir
+ smt.bv.repeat 5 times %a : !smt.bv<4>
+ // is the same as
+ %0 = smt.bv.repeat 4 times %a : !smt.bv<4>
+ smt.bv.concat %a, %0 : !smt.bv<4>, !smt.bv<16>
+ // or also
+ %0 = smt.bv.repeat 4 times %a : !smt.bv<4>
+ smt.bv.concat %0, %a : !smt.bv<16>, !smt.bv<4>
+ ```
+
+ The semantics are equivalent to the `repeat` operator defined in the SMT-LIB
+ 2.6 standard. More precisely in the
+ [theory of FixedSizeBitVectors](https://smtlib.cs.uiowa.edu/Theories/FixedSizeBitVectors.smt2)
+ and the [QF_BV logic](https://smtlib.cs.uiowa.edu/Logics/QF_BV.smt2)
+ describing closed quantifier-free formulas over the theory of fixed-size
+ bit-vectors.
+ }];
+
+ let arguments = (ins BitVectorType:$input);
+ let results = (outs BitVectorType:$result);
+
+ let hasCustomAssemblyFormat = true;
+ let hasVerifier = true;
+
+ let builders = [
+ OpBuilder<(ins "unsigned":$count, "mlir::Value":$input)>,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Get the number of times the input operand is repeated.
+ unsigned getCount();
+ }];
+}
+
+def BV2IntOp : SMTOp<"bv2int", [Pure]> {
+ let summary = "Convert an SMT bit-vector to an SMT integer.";
+ let description = [{
+ Create an integer from the bit-vector argument `input`. If `is_signed` is
+ present, the bit-vector is treated as two's complement signed. Otherwise,
+ it is treated as an unsigned integer in the range [0..2^N-1], where N is
+ the number of bits in `input`.
+ }];
+ let arguments = (ins BitVectorType:$input, UnitAttr:$is_signed);
+ let results = (outs IntType:$result);
+ let assemblyFormat = [{$input (`signed` $is_signed^)? attr-dict `:`
+ qualified(type($input))}];
+}
+
+#endif // MLIR_DIALECT_SMT_SMTBITVECTOROPS_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h
new file mode 100644
index 0000000000000..e808583a9e593
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.h
@@ -0,0 +1,20 @@
+//===- SMTDialect.h - SMT dialect definition --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTDIALECT_H
+#define MLIR_DIALECT_SMT_SMTDIALECT_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+
+// Pull in the dialect definition.
+#include "mlir/Dialect/SMT/IR/SMTDialect.h.inc"
+#include "mlir/Dialect/SMT/IR/SMTEnums.h.inc"
+
+#endif // MLIR_DIALECT_SMT_SMTDIALECT_H
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td
new file mode 100644
index 0000000000000..4b74187e85b87
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTDialect.td
@@ -0,0 +1,30 @@
+//===- SMTDialect.td - SMT dialect definition --------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTDIALECT_TD
+#define MLIR_DIALECT_SMT_SMTDIALECT_TD
+
+include "mlir/IR/DialectBase.td"
+
+def SMTDialect : Dialect {
+ let name = "smt";
+ let summary = "a dialect that models satisfiability modulo theories";
+ let cppNamespace = "mlir::smt";
+
+ let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
+
+ let hasConstantMaterializer = 1;
+
+ let extraClassDeclaration = [{
+ void registerAttributes();
+ void registerTypes();
+ }];
+}
+
+#endif // MLIR_DIALECT_SMT_SMTDIALECT_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td
new file mode 100644
index 0000000000000..6606c9608ef55
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTIntOps.td
@@ -0,0 +1,137 @@
+//===- SMTIntOps.td - SMT dialect int theory operations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTINTOPS_TD
+#define MLIR_DIALECT_SMT_SMTINTOPS_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/Dialect/SMT/IR/SMTAttributes.td"
+include "mlir/Dialect/SMT/IR/SMTTypes.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class SMTIntOp<string mnemonic, list<Trait> traits = []> :
+ SMTOp<"int." # mnemonic, traits>;
+
+def IntConstantOp : SMTIntOp<"constant", [
+ Pure,
+ ConstantLike,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = "produce a constant (infinite-precision) integer";
+ let description = [{
+ This operation represents (infinite-precision) integer literals of the `Int`
+ sort. The set of values for the sort `Int` consists of all numerals and
+ all terms of the form `-n` where n is a numeral other than 0. For more
+ information refer to the
+ [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+ SMT-LIB 2.6 standard.
+ }];
+
+ let arguments = (ins APIntAttr:$value);
+ let results = (outs IntType:$result);
+
+ let hasCustomAssemblyFormat = true;
+ let hasFolder = true;
+}
+
+class VariadicIntOp<string mnemonic> : SMTIntOp<mnemonic, [Pure, Commutative]> {
+ let description = [{
+ This operation represents (infinite-precision) }] # summary # [{.
+ The semantics are equivalent to the corresponding operator described in
+ the [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+ SMT-LIB 2.6 standard.
+ }];
+
+ let arguments = (ins Variadic<IntType>:$inputs);
+ let results = (outs IntType:$result);
+ let assemblyFormat = "$inputs attr-dict";
+
+ let builders = [
+ OpBuilder<(ins "mlir::ValueRange":$inputs), [{
+ build($_builder, $_state, $_builder.getType<smt::IntType>(), inputs);
+ }]>,
+ ];
+}
+
+class BinaryIntOp<string mnemonic> : SMTIntOp<mnemonic, [Pure]> {
+ let description = [{
+ This operation represents (infinite-precision) }] # summary # [{.
+ The semantics are equivalent to the corresponding operator described in
+ the [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+ SMT-LIB 2.6 standard.
+ }];
+
+ let arguments = (ins IntType:$lhs, IntType:$rhs);
+ let results = (outs IntType:$result);
+ let assemblyFormat = "$lhs `,` $rhs attr-dict";
+}
+
+def IntAbsOp : SMTIntOp<"abs", [Pure]> {
+ let summary = "the absolute value of an Int";
+ let description = [{
+ This operation represents the absolute value function for the `Int` sort.
+ The semantics are equivalent to the `abs` operator as described in the
+ [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+ SMT-LIB 2.6 standard.
+ }];
+
+ let arguments = (ins IntType:$input);
+ let results = (outs IntType:$result);
+ let assemblyFormat = "$input attr-dict";
+}
+
+def IntAddOp : VariadicIntOp<"add"> { let summary = "integer addition"; }
+def IntMulOp : VariadicIntOp<"mul"> { let summary = "integer multiplication"; }
+def IntSubOp : BinaryIntOp<"sub"> { let summary = "integer subtraction"; }
+def IntDivOp : BinaryIntOp<"div"> { let summary = "integer division"; }
+def IntModOp : BinaryIntOp<"mod"> { let summary = "integer remainder"; }
+
+def IntPredicateLT : I64EnumAttrCase<"lt", 0>;
+def IntPredicateLE : I64EnumAttrCase<"le", 1>;
+def IntPredicateGT : I64EnumAttrCase<"gt", 2>;
+def IntPredicateGE : I64EnumAttrCase<"ge", 3>;
+let cppNamespace = "mlir::smt" in
+def IntPredicate : I64EnumAttr<
+ "IntPredicate",
+ "smt comparison predicate for integers",
+ [IntPredicateLT, IntPredicateLE, IntPredicateGT, IntPredicateGE]>;
+
+def IntCmpOp : SMTIntOp<"cmp", [Pure]> {
+ let summary = "integer comparison";
+ let description = [{
+ This operation represents the comparison of (infinite-precision) integers.
+ The semantics are equivalent to the `<= (le)`, `< (lt)`, `>= (ge)`, or
+ `> (gt)` operator depending on the predicate (indicated in parentheses) as
+ described in the
+ [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+ SMT-LIB 2.6 standard.
+ }];
+
+ let arguments = (ins IntPredicate:$pred, IntType:$lhs, IntType:$rhs);
+ let results = (outs BoolType:$result);
+ let assemblyFormat = "$pred $lhs `,` $rhs attr-dict";
+}
+
+def Int2BVOp : SMTOp<"int2bv", [Pure]> {
+ let summary = "Convert an integer to an inferred-width bitvector.";
+ let description = [{
+ Designed to lower directly to an operation of the same name in Z3. The Z3
+ C API describes the semantics as follows:
+ Create an n bit bit-vector from the integer argument t1.
+ The resulting bit-vector has n bits, where the i'th bit (counting from 0
+ to n-1) is 1 if (t1 div 2^i) mod 2 is 1.
+ The node t1 must have integer sort.
+ }];
+ let arguments = (ins IntType:$input);
+ let results = (outs BitVectorType:$result);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($result))";
+}
+
+#endif // MLIR_DIALECT_SMT_SMTINTOPS_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h
new file mode 100644
index 0000000000000..859566ec6dbdb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.h
@@ -0,0 +1,25 @@
+//===- SMTOps.h - SMT dialect operations ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTOPS_H
+#define MLIR_DIALECT_SMT_SMTOPS_H
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/SMT/IR/SMTAttributes.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SMT/IR/SMTTypes.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SMT/IR/SMT.h.inc"
+
+#endif // MLIR_DIALECT_SMT_SMTOPS_H
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
new file mode 100644
index 0000000000000..18a1483f1dab1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -0,0 +1,477 @@
+//===- SMTOps.td - SMT dialect operations ------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTOPS_TD
+#define MLIR_DIALECT_SMT_SMTOPS_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/Dialect/SMT/IR/SMTAttributes.td"
+include "mlir/Dialect/SMT/IR/SMTTypes.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+
+// Base class of every operation in the SMT dialect.
+class SMTOp<string mnemonic, list<Trait> traits = []>
+    : Op<SMTDialect, mnemonic, traits>;
+
+// Declares a fresh symbolic value; deliberately not Pure (see description) and
+// its result is modelled as a MemAlloc so that two declarations are never
+// merged.
+def DeclareFunOp : SMTOp<"declare_fun", [
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "declare a symbolic value of a given sort";
+  let description = [{
+    This operation declares a symbolic value just as the `declare-const` and
+    `declare-fun` statements in SMT-LIB 2.6. The result type determines the SMT
+    sort of the symbolic value. The returned value can then be used to refer to
+    the symbolic value instead of using the identifier like in SMT-LIB.
+
+    The optionally provided string will be used as a prefix for the newly
+    generated identifier (useful for easier readability when exporting to
+    SMT-LIB). Each `declare` will always provide a unique new symbolic value
+    even if the identifier strings are the same.
+
+    Note that there does not exist a separate operation equivalent to
+    SMT-LIB's `define-fun` since
+    ```
+    (define-fun f ((a Int)) Int (- a))
+    ```
+    is only syntactic sugar for
+    ```
+    %f = smt.declare_fun : !smt.func<(!smt.int) !smt.int>
+    %0 = smt.forall {
+    ^bb0(%arg0: !smt.int):
+      %1 = smt.apply_func %f(%arg0) : !smt.func<(!smt.int) !smt.int>
+      %2 = smt.int.neg %arg0
+      %3 = smt.eq %1, %2 : !smt.int
+      smt.yield %3 : !smt.bool
+    }
+    smt.assert %0
+    ```
+
+    Note that this operation cannot be marked as Pure since two operations (even
+    with the same identifier string) could then be CSEd, leading to incorrect
+    behavior.
+  }];
+
+  let arguments = (ins OptionalAttr<StrAttr>:$namePrefix);
+  let results = (outs Res<AnySMTType, "a symbolic value", [MemAlloc]>:$result);
+
+  let assemblyFormat = [{
+    ($namePrefix^)? attr-dict `:` qualified(type($result))
+  }];
+
+  let builders = [
+    // Builds a declaration without a name prefix.
+    OpBuilder<(ins "mlir::Type":$type), [{
+      build($_builder, $_state, type, nullptr);
+    }]>
+  ];
+}
+
+// Boolean literal (`smt.constant true` / `smt.constant false`); ConstantLike
+// and folds to its `$value` attribute.
+def BoolConstantOp : SMTOp<"constant", [
+  Pure,
+  ConstantLike,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+  let summary = "Produce a constant boolean";
+  let arguments = (ins BoolAttr:$value);
+  let results = (outs BoolType:$result);
+  let assemblyFormat = "$value attr-dict";
+  let hasFolder = true;
+
+  let description = [{
+    Produces the constant expressions 'true' and 'false' as described in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB
+    Standard 2.6.
+  }];
+}
+
+// Solver scope. Inputs and results are restricted to non-SMT types so that no
+// SMT value can enter or leave the region.
+def SolverOp : SMTOp<"solver", [
+  IsolatedFromAbove,
+  SingleBlockImplicitTerminator<"smt::YieldOp">,
+]> {
+  let summary = "create a solver instance within a lifespan";
+
+  let arguments = (ins Variadic<AnyNonSMTType>:$inputs);
+  let results = (outs Variadic<AnyNonSMTType>:$results);
+  let regions = (region SizedRegion<1>:$bodyRegion);
+
+  let assemblyFormat = [{
+    `(` $inputs `)` attr-dict `:` functional-type($inputs, $results) $bodyRegion
+  }];
+  let hasRegionVerifier = true;
+
+  let description = [{
+    This operation defines an SMT context with a solver instance. SMT operations
+    are only valid when being executed between the start and end of the region
+    of this operation. Any invocation outside is undefined. However, they do not
+    have to be direct children of this operation. For example, it is allowed to
+    have SMT operations in a `func.func` which is only called from within this
+    region. No SMT value may enter or exit the lifespan of this region (such
+    that no value created from another SMT context can be used in this scope and
+    the solver can deallocate all state required to keep track of SMT values at
+    the end).
+
+    As a result, the region is comparable to an entire SMT-LIB script, but
+    allows for concrete operations and control-flow. Concrete values may be
+    passed in and returned to influence the computations after the `smt.solver`
+    operation.
+
+    Example:
+    ```mlir
+    %0:2 = smt.solver (%in) {smt.some_attr} : (i8) -> (i8, i32) {
+    ^bb0(%arg0: i8):
+      %c = smt.declare_fun "c" : !smt.bool
+      smt.assert %c
+      %1 = smt.check sat {
+        %c1_i32 = arith.constant 1 : i32
+        smt.yield %c1_i32 : i32
+      } unknown {
+        %c0_i32 = arith.constant 0 : i32
+        smt.yield %c0_i32 : i32
+      } unsat {
+        %c-1_i32 = arith.constant -1 : i32
+        smt.yield %c-1_i32 : i32
+      } -> i32
+      smt.yield %arg0, %1 : i8, i32
+    }
+    ```
+
+    TODO: solver configuration attributes
+  }];
+}
+
+// Sets the solver logic; the HasParent trait restricts it to being a direct
+// child of `smt.solver`.
+def SetLogicOp : SMTOp<"set_logic", [
+  HasParent<"smt::SolverOp">,
+]> {
+  let summary = "set the logic for the SMT solver";
+  let assemblyFormat = "$logic attr-dict";
+  let arguments = (ins StrAttr:$logic);
+}
+
+// Takes a single `!smt.bool` operand and produces no result.
+def AssertOp : SMTOp<"assert", []> {
+  let summary = "assert that a boolean expression holds";
+  let assemblyFormat = "$input attr-dict";
+  let arguments = (ins BoolType:$input);
+}
+
+// Takes no operands and produces no results.
+def ResetOp : SMTOp<"reset", []> {
+  let assemblyFormat = "attr-dict";
+  let summary = "reset the solver";
+}
+
+// PushOp and PopOp both take a non-negative 32-bit level count.
+def PushOp : SMTOp<"push", []> {
+  let summary = "push a given number of levels onto the assertion stack";
+  let assemblyFormat = "$count attr-dict";
+  let arguments = (ins ConfinedAttr<I32Attr, [IntNonNegative]>:$count);
+}
+
+def PopOp : SMTOp<"pop", []> {
+  let summary = "pop a given number of levels from the assertion stack";
+  let assemblyFormat = "$count attr-dict";
+  let arguments = (ins ConfinedAttr<I32Attr, [IntNonNegative]>:$count);
+}
+
+// Satisfiability check with one region per possible outcome; each region ends
+// in an `smt.yield` whose operands become the op's results.
+def CheckOp : SMTOp<"check", [
+  NoRegionArguments,
+  SingleBlockImplicitTerminator<"smt::YieldOp">,
+]> {
+  let summary = "check if the current set of assertions is satisfiable";
+  let description = [{
+    This operation checks if all the assertions in the solver defined by the
+    nearest ancestor operation of type `smt.solver` are consistent. The outcome
+    can be 'satisfiable', 'unknown', or 'unsatisfiable' and the corresponding
+    region will be executed. It is the corresponding construct to the
+    `check-sat` in SMT-LIB.
+
+    Example:
+    ```mlir
+    %0 = smt.check sat {
+      %c1_i32 = arith.constant 1 : i32
+      smt.yield %c1_i32 : i32
+    } unknown {
+      %c0_i32 = arith.constant 0 : i32
+      smt.yield %c0_i32 : i32
+    } unsat {
+      %c-1_i32 = arith.constant -1 : i32
+      smt.yield %c-1_i32 : i32
+    } -> i32
+    ```
+  }];
+
+  let regions = (region SizedRegion<1>:$satRegion,
+                        SizedRegion<1>:$unknownRegion,
+                        SizedRegion<1>:$unsatRegion);
+  let results = (outs Variadic<AnyType>:$results);
+
+  let assemblyFormat = [{
+    attr-dict `sat` $satRegion `unknown` $unknownRegion `unsat` $unsatRegion
+    (`->` qualified(type($results))^ )?
+  }];
+
+  let hasRegionVerifier = true;
+}
+
+// Shared terminator for the regions of `smt.solver`, `smt.check`,
+// `smt.forall` and `smt.exists` (see ParentOneOf); it is ReturnLike.
+def YieldOp : SMTOp<"yield", [
+  Pure,
+  Terminator,
+  ReturnLike,
+  ParentOneOf<["smt::SolverOp", "smt::CheckOp",
+               "smt::ForallOp", "smt::ExistsOp"]>,
+]> {
+  let summary = "terminator operation for various regions of SMT operations";
+  let arguments = (ins Variadic<AnyType>:$values);
+  let builders = [
+    // Builds a yield without operands.
+    OpBuilder<(ins), [{
+      build($_builder, $_state, std::nullopt);
+    }]>
+  ];
+  let assemblyFormat = "($values^ `:` qualified(type($values)))? attr-dict";
+}
+
+// Function application: the argument types must equal the domain of the
+// `$func` type and the result type must equal its range (enforced by the
+// TypesMatchWith / RangedTypesMatchWith traits).
+def ApplyFuncOp : SMTOp<"apply_func", [
+  Pure,
+  TypesMatchWith<"summary", "func", "result",
+                 "cast<SMTFuncType>($_self).getRangeType()">,
+  RangedTypesMatchWith<"summary", "func", "args",
+                       "cast<SMTFuncType>($_self).getDomainTypes()">
+]> {
+  let summary = "apply a function";
+  let arguments = (ins SMTFuncType:$func,
+                       Variadic<AnyNonFuncSMTType>:$args);
+  let results = (outs AnyNonFuncSMTType:$result);
+  let assemblyFormat = [{
+    $func `(` $args `)` attr-dict `:` qualified(type($func))
+  }];
+
+  let description = [{
+    This operation performs a function application as described in the
+    [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
+    It is part of the language itself rather than a theory or logic.
+  }];
+}
+
+// Variadic equality over operands of one (non-function) SMT type; custom
+// parser/printer and verifier are implemented in C++.
+def EqOp : SMTOp<"eq", [Pure, SameTypeOperands]> {
+  let summary = "returns true iff all operands are identical";
+  let arguments = (ins Variadic<AnyNonFuncSMTType>:$inputs);
+  let results = (outs BoolType:$result);
+  let hasCustomAssemblyFormat = true;
+  let hasVerifier = true;
+
+  let builders = [
+    // Convenience builder for the common two-operand case.
+    OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{
+      build($_builder, $_state, ValueRange{lhs, rhs});
+    }]>
+  ];
+
+  let description = [{
+    This operation compares the operands and returns true iff all operands are
+    identical. The semantics are equivalent to the `=` operator defined in the
+    SMT-LIB Standard 2.6 in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2).
+
+    Any SMT sort/type is allowed for the operands and it supports a variadic
+    number of operands, but requires at least two. This is because the `=`
+    operator is annotated with `:chainable` which means that `= a b c d` is
+    equivalent to `and (= a b) (= b c) (= c d)` where `and` is annotated
+    `:left-assoc`, i.e., it can be further rewritten to
+    `and (and (= a b) (= b c)) (= c d)`.
+  }];
+}
+
+// Variadic pairwise disequality over operands of one (non-function) SMT type;
+// custom parser/printer and verifier are implemented in C++.
+def DistinctOp : SMTOp<"distinct", [Pure, SameTypeOperands]> {
+  let summary = "returns true iff all operands are not identical to any other";
+  let description = [{
+    This operation compares the operands and returns true iff all operands are
+    not identical to any of the other operands. The semantics are equivalent to
+    the `distinct` operator defined in the SMT-LIB Standard 2.6 in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2).
+
+    Any SMT sort/type is allowed for the operands and it supports a variadic
+    number of operands, but requires at least two. This is because the
+    `distinct` operator is annotated with `:pairwise` which means that
+    `distinct a b c d` is equivalent to
+    ```
+    and (distinct a b) (distinct a c) (distinct a d)
+        (distinct b c) (distinct b d)
+        (distinct c d)
+    ```
+    where `and` is annotated `:left-assoc`, i.e., it can be further rewritten to
+    ```
+    (and (and (and (and (and (distinct a b)
+                             (distinct a c))
+                        (distinct a d))
+                   (distinct b c))
+              (distinct b d))
+         (distinct c d))
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyNonFuncSMTType>:$inputs);
+  let results = (outs BoolType:$result);
+
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{
+      build($_builder, $_state, ValueRange{lhs, rhs});
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = true;
+  let hasVerifier = true;
+}
+
+// If-then-else over SMT values; both branches and the result share one type.
+def IteOp : SMTOp<"ite", [
+  Pure,
+  AllTypesMatch<["thenValue", "elseValue", "result"]>
+]> {
+  let summary = "an if-then-else function";
+  let arguments = (ins BoolType:$cond,
+                       AnySMTType:$thenValue,
+                       AnySMTType:$elseValue);
+  let results = (outs AnySMTType:$result);
+  let assemblyFormat = [{
+    $cond `,` $thenValue `,` $elseValue attr-dict `:` qualified(type($result))
+  }];
+
+  let description = [{
+    This operation returns its second operand or its third operand depending on
+    whether its first operand is true or not. The semantics are equivalent to
+    the `ite` operator defined in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB
+    2.6 standard.
+  }];
+}
+
+// Unary boolean negation: `!smt.bool` in, `!smt.bool` out.
+def NotOp : SMTOp<"not", [Pure]> {
+  let summary = "a boolean negation";
+  let arguments = (ins BoolType:$input);
+  let results = (outs BoolType:$result);
+  let assemblyFormat = "$input attr-dict";
+
+  let description = [{
+    This operation performs a boolean negation. The semantics are equivalent to
+    the 'not' operator in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB
+    Standard 2.6.
+  }];
+}
+
+// Shared definition of the variadic boolean connectives (`and`, `or`, `xor`).
+// NOTE(review): the description requires at least two operands, but this
+// definition does not enforce it: `$inputs` is a plain Variadic<BoolType> and
+// no verifier is declared. Consider adding one.
+class VariadicBoolOp<string mnemonic, string desc> : SMTOp<mnemonic, [Pure]> {
+  let summary = desc;
+  let description = "This operation performs " # desc # [{.
+    The semantics are equivalent to the '}] # mnemonic # [{' operator in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2)
+    of the SMT-LIB Standard 2.6.
+
+    It supports a variadic number of operands, but requires at least two.
+    This is because the operator is annotated with the `:left-assoc` attribute
+    which means that `op a b c` is equivalent to `(op (op a b) c)`.
+  }];
+
+  let arguments = (ins Variadic<BoolType>:$inputs);
+  let results = (outs BoolType:$result);
+  let assemblyFormat = "$inputs attr-dict";
+
+  let builders = [
+    // Convenience builder for the common two-operand case.
+    OpBuilder<(ins "mlir::Value":$lhs, "mlir::Value":$rhs), [{
+      build($_builder, $_state, ValueRange{lhs, rhs});
+    }]>
+  ];
+}
+
+// Concrete variadic boolean connectives; each instantiates VariadicBoolOp with
+// its mnemonic and a description fragment.
+def AndOp : VariadicBoolOp<"and", "a boolean conjunction">;
+def OrOp  : VariadicBoolOp<"or",  "a boolean disjunction">;
+def XOrOp : VariadicBoolOp<"xor", "a boolean exclusive OR">;
+
+// Binary boolean implication `lhs => rhs`.
+def ImpliesOp : SMTOp<"implies", [Pure]> {
+  let summary = "boolean implication";
+  let arguments = (ins BoolType:$lhs, BoolType:$rhs);
+  let results = (outs BoolType:$result);
+  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+
+  let description = [{
+    This operation performs a boolean implication. The semantics are equivalent
+    to the '=>' operator in the
+    [Core theory](https://smtlib.cs.uiowa.edu/Theories/Core.smt2) of the SMT-LIB
+    Standard 2.6.
+  }];
+}
+
+// Shared definition of `smt.forall` and `smt.exists`. The derived defs supply
+// `summary`, which is spliced into the description below.
+class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
+  RecursivelySpeculatable,
+  RecursiveMemoryEffects,
+  SingleBlockImplicitTerminator<"smt::YieldOp">,
+]> {
+  let description = [{
+    This operation represents the }] # summary # [{ as described in the
+    [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
+    It is part of the language itself rather than a theory or logic.
+
+    The operation specifies the name prefixes (as an optional attribute) and
+    types (as the types of the block arguments of the regions) of bound
+    variables that may be used in the 'body' of the operation. If a 'patterns'
+    region is specified, the block arguments must match the ones of the 'body'
+    region and (other than there) must be used at least once in the 'patterns'
+    region. It may also not contain any operations that bind variables, such as
+    quantifiers. While the 'body' region must always yield exactly one
+    `!smt.bool`-typed value, the 'patterns' region can yield an arbitrary number
+    (but at least one) of SMT values.
+
+    The bound variables can be of any SMT type except functions, since SMT only
+    supports first-order logic.
+
+    The 'no_pattern' attribute is only allowed when no 'patterns' region is
+    specified and forbids the solver to generate and use patterns for this
+    quantifier.
+
+    The 'weight' attribute indicates the importance of this quantifier being
+    instantiated compared to other quantifiers that may be present. The default
+    value is zero.
+
+    Both the 'no_pattern' and 'weight' attributes are annotations to the
+    quantifier's body term. Annotations and attributes are described in the
+    standard in sections 3.4 and 3.6 (specifically 3.6.5). SMT-LIB allows
+    adding custom attributes to provide solvers with additional metadata, e.g.,
+    hints such as the attributes mentioned above. They are not part of the
+    standard themselves, but supported by common SMT solvers (e.g., Z3).
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I32Attr, "0">:$weight,
+                       UnitAttr:$noPattern,
+                       OptionalAttr<StrArrayAttr>:$boundVarNames);
+  let regions = (region SizedRegion<1>:$body,
+                        VariadicRegion<SizedRegion<1>>:$patterns);
+  let results = (outs BoolType:$result);
+
+  let builders = [
+    OpBuilder<(ins
+      "TypeRange":$boundVarTypes,
+      "function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
+      CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
+      CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
+           "{}">:$patternBuilder,
+      CArg<"uint32_t", "0">:$weight,
+      CArg<"bool", "false">:$noPattern)>
+  ];
+  let skipDefaultBuilders = true;
+
+  let assemblyFormat = [{
+    ($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
+    attr-dict-with-keyword $body (`patterns` $patterns^)?
+  }];
+
+  let hasVerifier = true;
+  let hasRegionVerifier = true;
+}
+
+// The two quantifiers; they differ only in mnemonic and summary.
+def ForallOp : QuantifierOp<"forall"> {
+  let summary = "forall quantifier";
+}
+def ExistsOp : QuantifierOp<"exists"> {
+  let summary = "exists quantifier";
+}
+
+#endif // MLIR_DIALECT_SMT_SMTOPS_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h
new file mode 100644
index 0000000000000..4db28f7a07a41
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.h
@@ -0,0 +1,30 @@
+//===- SMTTypes.h - SMT dialect types ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTTYPES_H
+#define MLIR_DIALECT_SMT_SMTTYPES_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/SMT/IR/SMTTypes.h.inc"
+
+namespace mlir::smt {
+
+/// Returns true if `type` is one of the SMT dialect's value types.
+bool isAnySMTValueType(mlir::Type type);
+
+/// Returns true if `type` is an SMT value type other than an SMT function
+/// type.
+bool isAnyNonFuncSMTValueType(mlir::Type type);
+
+} // namespace mlir::smt
+
+#endif // MLIR_DIALECT_SMT_SMTTYPES_H
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td
new file mode 100644
index 0000000000000..3032900b52178
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTTypes.td
@@ -0,0 +1,145 @@
+//===- SMTTypes.td - SMT dialect types ---------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTTYPES_TD
+#define MLIR_DIALECT_SMT_SMTTYPES_TD
+
+include "mlir/Dialect/SMT/IR/SMTDialect.td"
+include "mlir/IR/AttrTypeBase.td"
+
+// Base class of every type in the SMT dialect.
+class SMTTypeDef<string name> : TypeDef<SMTDialect, name>;
+
+// The parameterless `!smt.bool` type.
+def BoolType : SMTTypeDef<"Bool"> {
+  let assemblyFormat = "";
+  let mnemonic = "bool";
+}
+
+// The parameterless `!smt.int` type.
+def IntType : SMTTypeDef<"Int"> {
+  let assemblyFormat = "";
+  let mnemonic = "int";
+  let description = [{
+    This type represents the `Int` sort as described in the
+    [SMT Ints theory](https://smtlib.cs.uiowa.edu/Theories/Ints.smt2) of the
+    SMT-LIB 2.6 standard.
+  }];
+}
+
+// Fixed-width bit-vector type, written `!smt.bv<N>`; a custom verifier is
+// declared (genVerifyDecl) for the width parameter.
+def BitVectorType : SMTTypeDef<"BitVector"> {
+  let mnemonic = "bv";
+  let parameters = (ins "int64_t":$width);
+  let assemblyFormat = "`<` $width `>`";
+  let genVerifyDecl = true;
+
+  let description = [{
+    This type represents the `(_ BitVec width)` sort as described in the
+    [SMT bit-vector
+    theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).
+
+    The bit-width must be strictly greater than zero.
+  }];
+}
+
+// Array type `!smt.array<[domain -> range]>`; a custom verifier is declared
+// (genVerifyDecl) for the domain/range types.
+def ArrayType : SMTTypeDef<"Array"> {
+  let mnemonic = "array";
+  let description = [{
+    This type represents the `(Array X Y)` sort, where X and Y are any
+    sort/type, as described in the
+    [SMT ArraysEx theory](https://smtlib.cs.uiowa.edu/Theories/ArraysEx.smt2) of
+    the SMT-LIB standard 2.6.
+  }];
+
+  let parameters = (ins "mlir::Type":$domainType, "mlir::Type":$rangeType);
+  let assemblyFormat = "`<` `[` $domainType `->` $rangeType `]` `>`";
+
+  let genVerifyDecl = true;
+}
+
+def SMTFuncType : SMTTypeDef<"SMTFunc"> {
+  let mnemonic = "func";
+  let description = [{
+    This type represents the SMT function sort as described in the
+    [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
+    It is part of the language itself rather than a theory or logic.
+
+    A function in SMT can have an arbitrary domain size, but always has exactly
+    one range sort.
+
+    Since SMT only supports first-order logic, it is not possible to nest
+    function types.
+
+    Example: `!smt.func<(!smt.bool, !smt.int) !smt.bool>` is equivalent to
+    `((Bool Int) Bool)` in SMT-LIB.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"mlir::Type", "domain types">:$domainTypes,
+    "mlir::Type":$rangeType
+  );
+
+  // Note: the format below always prints the parentheses around the domain
+  // types, including when no domain type is present. The range type is
+  // separated by a space rather than an arrow.
+  // NOTE(review): presumably the space (instead of `->`) avoids the domain and
+  // range being parsed as a builtin function type — confirm. Also verify that
+  // a type built with an empty domain (see the range-only builder below)
+  // round-trips: `$domainTypes` is a plain ArrayRefParameter, not an
+  // OptionalArrayRefParameter as used by SortType.
+  let assemblyFormat = "`<` `(` $domainTypes `)` ` ` $rangeType `>`";
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<mlir::Type>":$domainTypes,
+      "mlir::Type":$rangeType), [{
+      return $_get(rangeType.getContext(), domainTypes, rangeType);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$rangeType), [{
+      return $_get(rangeType.getContext(),
+                   llvm::ArrayRef<mlir::Type>{}, rangeType);
+    }]>
+  ];
+
+  let genVerifyDecl = true;
+}
+
+// Use-site of an uninterpreted sort, identified by name plus optional sort
+// parameters, e.g. `!smt.sort<"name"[!smt.bool]>`.
+def SortType : SMTTypeDef<"Sort"> {
+  let mnemonic = "sort";
+  let parameters = (ins
+    "mlir::StringAttr":$identifier,
+    OptionalArrayRefParameter<"mlir::Type", "sort parameters">:$sortParams
+  );
+  let assemblyFormat = "`<` $identifier (`[` $sortParams^ `]`)? `>`";
+  let genVerifyDecl = true;
+
+  let builders = [
+    // Builds a parameterized sort from a plain string identifier.
+    TypeBuilder<(ins "llvm::StringRef":$identifier,
+                     "llvm::ArrayRef<mlir::Type>":$sortParams), [{
+      return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier),
+                   sortParams);
+    }]>,
+    // Builds a sort without parameters.
+    TypeBuilder<(ins "llvm::StringRef":$identifier), [{
+      return $_get($_ctxt, mlir::StringAttr::get($_ctxt, identifier),
+                   llvm::ArrayRef<mlir::Type>{});
+    }]>,
+  ];
+
+  let description = [{
+    This type represents uninterpreted sorts. The usage of a type like
+    `!smt.sort<"sort_name"[!smt.bool, !smt.sort<"other_sort">]>` implies a
+    `declare-sort sort_name 2` and a `declare-sort other_sort 0` in SMT-LIB.
+    This type represents concrete use-sites of such declared sorts, in this
+    particular case it would be equivalent to `(sort_name Bool other_sort)` in
+    SMT-LIB. More details about the semantics can be found in the
+    [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
+  }];
+}
+
+// Operand/result constraints. The first two defer to the C++ predicates
+// declared in SMTTypes.h; AnyNonSMTType is the negation of AnySMTType.
+def AnySMTType
+    : Type<CPred<"smt::isAnySMTValueType($_self)">, "any SMT value type">;
+def AnyNonFuncSMTType
+    : Type<CPred<"smt::isAnyNonFuncSMTValueType($_self)">,
+           "any non-function SMT value type">;
+def AnyNonSMTType : Type<Neg<AnySMTType.predicate>, "any non-smt type">;
+
+#endif // MLIR_DIALECT_SMT_SMTTYPES_TD
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h b/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h
new file mode 100644
index 0000000000000..38fad21019158
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTVisitors.h
@@ -0,0 +1,201 @@
+//===- SMTVisitors.h - SMT Dialect Visitors ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines visitors that make it easier to work with the SMT IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SMT_SMTVISITORS_H
+#define MLIR_DIALECT_SMT_SMTVISITORS_H
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace smt {
+
+/// CRTP visitor over SMT dialect operations.
+///
+/// `dispatchSMTOpVisitor` casts `op` to one of the known SMT op classes and
+/// calls the matching `visitSMTOp` overload on `ConcreteType`. Each overload
+/// declared below forwards to `visitUnhandledSMTOp` unless the derived class
+/// provides its own; operations not in the list go to `visitInvalidSMTOp`.
+template <typename ConcreteType, typename ResultType = void,
+          typename... ExtraArgs>
+class SMTOpVisitor {
+public:
+  ResultType dispatchSMTOpVisitor(Operation *op, ExtraArgs... args) {
+    auto *derived = static_cast<ConcreteType *>(this);
+    auto onKnownOp = [&](auto smtOp) -> ResultType {
+      return derived->visitSMTOp(smtOp, args...);
+    };
+    auto onUnknownOp = [&](auto) -> ResultType {
+      return derived->visitInvalidSMTOp(op, args...);
+    };
+    return TypeSwitch<Operation *, ResultType>(op)
+        .template Case<
+            // Constants
+            BoolConstantOp, IntConstantOp, BVConstantOp,
+            // Bit-vector arithmetic
+            BVNegOp, BVAddOp, BVMulOp, BVURemOp, BVSRemOp, BVSModOp, BVShlOp,
+            BVLShrOp, BVAShrOp, BVUDivOp, BVSDivOp,
+            // Bit-vector bitwise
+            BVNotOp, BVAndOp, BVOrOp, BVXOrOp,
+            // Other bit-vector ops
+            ConcatOp, ExtractOp, RepeatOp, BVCmpOp, BV2IntOp,
+            // Int arithmetic
+            IntAddOp, IntMulOp, IntSubOp, IntDivOp, IntModOp, IntCmpOp,
+            Int2BVOp,
+            // Core Ops
+            EqOp, DistinctOp, IteOp,
+            // Variable/symbol declaration
+            DeclareFunOp, ApplyFuncOp,
+            // solver interaction
+            SolverOp, AssertOp, ResetOp, PushOp, PopOp, CheckOp, SetLogicOp,
+            // Boolean logic
+            NotOp, AndOp, OrOp, XOrOp, ImpliesOp,
+            // Arrays
+            ArrayStoreOp, ArraySelectOp, ArrayBroadcastOp,
+            // Quantifiers
+            ForallOp, ExistsOp, YieldOp>(onKnownOp)
+        .Default(onUnknownOp);
+  }
+
+  /// Fallback for operations that are not SMT operations: emits an error on
+  /// `op` and aborts.
+  ResultType visitInvalidSMTOp(Operation *op, ExtraArgs... args) {
+    op->emitOpError("unknown SMT node");
+    abort();
+  }
+
+  /// Fallback for SMT operations the concrete visitor does not handle;
+  /// returns a value-initialized `ResultType`.
+  ResultType visitUnhandledSMTOp(Operation *op, ExtraArgs... args) {
+    return ResultType();
+  }
+
+// Declares `visitSMTOp(OP_CLASS)` forwarding to `visit<KIND>SMTOp` on the
+// derived class.
+#define HANDLE(OP_CLASS, KIND)                                                 \
+  ResultType visitSMTOp(OP_CLASS op, ExtraArgs... args) {                      \
+    auto *derived = static_cast<ConcreteType *>(this);                         \
+    return derived->visit##KIND##SMTOp(op, args...);                           \
+  }
+
+  // Constants
+  HANDLE(BoolConstantOp, Unhandled);
+  HANDLE(IntConstantOp, Unhandled);
+  HANDLE(BVConstantOp, Unhandled);
+
+  // Bit-vector arithmetic
+  HANDLE(BVNegOp, Unhandled);
+  HANDLE(BVAddOp, Unhandled);
+  HANDLE(BVMulOp, Unhandled);
+  HANDLE(BVURemOp, Unhandled);
+  HANDLE(BVSRemOp, Unhandled);
+  HANDLE(BVSModOp, Unhandled);
+  HANDLE(BVShlOp, Unhandled);
+  HANDLE(BVLShrOp, Unhandled);
+  HANDLE(BVAShrOp, Unhandled);
+  HANDLE(BVUDivOp, Unhandled);
+  HANDLE(BVSDivOp, Unhandled);
+
+  // Bit-vector bitwise operations
+  HANDLE(BVNotOp, Unhandled);
+  HANDLE(BVAndOp, Unhandled);
+  HANDLE(BVOrOp, Unhandled);
+  HANDLE(BVXOrOp, Unhandled);
+
+  // Other bit-vector operations
+  HANDLE(ConcatOp, Unhandled);
+  HANDLE(ExtractOp, Unhandled);
+  HANDLE(RepeatOp, Unhandled);
+  HANDLE(BVCmpOp, Unhandled);
+  HANDLE(BV2IntOp, Unhandled);
+
+  // Int arithmetic
+  HANDLE(IntAddOp, Unhandled);
+  HANDLE(IntMulOp, Unhandled);
+  HANDLE(IntSubOp, Unhandled);
+  HANDLE(IntDivOp, Unhandled);
+  HANDLE(IntModOp, Unhandled);
+
+  HANDLE(IntCmpOp, Unhandled);
+  HANDLE(Int2BVOp, Unhandled);
+
+  // Core operations
+  HANDLE(EqOp, Unhandled);
+  HANDLE(DistinctOp, Unhandled);
+  HANDLE(IteOp, Unhandled);
+
+  // Variable/symbol declaration
+  HANDLE(DeclareFunOp, Unhandled);
+  HANDLE(ApplyFuncOp, Unhandled);
+
+  // Solver interaction
+  HANDLE(SolverOp, Unhandled);
+  HANDLE(AssertOp, Unhandled);
+  HANDLE(ResetOp, Unhandled);
+  HANDLE(PushOp, Unhandled);
+  HANDLE(PopOp, Unhandled);
+  HANDLE(CheckOp, Unhandled);
+  HANDLE(SetLogicOp, Unhandled);
+
+  // Boolean logic operations
+  HANDLE(NotOp, Unhandled);
+  HANDLE(AndOp, Unhandled);
+  HANDLE(OrOp, Unhandled);
+  HANDLE(XOrOp, Unhandled);
+  HANDLE(ImpliesOp, Unhandled);
+
+  // Array operations
+  HANDLE(ArrayStoreOp, Unhandled);
+  HANDLE(ArraySelectOp, Unhandled);
+  HANDLE(ArrayBroadcastOp, Unhandled);
+
+  // Quantifier operations
+  HANDLE(ForallOp, Unhandled);
+  HANDLE(ExistsOp, Unhandled);
+  HANDLE(YieldOp, Unhandled);
+
+#undef HANDLE
+};
+
+/// CRTP visitor over SMT dialect types.
+///
+/// `dispatchSMTTypeVisitor` casts `type` to one of the SMT type classes and
+/// calls the matching `visitSMTType` overload on `ConcreteType`. Each overload
+/// declared below forwards to `visitUnhandledSMTType` unless the derived class
+/// provides its own; any other type goes to `visitInvalidSMTType`.
+template <typename ConcreteType, typename ResultType = void,
+          typename... ExtraArgs>
+class SMTTypeVisitor {
+public:
+  ResultType dispatchSMTTypeVisitor(Type type, ExtraArgs... args) {
+    auto *thisCast = static_cast<ConcreteType *>(this);
+    return TypeSwitch<Type, ResultType>(type)
+        .template Case<BoolType, IntType, BitVectorType, ArrayType, SMTFuncType,
+                       SortType>([&](auto expr) -> ResultType {
+          return thisCast->visitSMTType(expr, args...);
+        })
+        .Default([&](auto expr) -> ResultType {
+          return thisCast->visitInvalidSMTType(type, args...);
+        });
+  }
+
+  /// Fallback for types that are not SMT types: aborts.
+  ResultType visitInvalidSMTType(Type type, ExtraArgs... args) { abort(); }
+
+  /// Fallback for SMT types the concrete visitor does not handle; returns a
+  /// value-initialized `ResultType`.
+  ResultType visitUnhandledSMTType(Type type, ExtraArgs... args) {
+    return ResultType();
+  }
+
+#define HANDLE(TYPE, KIND)                                                     \
+  ResultType visitSMTType(TYPE op, ExtraArgs... args) {                        \
+    return static_cast<ConcreteType *>(this)->visit##KIND##SMTType(op,         \
+                                                                   args...);   \
+  }
+
+  // One overload per type class dispatched above. Note that the SMT integer
+  // sort is `smt::IntType` (`!smt.int`), not the builtin `IntegerType`; the
+  // overload must match the class passed by the TypeSwitch case.
+  HANDLE(BoolType, Unhandled);
+  HANDLE(IntType, Unhandled);
+  HANDLE(BitVectorType, Unhandled);
+  HANDLE(ArrayType, Unhandled);
+  HANDLE(SMTFuncType, Unhandled);
+  HANDLE(SortType, Unhandled);
+
+#undef HANDLE
+};
+
+} // namespace smt
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SMT_SMTVISITORS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 33bc89279c08c..e83be7b40eded 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -73,6 +73,7 @@
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
@@ -143,6 +144,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
ROCDL::ROCDLDialect,
scf::SCFDialect,
shape::ShapeDialect,
+ smt::SMTDialect,
sparse_tensor::SparseTensorDialect,
spirv::SPIRVDialect,
tensor::TensorDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 80b0ef068d96d..a473f2ff317c9 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -33,6 +33,7 @@ add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
+add_subdirectory(SMT)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
add_subdirectory(Tensor)
diff --git a/mlir/lib/Dialect/SMT/CMakeLists.txt b/mlir/lib/Dialect/SMT/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/SMT/IR/CMakeLists.txt b/mlir/lib/Dialect/SMT/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..e287613da9fd0
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/IR/CMakeLists.txt
@@ -0,0 +1,27 @@
+# The MLIRSMT dialect library: attribute, dialect, op, and type implementations.
+add_mlir_dialect_library(MLIRSMT
+ SMTAttributes.cpp
+ SMTDialect.cpp
+ SMTOps.cpp
+ SMTTypes.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SMT
+
+ # Build-order dependencies on the SMT tablegen targets; these are declared
+ # elsewhere (presumably in include/mlir/Dialect/SMT/IR/CMakeLists.txt).
+ DEPENDS
+ MLIRSMTAttrIncGen
+ MLIRSMTEnumsIncGen
+ MLIRSMTIncGen
+
+ LINK_COMPONENTS
+ Support
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRInferTypeOpInterface
+ MLIRSideEffectInterfaces
+ MLIRControlFlowInterfaces
+)
+
+# Make the aggregate 'mlir-headers' target also build MLIRSMTIncGen.
+add_dependencies(mlir-headers
+ MLIRSMTIncGen
+)
diff --git a/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp
new file mode 100644
index 0000000000000..c28f3558a02d2
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp
@@ -0,0 +1,201 @@
+//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTAttributes.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SMT/IR/SMTTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Format.h"
+
+using namespace mlir;
+using namespace mlir::smt;
+
+//===----------------------------------------------------------------------===//
+// BitVectorAttr
+//===----------------------------------------------------------------------===//
+
+namespace mlir::smt::detail {
+
+/// Uniqued storage for `BitVectorAttr`. The key is the APInt value itself; its
+/// bit-width is the width of the bit-vector.
+struct BitVectorAttrStorage : public mlir::AttributeStorage {
+  using KeyTy = APInt;
+
+  BitVectorAttrStorage(APInt value) : value(std::move(value)) {}
+
+  KeyTy getAsKey() const { return value; }
+
+  /// Key comparison. This is why the storage is hand-written instead of
+  /// ODS-generated: `APInt::operator==` asserts that both operands have the
+  /// same bit-width. Widths are therefore compared first, and the value
+  /// comparison only runs when they agree.
+  bool operator==(const KeyTy &key) const {
+    if (value.getBitWidth() != key.getBitWidth())
+      return false;
+    return value == key;
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_value(key);
+  }
+
+  /// Placement-constructs a new storage object in the uniquer's allocator.
+  static BitVectorAttrStorage *
+  construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
+    void *memory = allocator.allocate<BitVectorAttrStorage>();
+    return new (memory) BitVectorAttrStorage(std::move(key));
+  }
+
+  APInt value;
+};
+
+} // namespace mlir::smt::detail
+
+/// Returns the APInt held by the attribute's uniqued storage.
+APInt BitVectorAttr::getValue() const {
+  const auto *storage = getImpl();
+  return storage->value;
+}
+
+/// Rejects zero-width bit-vector values.
+LogicalResult BitVectorAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    APInt value) { // NOLINT(performance-unnecessary-value-param)
+  unsigned width = value.getBitWidth();
+  if (width >= 1)
+    return success();
+  return emitError() << "bit-width must be at least 1, but got " << width;
+}
+
+/// Formats the value as an SMT-LIB bit-vector literal.
+///
+/// Widths that are a multiple of four use hexadecimal (`x...`); all others use
+/// binary (`b...`). The digits are left-padded with zeros so that the digit
+/// count encodes the width exactly. When `prefix` is set, a leading `#` is
+/// added.
+std::string BitVectorAttr::getValueAsString(bool prefix) const {
+  APInt value = getValue();
+  unsigned width = value.getBitWidth();
+  bool useHex = width % 4 == 0;
+  unsigned numDigits = useHex ? width / 4 : width;
+
+  SmallVector<char> digits;
+  value.toString(digits, useHex ? 16 : 2, /*Signed=*/false,
+                 /*formatAsCLiteral=*/false, /*UpperCase=*/false);
+
+  std::string result = prefix ? "#" : "";
+  result += useHex ? 'x' : 'b';
+  // APInt::toString omits leading zeros, but they carry the bit-width here.
+  result.append(numDigits - digits.size(), '0');
+  result.append(digits.begin(), digits.end());
+  return result;
+}
+
+/// Parse an SMT-LIB formatted bit-vector literal: `#b` followed by binary
+/// digits, or `#x` followed by hexadecimal digits. The digit count determines
+/// the bit-width of the result (one bit per binary digit, four per hex digit).
+///
+/// `emitError` may be null; the unchecked `BitVectorAttr::get` passes
+/// `nullptr`. In that case malformed input returns failure without emitting a
+/// diagnostic, instead of calling through a null callback.
+static FailureOr<APInt>
+parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
+                     StringRef value) {
+  // Emit `msg` if a diagnostic handler was supplied; always yields failure.
+  auto fail = [&](const Twine &msg) -> FailureOr<APInt> {
+    if (emitError)
+      emitError() << msg;
+    return failure();
+  };
+
+  // Check for emptiness before indexing; StringRef::operator[] is unchecked.
+  if (value.empty() || value.front() != '#')
+    return fail("expected '#'");
+
+  if (value.size() < 3)
+    return fail("expected at least one digit");
+
+  StringRef digits = value.drop_front(2);
+
+  if (value[1] == 'b') {
+    // APInt's string constructor asserts on invalid digits (and would accept a
+    // leading sign), so validate the digits up front.
+    if (digits.find_first_not_of("01") != StringRef::npos)
+      return fail("expected only binary digits");
+    return APInt(digits.size(), digits, 2);
+  }
+
+  if (value[1] == 'x') {
+    if (digits.find_first_not_of("0123456789abcdefABCDEF") != StringRef::npos)
+      return fail("expected only hexadecimal digits");
+    return APInt(digits.size() * 4, digits, 16);
+  }
+
+  return fail("expected either 'b' or 'x'");
+}
+
+/// Unchecked builder from an SMT-LIB literal such as `#b0101` or `#xff`. It is
+/// called without a diagnostic callback, so the input must already be
+/// well-formed.
+BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
+  FailureOr<APInt> parsed = parseBitVectorString(nullptr, value);
+  assert(succeeded(parsed) && "string must have SMT-LIB format");
+  return Base::get(context, *parsed);
+}
+
+/// Checked builder from an SMT-LIB literal. Parse and verification errors are
+/// reported through `emitError`; a null attribute signals failure.
+BitVectorAttr
+BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                          MLIRContext *context, StringRef value) {
+  FailureOr<APInt> parsed = parseBitVectorString(emitError, value);
+  if (succeeded(parsed))
+    return Base::getChecked(emitError, context, *parsed);
+  return {};
+}
+
+/// Unchecked builder for a `width`-bit bit-vector holding `value`.
+BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
+                                 unsigned width) {
+  APInt bits(width, value);
+  return Base::get(context, bits);
+}
+
+/// Checked builder for a `width`-bit bit-vector holding `value`. Reports an
+/// error and returns a null attribute if `value` needs more than `width` bits.
+BitVectorAttr
+BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                          MLIRContext *context, uint64_t value,
+                          unsigned width) {
+  // Any uint64_t fits in 64 or more bits. Narrower widths require
+  // value < 2^width. Testing `width >= 64` first also keeps the shift in range.
+  bool fits = width >= 64 || value < (UINT64_C(1) << width);
+  if (!fits) {
+    emitError() << "value does not fit in a bit-vector of desired width";
+    return {};
+  }
+  return Base::getChecked(emitError, context, APInt(width, value));
+}
+
+/// Parses `<` integer `>`. The bit-width is taken from the attribute's type,
+/// which therefore has to be supplied by the caller.
+Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
+  llvm::SMLoc loc = odsParser.getCurrentLocation();
+
+  APInt val;
+  if (odsParser.parseLess())
+    return {};
+  if (odsParser.parseInteger(val))
+    return {};
+  if (odsParser.parseGreater())
+    return {};
+
+  // The width comes from the type. Operation assembly formats must therefore
+  // use `qualified(<attr>)` so that the type is available here.
+  auto bvType = llvm::dyn_cast_if_present<BitVectorType>(odsType);
+  if (!bvType) {
+    odsParser.emitError(loc) << "explicit bit-vector type required";
+    return {};
+  }
+
+  unsigned width = bvType.getWidth();
+  unsigned parsedWidth = val.getBitWidth();
+  if (width > parsedWidth) {
+    // Sign-extension is also safe for non-negative literals: the integer
+    // parser leaves a zero in the top bit of positive values.
+    val = val.sext(width);
+  } else if (width < parsedWidth) {
+    // The parser may return a wider APInt than necessary. Narrowing is fine,
+    // but only if no significant bits are dropped.
+    unsigned neededBits =
+        val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
+    if (neededBits > width) {
+      odsParser.emitError(loc)
+          << "integer value out of range for given bit-vector type " << odsType;
+      return {};
+    }
+    val = val.trunc(width);
+  }
+
+  return BitVectorAttr::get(odsParser.getContext(), val);
+}
+
+/// Prints only `<` value `>`. The bit-vector type, and with it the width, is
+/// left to the surrounding syntax. Operation assembly formats must therefore
+/// wrap this attribute in `qualified(...)`.
+void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
+  APInt value = getValue();
+  odsPrinter << "<" << value << ">";
+}
+
+/// The attribute's type is the bit-vector type whose width matches the value.
+Type BitVectorAttr::getType() const {
+  unsigned width = getValue().getBitWidth();
+  return BitVectorType::get(getContext(), width);
+}
+
+//===----------------------------------------------------------------------===//
+// ODS Boilerplate
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
+
+/// Registers the SMT attribute classes with the dialect.
+void SMTDialect::registerAttributes() {
+ // The generated file is included with GET_ATTRDEF_LIST defined, presumably
+ // expanding to the list of ODS-generated attribute classes.
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
new file mode 100644
index 0000000000000..66eed861b2bb7
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp
@@ -0,0 +1,47 @@
+//===- SMTDialect.cpp - SMT dialect implementation ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SMT/IR/SMTAttributes.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/SMT/IR/SMTTypes.h"
+
+using namespace mlir;
+using namespace smt;
+
+/// Dialect initialization: registers the SMT attributes, types, and ops.
+void SMTDialect::initialize() {
+ registerAttributes();
+ registerTypes();
+ // The generated file is included with GET_OP_LIST defined, presumably
+ // expanding to the list of ODS-generated operation classes.
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
+ >();
+}
+
+/// Materializes a constant op for `value` with result type `type`. Returns
+/// nullptr for attribute/type combinations that are not handled.
+Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                           Type type, Location loc) {
+  // Bit-vector results are materialized as `smt.bv.constant`.
+  if (auto bvType = dyn_cast<BitVectorType>(type)) {
+    auto bvAttr = dyn_cast<BitVectorAttr>(value);
+    if (!bvAttr)
+      return nullptr;
+    assert(bvType == bvAttr.getType() &&
+           "attribute and desired result types have to match");
+    return builder.create<BVConstantOp>(loc, bvAttr);
+  }
+
+  // Boolean results are materialized as `smt.constant`.
+  if (isa<BoolType>(type)) {
+    if (auto boolAttr = dyn_cast<BoolAttr>(value))
+      return builder.create<BoolConstantOp>(loc, boolAttr);
+  }
+
+  return nullptr;
+}
+
+#include "mlir/Dialect/SMT/IR/SMTDialect.cpp.inc"
+#include "mlir/Dialect/SMT/IR/SMTEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
new file mode 100644
index 0000000000000..8977a3abc125d
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -0,0 +1,472 @@
+//===- SMTOps.cpp ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/APSInt.h"
+
+using namespace mlir;
+using namespace smt;
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BVConstantOp
+//===----------------------------------------------------------------------===//
+
+/// The result type is the type of the `value` attribute, which is stored in
+/// the op's properties.
+LogicalResult BVConstantOp::inferReturnTypes(
+    mlir::MLIRContext *context, std::optional<mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  // Do not assume populated properties. Fail gracefully instead of
+  // dereferencing a null properties pointer or a null attribute.
+  auto *props = properties.as<Properties *>();
+  if (!props || !props->getValue())
+    return emitOptionalError(location,
+                             "expected 'value' attribute to infer result type");
+  inferredReturnTypes.push_back(props->getValue().getType());
+  return success();
+}
+
+/// Suggests SSA names of the form `c<value>_bv<width>`.
+void BVConstantOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  APInt value = getValue().getValue();
+  SmallVector<char, 128> nameBuffer;
+  llvm::raw_svector_ostream nameStream(nameBuffer);
+  nameStream << "c" << value << "_bv" << value.getBitWidth();
+  setNameFn(getResult(), nameStream.str());
+}
+
+/// A constant folds to its own value attribute.
+OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {
+  assert(adaptor.getOperands().empty() && "constant has no operands");
+  OpFoldResult folded = getValueAttr();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// DeclareFunOp
+//===----------------------------------------------------------------------===//
+
+/// Uses the optional name prefix as the SSA name hint. An absent prefix is
+/// passed as the empty string.
+void DeclareFunOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef hint = getNamePrefix().value_or("");
+  setNameFn(getResult(), hint);
+}
+
+//===----------------------------------------------------------------------===//
+// SolverOp
+//===----------------------------------------------------------------------===//
+
+/// Checks two things: the body's terminator operands must have the op's
+/// result types, and the body's block arguments must have the types of
+/// 'inputs'.
+LogicalResult SolverOp::verifyRegions() {
+  Block *body = getBody();
+
+  Operation *terminator = body->getTerminator();
+  if (terminator->getOperands().getTypes() != getResultTypes())
+    return emitOpError() << "types of yielded values must match return values";
+
+  if (body->getArgumentTypes() != getInputs().getTypes())
+    return emitOpError()
+           << "block argument types must match the types of the 'inputs'";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CheckOp
+//===----------------------------------------------------------------------===//
+
+/// Each outcome region ('sat', 'unknown', 'unsat') must yield values whose
+/// types match the op's result types. Regions are checked in that order.
+LogicalResult CheckOp::verifyRegions() {
+  struct Outcome {
+    Region *region;
+    StringRef name;
+  };
+  Outcome outcomes[] = {{&getSatRegion(), "sat"},
+                        {&getUnknownRegion(), "unknown"},
+                        {&getUnsatRegion(), "unsat"}};
+
+  for (const Outcome &outcome : outcomes) {
+    Operation *terminator = outcome.region->front().getTerminator();
+    if (terminator->getOperands().getTypes() != getResultTypes())
+      return emitOpError() << "types of yielded values in '" << outcome.name
+                           << "' region must match return values";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// EqOp
+//===----------------------------------------------------------------------===//
+
+/// Parses `operand-list attr-dict : type`. Every operand has the single given
+/// type, and the result type is always `!smt.bool`.
+static LogicalResult
+parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+  Type operandType;
+  if (parser.parseOperandList(operands) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColon() || parser.parseType(operandType))
+    return failure();
+
+  result.addTypes(BoolType::get(parser.getContext()));
+  return parser.resolveOperands(operands, operandType, result.operands);
+}
+
+/// Prints the custom form shared by EqOp and DistinctOp. It mirrors
+/// parseSameOperandTypeVariadicToBoolOp.
+template <typename OpTy>
+static void printSameOperandTypeVariadicToBoolOp(OpAsmPrinter &printer,
+                                                 OpTy op) {
+  auto inputs = op.getInputs();
+  printer << ' ' << inputs;
+  printer.printOptionalAttrDict(op.getOperation()->getAttrs());
+  printer << " : " << inputs.front().getType();
+}
+
+/// Requires at least two compared operands.
+template <typename OpTy>
+static LogicalResult verifySameOperandTypeVariadicToBoolOp(OpTy op) {
+  auto numInputs = op.getInputs().size();
+  if (numInputs >= 2)
+    return success();
+  return op.emitOpError() << "'inputs' must have at least size 2, but got "
+                          << numInputs;
+}
+
+ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseSameOperandTypeVariadicToBoolOp(parser, result);
+}
+
+void EqOp::print(OpAsmPrinter &printer) {
+  printSameOperandTypeVariadicToBoolOp(printer, *this);
+}
+
+LogicalResult EqOp::verify() {
+  return verifySameOperandTypeVariadicToBoolOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// DistinctOp
+//===----------------------------------------------------------------------===//
+
+ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseSameOperandTypeVariadicToBoolOp(parser, result);
+}
+
+void DistinctOp::print(OpAsmPrinter &printer) {
+  printSameOperandTypeVariadicToBoolOp(printer, *this);
+}
+
+LogicalResult DistinctOp::verify() {
+  return verifySameOperandTypeVariadicToBoolOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// ExtractOp
+//===----------------------------------------------------------------------===//
+
+/// The extracted range [low_bit, low_bit + result width) must lie inside the
+/// input bit-vector.
+LogicalResult ExtractOp::verify() {
+  // Do the range arithmetic in 64 bits. Bit-vector widths are 'int64_t' and
+  // must not be narrowed to 'unsigned'. 'lowBit + rangeWidth' must not wrap
+  // around, which could let an out-of-range extraction pass verification.
+  uint64_t lowBit = getLowBit();
+  uint64_t rangeWidth = getType().getWidth();
+  uint64_t inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
+  uint64_t rangeEnd = lowBit + rangeWidth;
+  if (rangeEnd > inputWidth)
+    return emitOpError("range to be extracted is too big, expected range "
+                       "starting at index ")
+           << lowBit << " of length " << rangeWidth
+           << " requires input width of at least " << rangeEnd
+           << ", but the input width is only " << inputWidth;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+/// The result is a bit-vector whose width is the sum of both operand widths.
+LogicalResult ConcatOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Do not assume well-formed operands. Check their number and types instead
+  // of indexing and `cast`ing unconditionally, and report failure otherwise.
+  if (operands.size() != 2)
+    return emitOptionalError(location, "expected exactly two operands, but got ",
+                             operands.size());
+  auto lhsType = dyn_cast<BitVectorType>(operands[0].getType());
+  auto rhsType = dyn_cast<BitVectorType>(operands[1].getType());
+  if (!lhsType || !rhsType)
+    return emitOptionalError(location, "operands must be bit-vectors");
+
+  inferredReturnTypes.push_back(
+      BitVectorType::get(context, lhsType.getWidth() + rhsType.getWidth()));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RepeatOp
+//===----------------------------------------------------------------------===//
+
+/// The result width must be an exact multiple of the input width.
+LogicalResult RepeatOp::verify() {
+  // Widths are 'int64_t' and strictly positive (see BitVectorType::verify).
+  // Keep them in 64 bits: narrowing to 'unsigned' could turn a width such as
+  // 2^32 into 0 and make the modulo below divide by zero.
+  int64_t inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
+  int64_t resultWidth = getType().getWidth();
+  if (resultWidth % inputWidth != 0)
+    return emitOpError() << "result bit-vector width must be a multiple of the "
+                            "input bit-vector width";
+
+  return success();
+}
+
+/// Number of repetitions, derived from the two widths. The verifier above
+/// guarantees that the division is exact.
+unsigned RepeatOp::getCount() {
+  int64_t inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
+  int64_t resultWidth = getType().getWidth();
+  // NOTE(review): the declared return type is 'unsigned', so counts above
+  // UINT_MAX are truncated here.
+  return static_cast<unsigned>(resultWidth / inputWidth);
+}
+
+/// Builds a repetition of `input` `count` times. The result width is the input
+/// width times `count`.
+void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count,
+                     Value input) {
+  int64_t inputWidth = cast<BitVectorType>(input.getType()).getWidth();
+  // Multiply in 64 bits ('int64_t' is the width type). 'unsigned' arithmetic
+  // could silently wrap for large widths.
+  int64_t resultWidth = inputWidth * static_cast<int64_t>(count);
+  Type resultTy = BitVectorType::get(builder.getContext(), resultWidth);
+  build(builder, state, resultTy, input);
+}
+
+/// Parses `<count> times <input> attr-dict : <input-type>`. The result type is
+/// a bit-vector of `count` times the input width.
+ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SMLoc countLoc = parser.getCurrentLocation();
+  APInt repetitions;
+  if (parser.parseInteger(repetitions))
+    return failure();
+  if (parser.parseKeyword("times"))
+    return failure();
+  if (repetitions.isNonPositive())
+    return parser.emitError(countLoc) << "integer must be positive";
+
+  llvm::SMLoc inputLoc = parser.getCurrentLocation();
+  OpAsmParser::UnresolvedOperand operand;
+  Type operandType;
+  if (parser.parseOperand(operand))
+    return failure();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  if (parser.parseColon() || parser.parseType(operandType))
+    return failure();
+  if (parser.resolveOperand(operand, operandType, result.operands))
+    return failure();
+
+  auto bvType = dyn_cast<BitVectorType>(operandType);
+  if (!bvType)
+    return parser.emitError(inputLoc) << "input must have bit-vector type";
+
+  // 'BitVectorType' stores its width as an 'int64_t'. Limiting both the count
+  // and the resulting width to 63 bits rules out assertions and silent
+  // overflow.
+  const unsigned maxBw = 63;
+  if (repetitions.getActiveBits() > maxBw)
+    return parser.emitError(countLoc)
+           << "integer must fit into " << maxBw << " bits";
+
+  // Multiply in an APInt twice that wide so the product cannot overflow. Then
+  // make sure it still fits into 'maxBw' bits without cutting off any
+  // significant bits.
+  APInt resultWidth = bvType.getWidth() * repetitions.zext(2 * maxBw);
+  if (resultWidth.getActiveBits() > maxBw)
+    return parser.emitError(countLoc)
+           << "result bit-width (provided integer times bit-width of the input "
+              "type) must fit into "
+           << maxBw << " bits";
+
+  result.addTypes(
+      BitVectorType::get(parser.getContext(), resultWidth.getZExtValue()));
+  return success();
+}
+
+/// Prints `<count> times <input> attr-dict : <input-type>`, the form accepted
+/// by RepeatOp::parse.
+void RepeatOp::print(OpAsmPrinter &printer) {
+  Value input = getInput();
+  printer << " " << getCount() << " times " << input;
+  printer.printOptionalAttrDict(getOperation()->getAttrs());
+  printer << " : " << input.getType();
+}
+
+//===----------------------------------------------------------------------===//
+// BoolConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Names the result `true` or `false` after the constant it holds.
+void BoolConstantOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef name = getValue() ? "true" : "false";
+  setNameFn(getResult(), name);
+}
+
+/// A constant folds to its own value attribute.
+OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
+  assert(adaptor.getOperands().empty() && "constant has no operands");
+  OpFoldResult folded = getValueAttr();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// IntConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Suggests SSA names of the form `c<value>`.
+void IntConstantOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  SmallVector<char, 32> nameBuffer;
+  llvm::raw_svector_ostream nameStream(nameBuffer);
+  nameStream << "c" << getValue();
+  setNameFn(getResult(), nameStream.str());
+}
+
+/// A constant folds to its own value attribute.
+OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) {
+  assert(adaptor.getOperands().empty() && "constant has no operands");
+  OpFoldResult folded = getValueAttr();
+  return folded;
+}
+
+/// Custom form: the bare integer, followed by any remaining attributes.
+void IntConstantOp::print(OpAsmPrinter &p) {
+  p << " " << getValue();
+  // 'value' has already been printed inline, so it is elided here.
+  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
+}
+
+/// Parses the form produced by IntConstantOp::print. The result is `!smt.int`.
+ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) {
+  APInt literal;
+  if (parser.parseInteger(literal))
+    return failure();
+
+  // NOTE(review): APSInt's constructor defaults to isUnsigned=true, so the
+  // attribute gets an unsigned integer type even for negative literals.
+  // Confirm that this is intended.
+  auto valueAttr = IntegerAttr::get(parser.getContext(), APSInt(literal));
+  result.getOrAddProperties<Properties>().setValue(valueAttr);
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  result.addTypes(smt::IntType::get(parser.getContext()));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ForallOp
+//===----------------------------------------------------------------------===//
+
+/// Region verification shared by ForallOp and ExistsOp:
+///  - optional bound-variable names match the body's block arguments;
+///  - bound variables are non-function SMT values;
+///  - the body yields exactly one `!smt.bool`;
+///  - each 'patterns' region has the body's block signature, yields at least
+///    one value, and contains only SMT ops that are not ForallOp/ExistsOp.
+template <typename QuantifierOp>
+static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
+  if (op.getBoundVarNames() &&
+      op.getBody().getNumArguments() != op.getBoundVarNames()->size())
+    return op.emitOpError(
+        "number of bound variable names must match number of block arguments");
+  // NOTE(review): "must by" is a typo for "must be". Fix it together with any
+  // diagnostics test that matches this exact message.
+  if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType))
+    return op.emitOpError()
+           << "bound variables must by any non-function SMT value";
+
+  if (op.getBody().front().getTerminator()->getNumOperands() != 1)
+    return op.emitOpError("must have exactly one yielded value");
+  if (!isa<BoolType>(
+          op.getBody().front().getTerminator()->getOperand(0).getType()))
+    return op.emitOpError("yielded value must be of '!smt.bool' type");
+
+  for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
+    unsigned i = regionWithIndex.index();
+    Region &region = regionWithIndex.value();
+
+    if (op.getBody().getArgumentTypes() != region.getArgumentTypes())
+      return op.emitOpError()
+             << "block argument number and types of the 'body' "
+                "and 'patterns' region #"
+             << i << " must match";
+    if (region.front().getTerminator()->getNumOperands() < 1)
+      return op.emitOpError() << "'patterns' region #" << i
+                              << " must have at least one yielded value";
+
+    // All operations in the 'patterns' region must be SMT operations. Stop at
+    // the first violation.
+    auto result = region.walk([&](Operation *childOp) {
+      // Unregistered operations have no dialect (nullptr), so a plain `isa`
+      // would assert. Treat them as non-SMT operations instead.
+      if (!isa_and_nonnull<SMTDialect>(childOp->getDialect())) {
+        auto diag = op.emitOpError()
+                    << "the 'patterns' region #" << i
+                    << " may only contain SMT dialect operations";
+        diag.attachNote(childOp->getLoc()) << "first non-SMT operation here";
+        return WalkResult::interrupt();
+      }
+
+      // There may be no quantifier (or other variable binding) operations in
+      // the 'patterns' region.
+      if (isa<ForallOp, ExistsOp>(childOp)) {
+        auto diag = op.emitOpError() << "the 'patterns' region #" << i
+                                     << " must not contain "
+                                        "any variable binding operations";
+        diag.attachNote(childOp->getLoc()) << "first violating operation here";
+        return WalkResult::interrupt();
+      }
+
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted())
+      return failure();
+  }
+
+  return success();
+}
+
+/// Builder shared by ForallOp and ExistsOp. It does the following:
+///  - sets the `!smt.bool` result type;
+///  - sets the optional properties (non-zero weight, no_pattern, bound
+///    variable names);
+///  - creates the body region, whose single block has one argument per bound
+///    variable and yields the value returned by `bodyBuilder`;
+///  - if `patternBuilder` is given, creates one pattern region of the same
+///    shape that yields the values returned by `patternBuilder`.
+template <typename Properties>
+static void buildQuantifier(
+    OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  Location loc = odsState.location;
+  odsState.addTypes(BoolType::get(odsBuilder.getContext()));
+
+  if (weight != 0)
+    odsState.getOrAddProperties<Properties>().weight =
+        odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
+  if (noPattern)
+    odsState.getOrAddProperties<Properties>().noPattern =
+        odsBuilder.getUnitAttr();
+  if (boundVarNames) {
+    SmallVector<Attribute> nameAttrs;
+    nameAttrs.reserve(boundVarNames->size());
+    for (StringRef name : *boundVarNames)
+      nameAttrs.push_back(odsBuilder.getStringAttr(name));
+    odsState.getOrAddProperties<Properties>().boundVarNames =
+        odsBuilder.getArrayAttr(nameAttrs);
+  }
+
+  // Adds a new region holding one block whose arguments are the bound
+  // variables. The builder's insertion point is moved into that block.
+  SmallVector<Location> argLocs(boundVarTypes.size(), loc);
+  auto createBoundVarBlock = [&]() {
+    Region *region = odsState.addRegion();
+    Block *block = odsBuilder.createBlock(region);
+    block->addArguments(boundVarTypes, argLocs);
+    return block;
+  };
+
+  {
+    OpBuilder::InsertionGuard guard(odsBuilder);
+    Block *body = createBoundVarBlock();
+    Value yielded = bodyBuilder(odsBuilder, loc, body->getArguments());
+    odsBuilder.create<smt::YieldOp>(loc, yielded);
+  }
+
+  if (!patternBuilder)
+    return;
+
+  OpBuilder::InsertionGuard guard(odsBuilder);
+  Block *pattern = createBoundVarBlock();
+  ValueRange yielded = patternBuilder(odsBuilder, loc, pattern->getArguments());
+  odsBuilder.create<smt::YieldOp>(loc, yielded);
+}
+
+/// Explicit patterns and the no_pattern flag are mutually exclusive.
+LogicalResult ForallOp::verify() {
+  bool hasPatterns = !getPatterns().empty();
+  if (hasPatterns && getNoPattern())
+    return emitOpError() << "patterns and the no_pattern attribute must not be "
+                            "specified at the same time";
+  return success();
+}
+
+/// Region checks are shared with ExistsOp.
+LogicalResult ForallOp::verifyRegions() {
+  LogicalResult regionsOk = verifyQuantifierRegions(*this);
+  return regionsOk;
+}
+
+/// Construction is shared with ExistsOp via buildQuantifier.
+void ForallOp::build(
+    OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<ForallOp::Properties>(odsBuilder, odsState, boundVarTypes,
+                                        bodyBuilder, boundVarNames,
+                                        patternBuilder, weight, noPattern);
+}
+
+//===----------------------------------------------------------------------===//
+// ExistsOp
+//===----------------------------------------------------------------------===//
+
+/// Explicit patterns and the no_pattern flag are mutually exclusive.
+LogicalResult ExistsOp::verify() {
+  bool hasPatterns = !getPatterns().empty();
+  if (hasPatterns && getNoPattern())
+    return emitOpError() << "patterns and the no_pattern attribute must not be "
+                            "specified at the same time";
+  return success();
+}
+
+/// Region checks are shared with ForallOp.
+LogicalResult ExistsOp::verifyRegions() {
+  LogicalResult regionsOk = verifyQuantifierRegions(*this);
+  return regionsOk;
+}
+
+/// Construction is shared with ForallOp via buildQuantifier.
+void ExistsOp::build(
+    OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<ExistsOp::Properties>(odsBuilder, odsState, boundVarTypes,
+                                        bodyBuilder, boundVarNames,
+                                        patternBuilder, weight, noPattern);
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
diff --git a/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp b/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp
new file mode 100644
index 0000000000000..6188719bb1ab5
--- /dev/null
+++ b/mlir/lib/Dialect/SMT/IR/SMTTypes.cpp
@@ -0,0 +1,92 @@
+//===- SMTTypes.cpp -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTTypes.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace smt;
+using namespace mlir;
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/SMT/IR/SMTTypes.cpp.inc"
+
+/// Registers the SMT type classes with the dialect.
+void SMTDialect::registerTypes() {
+ // The generated file is included with GET_TYPEDEF_LIST defined, presumably
+ // expanding to the list of ODS-generated type classes.
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/SMT/IR/SMTTypes.cpp.inc"
+ >();
+}
+
+/// True for any SMT value type except function types.
+bool smt::isAnyNonFuncSMTValueType(Type type) {
+  if (isa<SMTFuncType>(type))
+    return false;
+  return isAnySMTValueType(type);
+}
+
+/// True for the SMT value types: bool, bit-vector, array, int, sort, and
+/// function.
+bool smt::isAnySMTValueType(Type type) {
+  return isa<BoolType, BitVectorType, ArrayType, IntType, SortType,
+             SMTFuncType>(type);
+}
+
+//===----------------------------------------------------------------------===//
+// BitVectorType
+//===----------------------------------------------------------------------===//
+
+/// Bit-vector widths are signed 64-bit values; only strictly positive widths
+/// are accepted.
+LogicalResult
+BitVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
+                      int64_t width) {
+  if (width > 0)
+    return success();
+  return emitError() << "bit-vector must have at least a width of one";
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+/// Both the domain and the range of an array must be SMT value types. The
+/// domain is checked first.
+LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                Type domainType, Type rangeType) {
+  bool domainOk = isAnySMTValueType(domainType);
+  if (!domainOk)
+    return emitError() << "domain must be any SMT value type";
+
+  bool rangeOk = isAnySMTValueType(rangeType);
+  if (!rangeOk)
+    return emitError() << "range must be any SMT value type";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SMTFuncType
+//===----------------------------------------------------------------------===//
+
+/// A function type needs a non-empty domain. Neither the domain types nor the
+/// range type may themselves be function types.
+LogicalResult SMTFuncType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<Type> domainTypes, Type rangeType) {
+  auto isInvalid = [](Type type) { return !isAnyNonFuncSMTValueType(type); };
+
+  if (domainTypes.empty())
+    return emitError() << "domain must not be empty";
+  if (llvm::any_of(domainTypes, isInvalid))
+    return emitError() << "domain types must be any non-function SMT type";
+  if (isInvalid(rangeType))
+    return emitError() << "range type must be any non-function SMT type";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SortType
+//===----------------------------------------------------------------------===//
+
+/// Every sort parameter must be a non-function SMT value type. The identifier
+/// itself is not checked.
+LogicalResult SortType::verify(function_ref<InFlightDiagnostic()> emitError,
+                               StringAttr identifier,
+                               ArrayRef<Type> sortParams) {
+  for (Type param : sortParams) {
+    if (!isAnyNonFuncSMTValueType(param))
+      return emitError()
+             << "sort parameter types must be any non-function SMT type";
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/SMT/array-errors.mlir b/mlir/test/Dialect/SMT/array-errors.mlir
new file mode 100644
index 0000000000000..4e90948eed848
--- /dev/null
+++ b/mlir/test/Dialect/SMT/array-errors.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+// ArrayType::verify requires the domain to be an SMT value type; the builtin
+// i32 is not one.
+// expected-error @below {{domain must be any SMT value type}}
+func.func @array_domain_no_smt_type(%arg0: !smt.array<[i32 -> !smt.bool]>) {
+ return
+}
+
+// -----
+
+// The same requirement applies to the range: the builtin i32 is rejected.
+// expected-error @below {{range must be any SMT value type}}
+func.func @array_range_no_smt_type(%arg0: !smt.array<[!smt.bool -> i32]>) {
+ return
+}
diff --git a/mlir/test/Dialect/SMT/array.mlir b/mlir/test/Dialect/SMT/array.mlir
new file mode 100644
index 0000000000000..89cb45c5e878a
--- /dev/null
+++ b/mlir/test/Dialect/SMT/array.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// Round-trip test: the printed output of one mlir-opt run is parsed again by a
+// second run and matched against the directives below. Each op also carries a
+// dialect-prefixed attribute to exercise attribute-dictionary printing.
+// CHECK-LABEL: func @arrayOperations
+// CHECK-SAME: ([[A0:%.+]]: !smt.bool)
+func.func @arrayOperations(%arg0: !smt.bool) {
+ // CHECK-NEXT: [[V0:%.+]] = smt.array.broadcast [[A0]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+ %0 = smt.array.broadcast %arg0 {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+ // CHECK-NEXT: [[V1:%.+]] = smt.array.select [[V0]][[[A0]]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+ %1 = smt.array.select %0[%arg0] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+ // CHECK-NEXT: [[V2:%.+]] = smt.array.store [[V0]][[[A0]]], [[A0]] {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+ %2 = smt.array.store %0[%arg0], %arg0 {smt.some_attr} : !smt.array<[!smt.bool -> !smt.bool]>
+
+ return
+}
diff --git a/mlir/test/Dialect/SMT/basic.mlir b/mlir/test/Dialect/SMT/basic.mlir
new file mode 100644
index 0000000000000..a4975d66e9769
--- /dev/null
+++ b/mlir/test/Dialect/SMT/basic.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Round-trip test: every SMT type and operation in this file must print in a form the parser accepts again.
+// CHECK-LABEL: func @types
+// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>, %{{.*}}: !smt.int, %{{.*}}: !smt.sort<"uninterpreted_sort">, %{{.*}}: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %{{.*}}: !smt.func<(!smt.bool, !smt.bool) !smt.bool>)
+func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>, %arg2: !smt.int, %arg3: !smt.sort<"uninterpreted_sort">, %arg4: !smt.sort<"uninterpreted_sort"[!smt.bool, !smt.int]>, %arg5: !smt.func<(!smt.bool, !smt.bool) !smt.bool>) {
+ return
+}
+// CHECK-LABEL: func @core
+func.func @core(%in: i8) {
+ // CHECK: %a = smt.declare_fun "a" {smt.some_attr} : !smt.bool
+ %a = smt.declare_fun "a" {smt.some_attr} : !smt.bool
+ // CHECK: smt.declare_fun {smt.some_attr} : !smt.bv<32>
+ %b = smt.declare_fun {smt.some_attr} : !smt.bv<32>
+ // CHECK: smt.declare_fun {smt.some_attr} : !smt.int
+ %c = smt.declare_fun {smt.some_attr} : !smt.int
+ // CHECK: smt.declare_fun {smt.some_attr} : !smt.sort<"uninterpreted_sort">
+ %d = smt.declare_fun {smt.some_attr} : !smt.sort<"uninterpreted_sort">
+ // CHECK: smt.declare_fun {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
+ %e = smt.declare_fun {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
+
+ // CHECK: smt.constant true {smt.some_attr}
+ %true = smt.constant true {smt.some_attr}
+ // CHECK: smt.constant false {smt.some_attr}
+ %false = smt.constant false {smt.some_attr}
+
+ // CHECK: smt.assert %a {smt.some_attr}
+ smt.assert %a {smt.some_attr}
+
+ // CHECK: smt.reset {smt.some_attr}
+ smt.reset {smt.some_attr}
+
+ // CHECK: smt.push 1 {smt.some_attr}
+ smt.push 1 {smt.some_attr}
+
+ // CHECK: smt.pop 1 {smt.some_attr}
+ smt.pop 1 {smt.some_attr}
+
+ // CHECK: %{{.*}} = smt.solver(%{{.*}}) {smt.some_attr} : (i8) -> (i8, i32) {
+ // CHECK: ^bb0(%{{.*}}: i8)
+ // CHECK: %{{.*}} = smt.check {smt.some_attr} sat {
+ // CHECK: smt.yield %{{.*}} : i32
+ // CHECK: } unknown {
+ // CHECK: smt.yield %{{.*}} : i32
+ // CHECK: } unsat {
+ // CHECK: smt.yield %{{.*}} : i32
+ // CHECK: } -> i32
+ // CHECK: smt.yield %{{.*}}, %{{.*}} : i8, i32
+ // CHECK: }
+ %0:2 = smt.solver(%in) {smt.some_attr} : (i8) -> (i8, i32) {
+ ^bb0(%arg0: i8):
+ %1 = smt.check {smt.some_attr} sat {
+ %c1_i32 = arith.constant 1 : i32
+ smt.yield %c1_i32 : i32
+ } unknown {
+ %c0_i32 = arith.constant 0 : i32
+ smt.yield %c0_i32 : i32
+ } unsat {
+ %c-1_i32 = arith.constant -1 : i32
+ smt.yield %c-1_i32 : i32
+ } -> i32
+ smt.yield %arg0, %1 : i8, i32
+ }
+ // Solver and check ops whose regions are empty or hold a single op.
+ // CHECK: smt.solver() : () -> () {
+ // CHECK-NEXT: }
+ smt.solver() : () -> () { }
+
+ // CHECK: smt.solver() : () -> () {
+ // CHECK-NEXT: smt.set_logic "AUFLIA"
+ // CHECK-NEXT: }
+ smt.solver() : () -> () {
+ smt.set_logic "AUFLIA"
+ }
+
+ // CHECK: smt.check sat {
+ // CHECK-NEXT: } unknown {
+ // CHECK-NEXT: } unsat {
+ // CHECK-NEXT: }
+ smt.check sat { } unknown { } unsat { }
+ // Equality, boolean connectives and function application, with and without a discardable attribute.
+ // CHECK: %{{.*}} = smt.eq %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32>
+ %1 = smt.eq %b, %b {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.distinct %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32>
+ %2 = smt.distinct %b, %b {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.eq %{{.*}}, %{{.*}}, %{{.*}} : !smt.bool
+ %3 = smt.eq %a, %a, %a : !smt.bool
+ // CHECK: %{{.*}} = smt.distinct %{{.*}}, %{{.*}}, %{{.*}} : !smt.bool
+ %4 = smt.distinct %a, %a, %a : !smt.bool
+
+ // CHECK: %{{.*}} = smt.ite %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr} : !smt.bv<32>
+ %5 = smt.ite %a, %b, %b {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.not %{{.*}} {smt.some_attr}
+ %6 = smt.not %a {smt.some_attr}
+ // CHECK: %{{.*}} = smt.and %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr}
+ %7 = smt.and %a, %a, %a {smt.some_attr}
+ // CHECK: %{{.*}} = smt.or %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr}
+ %8 = smt.or %a, %a, %a {smt.some_attr}
+ // CHECK: %{{.*}} = smt.xor %{{.*}}, %{{.*}}, %{{.*}} {smt.some_attr}
+ %9 = smt.xor %a, %a, %a {smt.some_attr}
+ // CHECK: %{{.*}} = smt.implies %{{.*}}, %{{.*}} {smt.some_attr}
+ %10 = smt.implies %a, %a {smt.some_attr}
+
+ // CHECK: smt.apply_func %{{.*}}(%{{.*}}, %{{.*}}) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
+ %11 = smt.apply_func %e(%c, %a) {smt.some_attr} : !smt.func<(!smt.int, !smt.bool) !smt.bool>
+
+ return
+}
+
+// CHECK-LABEL: func @quantifiers
+func.func @quantifiers() {
+ // CHECK-NEXT: smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.eq
+ // CHECK-NEXT: smt.yield %{{.*}}
+ // CHECK-NEXT: } patterns {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.yield %{{.*}}, %{{.*}} : !smt.bool, !smt.bool
+ // CHECK-NEXT: }, {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.yield %{{.*}}, %{{.*}} : !smt.bool, !smt.bool
+ // CHECK-NEXT: }
+ %0 = smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ %1 = smt.eq %arg2, %arg3 : !smt.bool
+ smt.yield %1 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }, {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }
+ // no_pattern form: a body region only, no pattern regions.
+ // CHECK-NEXT: smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.eq
+ // CHECK-NEXT: smt.yield %{{.*}}
+ // CHECK-NEXT: }
+ %1 = smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ %2 = smt.eq %arg2, %arg3 : !smt.bool
+ smt.yield %2 : !smt.bool
+ }
+
+ // CHECK-NEXT: smt.forall {
+ // CHECK-NEXT: smt.constant
+ // CHECK-NEXT: smt.yield %{{.*}}
+ // CHECK-NEXT: }
+ %2 = smt.forall {
+ %3 = smt.constant true
+ smt.yield %3 : !smt.bool
+ }
+ // The exists bodies below also attach a discardable attribute to the yield.
+ // CHECK-NEXT: smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.eq
+ // CHECK-NEXT: smt.yield %{{.*}} {smt.some_attr}
+ // CHECK-NEXT: } patterns {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.yield %{{.*}}, %{{.*}} : !smt.bool, !smt.bool
+ // CHECK-NEXT: }, {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.yield %{{.*}}, %{{.*}} : !smt.bool, !smt.bool
+ // CHECK-NEXT: }
+ %3 = smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ %4 = smt.eq %arg2, %arg3 : !smt.bool
+ smt.yield %4 : !smt.bool {smt.some_attr}
+ } patterns {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }, {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }
+
+ // CHECK-NEXT: smt.exists no_pattern attributes {smt.some_attr} {
+ // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
+ // CHECK-NEXT: smt.eq
+ // CHECK-NEXT: smt.yield %{{.*}} {smt.some_attr}
+ // CHECK-NEXT: }
+ %4 = smt.exists no_pattern attributes {smt.some_attr} {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ %5 = smt.eq %arg2, %arg3 : !smt.bool
+ smt.yield %5 : !smt.bool {smt.some_attr}
+ }
+
+ // CHECK-NEXT: smt.exists [] {
+ // CHECK-NEXT: smt.constant
+ // CHECK-NEXT: smt.yield %{{.*}}
+ // CHECK-NEXT: }
+ %5 = smt.exists [] {
+ %6 = smt.constant true
+ smt.yield %6 : !smt.bool
+ }
+
+ return
+}
diff --git a/mlir/test/Dialect/SMT/bitvector-errors.mlir b/mlir/test/Dialect/SMT/bitvector-errors.mlir
new file mode 100644
index 0000000000000..58226f4d55f62
--- /dev/null
+++ b/mlir/test/Dialect/SMT/bitvector-errors.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+// Parser and verifier diagnostics for bit-vector types, attributes and operations.
+// expected-error @below {{bit-vector must have at least a width of one}}
+func.func @at_least_size_one(%arg0: !smt.bv<0>) {
+ return
+}
+
+// -----
+
+// expected-error @below {{bit-vector must have at least a width of one}}
+func.func @positive_width(%arg0: !smt.bv<-1>) {
+ return
+}
+
+// -----
+
+func.func @attr_type_and_return_type_match() {
+ // expected-error @below {{inferred type(s) '!smt.bv<1>' are incompatible with return type(s) of operation '!smt.bv<32>'}}
+ // expected-error @below {{failed to infer returned types}}
+ %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<0> : !smt.bv<1>}> : () -> !smt.bv<32>
+ return
+}
+
+// -----
+
+func.func @bitvector_attr_without_type() {
+ // expected-error @below {{explicit bit-vector type required}}
+ smt.bv.constant #smt.bv<5>
+}
+
+// -----
+
+func.func @bitvector_attr_value_too_large() {
+ // expected-error @below {{integer value out of range for given bit-vector type}}
+ smt.bv.constant #smt.bv<32> : !smt.bv<2>
+}
+
+// -----
+
+func.func @bitvector_attr_value_too_small() {
+ // expected-error @below {{integer value out of range for given bit-vector type}}
+ smt.bv.constant #smt.bv<-4> : !smt.bv<2>
+}
+
+// -----
+
+func.func @extraction(%arg0: !smt.bv<32>) {
+ // expected-error @below {{range to be extracted is too big, expected range starting at index 20 of length 16 requires input width of at least 36, but the input width is only 32}}
+ smt.bv.extract %arg0 from 20 : (!smt.bv<32>) -> !smt.bv<16>
+ return
+}
+
+// -----
+
+func.func @concat(%arg0: !smt.bv<32>) {
+ // expected-error @below {{inferred type(s) '!smt.bv<64>' are incompatible with return type(s) of operation '!smt.bv<33>'}}
+ // expected-error @below {{failed to infer returned types}}
+ "smt.bv.concat"(%arg0, %arg0) {} : (!smt.bv<32>, !smt.bv<32>) -> !smt.bv<33>
+ return
+}
+
+// -----
+
+func.func @repeat_result_type_no_multiple_of_input_type(%arg0: !smt.bv<32>) {
+ // expected-error @below {{result bit-vector width must be a multiple of the input bit-vector width}}
+ "smt.bv.repeat"(%arg0) : (!smt.bv<32>) -> !smt.bv<65>
+ return
+}
+
+// -----
+
+func.func @repeat_negative_count(%arg0: !smt.bv<32>) {
+ // expected-error @below {{integer must be positive}}
+ smt.bv.repeat -2 times %arg0 : !smt.bv<32>
+ return
+}
+
+// -----
+
+// The parser needs the bit-width of the input to compute the result type, so a
+// non-bit-vector input must be rejected by the parser itself; the verifier
+// never gets a chance to run on this case.
+func.func @repeat_wrong_input_type(%arg0: !smt.bool) {
+ // expected-error @below {{input must have bit-vector type}}
+ smt.bv.repeat 2 times %arg0 : !smt.bool
+ return
+}
+
+// -----
+
+func.func @repeat_count_too_large(%arg0: !smt.bv<32>) {
+ // expected-error @below {{integer must fit into 63 bits}}
+ smt.bv.repeat 18446744073709551617 times %arg0 : !smt.bv<32>
+ return
+}
+
+// -----
+
+func.func @repeat_result_type_bitwidth_too_large(%arg0: !smt.bv<9223372036854775807>) {
+ // expected-error @below {{result bit-width (provided integer times bit-width of the input type) must fit into 63 bits}}
+ smt.bv.repeat 2 times %arg0 : !smt.bv<9223372036854775807>
+ return
+}
+
+// -----
+
+func.func @invalid_bv2int_signedness() {
+ %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32>
+ // expected-error @below {{expected ':'}}
+ %bv2int = smt.bv2int %c5_bv32 unsigned : !smt.bv<32>
+ return
+}
diff --git a/mlir/test/Dialect/SMT/bitvectors.mlir b/mlir/test/Dialect/SMT/bitvectors.mlir
new file mode 100644
index 0000000000000..2482f55b5ed31
--- /dev/null
+++ b/mlir/test/Dialect/SMT/bitvectors.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Round-trip test for bit-vector constants and operations.
+// CHECK-LABEL: func @bitvectors
+func.func @bitvectors() {
+ // CHECK: %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr}
+ %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr}
+ // CHECK: %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> {smt.some_attr}
+ %c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr}
+ // CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
+ %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
+ // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
+ %c-1_bv1_neg = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
+ // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
+ %c-1_bv1_pos = smt.bv.constant #smt.bv<1> : !smt.bv<1>
+ // Constants print as signed decimals (0x5c reads back as 92, the 1-bit value 1 as -1); the optional _N suffix allows for the two identical 1-bit names being disambiguated.
+ // CHECK: [[C0:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ %c = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ // All remaining ops use the 32-bit zero constant above as their operand(s).
+ // CHECK: %{{.*}} = smt.bv.neg [[C0]] {smt.some_attr} : !smt.bv<32>
+ %0 = smt.bv.neg %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.add [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %1 = smt.bv.add %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.mul [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %3 = smt.bv.mul %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.urem [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %4 = smt.bv.urem %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.srem [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %5 = smt.bv.srem %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.smod [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %7 = smt.bv.smod %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.shl [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %8 = smt.bv.shl %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.lshr [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %9 = smt.bv.lshr %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.ashr [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %10 = smt.bv.ashr %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.udiv [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %11 = smt.bv.udiv %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.sdiv [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %12 = smt.bv.sdiv %c, %c {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.bv.not [[C0]] {smt.some_attr} : !smt.bv<32>
+ %13 = smt.bv.not %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.and [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %14 = smt.bv.and %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.or [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %15 = smt.bv.or %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.xor [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %16 = smt.bv.xor %c, %c {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.bv.cmp slt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %17 = smt.bv.cmp slt %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp sle [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %18 = smt.bv.cmp sle %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp sgt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %19 = smt.bv.cmp sgt %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp sge [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %20 = smt.bv.cmp sge %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp ult [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %21 = smt.bv.cmp ult %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp ule [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %22 = smt.bv.cmp ule %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp ugt [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %23 = smt.bv.cmp ugt %c, %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.cmp uge [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>
+ %24 = smt.bv.cmp uge %c, %c {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.bv.concat [[C0]], [[C0]] {smt.some_attr} : !smt.bv<32>, !smt.bv<32>
+ %25 = smt.bv.concat %c, %c {smt.some_attr} : !smt.bv<32>, !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv.extract [[C0]] from 8 {smt.some_attr} : (!smt.bv<32>) -> !smt.bv<16>
+ %26 = smt.bv.extract %c from 8 {smt.some_attr} : (!smt.bv<32>) -> !smt.bv<16>
+ // CHECK: %{{.*}} = smt.bv.repeat 2 times [[C0]] {smt.some_attr} : !smt.bv<32>
+ %27 = smt.bv.repeat 2 times %c {smt.some_attr} : !smt.bv<32>
+
+ // CHECK: %{{.*}} = smt.bv2int [[C0]] {smt.some_attr} : !smt.bv<32>
+ %29 = smt.bv2int %c {smt.some_attr} : !smt.bv<32>
+ // CHECK: %{{.*}} = smt.bv2int [[C0]] signed {smt.some_attr} : !smt.bv<32>
+ %28 = smt.bv2int %c signed {smt.some_attr} : !smt.bv<32>
+
+ return
+}
diff --git a/mlir/test/Dialect/SMT/core-errors.mlir b/mlir/test/Dialect/SMT/core-errors.mlir
new file mode 100644
index 0000000000000..67bebda56b68e
--- /dev/null
+++ b/mlir/test/Dialect/SMT/core-errors.mlir
@@ -0,0 +1,497 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+// Each split chunk triggers exactly one diagnostic of the core SMT operations or types.
+func.func @solver_isolated_from_above(%arg0: !smt.bool) {
+ // expected-note @below {{required by region isolation constraints}}
+ smt.solver() : () -> () {
+ // expected-error @below {{using value defined outside the region}}
+ smt.assert %arg0
+ }
+ return
+}
+
+// -----
+
+func.func @no_smt_value_enters_solver(%arg0: !smt.bool) {
+ // expected-error @below {{operand #0 must be variadic of any non-smt type, but got '!smt.bool'}}
+ smt.solver(%arg0) : (!smt.bool) -> () {
+ ^bb0(%arg1: !smt.bool):
+ smt.assert %arg1
+ }
+ return
+}
+
+// -----
+
+func.func @no_smt_value_exits_solver() {
+ // expected-error @below {{result #0 must be variadic of any non-smt type, but got '!smt.bool'}}
+ %0 = smt.solver() : () -> !smt.bool {
+ %a = smt.declare_fun "a" : !smt.bool
+ smt.yield %a : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @block_args_and_inputs_match() {
+ // expected-error @below {{block argument types must match the types of the 'inputs'}}
+ smt.solver() : () -> () {
+ ^bb0(%arg0: i32):
+ }
+ return
+}
+
+// -----
+
+func.func @solver_yield_operands_and_results_match() {
+ // expected-error @below {{types of yielded values must match return values}}
+ smt.solver() : () -> () {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ }
+ return
+}
+
+// -----
+
+func.func @check_yield_operands_and_results_match() {
+ // expected-error @below {{types of yielded values in 'unsat' region must match return values}}
+ %0 = smt.check sat {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } unknown {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } unsat { } -> i32
+ return
+}
+
+// -----
+
+func.func @check_yield_operands_and_results_match() {
+ // expected-error @below {{types of yielded values in 'unknown' region must match return values}}
+ %0 = smt.check sat {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } unknown {
+ } unsat {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } -> i32
+ return
+}
+
+// -----
+
+func.func @check_yield_operands_and_results_match() {
+ // expected-error @below {{types of yielded values in 'sat' region must match return values}}
+ %0 = smt.check sat {
+ } unknown {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } unsat {
+ %1 = arith.constant 0 : i32
+ smt.yield %1 : i32
+ } -> i32
+ return
+}
+
+// -----
+
+func.func @check_no_block_arguments() {
+ // expected-error @below {{region #0 should have no arguments}}
+ smt.check sat {
+ ^bb0(%arg0: i32):
+ } unknown {
+ } unsat {
+ }
+ return
+}
+
+// -----
+
+func.func @check_no_block_arguments() {
+ // expected-error @below {{region #1 should have no arguments}}
+ smt.check sat {
+ } unknown {
+ ^bb0(%arg0: i32):
+ } unsat {
+ }
+ return
+}
+
+// -----
+
+func.func @check_no_block_arguments() {
+ // expected-error @below {{region #2 should have no arguments}}
+ smt.check sat {
+ } unknown {
+ } unsat {
+ ^bb0(%arg0: i32):
+ }
+ return
+}
+
+// -----
+// smt.eq and smt.distinct require at least two inputs.
+func.func @too_few_operands() {
+ // expected-error @below {{'inputs' must have at least size 2, but got 0}}
+ smt.eq : !smt.bool
+ return
+}
+
+// -----
+
+func.func @too_few_operands(%a: !smt.bool) {
+ // expected-error @below {{'inputs' must have at least size 2, but got 1}}
+ smt.distinct %a : !smt.bool
+ return
+}
+
+// -----
+
+func.func @ite_type_mismatch(%a: !smt.bool, %b: !smt.bv<32>) {
+ // expected-error @below {{failed to verify that all of {thenValue, elseValue, result} have same type}}
+ "smt.ite"(%a, %a, %b) {} : (!smt.bool, !smt.bool, !smt.bv<32>) -> !smt.bool
+ return
+}
+
+// -----
+// A quantifier must list exactly one bound-variable name per body block argument.
+func.func @forall_number_of_decl_names_must_match_num_args() {
+ // expected-error @below {{number of bound variable names must match number of block arguments}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ %2 = smt.eq %arg2, %arg3 : !smt.int
+ smt.yield %2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @exists_number_of_decl_names_must_match_num_args() {
+ // expected-error @below {{number of bound variable names must match number of block arguments}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ %2 = smt.eq %arg2, %arg3 : !smt.int
+ smt.yield %2 : !smt.bool
+ }
+ return
+}
+
+// -----
+// A quantifier body must yield exactly one value, and it must be !smt.bool.
+func.func @forall_yield_must_have_exactly_one_bool_value() {
+ // expected-error @below {{yielded value must be of '!smt.bool' type}}
+ %1 = smt.forall ["a", "b"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ %2 = smt.int.add %arg2, %arg3
+ smt.yield %2 : !smt.int
+ }
+ return
+}
+
+// -----
+
+func.func @forall_yield_must_have_exactly_one_bool_value() {
+ // expected-error @below {{must have exactly one yielded value}}
+ %1 = smt.forall ["a", "b"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ smt.yield
+ }
+ return
+}
+
+// -----
+
+func.func @exists_yield_must_have_exactly_one_bool_value() {
+ // expected-error @below {{yielded value must be of '!smt.bool' type}}
+ %1 = smt.exists ["a", "b"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ %2 = smt.int.add %arg2, %arg3
+ smt.yield %2 : !smt.int
+ }
+ return
+}
+
+// -----
+
+func.func @exists_yield_must_have_exactly_one_bool_value() {
+ // expected-error @below {{must have exactly one yielded value}}
+ %1 = smt.exists ["a", "b"] {
+ ^bb0(%arg2: !smt.int, %arg3: !smt.int):
+ smt.yield
+ }
+ return
+}
+
+// -----
+// Constraints on the 'patterns' regions of smt.exists and smt.forall.
+func.func @exists_patterns_region_and_no_patterns_attr_are_mutually_exclusive() {
+ // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}}
+ %1 = smt.exists ["a"] no_pattern {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @forall_patterns_region_and_no_patterns_attr_are_mutually_exclusive() {
+ // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}}
+ %1 = smt.forall ["a"] no_pattern {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @exists_patterns_region_num_args() {
+ // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @forall_patterns_region_num_args() {
+ // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
+ smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @exists_patterns_region_at_least_one_yielded_value() {
+ // expected-error @below {{'patterns' region #0 must have at least one yielded value}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield
+ }
+ return
+}
+
+// -----
+
+func.func @forall_patterns_region_at_least_one_yielded_value() {
+ // expected-error @below {{'patterns' region #0 must have at least one yielded value}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield
+ }
+ return
+}
+
+// -----
+
+func.func @exists_all_pattern_regions_tested() {
+ // expected-error @below {{'patterns' region #1 must have at least one yielded value}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ }, {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield
+ }
+ return
+}
+
+// -----
+
+func.func @forall_all_pattern_regions_tested() {
+ // expected-error @below {{'patterns' region #1 must have at least one yielded value}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ }, {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield
+ }
+ return
+}
+
+// -----
+
+func.func @exists_patterns_region_no_non_smt_operations() {
+ // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ // expected-note @below {{first non-SMT operation here}}
+ arith.constant 0 : i32
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @forall_patterns_region_no_non_smt_operations() {
+ // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ // expected-note @below {{first non-SMT operation here}}
+ arith.constant 0 : i32
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @exists_patterns_region_no_var_binding_operations() {
+ // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}}
+ %1 = smt.exists ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ // expected-note @below {{first violating operation here}}
+ smt.exists ["b"] {
+ ^bb0(%arg3: !smt.bool):
+ smt.yield %arg3 : !smt.bool
+ }
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @forall_patterns_region_no_var_binding_operations() {
+ // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}}
+ %1 = smt.forall ["a"] {
+ ^bb0(%arg2: !smt.bool):
+ smt.yield %arg2 : !smt.bool
+ } patterns {
+ ^bb0(%arg2: !smt.bool):
+ // expected-note @below {{first violating operation here}}
+ smt.forall ["b"] {
+ ^bb0(%arg3: !smt.bool):
+ smt.yield %arg3 : !smt.bool
+ }
+ smt.yield %arg2 : !smt.bool
+ }
+ return
+}
+
+// -----
+// Bound variables of a quantifier must not have !smt.func type.
+func.func @exists_bound_variable_type_invalid() {
+ // expected-error @below {{bound variables must by any non-function SMT value}}
+ %1 = smt.exists ["a", "b"] {
+ ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool):
+ smt.yield %arg3 : !smt.bool
+ }
+ return
+}
+
+// -----
+
+func.func @forall_bound_variable_type_invalid() {
+ // expected-error @below {{bound variables must by any non-function SMT value}}
+ %1 = smt.forall ["a", "b"] {
+ ^bb0(%arg2: !smt.func<(!smt.int) !smt.int>, %arg3: !smt.bool):
+ smt.yield %arg3 : !smt.bool
+ }
+ return
+}
+
+// -----
+// The domain and range types of !smt.func must be non-function SMT types.
+// expected-error @below {{domain types must be any non-function SMT type}}
+func.func @func_domain_no_smt_type(%arg0: !smt.func<(i32) !smt.bool>) {
+ return
+}
+
+// -----
+
+// expected-error @below {{range type must be any non-function SMT type}}
+func.func @func_range_no_smt_type(%arg0: !smt.func<(!smt.bool) i32>) {
+ return
+}
+
+// -----
+
+// expected-error @below {{range type must be any non-function SMT type}}
+func.func @func_range_is_func_type(%arg0: !smt.func<(!smt.bool) !smt.func<(!smt.bool) !smt.bool>>) {
+ return
+}
+
+// -----
+// smt.apply_func must pass one operand per domain type of the applied function.
+func.func @apply_func_operand_count_mismatch(%arg0: !smt.func<(!smt.bool) !smt.bool>) {
+ // expected-error @below {{got 0 operands and 1 types}}
+ smt.apply_func %arg0() : !smt.func<(!smt.bool) !smt.bool>
+ return
+}
+
+// -----
+// Sort parameters must be non-function SMT types.
+// expected-error @below {{sort parameter types must be any non-function SMT type}}
+func.func @sort_type_no_smt_type(%arg0: !smt.sort<"sortname"[i32]>) {
+ return
+}
+
+// -----
+// The push/pop count attribute must be a non-negative 32-bit integer.
+func.func @negative_push() {
+ // expected-error @below {{'smt.push' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
+ smt.push -1
+ return
+}
+
+// -----
+
+func.func @negative_pop() {
+ // expected-error @below {{'smt.pop' op attribute 'count' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
+ smt.pop -1
+ return
+}
+
+// -----
+// smt.set_logic must have smt.solver as its parent op.
+func.func @set_logic_outside_solver() {
+ // expected-error @below {{'smt.set_logic' op expects parent op 'smt.solver'}}
+ smt.set_logic "AUFLIA"
+ return
+}
diff --git a/mlir/test/Dialect/SMT/cse-test.mlir b/mlir/test/Dialect/SMT/cse-test.mlir
new file mode 100644
index 0000000000000..ff254857f3b33
--- /dev/null
+++ b/mlir/test/Dialect/SMT/cse-test.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s --cse | FileCheck %s
+
+// CSE must not merge smt.declare_fun ops, even when they share the name "a"
+// (presumably because each one introduces its own symbol). A declaration
+// without uses, however, is removed as dead code.
+// CHECK-LABEL: func @declare_const_cse
+func.func @declare_const_cse() -> (!smt.bool, !smt.bool) {
+ // CHECK-NEXT: smt.declare_fun "a" : !smt.bool
+ %a = smt.declare_fun "a" : !smt.bool
+ // CHECK-NEXT: smt.declare_fun "a" : !smt.bool
+ %b = smt.declare_fun "a" : !smt.bool
+ // %c has no uses, so nothing may be printed between %b and the return.
+ %c = smt.declare_fun "a" : !smt.bool
+ // CHECK-NEXT: return
+ return %a, %b : !smt.bool, !smt.bool
+}
diff --git a/mlir/test/Dialect/SMT/integers.mlir b/mlir/test/Dialect/SMT/integers.mlir
new file mode 100644
index 0000000000000..f5133c8c72b5d
--- /dev/null
+++ b/mlir/test/Dialect/SMT/integers.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Round-trip test for the integer ops; the second constant deliberately does not fit in 64 bits.
+// CHECK-LABEL: func @integer_operations
+func.func @integer_operations() {
+ // CHECK-NEXT: [[V0:%.+]] = smt.int.constant -123 {smt.some_attr}
+ %0 = smt.int.constant -123 {smt.some_attr}
+ // CHECK-NEXT: %c184467440737095516152 = smt.int.constant 184467440737095516152 {smt.some_attr}
+ %1 = smt.int.constant 184467440737095516152 {smt.some_attr}
+
+ // Arithmetic, comparison and int2bv conversion, all applied to the first constant.
+ // CHECK-NEXT: smt.int.add [[V0]], [[V0]], [[V0]] {smt.some_attr}
+ %2 = smt.int.add %0, %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.mul [[V0]], [[V0]], [[V0]] {smt.some_attr}
+ %3 = smt.int.mul %0, %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.sub [[V0]], [[V0]] {smt.some_attr}
+ %4 = smt.int.sub %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.div [[V0]], [[V0]] {smt.some_attr}
+ %5 = smt.int.div %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.mod [[V0]], [[V0]] {smt.some_attr}
+ %6 = smt.int.mod %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.abs [[V0]] {smt.some_attr}
+ %7 = smt.int.abs %0 {smt.some_attr}
+
+ // CHECK-NEXT: smt.int.cmp le [[V0]], [[V0]] {smt.some_attr}
+ %9 = smt.int.cmp le %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.cmp lt [[V0]], [[V0]] {smt.some_attr}
+ %10 = smt.int.cmp lt %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.cmp ge [[V0]], [[V0]] {smt.some_attr}
+ %11 = smt.int.cmp ge %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int.cmp gt [[V0]], [[V0]] {smt.some_attr}
+ %12 = smt.int.cmp gt %0, %0 {smt.some_attr}
+ // CHECK-NEXT: smt.int2bv [[V0]] {smt.some_attr} : !smt.bv<4>
+ %13 = smt.int2bv %0 {smt.some_attr} : !smt.bv<4>
+
+ return
+}
>From 7150eb70d594189f81f99d0ea570bcb4c1416c37 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sat, 15 Mar 2025 21:53:52 -0400
Subject: [PATCH 2/2] [mlir][smt] add arith-to-smt
---
.../mlir/Conversion/ArithToSMT/ArithToSMT.h | 30 ++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 13 +
mlir/include/mlir/InitAllPasses.h | 1 +
mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp | 351 ++++++++++++++++++
mlir/lib/Conversion/ArithToSMT/CMakeLists.txt | 14 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/ArithToSMT/arith-to-smt.mlir | 87 +++++
8 files changed, 498 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h
create mode 100644 mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp
create mode 100644 mlir/lib/Conversion/ArithToSMT/CMakeLists.txt
create mode 100644 mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h b/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h
new file mode 100644
index 0000000000000..5bb76321199ee
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToSMT/ArithToSMT.h
@@ -0,0 +1,30 @@
+//===- ArithToSMT.h - Arith to SMT dialect conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOSMT_H
+#define MLIR_CONVERSION_ARITHTOSMT_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+class TypeConverter;
+class RewritePatternSet;
+
+// Declares the ConvertArithToSMT pass (defined as `ConvertArithToSMT` in
+// mlir/Conversion/Passes.td), e.g. its options struct and factory.
+#define GEN_PASS_DECL_CONVERTARITHTOSMT
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+/// Get the Arith to SMT conversion patterns.
+///
+/// Adds to `patterns` the dialect-conversion patterns that lower integer
+/// `arith` ops (constants, cmpi, add/sub/mul, and/or/xor, shifts, and
+/// signed/unsigned div/rem) to SMT bit-vector ops. Every pattern is built with
+/// `converter`, which is used to compute converted result types; this
+/// function does not itself register any type conversions on `converter`.
+/// `arith.ceildivsi` is not handled by these patterns.
+void populateArithToSMTConversionPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOSMT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ccd862f67c068..b3a65d611a0d5 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -15,6 +15,7 @@
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ArithToSMT/ArithToSMT.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..89d4c3c0b35b7 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1464,4 +1464,17 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// ConvertArithToSMT
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToSMT : Pass<"convert-arith-to-smt"> {
+  let summary = "Convert arith ops and constants to SMT ops";
+  let description = [{
+    This pass lowers integer `arith` dialect operations to the `smt` dialect.
+    Integer types are converted to `!smt.bv<N>` bit-vectors of the same width.
+    `arith.cmpi` becomes `smt.eq`, `smt.distinct`, or `smt.bv.cmp`; arithmetic,
+    bitwise, and shift operations become the corresponding `smt.bv.*`
+    operations. Signed and unsigned division and remainder produce a fresh
+    symbolic value (`smt.declare_fun`) when the divisor is zero.
+    `arith.ceildivsi` is first expanded into simpler `arith` operations.
+  }];
+  let dependentDialects = [
+    "smt::SMTDialect",
+    "arith::ArithDialect",
+    "mlir::func::FuncDialect"
+  ];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index dd8b292a87344..bb8dff47ab480 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -95,6 +95,7 @@ inline void registerAllPasses() {
arm_sve::registerArmSVEPasses();
emitc::registerEmitCPasses();
xegpu::registerXeGPUPasses();
+ registerConvertArithToSMTPass();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp b/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp
new file mode 100644
index 0000000000000..6b8714a5a1c44
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToSMT/ArithToSMT.cpp
@@ -0,0 +1,351 @@
+//===- ArithToSMT.cpp - Arith to SMT dialect conversion -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToSMT/ArithToSMT.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOSMT
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Lowers arith::CmpIOp to SMT: `eq` becomes smt::EqOp, `ne` becomes
+/// smt::DistinctOp, and every ordering predicate becomes an smt::BVCmpOp with
+/// the corresponding signed/unsigned bit-vector predicate.
+struct CmpIOpConversion : OpConversionPattern<arith::CmpIOp> {
+  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    const arith::CmpIPredicate predicate = adaptor.getPredicate();
+
+    // Equality predicates map onto dedicated SMT ops.
+    switch (predicate) {
+    case arith::CmpIPredicate::eq:
+      rewriter.replaceOpWithNewOp<smt::EqOp>(op, lhs, rhs);
+      return success();
+    case arith::CmpIPredicate::ne:
+      rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, lhs, rhs);
+      return success();
+    default:
+      break;
+    }
+
+    rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, toBVCmpPredicate(predicate),
+                                              lhs, rhs);
+    return success();
+  }
+
+private:
+  /// Maps an ordering arith predicate onto the matching smt::BVCmpPredicate.
+  /// `eq` and `ne` are handled by the caller and never reach this helper.
+  static smt::BVCmpPredicate toBVCmpPredicate(arith::CmpIPredicate predicate) {
+    switch (predicate) {
+    case arith::CmpIPredicate::slt:
+      return smt::BVCmpPredicate::slt;
+    case arith::CmpIPredicate::sle:
+      return smt::BVCmpPredicate::sle;
+    case arith::CmpIPredicate::sgt:
+      return smt::BVCmpPredicate::sgt;
+    case arith::CmpIPredicate::sge:
+      return smt::BVCmpPredicate::sge;
+    case arith::CmpIPredicate::ult:
+      return smt::BVCmpPredicate::ult;
+    case arith::CmpIPredicate::ule:
+      return smt::BVCmpPredicate::ule;
+    case arith::CmpIPredicate::ugt:
+      return smt::BVCmpPredicate::ugt;
+    case arith::CmpIPredicate::uge:
+      return smt::BVCmpPredicate::uge;
+    default:
+      llvm_unreachable("all cases handled above");
+    }
+  }
+};
+
+/// Lowers arith::SubIOp `lhs - rhs` as `lhs + (-rhs)`: an smt::BVNegOp on the
+/// right operand feeding an smt::BVAddOp.
+struct SubOpConversion : OpConversionPattern<arith::SubIOp> {
+  using OpConversionPattern<arith::SubIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value minuend = adaptor.getLhs();
+    Value negatedSubtrahend =
+        rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
+    rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, minuend, negatedSubtrahend);
+    return success();
+  }
+};
+
+/// Replaces `SourceOp` with a single `TargetOp`: the converted operands are
+/// forwarded unchanged and the result type is mapped through the pattern's
+/// type converter.
+template <typename SourceOp, typename TargetOp>
+struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type convertedType =
+        this->typeConverter->convertType(op.getResult().getType());
+    rewriter.replaceOpWithNewOp<TargetOp>(op, convertedType,
+                                          adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Expands arith::CeilDivSIOp into arith ops that this pass's conversion
+/// patterns know how to lower to SMT.
+///
+/// The pass applies this pattern with walkAndApplyPatterns (a plain rewrite
+/// driver) before the dialect conversion runs. It is therefore an
+/// OpRewritePattern working on the original integer-typed operands. An
+/// OpConversionPattern expects a ConversionPatternRewriter and must only be
+/// driven by the dialect conversion framework.
+///
+/// With truncating signed division q = a / b and remainder r = a % b (r has
+/// the sign of a), ceil(a / b) is q + 1 when the division is inexact
+/// (r != 0) and the exact quotient is positive (r and b have the same sign).
+/// Otherwise it is q. The `(a + b - 1) / b` idiom is wrong for negative
+/// operands: for -4 / 2 it yields -1 instead of -2.
+///
+/// arith.select and arith.extui are not lowered by this pass, so the 0/1
+/// adjustment is computed with bitwise ops on w-bit values:
+///   inexact     = (r | (0 - r)) >>u (w - 1)  // 1 iff r != 0
+///   signsDiffer = (r ^ b) >>u (w - 1)        // 1 iff sign(r) != sign(b)
+///   result      = q + (inexact & (signsDiffer ^ 1))
+struct CeilDivSIOpConversion : OpRewritePattern<arith::CeilDivSIOp> {
+  using OpRewritePattern<arith::CeilDivSIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only scalar integers are supported. The type converter only handles
+    // IntegerType, and the constants below are built as scalars.
+    auto intType = dyn_cast<IntegerType>(op.getResult().getType());
+    if (!intType || intType.getWidth() == 0)
+      return rewriter.notifyMatchFailure(
+          op, "expected a scalar integer type of non-zero width");
+
+    Location loc = op.getLoc();
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    unsigned width = intType.getWidth();
+    auto makeConstant = [&](int64_t value) -> Value {
+      return rewriter.create<arith::ConstantIntOp>(loc, value, width);
+    };
+
+    Value quotient = rewriter.create<arith::DivSIOp>(loc, lhs, rhs);
+    Value remainder = rewriter.create<arith::RemSIOp>(loc, lhs, rhs);
+    Value signBitShift = makeConstant(width - 1);
+
+    // 1 iff remainder != 0: (r | -r) has its sign bit set exactly when r != 0.
+    Value negRemainder =
+        rewriter.create<arith::SubIOp>(loc, makeConstant(0), remainder);
+    Value remOrNeg = rewriter.create<arith::OrIOp>(loc, remainder, negRemainder);
+    Value inexact = rewriter.create<arith::ShRUIOp>(loc, remOrNeg, signBitShift);
+
+    // 1 iff remainder and divisor have different sign bits.
+    Value signXor = rewriter.create<arith::XOrIOp>(loc, remainder, rhs);
+    Value signsDiffer =
+        rewriter.create<arith::ShRUIOp>(loc, signXor, signBitShift);
+    Value signsMatch =
+        rewriter.create<arith::XOrIOp>(loc, signsDiffer, makeConstant(1));
+
+    Value adjustment = rewriter.create<arith::AndIOp>(loc, inexact, signsMatch);
+    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, quotient, adjustment);
+    return success();
+  }
+};
+
+/// Lowers a division/remainder `SourceOp` to `TargetOp` with the divisor
+/// guarded. The replacement is `smt.ite(rhs == 0, fresh, TargetOp(lhs, rhs))`,
+/// where `fresh` is a new symbolic value from smt::DeclareFunOp. The pattern
+/// fails if the converted divisor is not a bit-vector.
+template <typename SourceOp, typename TargetOp>
+struct DivisionOpConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value divisor = adaptor.getRhs();
+    auto divisorType = dyn_cast<smt::BitVectorType>(divisor.getType());
+    if (!divisorType)
+      return failure();
+
+    Location loc = op.getLoc();
+    Type resultType =
+        this->typeConverter->convertType(op.getResult().getType());
+
+    // Build the zero test, the fallback symbol and the real division, then
+    // select between them.
+    Value zeroBits = rewriter.create<smt::BVConstantOp>(
+        loc, APInt(divisorType.getWidth(), 0));
+    Value divisorIsZero = rewriter.create<smt::EqOp>(loc, divisor, zeroBits);
+    Value fresh = rewriter.create<smt::DeclareFunOp>(loc, resultType);
+    Value division =
+        rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<smt::IteOp>(op, divisorIsZero, fresh, division);
+    return success();
+  }
+};
+
+/// Lowers a `SourceOp` with two or more operands to a left-nested chain of
+/// binary `TargetOp`s: ((x0 op x1) op x2) ... The pattern fails when there are
+/// fewer than two operands.
+template <typename SourceOp, typename TargetOp>
+struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange inputs = adaptor.getOperands();
+    const size_t numInputs = inputs.size();
+    if (numInputs < 2)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value accumulated = inputs[0];
+    for (size_t idx = 1; idx < numInputs; ++idx)
+      accumulated = rewriter.create<TargetOp>(loc, accumulated, inputs[idx]);
+
+    rewriter.replaceOp(op, accumulated);
+    return success();
+  }
+};
+
+/// Lowers an arith::ConstantIntOp to an smt::BVConstantOp holding the same
+/// APInt value. Zero-width constants are rejected via notifyMatchFailure.
+struct ArithConstantIntOpConversion
+    : OpConversionPattern<arith::ConstantIntOp> {
+  using OpConversionPattern<arith::ConstantIntOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantIntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    APInt value = llvm::cast<IntegerAttr>(adaptor.getValue()).getValue();
+    if (value.getBitWidth() == 0)
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "0-bit constants not supported");
+    // TODO(max): decide how signed/unsigned/signless semantics map onto
+    // bit-vectors.
+    rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, value);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers on `converter` the IntegerType -> smt::BitVectorType conversion
+/// and the materializations used by the Arith to SMT lowering.
+///
+/// This is a file-local helper with no header declaration, hence `static`.
+/// Materializations are registered in a fixed order. Later registrations are
+/// consulted first, so the order must be preserved.
+static void populateArithToSMTTypeConverter(TypeConverter &converter) {
+  // The semantics of the builtin integer at the MLIR core level is currently
+  // not very well defined. It is used for two-valued, four-valued, and possibly
+  // other multi-valued logic. Here, we interpret it as two-valued for now.
+  // From a formal perspective, MLIR would ideally define its own types for
+  // two-valued, four-valued, nine-valued (etc.) logic each. Upstream arith
+  // integer values may also be poison; this conversion does not model poison.
+  converter.addConversion([](IntegerType type) -> std::optional<Type> {
+    // Zero-width integers are not handled by this conversion.
+    if (type.getWidth() == 0)
+      return std::nullopt;
+    return smt::BitVectorType::get(type.getContext(), type.getWidth());
+  });
+
+  // The callbacks below are stored in `converter` and outlive this function.
+  // They capture nothing, so they use an empty capture list.
+
+  // Default target materialization to convert from illegal types to legal
+  // types, e.g., at the boundary of an inlined child block.
+  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
+    return builder
+        .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
+        ->getResult(0);
+  });
+
+  // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value.
+  converter.addTargetMaterialization(
+      [](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
+         Location loc) -> Value {
+        if (inputs.size() != 1)
+          return Value();
+
+        if (!isa<smt::BoolType>(inputs[0].getType()))
+          return Value();
+
+        unsigned width = resultType.getWidth();
+        Value constZero = builder.create<smt::BVConstantOp>(loc, 0, width);
+        Value constOne = builder.create<smt::BVConstantOp>(loc, 1, width);
+        return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
+      });
+
+  // Convert an unrealized conversion cast from 'smt.bool' to i1
+  // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
+  converter.addTargetMaterialization(
+      [](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
+         Location loc) -> Value {
+        if (inputs.size() != 1 || resultType.getWidth() != 1)
+          return Value();
+
+        auto intType = dyn_cast<IntegerType>(inputs[0].getType());
+        if (!intType || intType.getWidth() != 1)
+          return Value();
+
+        auto castOp =
+            inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
+        if (!castOp || castOp.getInputs().size() != 1)
+          return Value();
+
+        if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
+          return Value();
+
+        Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
+        Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
+        return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
+                                          constZero);
+      });
+
+  // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value.
+  converter.addTargetMaterialization(
+      [](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
+         Location loc) -> Value {
+        if (inputs.size() != 1)
+          return Value();
+
+        auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
+        if (!bvType || bvType.getWidth() != 1)
+          return Value();
+
+        Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
+        return builder.create<smt::EqOp>(loc, inputs[0], constOne);
+      });
+
+  // Default source materialization to convert from illegal types to legal
+  // types, e.g., at the boundary of an inlined child block.
+  converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
+    return builder
+        .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
+        ->getResult(0);
+  });
+}
+
+namespace {
+/// Pass lowering integer arith ops to the SMT dialect.
+///
+/// It runs in two phases:
+///   1. A walk-based rewrite (walkAndApplyPatterns) expands ops that the
+///      conversion patterns do not handle directly (arith.ceildivsi) into
+///      simpler arith ops.
+///   2. A partial dialect conversion, with arith marked illegal and smt
+///      legal, rewrites the remaining arith ops to SMT bit-vector ops.
+struct ConvertArithToSMT
+    : public impl::ConvertArithToSMTBase<ConvertArithToSMT> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Phase 1: pre-expansion. This set is consumed (moved into a frozen
+    // pattern set) by the driver and is not reused afterwards.
+    RewritePatternSet expansionPatterns(ctx);
+    expansionPatterns.add<CeilDivSIOpConversion>(ctx);
+    walkAndApplyPatterns(getOperation(), std::move(expansionPatterns));
+
+    // Phase 2: arith -> smt dialect conversion, using a fresh pattern set
+    // instead of reusing the moved-from one.
+    ConversionTarget target(*ctx);
+    target.addIllegalDialect<arith::ArithDialect>();
+    target.addLegalDialect<smt::SMTDialect>();
+
+    TypeConverter converter;
+    populateArithToSMTTypeConverter(converter);
+    RewritePatternSet conversionPatterns(ctx);
+    arith::populateArithToSMTConversionPatterns(converter, conversionPatterns);
+
+    if (failed(mlir::applyPartialConversion(getOperation(), target,
+                                            std::move(conversionPatterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::arith {
+/// Populates `patterns` with the Arith -> SMT conversion patterns defined in
+/// this file. Every pattern is constructed with `converter`, which it uses for
+/// type conversion. No type conversions are registered here.
+void populateArithToSMTConversionPatterns(TypeConverter &converter,
+                                          RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+
+  // Ops with hand-written lowerings.
+  patterns.add<ArithConstantIntOpConversion, CmpIOpConversion,
+               SubOpConversion>(converter, context);
+
+  // Shifts: direct one-to-one replacements.
+  patterns.add<OneToOneOpConversion<arith::ShLIOp, smt::BVShlOp>,
+               OneToOneOpConversion<arith::ShRUIOp, smt::BVLShrOp>,
+               OneToOneOpConversion<arith::ShRSIOp, smt::BVAShrOp>>(converter,
+                                                                    context);
+
+  // Division and remainder, guarded against a zero divisor.
+  patterns.add<DivisionOpConversion<arith::DivSIOp, smt::BVSDivOp>,
+               DivisionOpConversion<arith::DivUIOp, smt::BVUDivOp>,
+               DivisionOpConversion<arith::RemSIOp, smt::BVSRemOp>,
+               DivisionOpConversion<arith::RemUIOp, smt::BVURemOp>>(converter,
+                                                                    context);
+
+  // Binary ops lowered as a left fold of the binary SMT op.
+  patterns.add<VariadicToBinaryOpConversion<arith::AddIOp, smt::BVAddOp>,
+               VariadicToBinaryOpConversion<arith::MulIOp, smt::BVMulOp>,
+               VariadicToBinaryOpConversion<arith::AndIOp, smt::BVAndOp>,
+               VariadicToBinaryOpConversion<arith::OrIOp, smt::BVOrOp>,
+               VariadicToBinaryOpConversion<arith::XOrIOp, smt::BVXOrOp>>(
+      converter, context);
+}
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt b/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt
new file mode 100644
index 0000000000000..ef9df95568cb4
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToSMT/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Arith -> SMT conversion pass and patterns. ArithToSMT.cpp includes the
+# Func dialect and registers it as a dependent dialect, and uses the pass and
+# dialect-conversion / walk-rewrite infrastructure, so those are linked too.
+add_mlir_conversion_library(MLIRArithToSMT
+  ArithToSMT.cpp
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRPass
+  MLIRSMT
+  MLIRTransforms
+  MLIRTransformUtils
+)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b6c21440c571c..78d0ffd382cce 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
+add_subdirectory(ArithToSMT)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(ArmSMEToSCF)
diff --git a/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir b/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir
new file mode 100644
index 0000000000000..a1cf033a461c5
--- /dev/null
+++ b/mlir/test/Conversion/ArithToSMT/arith-to-smt.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --convert-arith-to-smt | FileCheck %s
+
+// CHECK-LABEL: func @test
+// CHECK-SAME: ([[A0:%.+]]: !smt.bv<32>, [[A1:%.+]]: !smt.bv<32>, [[A2:%.+]]: !smt.bv<32>, [[A3:%.+]]: !smt.bv<32>, [[A4:%.+]]: !smt.bv<1>, [[ARG5:%.+]]: !smt.bv<4>)
+func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt.bv<32>, %a4: !smt.bv<1>, %a5: !smt.bv<4>) {
+ // Cast the SMT-typed arguments to builtin integers so the arith ops below
+ // have integer operands. The expected output uses the block arguments
+ // ([[A0]], [[A1]]) directly as operands of the SMT ops.
+ // NOTE(review): %arg2..%arg5 are not used by any op below.
+ %arg0 = builtin.unrealized_conversion_cast %a0 : !smt.bv<32> to i32
+ %arg1 = builtin.unrealized_conversion_cast %a1 : !smt.bv<32> to i32
+ %arg2 = builtin.unrealized_conversion_cast %a2 : !smt.bv<32> to i32
+ %arg3 = builtin.unrealized_conversion_cast %a3 : !smt.bv<32> to i32
+ %arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1
+ %arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4
+
+ // Division and remainder: smt.ite selects a fresh smt.declare_fun value
+ // when the divisor equals zero, and the bit-vector division otherwise.
+ // CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
+ // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
+ // CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32>
+ // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
+ %0 = arith.divsi %arg0, %arg1 : i32
+ // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
+ // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
+ // CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32>
+ // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
+ %1 = arith.divui %arg0, %arg1 : i32
+ // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
+ // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
+ // CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32>
+ // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
+ %2 = arith.remsi %arg0, %arg1 : i32
+ // CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
+ // CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
+ // CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
+ // CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32>
+ // CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
+ %3 = arith.remui %arg0, %arg1 : i32
+
+ // Subtraction is lowered as lhs + (-rhs).
+ // CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32>
+ // CHECK-NEXT: smt.bv.add [[A0]], [[NEG]] : !smt.bv<32>
+ %7 = arith.subi %arg0, %arg1 : i32
+
+ // The remaining binary integer ops map to the same-named smt.bv ops.
+ // CHECK-NEXT: [[A5:%.+]] = smt.bv.add [[A0]], [[A1]] : !smt.bv<32>
+ %8 = arith.addi %arg0, %arg1 : i32
+ // CHECK-NEXT: [[B1:%.+]] = smt.bv.mul [[A0]], [[A1]] : !smt.bv<32>
+ %9 = arith.muli %arg0, %arg1 : i32
+ // CHECK-NEXT: [[C1:%.+]] = smt.bv.and [[A0]], [[A1]] : !smt.bv<32>
+ %10 = arith.andi %arg0, %arg1 : i32
+ // CHECK-NEXT: [[D1:%.+]] = smt.bv.or [[A0]], [[A1]] : !smt.bv<32>
+ %11 = arith.ori %arg0, %arg1 : i32
+ // CHECK-NEXT: [[E1:%.+]] = smt.bv.xor [[A0]], [[A1]] : !smt.bv<32>
+ %12 = arith.xori %arg0, %arg1 : i32
+
+ // eq/ne become smt.eq/smt.distinct; ordering predicates become smt.bv.cmp.
+ // CHECK-NEXT: smt.eq [[A0]], [[A1]] : !smt.bv<32>
+ %14 = arith.cmpi eq, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.distinct [[A0]], [[A1]] : !smt.bv<32>
+ %15 = arith.cmpi ne, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp sle [[A0]], [[A1]] : !smt.bv<32>
+ %20 = arith.cmpi sle, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp slt [[A0]], [[A1]] : !smt.bv<32>
+ %21 = arith.cmpi slt, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp ule [[A0]], [[A1]] : !smt.bv<32>
+ %22 = arith.cmpi ule, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp ult [[A0]], [[A1]] : !smt.bv<32>
+ %23 = arith.cmpi ult, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp sge [[A0]], [[A1]] : !smt.bv<32>
+ %24 = arith.cmpi sge, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp sgt [[A0]], [[A1]] : !smt.bv<32>
+ %25 = arith.cmpi sgt, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp uge [[A0]], [[A1]] : !smt.bv<32>
+ %26 = arith.cmpi uge, %arg0, %arg1 : i32
+ // CHECK-NEXT: smt.bv.cmp ugt [[A0]], [[A1]] : !smt.bv<32>
+ %27 = arith.cmpi ugt, %arg0, %arg1 : i32
+
+ // Shifts: shli -> shl, shrsi -> ashr, shrui -> lshr.
+ // CHECK-NEXT: %{{.*}} = smt.bv.shl [[A0]], [[A1]] : !smt.bv<32>
+ %32 = arith.shli %arg0, %arg1 : i32
+ // CHECK-NEXT: %{{.*}} = smt.bv.ashr [[A0]], [[A1]] : !smt.bv<32>
+ %33 = arith.shrsi %arg0, %arg1 : i32
+ // CHECK-NEXT: %{{.*}} = smt.bv.lshr [[A0]], [[A1]] : !smt.bv<32>
+ %34 = arith.shrui %arg0, %arg1 : i32
+
+ // The arith.cmpi folder is called before the conversion patterns and produces
+ // an `arith.constant` operation (i1 true). That constant is then lowered to a
+ // 1-bit bit-vector constant with its single bit set, printed as #smt.bv<-1>.
+ // CHECK-NEXT: smt.bv.constant #smt.bv<-1> : !smt.bv<1>
+ %35 = arith.cmpi eq, %arg0, %arg0 : i32
+
+ return
+}
More information about the Mlir-commits
mailing list