[Mlir-commits] [mlir] [MLIR][Transform][SMT] Allow for declarative computations in schedules (PR #160895)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 26 07:44:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

By allowing `transform.smt.constrain_params`'s region to yield SMT-vars, op instances can declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to specify computations on params declaratively.

The semantics are that the yielded SMT-vars take their values from any one satisfying assignment (i.e., model) of the constraints in the region.

---

Patch is 21.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160895.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/SMT/IR/SMTOps.td (-2) 
- (modified) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h (+1) 
- (modified) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td (+12-5) 
- (modified) mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp (+96-3) 
- (modified) mlir/python/mlir/dialects/transform/smt.py (+12) 
- (modified) mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir (+47-2) 
- (modified) mlir/test/Dialect/Transform/test-smt-extension.mlir (+13-8) 
- (modified) mlir/test/python/dialects/transform_smt_ext.py (+24-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index 3143ab7de1b14..99b22e5609c74 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
   Pure,
   Terminator,
   ReturnLike,
-  ParentOneOf<["smt::SolverOp", "smt::CheckOp",
-               "smt::ForallOp", "smt::ExistsOp"]>,
 ]> {
   let summary = "terminator operation for various regions of SMT operations";
   let arguments = (ins Variadic<AnyType>:$values);
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index fc69b039f24ff..f6353a995d747 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
index b987cb31e54bb..9d9783aa66ed9 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   DeclareOpInterfaceMethods<TransformOpInterface>,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-  NoTerminator
+  SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
 ]> {
   let cppNamespace = [{ mlir::transform::smt }];
 
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   let description = [{
     Allows expressing constraints on params using the SMT dialect.
 
-    Each Transform dialect param provided as an operand has a corresponding
+    Each Transform-dialect param provided as an operand has a corresponding
     argument of SMT-type in the region. The SMT-Dialect ops in the region use
-    these arguments as operands.
+    these params-as-SMT-vars as operands, thereby expressing relevant
+    constraints on their allowed values.
+
+    Computations w.r.t. passed-in params can also be expressed through the
+    region's SMT-ops. Namely, the constraints express relationships to other
+    SMT-variables which can then be yielded from the region (with `smt.yield`).
 
     The semantics of this op is that all the ops in the region together express
     a constraint on the params-interpreted-as-smt-vars. The op fails in case the
     expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
-    op succeeds.
+    op succeeds and any one satisfying assignment is used to map the
+    SMT-variables yielded in the region to `transform.param`s.
 
     ---
 
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   }];
 
   let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+  let results = (outs Variadic<TransformParamTypeInterface>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-      "`(` $params `)` attr-dict `:` type(operands) $body";
+      "`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05353de7..d85268da2ad5d 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
 #include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 
 using namespace mlir;
 
@@ -23,6 +23,7 @@ using namespace mlir;
 void transform::smt::ConstrainParamsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getParamsMutable(), effects);
+  producesHandle(getResults(), effects); // Results are new param handles.
 }
 
 DiagnosedSilenceableFailure
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
   //       and allow for users to attach their own implementation, which would,
   //       e.g., translate the ops to SMTLIB and hand that over to the user's
   //       favourite solver. This requires changes to the dialect's verifier.
-  return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+  return emitSilenceableFailure(getLoc())
+         << "op does not have interpreted semantics yet";
 }
 
 LogicalResult transform::smt::ConstrainParamsOp::verify() {
+  auto yieldTerminator =
+      llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+  if (!yieldTerminator)
+    return emitOpError() << "expected '"
+                         << mlir::smt::YieldOp::getOperationName()
+                         << "' as terminator";
+
   if (getOperands().size() != getBody().getNumArguments())
     return emitOpError(
         "must have the same number of block arguments as operands");
 
+  for (auto [i, operandType, blockArgType] :
+       llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
+                       getOperandTypes(), getBody().getArgumentTypes())) {
+    if (isa<transform::AnyParamType>(operandType))
+      continue; // No type checking as operand is of !transform.any_param type.
+    auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
+    if (!paramOperandType)
+      return emitOpError() << "operand type #" << i
+                           << " is not a !transform.param";
+    Type wrappedOperandType = paramOperandType.getType();
+
+    if (isa<mlir::smt::IntType>(blockArgType)) {
+      if (!isa<IntegerType>(paramOperandType.getType()))
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.int though the corresponding operand type ("
+               << operandType << ") is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType || intOperandType.getWidth() != 1)
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.bool though the corresponding operand type ("
+               << operandType << ") is not wrapping i1 (i.e. bool)";
+    } else if (auto bvBlockArgType =
+                   dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType ||
+          intOperandType.getWidth() != bvBlockArgType.getWidth())
+        return emitOpError()
+               << "the type of block arg #" << i << " is " << blockArgType
+               << " though the corresponding operand type (" << operandType
+               << ") is not wrapping an integer type of the same bitwidth";
+    }
+  }
+
   for (auto &op : getBody().getOps()) {
     if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
       return emitOpError(
           "ops contained in region should belong to SMT-dialect");
   }
 
+  if (getOperands().size() != getBody().getNumArguments())
+    return emitOpError(
+        "must have the same number of block arguments as operands");
+
+  if (yieldTerminator->getNumOperands() != getNumResults())
+    return yieldTerminator.emitOpError()
+           << "expected terminator to have as many operands as the parent op "
+              "has results";
+
+  for (auto [i, termOperandType, resultType] : llvm::zip_equal(
+           llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+           yieldTerminator->getOperands().getType(), getResultTypes())) {
+    if (isa<transform::AnyParamType>(resultType))
+      continue; // No type checking as result is of !transform.any_param type.
+    auto paramResultType = dyn_cast<transform::ParamType>(resultType);
+    if (!paramResultType)
+      return emitOpError() << "result type #" << i
+                           << " is not a !transform.param";
+    Type wrappedResultType = paramResultType.getType();
+
+    if (isa<mlir::smt::IntType>(termOperandType)) {
+      if (!isa<IntegerType>(wrappedResultType))
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.int though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(termOperandType)) {
+      auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+      if (!intResultType || intResultType.getWidth() != 1)
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.bool though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping i1 (i.e. bool)";
+    } else if (auto bvOperandType =
+                   dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
+      auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+      if (!intResultType ||
+          intResultType.getWidth() != bvOperandType.getWidth())
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i << " is "
+               << termOperandType << " though the corresponding result type ("
+               << resultType
+               << ") is not wrapping an integer type of the same bitwidth";
+    }
+  }
+
   return success();
 }
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index 1f0b7f066118c..af88fffcd3bba 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -19,6 +19,7 @@
 class ConstrainParamsOp(ConstrainParamsOp):
     def __init__(
         self,
+        results: Sequence[Type],
         params: Sequence[transform.AnyParamType],
         arg_types: Sequence[Type],
         loc=None,
@@ -27,6 +28,7 @@ def __init__(
         if len(params) != len(arg_types):
             raise ValueError(f"{params=} not same length as {arg_types=}")
         super().__init__(
+            results,
             params,
             loc=loc,
             ip=ip,
@@ -36,3 +38,13 @@ def __init__(
     @property
     def body(self) -> Block:
         return self.regions[0].blocks[0]
+
+
+def constrain_params(  # Builds and returns a ConstrainParamsOp.
+    results: Sequence[Type],
+    params: Sequence[transform.AnyParamType],
+    arg_types: Sequence[Type],
+    loc=None,
+    ip=None,
+):
+    return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 314b8d493c5d4..4e365fa2dbaf9 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -5,7 +5,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
     // expected-error at below {{ops contained in region should belong to SMT-dialect}}
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       ^bb0(%param_as_smt_var: !smt.int):
       %c4 = arith.constant 4 : i32
       // This is the kind of thing one might think works:
@@ -22,9 +22,54 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
     // expected-error at below {{must have the same number of block arguments as operands}}
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
     }
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @results_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
+      ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+      // expected-error at below {{expected terminator to have as many operands as the parent op has results}}
+      smt.yield %param_as_smt_var : !smt.int
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bool
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error at below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+      ^bb0(%param_as_smt_var: !smt.bool):
+      smt.yield %param_as_smt_var : !smt.bool
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bitvector
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error at below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+      ^bb0(%param_as_smt_var: !smt.bv<8>):
+      smt.yield %param_as_smt_var : !smt.bv<8>
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
index 29d15175ae4ec..6cc41dd52473e 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
       ^bb0(%param_as_smt_var: !smt.int):
       // CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+  transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
     // CHECK: %[[PARAM_A:.*]] = transform.param.constant
     %param_a = transform.param.constant 4 -> !transform.param<i64>
     // CHECK: %[[PARAM_B:.*]] = transform.param.constant
-    %param_b = transform.param.constant 16 -> !transform.param<i64>
+    %param_b = transform.param.constant 32 -> !transform.param<i64>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
-    transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+    %divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
       // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
       ^bb0(%var_a: !smt.int, %var_b: !smt.int):
+      // CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
+      %divisor = smt.int.div %var_b, %var_a
       // CHECK: %[[C0:.*]] = smt.int.constant 0
       %c0 = smt.int.constant 0
       // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
       %eq = smt.eq %remainder, %c0 : !smt.int
       // CHECK: smt.assert %[[EQ]]
       smt.assert %eq
+      // CHECK: smt.yield %[[DIV]]
+      smt.yield %divisor : !smt.int
     }
-    // NB: from here can rely on that %param_a is a divisor of %param_b
+    // NB: from here can rely on that %param_a is a divisor of %param_b and
+    //     that the relevant factor, 8, got associated to %divisor.
     transform.yield
   }
 }
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
     // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
