[Mlir-commits] [mlir] [MLIR][Transform][SMT] Allow for declarative computations in schedules (PR #160895)
Rolf Morel
llvmlistbot at llvm.org
Thu Oct 9 08:33:53 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/160895
>From 65638acadfc4a87296a036fd4ffa8716762bb781 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 25 Sep 2025 10:23:56 -0700
Subject: [PATCH 1/2] [MLIR][Transform][SMT] Allow for declarative computations
in schedules
By allowing `transform.smt.constrain_params`'s region to yield SMT vars,
we let op instances declare relationships, through constraints, on incoming
params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it
possible to declare that computations on params should be performed.
The semantics are that the yielded SMT-vars should be from any valid
satisfying assignment/model of the constraints in the region.
---
mlir/include/mlir/Dialect/SMT/IR/SMTOps.td | 2 -
.../Transform/SMTExtension/SMTExtensionOps.h | 1 +
.../Transform/SMTExtension/SMTExtensionOps.td | 17 +++-
.../SMTExtension/SMTExtensionOps.cpp | 99 ++++++++++++++++++-
mlir/python/mlir/dialects/transform/smt.py | 12 +++
.../Transform/test-smt-extension-invalid.mlir | 49 ++++++++-
.../Dialect/Transform/test-smt-extension.mlir | 21 ++--
.../test/python/dialects/transform_smt_ext.py | 30 ++++--
8 files changed, 205 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index 3143ab7de1b14..99b22e5609c74 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
- ParentOneOf<["smt::SolverOp", "smt::CheckOp",
- "smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index fc69b039f24ff..f6353a995d747 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
index b987cb31e54bb..9d9783aa66ed9 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- NoTerminator
+ SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
]> {
let cppNamespace = [{ mlir::transform::smt }];
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
let description = [{
Allows expressing constraints on params using the SMT dialect.
- Each Transform dialect param provided as an operand has a corresponding
+ Each Transform-dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
- these arguments as operands.
+ these params-as-SMT-vars as operands, thereby expressing relevant
+ constraints on their allowed values.
+
+ Computations w.r.t. passed-in params can also be expressed through the
+ region's SMT-ops. Namely, the constraints express relationships to other
+ SMT-variables which can then be yielded from the region (with `smt.yield`).
The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
- op succeeds.
+ op succeeds and any one satisfying assignment is used to map the
+ SMT-variables yielded in the region to `transform.param`s.
---
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
}];
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+ let results = (outs Variadic<TransformParamTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "`(` $params `)` attr-dict `:` type(operands) $body";
+ "`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05353de7..d85268da2ad5d 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
@@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
+ producesHandle(getResults(), effects);
}
DiagnosedSilenceableFailure
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
- return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+ return emitSilenceableFailure(getLoc())
+ << "op does not have interpreted semantics yet";
}
LogicalResult transform::smt::ConstrainParamsOp::verify() {
+ auto yieldTerminator =
+ llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << mlir::smt::YieldOp::getOperationName()
+ << "' as terminator";
+
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
+ for (auto [i, operandType, blockArgType] :
+ llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
+ getOperandTypes(), getBody().getArgumentTypes())) {
+ if (isa<transform::AnyParamType>(operandType))
+ continue; // No type checking as operand is of !transform.any_param type.
+ auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
+ if (!paramOperandType)
+ return emitOpError() << "operand type #" << i
+ << " is not a !transform.param";
+ Type wrappedOperandType = paramOperandType.getType();
+
+ if (isa<mlir::smt::IntType>(blockArgType)) {
+ if (!isa<IntegerType>(paramOperandType.getType()))
+ return emitOpError()
+ << "the type of block arg #" << i
+ << " is !smt.int though the corresponding operand type ("
+ << operandType << ") is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(blockArgType)) {
+ auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+ if (!intOperandType || intOperandType.getWidth() != 1)
+ return emitOpError()
+ << "the type of block arg #" << i
+ << " is !smt.bool though the corresponding operand type ("
+ << operandType << ") is not wrapping i1 (i.e. bool)";
+ } else if (auto bvBlockArgType =
+ dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
+ auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+ if (!intOperandType ||
+ intOperandType.getWidth() != bvBlockArgType.getWidth())
+ return emitOpError()
+ << "the type of block arg #" << i << " is " << blockArgType
+ << " though the corresponding operand type (" << operandType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+ }
+
for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}
+ if (getOperands().size() != getBody().getNumArguments())
+ return emitOpError(
+ "must have the same number of block arguments as operands");
+
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [i, termOperandType, resultType] : llvm::zip_equal(
+ llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ if (isa<transform::AnyParamType>(resultType))
+ continue; // No type checking as result is of !transform.any_param type.
+ auto paramResultType = dyn_cast<transform::ParamType>(resultType);
+ if (!paramResultType)
+ return emitOpError() << "result type #" << i
+ << " is not a !transform.param";
+ Type wrappedResultType = paramResultType.getType();
+
+ if (isa<mlir::smt::IntType>(termOperandType)) {
+ if (!isa<IntegerType>(wrappedResultType))
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i
+ << " is !smt.int though the corresponding result type ("
+ << resultType
+ << ") of the parent op is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(termOperandType)) {
+ auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+ if (!intResultType || intResultType.getWidth() != 1)
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i
+ << " is !smt.bool though the corresponding result type ("
+ << resultType
+ << ") of the parent op is not wrapping i1 (i.e. bool)";
+ } else if (auto bvOperandType =
+ dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
+ auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+ if (!intResultType ||
+ intResultType.getWidth() != bvOperandType.getWidth())
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i << " is "
+ << termOperandType << " though the corresponding result type ("
+ << resultType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+ }
+
return success();
}
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index 1f0b7f066118c..af88fffcd3bba 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -19,6 +19,7 @@
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
self,
+ results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
@@ -27,6 +28,7 @@ def __init__(
if len(params) != len(arg_types):
raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
+ results,
params,
loc=loc,
ip=ip,
@@ -36,3 +38,13 @@ def __init__(
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
+
+
+def constrain_params(
+ results: Sequence[Type],
+ params: Sequence[transform.AnyParamType],
+ arg_types: Sequence[Type],
+ loc=None,
+ ip=None,
+):
+ return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 314b8d493c5d4..4e365fa2dbaf9 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -5,7 +5,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error at below {{ops contained in region should belong to SMT-dialect}}
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
@@ -22,9 +22,54 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error at below {{must have the same number of block arguments as operands}}
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
}
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @results_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
+ ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+ // expected-error at below {{expected terminator to have as many operands as the parent op has results}}
+ smt.yield %param_as_smt_var : !smt.int
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bool
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error at below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bool):
+ smt.yield %param_as_smt_var : !smt.bool
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bitvector
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error at below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bv<8>):
+ smt.yield %param_as_smt_var : !smt.bv<8>
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
index 29d15175ae4ec..6cc41dd52473e 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
^bb0(%param_as_smt_var: !smt.int):
// CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
module attributes {transform.with_named_sequence} {
- transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+ transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
%param_a = transform.param.constant 4 -> !transform.param<i64>
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
- %param_b = transform.param.constant 16 -> !transform.param<i64>
+ %param_b = transform.param.constant 32 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
- transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+ %divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
^bb0(%var_a: !smt.int, %var_b: !smt.int):
+ // CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
+ %divisor = smt.int.div %var_b, %var_a
// CHECK: %[[C0:.*]] = smt.int.constant 0
%c0 = smt.int.constant 0
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
%eq = smt.eq %remainder, %c0 : !smt.int
// CHECK: smt.assert %[[EQ]]
smt.assert %eq
+ // CHECK: smt.yield %[[DIV]]
+ smt.yield %divisor : !smt.int
}
- // NB: from here can rely on that %param_a is a divisor of %param_b
+ // NB: from here can rely on that %param_a is a divisor of %param_b and
+ // that the relevant factor, 8, got associated to %divisor.
transform.yield
}
}
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
module attributes {transform.with_named_sequence} {
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
- %param_as_param = transform.param.constant true -> !transform.any_param
+ %param_as_param = transform.param.constant true -> !transform.param<i1>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
- transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
^bb0(%param_as_smt_var: !smt.bool):
// CHECK: %[[C0:.*]] = smt.int.constant 0
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
index 3692fd92344a6..e28c56f277439 100644
--- a/mlir/test/python/dialects/transform_smt_ext.py
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -25,26 +25,44 @@ def run(f):
# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
- dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+ c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
- symbolic_value = transform.ParamConstantOp(
- transform.AnyParamType.get(), dummy_value
+ symbolic_value_as_param = transform.ParamConstantOp(
+ transform.AnyParamType.get(), c42_attr
)
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
constrain_params = transform_smt.ConstrainParamsOp(
- [symbolic_value], [smt.IntType.get()]
+ [], [symbolic_value_as_param], [smt.IntType.get()]
)
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
with ir.InsertionPoint(constrain_params.body):
+ symbolic_value_as_smt_var = constrain_params.body.arguments[0]
# CHECK: %[[C0:.*]] = smt.int.constant 0
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
# CHECK: %[[C43:.*]] = smt.int.constant 43
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
- lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+ lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var)
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
- ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
+ ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43)
# CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
bounded = smt.AndOp([lb, ub])
# CHECK: smt.assert %[[BOUNDED:.*]]
smt.AssertOp(bounded)
+ smt.YieldOp([])
+
+ # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
+ compute_with_params = transform_smt.ConstrainParamsOp(
+ [transform.ParamType.get(ir.IntegerType.get_signless(32))],
+ [symbolic_value_as_param],
+ [smt.IntType.get()],
+ )
+ # CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int):
+ with ir.InsertionPoint(compute_with_params.body):
+ symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
+ # CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]]
+ twice_symb = smt.IntAddOp(
+ [symbolic_value_as_smt_var, symbolic_value_as_smt_var]
+ )
+ # CHECK: smt.yield %[[TWICE]]
+ smt.YieldOp([twice_symb])
>From 778091dff053b4f7046b7d675df6fbf93a618dab Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 9 Oct 2025 08:02:08 -0700
Subject: [PATCH 2/2] Address Alex's comments
---
.../SMTExtension/SMTExtensionOps.cpp | 134 ++++++++----------
.../Transform/test-smt-extension-invalid.mlir | 83 +++++++++--
2 files changed, 131 insertions(+), 86 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index d85268da2ad5d..abc131639fb3a 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -44,50 +44,67 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
LogicalResult transform::smt::ConstrainParamsOp::verify() {
auto yieldTerminator =
- llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+ dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
if (!yieldTerminator)
return emitOpError() << "expected '"
<< mlir::smt::YieldOp::getOperationName()
<< "' as terminator";
+ auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
+ Type paramType, StringRef paramDesc,
+ auto *atOp) -> InFlightDiagnostic {
+ if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
+ smtType))
+ return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
+ << " is expected to be either a !smt.bool, a "
+ "!smt.int, or a !smt.bv";
+
+ assert(isa<TransformParamTypeInterface>(paramType) &&
+ "ODS specifies params' type should implement param interface");
+ if (isa<transform::AnyParamType>(paramType))
+ return {}; // No further checks can be done.
+
+ auto wrappingParamType = dyn_cast<ParamType>(paramType);
+ if (!wrappingParamType)
+ return {}; // Other param-type implementors expose no wrapped type.
+ Type typeWrappedByParam = wrappingParamType.getType();
+ if (isa<mlir::smt::IntType>(smtType)) {
+ if (!isa<IntegerType>(typeWrappedByParam))
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.int though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != 1)
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx
+ << " is !smt.bool though the corresponding " << paramDesc
+ << " type (" << paramType << ") is not wrapping i1";
+ } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
+ auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
+ if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
+ return atOp->emitOpError()
+ << "the type of " << smtDesc << " #" << idx << " is " << smtType
+ << " though the corresponding " << paramDesc << " type ("
+ << paramType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+
+ return {};
+ };
+
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
- for (auto [i, operandType, blockArgType] :
- llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
- getOperandTypes(), getBody().getArgumentTypes())) {
- if (isa<transform::AnyParamType>(operandType))
- continue; // No type checking as operand is of !transform.any_param type.
- auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
- if (!paramOperandType)
- return emitOpError() << "operand type #" << i
- << " is not a !transform.param";
- Type wrappedOperandType = paramOperandType.getType();
-
- if (isa<mlir::smt::IntType>(blockArgType)) {
- if (!isa<IntegerType>(paramOperandType.getType()))
- return emitOpError()
- << "the type of block arg #" << i
- << " is !smt.int though the corresponding operand type ("
- << operandType << ") is not wrapping an integer type";
- } else if (isa<mlir::smt::BoolType>(blockArgType)) {
- auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
- if (!intOperandType || intOperandType.getWidth() != 1)
- return emitOpError()
- << "the type of block arg #" << i
- << " is !smt.bool though the corresponding operand type ("
- << operandType << ") is not wrapping i1 (i.e. bool)";
- } else if (auto bvBlockArgType =
- dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
- auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
- if (!intOperandType ||
- intOperandType.getWidth() != bvBlockArgType.getWidth())
- return emitOpError()
- << "the type of block arg #" << i << " is " << blockArgType
- << " though the corresponding operand type (" << operandType
- << ") is not wrapping an integer type of the same bitwidth";
- }
+ for (auto [idx, operandType, blockArgType] :
+ llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, blockArgType, "block arg", operandType, "operand",
+ /*atOp=*/this);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
}
for (auto &op : getBody().getOps()) {
@@ -96,52 +113,19 @@ LogicalResult transform::smt::ConstrainParamsOp::verify() {
"ops contained in region should belong to SMT-dialect");
}
- if (getOperands().size() != getBody().getNumArguments())
- return emitOpError(
- "must have the same number of block arguments as operands");
-
if (yieldTerminator->getNumOperands() != getNumResults())
return yieldTerminator.emitOpError()
<< "expected terminator to have as many operands as the parent op "
"has results";
- for (auto [i, termOperandType, resultType] : llvm::zip_equal(
- llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ for (auto [idx, termOperandType, resultType] : llvm::enumerate(
yieldTerminator->getOperands().getType(), getResultTypes())) {
- if (isa<transform::AnyParamType>(resultType))
- continue; // No type checking as result is of !transform.any_param type.
- auto paramResultType = dyn_cast<transform::ParamType>(resultType);
- if (!paramResultType)
- return emitOpError() << "result type #" << i
- << " is not a !transform.param";
- Type wrappedResultType = paramResultType.getType();
-
- if (isa<mlir::smt::IntType>(termOperandType)) {
- if (!isa<IntegerType>(wrappedResultType))
- return yieldTerminator.emitOpError()
- << "the type of terminator operand #" << i
- << " is !smt.int though the corresponding result type ("
- << resultType
- << ") of the parent op is not wrapping an integer type";
- } else if (isa<mlir::smt::BoolType>(termOperandType)) {
- auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
- if (!intResultType || intResultType.getWidth() != 1)
- return yieldTerminator.emitOpError()
- << "the type of terminator operand #" << i
- << " is !smt.bool though the corresponding result type ("
- << resultType
- << ") of the parent op is not wrapping i1 (i.e. bool)";
- } else if (auto bvOperandType =
- dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
- auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
- if (!intResultType ||
- intResultType.getWidth() != bvOperandType.getWidth())
- return yieldTerminator.emitOpError()
- << "the type of terminator operand #" << i << " is "
- << termOperandType << " though the corresponding result type ("
- << resultType
- << ") is not wrapping an integer type of the same bitwidth";
- }
+ InFlightDiagnostic typeCheckResult =
+ checkTypes(idx, termOperandType, "terminator operand",
+ resultType, "result",
+ /*atOp=*/&yieldTerminator);
+ if (LogicalResult(typeCheckResult).failed())
+ return typeCheckResult;
}
return success();
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 4e365fa2dbaf9..d91d69a756458 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -1,15 +1,13 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
-// CHECK-LABEL: @constraint_not_using_smt_ops
+// CHECK-LABEL: @incorrect_terminator
module attributes {transform.with_named_sequence} {
- transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
+ transform.named_sequence @incorrect_terminator(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
- // expected-error at below {{ops contained in region should belong to SMT-dialect}}
+ // expected-error at below {{op expected 'smt.yield' as terminator}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
- %c4 = arith.constant 4 : i32
- // This is the kind of thing one might think works:
- //arith.remsi %param_as_smt_var, %c4 : i32
+ transform.yield
}
transform.yield
}
@@ -31,6 +29,23 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @constraint_not_using_smt_ops
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error at below {{ops contained in region should belong to SMT-dialect}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
+ ^bb0(%param_as_smt_var: !smt.int):
+ %c4 = arith.constant 4 : i32
+ // This is the kind of thing one might think works:
+ //arith.remsi %param_as_smt_var, %c4 : i32
+ }
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @results_not_one_to_one_with_vars
module attributes {transform.with_named_sequence} {
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
@@ -46,11 +61,27 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @mismatched_type_bool
+// CHECK-LABEL: @non_smt_type_block_args
module attributes {transform.with_named_sequence} {
- transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+ transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i8>
+ // expected-error at below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
+ ^bb0(%param_as_smt_var: !transform.param<i8>):
+ smt.yield %param_as_smt_var : !transform.param<i8>
+ }
+ transform.yield
+ }
+}
+
+
+// -----
+
+// CHECK-LABEL: @mismatched_arg_type_bool
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
- // expected-error at below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
+ // expected-error at below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bool):
smt.yield %param_as_smt_var : !smt.bool
@@ -61,9 +92,9 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @mismatched_type_bitvector
+// CHECK-LABEL: @mismatched_arg_type_bitvector
module attributes {transform.with_named_sequence} {
- transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+ transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error at below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
@@ -73,3 +104,33 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @mismatched_result_type_bool
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 1 -> !transform.param<i1>
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bool):
+ // expected-error at below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
+ smt.yield %param_as_smt_var : !smt.bool
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_result_type_bitvector
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i8>
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bv<8>):
+ // expected-error at below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
+ smt.yield %param_as_smt_var : !smt.bv<8>
+ }
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list