[Mlir-commits] [mlir] [MLIR][Transform][SMT] Introduce transform.smt.constrain_params (PR #159450)
Rolf Morel
llvmlistbot at llvm.org
Sun Sep 21 13:18:59 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/159450
>From 7de3873dea46a458e2d2bc6a168d42981e57cca7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 17 Sep 2025 13:11:12 -0700
Subject: [PATCH 1/3] [MLIR][Transform][SMT] Introduce
transform.smt.constrain_params
Introduces an SMT Transform-dialect extension so that we can have an op
to express constraints on Transform-dialect params, in particular when
these params are knobs -- see transform.tune.knob -- and hence can
be seen as symbolic variables. This op allows expressing joint
constraints over multiple params/knobs together.
While the op's semantics are clearly defined, the operational semantics
-- i.e. the `apply()` method -- for now just default to failure. In the
future we should support attaching an implementation so that users can
Bring Your Own Solver and thereby control performance of interpreting
the op. For now the main usage is to walk schedule IR and collect these
constraints so that knobs can be rewritten to constants that satisfy the
constraints.
---
.../mlir/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/SMTExtension/CMakeLists.txt | 6 ++
.../Transform/SMTExtension/SMTExtension.h | 27 ++++++
.../Transform/SMTExtension/SMTExtensionOps.h | 22 +++++
.../Transform/SMTExtension/SMTExtensionOps.td | 52 +++++++++++
mlir/lib/Bindings/Python/DialectSMT.cpp | 7 ++
mlir/lib/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/SMTExtension/CMakeLists.txt | 12 +++
.../Transform/SMTExtension/SMTExtension.cpp | 35 ++++++++
.../SMTExtension/SMTExtensionOps.cpp | 55 ++++++++++++
mlir/lib/RegisterAllExtensions.cpp | 2 +
mlir/python/CMakeLists.txt | 9 ++
.../mlir/dialects/TransformSMTExtensionOps.td | 19 ++++
mlir/python/mlir/dialects/smt.py | 1 +
mlir/python/mlir/dialects/transform/smt.py | 36 ++++++++
.../Transform/test-smt-extension-invalid.mlir | 32 +++++++
.../Dialect/Transform/test-smt-extension.mlir | 87 +++++++++++++++++++
.../test/python/dialects/transform_smt_ext.py | 50 +++++++++++
18 files changed, 454 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
create mode 100644 mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
create mode 100644 mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
create mode 100644 mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
create mode 100644 mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
create mode 100644 mlir/python/mlir/dialects/TransformSMTExtensionOps.td
create mode 100644 mlir/python/mlir/dialects/transform/smt.py
create mode 100644 mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
create mode 100644 mlir/test/Dialect/Transform/test-smt-extension.mlir
create mode 100644 mlir/test/python/dialects/transform_smt_ext.py
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index e70479b2a39f2..eb91ceccd4ef2 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -4,5 +4,6 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..da037c1e809de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
+mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)
+
+add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
new file mode 100644
index 0000000000000..7079873cec048
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
@@ -0,0 +1,27 @@
+//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the SMT extension of the Transform dialect in the given registry.
+void registerSMTExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
new file mode 100644
index 0000000000000..dfea2039a16c3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -0,0 +1,22 @@
+//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
+
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
new file mode 100644
index 0000000000000..b987cb31e54bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -0,0 +1,52 @@
+//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ NoTerminator
+]> {
+ let cppNamespace = [{ mlir::transform::smt }];
+
+ let summary = "Express constraints on params interpreted as symbolic values";
+ let description = [{
+ Allows expressing constraints on params using the SMT dialect.
+
+ Each Transform dialect param provided as an operand has a corresponding
+ argument of SMT-type in the region. Only SMT-dialect ops may appear in the
+ region; they use these arguments as operands.
+
+ The semantics of this op are that all the ops in the region together express
+ a constraint on the params-interpreted-as-smt-vars. The op fails in case the
+ expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
+ op succeeds.
+
+ ---
+
+ TODO: currently the operational semantics per the Transform interpreter are
+ to always fail. The intention is to build out support for hooking in your
+ own operational semantics so you can invoke your favourite solver to
+ determine satisfiability of the corresponding constraint problem.
+ }];
+
+ let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+ let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat =
+ "`(` $params `)` attr-dict `:` type(operands) $body";
+
+ let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 3123e3bdda496..6e28d96ca58a7 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -41,6 +41,13 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
return mlirSMTTypeGetBitVector(context, width);
},
"cls"_a, "width"_a, "context"_a = nb::none());
+ auto smtIntType = mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
+ .def_classmethod(
+ "get",
+ [](const nb::object &, MlirContext context) {
+ return mlirSMTTypeGetInt(context);
+ },
+ "cls"_a, "context"_a.none() = nb::none());
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 6e628353258d6..123c4b92271fe 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..ba1cc464e506d
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformSMTExtension
+ SMTExtension.cpp
+ SMTExtensionOps.cpp
+
+ DEPENDS
+ MLIRTransformDialectSMTExtensionOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRTransformDialect
+ MLIRSMT
+)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
new file mode 100644
index 0000000000000..228e8d342a1f6
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
@@ -0,0 +1,35 @@
+//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)
+
+ SMTExtension() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
+ dialectRegistry.addExtensions<SMTExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
new file mode 100644
index 0000000000000..8e7d0b18b7311
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -0,0 +1,61 @@
+//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ConstrainParamsOp
+//===----------------------------------------------------------------------===//
+
+// The only declared effect is a read of the param operands; no payload
+// effects are declared.
+void transform::smt::ConstrainParamsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getParamsMutable(), effects);
+}
+
+DiagnosedSilenceableFailure
+transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ // TODO: Proper operational semantics are to chuck the SMT problem in the body
+ // to a SMT solver with the arguments of the body constrained to the
+ // values passed into the op. Success or failure is then determined by
+ // the solver's result.
+ // One way to support this is to just promise the TransformOpInterface
+ // and allow for users to attach their own implementation, which would,
+ // e.g., translate the ops to SMTLIB and hand that over to the user's
+ // favourite solver. This requires changes to the dialect's verifier.
+ return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+}
+
+// Checks that there is one block argument per param operand and that every op
+// directly in the body belongs to the SMT dialect.
+LogicalResult transform::smt::ConstrainParamsOp::verify() {
+ if (getOperands().size() != getBody().getNumArguments())
+ return emitOpError(
+ "must have the same number of block arguments as operands");
+
+ for (auto &op : getBody().getOps()) {
+ // getDialect() is null when the op's dialect is not loaded (e.g. an
+ // unregistered op); isa_and_nonnull rejects that instead of asserting.
+ if (!isa_and_nonnull<mlir::smt::SMTDialect>(op.getDialect()))
+ return emitOpError(
+ "ops contained in region should belong to SMT-dialect");
+ }
+
+ return success();
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 69a85dbe141ce..3839172fd0b42 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -53,6 +53,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
+ transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 97f0778071ef9..d6686bb89ce4e 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/TransformSMTExtensionOps.td
+ SOURCES
+ dialects/transform/smt.py
+ DIALECT_NAME transform
+ EXTENSION_NAME transform_smt_extension)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformSMTExtensionOps.td b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
new file mode 100644
index 0000000000000..3e92417a35d13
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
@@ -0,0 +1,19 @@
+//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the SMT extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
index ae7a4c41cbc3a..38970d17abd47 100644
--- a/mlir/python/mlir/dialects/smt.py
+++ b/mlir/python/mlir/dialects/smt.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._smt_ops_gen import *
+from ._smt_enum_gen import *
from .._mlir_libs._mlirDialectsSMT import *
from ..extras.meta import region_op
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
new file mode 100644
index 0000000000000..7cb06e8bfed54
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -0,0 +1,39 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Sequence
+
+from ...ir import Type, Block, Value
+from .._transform_smt_extension_ops_gen import *
+from .._transform_smt_extension_ops_gen import _Dialect
+from ...dialects import transform
+
+try:
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstrainParamsOp(ConstrainParamsOp):
+ def __init__(
+ self,
+ # SSA values of Transform-dialect param type that are being constrained.
+ params: Sequence[Value],
+ # Types of the body block's arguments, one per entry of `params`.
+ arg_types: Sequence[Type],
+ loc=None,
+ ip=None,
+ ):
+ assert len(params) == len(arg_types)
+ super().__init__(
+ params,
+ loc=loc,
+ ip=ip,
+ )
+ self.regions[0].blocks.append(*arg_types)
+
+ @property
+ def body(self) -> Block:
+ """Return the region's single block; it has one argument per param."""
+ return self.regions[0].blocks[0]
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
new file mode 100644
index 0000000000000..3961d7c5ba72b
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+// CHECK-LABEL: @constraint_not_using_smt_ops
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error at below {{ops contained in region should belong to SMT-dialect}}
+ transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ ^bb0(%param_as_smt_var: !smt.int):
+ %c4 = arith.constant 4 : i32
+ // This is the kind of thing one might think works:
+ //arith.remsi %param_as_smt_var, %c4 : i32
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @operands_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error at below {{must have the same number of block arguments as operands}}
+ transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+ // This is the kind of thing one might think works:
+ //arith.remsi %param_as_smt_var, %c4 : i32
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
new file mode 100644
index 0000000000000..29d15175ae4ec
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @schedule_with_constrained_param
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_constrained_param(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
+ ^bb0(%param_as_smt_var: !smt.int):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[C43:.*]] = smt.int.constant 43
+ %c43 = smt.int.constant 43
+ // CHECK: %[[LOWER_BOUND:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
+ %lower_bound = smt.int.cmp le %c0, %param_as_smt_var
+ // CHECK: smt.assert %[[LOWER_BOUND]]
+ smt.assert %lower_bound
+ // CHECK: %[[UPPER_BOUND:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
+ %upper_bound = smt.int.cmp le %param_as_smt_var, %c43
+ // CHECK: smt.assert %[[UPPER_BOUND]]
+ smt.assert %upper_bound
+ }
+ // NB: from here can rely on that 0 <= %param_as_param <= 43, even if its
+ // definition changes.
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_A:.*]] = transform.param.constant
+ %param_a = transform.param.constant 4 -> !transform.param<i64>
+ // CHECK: %[[PARAM_B:.*]] = transform.param.constant
+ %param_b = transform.param.constant 16 -> !transform.param<i64>
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
+ transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+ // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
+ ^bb0(%var_a: !smt.int, %var_b: !smt.int):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
+ %remainder = smt.int.mod %var_b, %var_a
+ // CHECK: %[[EQ:.*]] = smt.eq %[[REMAINDER]], %[[C0]]
+ %eq = smt.eq %remainder, %c0 : !smt.int
+ // CHECK: smt.assert %[[EQ]]
+ smt.assert %eq
+ }
+ // NB: from here can rely on that %param_a is a divisor of %param_b
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @schedule_with_param_as_a_bool
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ %param_as_param = transform.param.constant true -> !transform.any_param
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+ // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
+ ^bb0(%param_as_smt_var: !smt.bool):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[C1:.*]] = smt.int.constant 1
+ %c1 = smt.int.constant 1
+ // CHECK: %[[FALSEHOOD:.*]] = smt.eq %[[C0]], %[[C1]]
+ %falsehood = smt.eq %c0, %c1 : !smt.int
+ // CHECK: %[[TRUE_IFF_PARAM_IS:.*]] = smt.or %[[PARAM_AS_SMT_VAR]], %[[FALSEHOOD]]
+ %true_iff_param_is = smt.or %param_as_smt_var, %falsehood
+ // CHECK: smt.assert %[[TRUE_IFF_PARAM_IS]]
+ smt.assert %true_iff_param_is
+ }
+ // NB: from here can rely on that %param_as_param holds true, even if its
+ // definition changes.
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
new file mode 100644
index 0000000000000..8e354e0512fad
--- /dev/null
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -0,0 +1,50 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform, smt
+from mlir.dialects.transform import smt as transform_smt
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ f(sequence.bodyTarget)
+ transform.YieldOp()
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: testConstrainParamsOp
+@run
+def testConstrainParamsOp(target):
+ dummy_value = IntegerAttr.get(IntegerType.get_signless(32), 42)
+ # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ symbolic_value = transform.ParamConstantOp(
+ transform.AnyParamType.get(), dummy_value
+ )
+ # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ constrain_params = transform_smt.ConstrainParamsOp(
+ [symbolic_value], [smt.IntType.get()]
+ )
+ # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
+ with InsertionPoint(constrain_params.body):
+ # CHECK: %[[C0:.*]] = smt.int.constant 0
+ c0 = smt.IntConstantOp(IntegerAttr.get(IntegerType.get_signless(32), 0))
+ # CHECK: %[[C43:.*]] = smt.int.constant 43
+ c43 = smt.IntConstantOp(IntegerAttr.get(IntegerType.get_signless(32), 43))
+ # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
+ lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+ # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
+ ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
+ # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
+ bounded = smt.AndOp([lb, ub])
+ # CHECK: smt.assert %[[BOUNDED]]
+ smt.AssertOp(bounded)
>From 838a41a9958cea742896cfd9be4c9f5b1bae24f1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 17 Sep 2025 14:02:39 -0700
Subject: [PATCH 2/3] Fix formatting
---
.../Dialect/Transform/SMTExtension/SMTExtensionOps.h | 1 -
mlir/lib/Bindings/Python/DialectSMT.cpp | 12 ++++++------
mlir/python/mlir/dialects/transform/smt.py | 1 +
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index dfea2039a16c3..fc69b039f24ff 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -18,5 +18,4 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
-
#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 6e28d96ca58a7..593e7505ce6e8 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -42,12 +42,12 @@ static void populateDialectSMTSubmodule(nanobind::module_ &m) {
},
"cls"_a, "width"_a, "context"_a = nb::none());
auto smtIntType = mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
- .def_classmethod(
- "get",
- [](const nb::object &, MlirContext context) {
- return mlirSMTTypeGetInt(context);
- },
- "cls"_a, "context"_a.none() = nb::none());
+ .def_classmethod(
+ "get",
+ [](const nb::object &, MlirContext context) {
+ return mlirSMTTypeGetInt(context);
+ },
+ "cls"_a, "context"_a.none() = nb::none());
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index 7cb06e8bfed54..c1fdf3ea0ca6c 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -14,6 +14,7 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
+
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
>From 93b6e0aa5d85c7eef0e2e650ca0f83e9ca900a48 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 18 Sep 2025 12:54:55 -0700
Subject: [PATCH 3/3] Address reviewer comments
---
mlir/lib/Bindings/Python/DialectSMT.cpp | 32 +++++++++----------
.../SMTExtension/SMTExtensionOps.cpp | 4 +--
mlir/python/mlir/dialects/transform/smt.py | 3 +-
.../Transform/test-smt-extension-invalid.mlir | 2 --
.../test/python/dialects/transform_smt_ext.py | 18 +++++------
5 files changed, 28 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 593e7505ce6e8..0d1d9e89f92f6 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -26,28 +26,26 @@ using namespace mlir::python::nanobind_adaptors;
static void populateDialectSMTSubmodule(nanobind::module_ &m) {
- auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
- .def_classmethod(
- "get",
- [](const nb::object &, MlirContext context) {
- return mlirSMTTypeGetBool(context);
- },
- "cls"_a, "context"_a = nb::none());
+ auto smtBoolType =
+ mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
+ .def_staticmethod(
+ "get",
+ [](MlirContext context) { return mlirSMTTypeGetBool(context); },
+ "context"_a = nb::none());
auto smtBitVectorType =
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
- .def_classmethod(
+ .def_staticmethod(
"get",
- [](const nb::object &, int32_t width, MlirContext context) {
+ [](int32_t width, MlirContext context) {
return mlirSMTTypeGetBitVector(context, width);
},
- "cls"_a, "width"_a, "context"_a = nb::none());
- auto smtIntType = mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
- .def_classmethod(
- "get",
- [](const nb::object &, MlirContext context) {
- return mlirSMTTypeGetInt(context);
- },
- "cls"_a, "context"_a.none() = nb::none());
+ "width"_a, "context"_a = nb::none());
+ auto smtIntType =
+ mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
+ .def_staticmethod(
+ "get",
+ [](MlirContext context) { return mlirSMTTypeGetInt(context); },
+ "context"_a = nb::none());
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7d0b18b7311..8e7af05353de7 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -29,8 +29,8 @@ DiagnosedSilenceableFailure
transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- // TODO: Proper operational semantics are to chuck the SMT problem in the body
- // to a SMT solver with the arguments of the body constrained to the
+ // TODO: Proper operational semantics are to check the SMT problem in the body
+ // with a SMT solver with the arguments of the body constrained to the
// values passed into the op. Success or failure is then determined by
// the solver's result.
// One way to support this is to just promise the TransformOpInterface
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index c1fdf3ea0ca6c..1f0b7f066118c 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -24,7 +24,8 @@ def __init__(
loc=None,
ip=None,
):
- assert len(params) == len(arg_types)
+ if len(params) != len(arg_types):
+ raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
params,
loc=loc,
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 3961d7c5ba72b..314b8d493c5d4 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -24,8 +24,6 @@ module attributes {transform.with_named_sequence} {
// expected-error at below {{must have the same number of block arguments as operands}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
- // This is the kind of thing one might think works:
- //arith.remsi %param_as_smt_var, %c4 : i32
}
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
index 8e354e0512fad..3692fd92344a6 100644
--- a/mlir/test/python/dialects/transform_smt_ext.py
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -1,21 +1,21 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.ir import *
+from mlir import ir
from mlir.dialects import transform, smt
from mlir.dialects.transform import smt as transform_smt
def run(f):
print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
- with InsertionPoint(sequence.body):
+ with ir.InsertionPoint(sequence.body):
f(sequence.bodyTarget)
transform.YieldOp()
print(module)
@@ -25,7 +25,7 @@ def run(f):
# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
- dummy_value = IntegerAttr.get(IntegerType.get_signless(32), 42)
+ dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
symbolic_value = transform.ParamConstantOp(
transform.AnyParamType.get(), dummy_value
@@ -35,11 +35,11 @@ def testConstrainParamsOp(target):
[symbolic_value], [smt.IntType.get()]
)
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
- with InsertionPoint(constrain_params.body):
+ with ir.InsertionPoint(constrain_params.body):
# CHECK: %[[C0:.*]] = smt.int.constant 0
- c0 = smt.IntConstantOp(IntegerAttr.get(IntegerType.get_signless(32), 0))
+ c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
# CHECK: %[[C43:.*]] = smt.int.constant 43
- c43 = smt.IntConstantOp(IntegerAttr.get(IntegerType.get_signless(32), 43))
+ c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
More information about the Mlir-commits mailing list