-    %param_as_param = transform.param.constant true -> !transform.any_param
+    %param_as_param = transform.param.constant true -> !transform.param<i1>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
-    transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
       // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
       ^bb0(%param_as_smt_var: !smt.bool):
       // CHECK: %[[C0:.*]] = smt.int.constant 0
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
index 3692fd92344a6..e28c56f277439 100644
--- a/mlir/test/python/dialects/transform_smt_ext.py
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -25,26 +25,44 @@ def run(f):
 # CHECK-LABEL: TEST: testConstrainParamsOp
 @run
 def testConstrainParamsOp(target):
-    dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+    c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
     # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
-    symbolic_value = transform.ParamConstantOp(
-        transform.AnyParamType.get(), dummy_value
+    symbolic_value_as_param = transform.ParamConstantOp(
+        transform.AnyParamType.get(), c42_attr
     )
     # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
     constrain_params = transform_smt.ConstrainParamsOp(
-        [symbolic_value], [smt.IntType.get()]
+        [], [symbolic_value_as_param], [smt.IntType.get()]
     )
     # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
     with ir.InsertionPoint(constrain_params.body):
+        symbolic_value_as_smt_var = constrain_params.body.arguments[0]
         # CHECK: %[[C0:.*]] = smt.int.constant 0
         c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
         # CHECK: %[[C43:.*]] = smt.int.constant 43
         c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
         # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
-        lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+        lb = smt.IntCmpOp(smt.IntPredicate....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/160895


More information about the Mlir-commits mailing list