[Mlir-commits] [mlir] d8b84be - [MLIR][Transform][SMT] Introduce transform.smt.constrain_params (#159450)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 21 13:32:49 PDT 2025
Author: Rolf Morel
Date: 2025-09-21T20:32:45Z
New Revision: d8b84be1078861dd463bae964c6443fbb613f6c8
URL: https://github.com/llvm/llvm-project/commit/d8b84be1078861dd463bae964c6443fbb613f6c8
DIFF: https://github.com/llvm/llvm-project/commit/d8b84be1078861dd463bae964c6443fbb613f6c8.diff
LOG: [MLIR][Transform][SMT] Introduce transform.smt.constrain_params (#159450)
Introduces a Transform-dialect SMT-extension so that we can have an op
to express constraints on Transform-dialect params, in particular when
these params are knobs -- see transform.tune.knob -- and can hence be
seen as symbolic variables. This op allows expressing joint constraints
over multiple params/knobs together.
While the op's semantics are clearly defined, per SMTLIB, the interpreted
semantics -- i.e. the `apply()` method -- for now just defaults to failure. In
the future we should support attaching an implementation so that users
can Bring Your Own Solver and thereby control performance of
interpreting the op. For now the main usage is to walk schedule IR and
collect these constraints so that knobs can be rewritten to constants that
satisfy the constraints.
Added:
mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
mlir/python/mlir/dialects/TransformSMTExtensionOps.td
mlir/python/mlir/dialects/transform/smt.py
mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
mlir/test/Dialect/Transform/test-smt-extension.mlir
mlir/test/python/dialects/transform_smt_ext.py
Modified:
mlir/include/mlir/Dialect/Transform/CMakeLists.txt
mlir/lib/Bindings/Python/DialectSMT.cpp
mlir/lib/Dialect/Transform/CMakeLists.txt
mlir/lib/RegisterAllExtensions.cpp
mlir/python/CMakeLists.txt
mlir/python/mlir/dialects/smt.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index e70479b2a39f2..eb91ceccd4ef2 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -4,5 +4,6 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..da037c1e809de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
+mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)
+
+add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
new file mode 100644
index 0000000000000..7079873cec048
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
@@ -0,0 +1,27 @@
+//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+// Forward declaration suffices: only a reference is taken below.
+class DialectRegistry;
+
+namespace transform {
+/// Registers the SMT extension of the Transform dialect in the given registry.
+/// The extension registers the ops declared in SMTExtensionOps.td (currently
+/// `transform.smt.constrain_params`) as Transform-dialect ops.
+void registerSMTExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
new file mode 100644
index 0000000000000..fc69b039f24ff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -0,0 +1,21 @@
+//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
new file mode 100644
index 0000000000000..b987cb31e54bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -0,0 +1,52 @@
+//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+// Transform op whose single-block region holds SMT-dialect ops; the block
+// arguments stand in for the param operands as symbolic SMT variables.
+def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  NoTerminator
+]> {
+  let cppNamespace = [{ mlir::transform::smt }];
+
+  let summary = "Express constraints on params interpreted as symbolic values";
+  let description = [{
+    Allows expressing constraints on params using the SMT dialect.
+
+    Each Transform dialect param provided as an operand has a corresponding
+    argument of SMT-type in the region. The SMT-Dialect ops in the region use
+    these arguments as operands.
+
+    The semantics of this op is that all the ops in the region together express
+    a constraint on the params-interpreted-as-smt-vars. The op fails in case the
+    expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
+    op succeeds.
+
+    ---
+
+    TODO: currently the operational semantics per the Transform interpreter is
+    to always fail. The intention is to build out support for hooking in your
+    own operational semantics so you can invoke your favourite solver to
+    determine satisfiability of the corresponding constraint problem.
+  }];
+
+  let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat =
+    "`(` $params `)` attr-dict `:` type(operands) $body";
+
+  let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp
index 3123e3bdda496..0d1d9e89f92f6 100644
--- a/mlir/lib/Bindings/Python/DialectSMT.cpp
+++ b/mlir/lib/Bindings/Python/DialectSMT.cpp
@@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors;
static void populateDialectSMTSubmodule(nanobind::module_ &m) {
- auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
- .def_classmethod(
- "get",
- [](const nb::object &, MlirContext context) {
- return mlirSMTTypeGetBool(context);
- },
- "cls"_a, "context"_a = nb::none());
+ auto smtBoolType =
+ mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
+ .def_staticmethod(
+ "get",
+ [](MlirContext context) { return mlirSMTTypeGetBool(context); },
+ "context"_a = nb::none());
auto smtBitVectorType =
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
- .def_classmethod(
+ .def_staticmethod(
"get",
- [](const nb::object &, int32_t width, MlirContext context) {
+ [](int32_t width, MlirContext context) {
return mlirSMTTypeGetBitVector(context, width);
},
- "cls"_a, "width"_a, "context"_a = nb::none());
+ "width"_a, "context"_a = nb::none());
+ auto smtIntType =
+ mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
+ .def_staticmethod(
+ "get",
+ [](MlirContext context) { return mlirSMTTypeGetInt(context); },
+ "context"_a = nb::none());
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 6e628353258d6..123c4b92271fe 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
+add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..ba1cc464e506d
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformSMTExtension
+ SMTExtension.cpp
+ SMTExtensionOps.cpp
+
+ DEPENDS
+ MLIRTransformDialectSMTExtensionOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRTransformDialect
+ MLIRSMT
+)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
new file mode 100644
index 0000000000000..228e8d342a1f6
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
@@ -0,0 +1,35 @@
+//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform-dialect extension that registers every op generated from
+/// SMTExtensionOps.td.
+class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)
+
+ // GET_OP_LIST makes the generated .cpp.inc expand to the comma-separated
+ // list of op classes, so the template argument list below stays in sync with
+ // the .td file without listing ops by hand. The #define/#include must stay
+ // exactly inside the angle brackets.
+ SMTExtension() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+/// Adds the SMT extension to `dialectRegistry`.
+void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
+ dialectRegistry.addExtensions<SMTExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
new file mode 100644
index 0000000000000..8e7af05353de7
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -0,0 +1,55 @@
+//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ConstrainParamsOp
+//===----------------------------------------------------------------------===//
+
+/// Memory effects of the op: the param operands are only read. No other
+/// effect is declared here.
+void transform::smt::ConstrainParamsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto paramOperands = getParamsMutable();
+  transform::onlyReadsHandle(paramOperands, effects);
+}
+
+/// Interpreter hook for the op. Not implemented yet: it unconditionally
+/// returns a definite (non-silenceable) failure carrying the message below.
+DiagnosedSilenceableFailure
+transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ // TODO: Proper operational semantics are to check the SMT problem in the body
+ // with an SMT solver, with the arguments of the body constrained to the
+ // values passed into the op. Success or failure is then determined by
+ // the solver's result.
+ // One way to support this is to declare TransformOpInterface as a
+ // promised interface and allow users to attach their own implementation,
+ // which would, e.g., translate the ops to SMTLIB and hand that over to the
+ // user's favourite solver. This requires changes to the dialect's verifier.
+ return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+}
+
+/// Checks the structural invariants the op's description relies on: one body
+/// block argument per param operand, block arguments of SMT-dialect types, and
+/// only SMT-dialect ops inside the region.
+LogicalResult transform::smt::ConstrainParamsOp::verify() {
+  Region &body = getBody();
+  if (getOperands().size() != body.getNumArguments())
+    return emitOpError(
+        "must have the same number of block arguments as operands");
+
+  // The block arguments are the params reinterpreted as SMT variables, so they
+  // must carry SMT types for the SMT ops in the region to be meaningful.
+  for (BlockArgument arg : body.getArguments()) {
+    if (!isa<mlir::smt::SMTDialect>(arg.getType().getDialect()))
+      return emitOpError("block argument #")
+             << arg.getArgNumber() << " must have an SMT-dialect type, got "
+             << arg.getType();
+  }
+
+  for (Operation &op : body.getOps()) {
+    // Unregistered ops (e.g. under --allow-unregistered-dialect) report a null
+    // dialect; plain `isa` asserts on null, so use the null-tolerant form and
+    // treat such ops as non-SMT.
+    if (!llvm::isa_and_present<mlir::smt::SMTDialect>(op.getDialect()))
+      return emitOpError(
+          "ops contained in region should belong to SMT-dialect");
+  }
+
+  return success();
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 69a85dbe141ce..3839172fd0b42 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -53,6 +53,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
+ transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 97f0778071ef9..d6686bb89ce4e 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/TransformSMTExtensionOps.td
+ SOURCES
+ dialects/transform/smt.py
+ DIALECT_NAME transform
+ EXTENSION_NAME transform_smt_extension)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformSMTExtensionOps.td b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
new file mode 100644
index 0000000000000..3e92417a35d13
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td
@@ -0,0 +1,19 @@
+//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the SMT extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py
index ae7a4c41cbc3a..38970d17abd47 100644
--- a/mlir/python/mlir/dialects/smt.py
+++ b/mlir/python/mlir/dialects/smt.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._smt_ops_gen import *
+from ._smt_enum_gen import *
from .._mlir_libs._mlirDialectsSMT import *
from ..extras.meta import region_op
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
new file mode 100644
index 0000000000000..1f0b7f066118c
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -0,0 +1,38 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Sequence
+
+from ...ir import Type, Block, Value
+from .._transform_smt_extension_ops_gen import *
+from .._transform_smt_extension_ops_gen import _Dialect
+from ...dialects import transform
+
+try:
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ConstrainParamsOp(ConstrainParamsOp):
+ def __init__(
+ self,
+ params: Sequence[transform.AnyParamType],
+ arg_types: Sequence[Type],
+ loc=None,
+ ip=None,
+ ):
+ if len(params) != len(arg_types):
+ raise ValueError(f"{params=} not same length as {arg_types=}")
+ super().__init__(
+ params,
+ loc=loc,
+ ip=ip,
+ )
+ self.regions[0].blocks.append(*arg_types)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
new file mode 100644
index 0000000000000..314b8d493c5d4
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+// CHECK-LABEL: @constraint_not_using_smt_ops
+// A non-SMT op (arith.constant) inside the region must be rejected.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{ops contained in region should belong to SMT-dialect}}
+    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+      ^bb0(%param_as_smt_var: !smt.int):
+      %c4 = arith.constant 4 : i32
+      // This is the kind of thing one might think works:
+      //arith.remsi %param_as_smt_var, %c4 : i32
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @operands_not_one_to_one_with_vars
+// One param operand but two block arguments must be rejected.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{must have the same number of block arguments as operands}}
+    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+      ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
new file mode 100644
index 0000000000000..29d15175ae4ec
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @schedule_with_constrained_param
+// Round-trips a single param constrained to the closed range [0, 43] via two
+// separate asserts.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_constrained_param(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
+ ^bb0(%param_as_smt_var: !smt.int):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[C43:.*]] = smt.int.constant 43
+ %c43 = smt.int.constant 43
+ // CHECK: %[[LOWER_BOUND:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
+ %lower_bound = smt.int.cmp le %c0, %param_as_smt_var
+ // CHECK: smt.assert %[[LOWER_BOUND]]
+ smt.assert %lower_bound
+ // CHECK: %[[UPPER_BOUND:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
+ %upper_bound = smt.int.cmp le %param_as_smt_var, %c43
+ // CHECK: smt.assert %[[UPPER_BOUND]]
+ smt.assert %upper_bound
+ }
+ // NB: from here on, one can rely on 0 <= %param_as_param <= 43 holding,
+ // even if the param's definition changes.
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+// Round-trips a joint constraint over two params: %param_b mod %param_a == 0.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_A:.*]] = transform.param.constant
+ %param_a = transform.param.constant 4 -> !transform.param<i64>
+ // CHECK: %[[PARAM_B:.*]] = transform.param.constant
+ %param_b = transform.param.constant 16 -> !transform.param<i64>
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
+ transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+ // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
+ ^bb0(%var_a: !smt.int, %var_b: !smt.int):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
+ %remainder = smt.int.mod %var_b, %var_a
+ // CHECK: %[[EQ:.*]] = smt.eq %[[REMAINDER]], %[[C0]]
+ %eq = smt.eq %remainder, %c0 : !smt.int
+ // CHECK: smt.assert %[[EQ]]
+ smt.assert %eq
+ }
+ // NB: from here on, one can rely on %param_a being a divisor of %param_b.
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @schedule_with_param_as_a_bool
+// Round-trips a boolean param modelled as !smt.bool; asserting
+// `param OR (0 == 1)` forces the param to be true.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ %param_as_param = transform.param.constant true -> !transform.any_param
+
+ // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+ // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
+ ^bb0(%param_as_smt_var: !smt.bool):
+ // CHECK: %[[C0:.*]] = smt.int.constant 0
+ %c0 = smt.int.constant 0
+ // CHECK: %[[C1:.*]] = smt.int.constant 1
+ %c1 = smt.int.constant 1
+ // CHECK: %[[FALSEHOOD:.*]] = smt.eq %[[C0]], %[[C1]]
+ %falsehood = smt.eq %c0, %c1 : !smt.int
+ // CHECK: %[[TRUE_IFF_PARAM_IS:.*]] = smt.or %[[PARAM_AS_SMT_VAR]], %[[FALSEHOOD]]
+ %true_iff_param_is = smt.or %param_as_smt_var, %falsehood
+ // CHECK: smt.assert %[[TRUE_IFF_PARAM_IS]]
+ smt.assert %true_iff_param_is
+ }
+ // NB: from here on, one can rely on %param_as_param being true, even if
+ // the param's definition changes.
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
new file mode 100644
index 0000000000000..3692fd92344a6
--- /dev/null
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -0,0 +1,50 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir import ir
+from mlir.dialects import transform, smt
+from mlir.dialects.transform import smt as transform_smt
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with ir.InsertionPoint(sequence.body):
+ f(sequence.bodyTarget)
+ transform.YieldOp()
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: testConstrainParamsOp
+ at run
+def testConstrainParamsOp(target):
+ dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+ # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
+ symbolic_value = transform.ParamConstantOp(
+ transform.AnyParamType.get(), dummy_value
+ )
+ # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ constrain_params = transform_smt.ConstrainParamsOp(
+ [symbolic_value], [smt.IntType.get()]
+ )
+ # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
+ with ir.InsertionPoint(constrain_params.body):
+ # CHECK: %[[C0:.*]] = smt.int.constant 0
+ c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
+ # CHECK: %[[C43:.*]] = smt.int.constant 43
+ c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
+ # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
+ lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+ # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
+ ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
+ # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
+ bounded = smt.AndOp([lb, ub])
+ # CHECK: smt.assert %[[BOUNDED:.*]]
+ smt.AssertOp(bounded)
More information about the Mlir-commits
mailing list