[Mlir-commits] [mlir] [MLIR][Transform][SMT] Allow for declarative computations in schedules (PR #160895)
    Oleksandr Alex Zinenko 
    llvmlistbot at llvm.org
       
    Thu Oct  9 03:07:22 PDT 2025
    
    
  
================
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
   //       and allow for users to attach their own implementation, which would,
   //       e.g., translate the ops to SMTLIB and hand that over to the user's
   //       favourite solver. This requires changes to the dialect's verifier.
-  return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+  return emitSilenceableFailure(getLoc())
+         << "op does not have interpreted semantics yet";
 }
 
 LogicalResult transform::smt::ConstrainParamsOp::verify() {
+  auto yieldTerminator =
+      llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+  if (!yieldTerminator)
+    return emitOpError() << "expected '"
+                         << mlir::smt::YieldOp::getOperationName()
+                         << "' as terminator";
+
   if (getOperands().size() != getBody().getNumArguments())
     return emitOpError(
         "must have the same number of block arguments as operands");
 
+  for (auto [i, operandType, blockArgType] :
+       llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
+                       getOperandTypes(), getBody().getArgumentTypes())) {
+    if (isa<transform::AnyParamType>(operandType))
+      continue; // No type checking as operand is of !transform.any_param type.
+    auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
+    if (!paramOperandType)
+      return emitOpError() << "operand type #" << i
+                           << " is not a !transform.param";
+    Type wrappedOperandType = paramOperandType.getType();
+
+    if (isa<mlir::smt::IntType>(blockArgType)) {
+      if (!isa<IntegerType>(paramOperandType.getType()))
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.int though the corresponding operand type ("
+               << operandType << ") is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType || intOperandType.getWidth() != 1)
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.bool though the corresponding operand type ("
+               << operandType << ") is not wrapping i1 (i.e. bool)";
+    } else if (auto bvBlockArgType =
+                   dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType ||
+          intOperandType.getWidth() != bvBlockArgType.getWidth())
+        return emitOpError()
+               << "the type of block arg #" << i << " is " << blockArgType
+               << " though the corresponding operand type (" << operandType
+               << ") is not wrapping an integer type of the same bitwidth";
+    }
+  }
+
   for (auto &op : getBody().getOps()) {
     if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
       return emitOpError(
           "ops contained in region should belong to SMT-dialect");
   }
 
+  if (getOperands().size() != getBody().getNumArguments())
+    return emitOpError(
+        "must have the same number of block arguments as operands");
+
+  if (yieldTerminator->getNumOperands() != getNumResults())
+    return yieldTerminator.emitOpError()
+           << "expected terminator to have as many operands as the parent op "
+              "has results";
+
+  for (auto [i, termOperandType, resultType] : llvm::zip_equal(
+           llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+           yieldTerminator->getOperands().getType(), getResultTypes())) {
+    if (isa<transform::AnyParamType>(resultType))
+      continue; // No type checking as result is of !transform.any_param type.
+    auto paramResultType = dyn_cast<transform::ParamType>(resultType);
+    if (!paramResultType)
+      return emitOpError() << "result type #" << i
+                           << " is not a !transform.param";
+    Type wrappedResultType = paramResultType.getType();
+
+    if (isa<mlir::smt::IntType>(termOperandType)) {
+      if (!isa<IntegerType>(wrappedResultType))
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.int though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(termOperandType)) {
+      auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+      if (!intResultType || intResultType.getWidth() != 1)
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.bool though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping i1 (i.e. bool)";
----------------
ftynse wrote:
Can this result-type checking be factored out into a helper shared with the similar operand/block-argument type-checking code above?
https://github.com/llvm/llvm-project/pull/160895
    
    
More information about the Mlir-commits
mailing list