[Mlir-commits] [mlir] [mlir] add normal form checked transform interface (PR #192647)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Apr 17 05:58:50 PDT 2026


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/192647

This interface can be implemented by operations that guarantee certain
normal forms for themselves and their regions. Such operations provide
the list of normal forms they guarantee, and the interface verifier
checks that the guarantee actually holds. This interface interacts with
typed transform handles, removing the need for them to re-check normal
forms that are already guaranteed (and preserved by transforms).
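The following is a minimal, self-contained C++17 model of the logic the patch adds to `transform::NormalizedOpType::checkPayload` and `transform::detail::checkNormalForms`. It does not use MLIR. `Status`, `CheckResult`, `NormalForm` and `Op` are hypothetical stand-ins for `DiagnosedSilenceableFailure`, `NormalFormAttrInterface` and payload operations, and normal forms are identified by a string that should include their parameters, since the patch compares attributes for equality. Operations that already guarantee every required normal form are skipped. The remaining ones are checked, returning the first definite failure immediately and otherwise the first silenceable failure.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for DiagnosedSilenceableFailure.
enum class Status { Success, Silenceable, Definite };

struct CheckResult {
  Status status = Status::Success;
  std::string message;
};

struct Op;

// Hypothetical stand-in for a NormalFormAttrInterface attribute: a unique
// name (including any parameters) and a checker for a single operation.
struct NormalForm {
  std::string name;
  std::function<CheckResult(const Op &)> check;
};

// Hypothetical payload operation. A non-empty `guaranteedForms` models an op
// implementing NormalFormCheckedOpInterface, such as transform.payload.
struct Op {
  std::string name;
  std::vector<std::string> guaranteedForms;
};

// Models detail::checkNormalForms: return the first definite failure right
// away; otherwise remember and return the first silenceable failure.
CheckResult checkNormalForms(const std::vector<NormalForm> &forms,
                             const std::vector<const Op *> &payload) {
  CheckResult overall;
  for (const Op *op : payload) {
    for (const NormalForm &form : forms) {
      CheckResult result = form.check(*op);
      if (result.status == Status::Definite)
        return result;
      if (result.status == Status::Silenceable &&
          overall.status == Status::Success)
        overall = result;
    }
  }
  return overall;
}

// Models NormalizedOpType::checkPayload: skip ops that already guarantee all
// required forms and run the checkers only on the remaining ones.
CheckResult checkPayload(const std::vector<NormalForm> &required,
                         const std::vector<const Op *> &payload) {
  std::vector<const Op *> toCheck;
  for (const Op *op : payload) {
    bool coversAll = std::all_of(
        required.begin(), required.end(), [&](const NormalForm &form) {
          return std::find(op->guaranteedForms.begin(),
                           op->guaranteedForms.end(),
                           form.name) != op->guaranteedForms.end();
        });
    if (!coversAll)
      toCheck.push_back(op);
  }
  return checkNormalForms(required, toCheck);
}
```

Skipping is only sound under the assumption stated above: the op's interface verifier has already run the same checks, and transforms preserve the guaranteed forms.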

Provide a simple `transform.payload` operation to carry a list of normal
forms and implement the interface.
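Because `transform.payload` carries an arbitrary list of normal forms, the interface verifier first rejects lists that contain two normal forms of the same class, even when their parameters differ (for example, `nested false` and `nested true`). The patch factors this check out of `NormalizedOpType::verify` into `detail::verifyNormalFormList`. The standalone C++17 sketch below models that check. `NormalFormAttr` is a hypothetical struct in which `kind` plays the role of the attribute's `TypeID`. It is not the MLIR API.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a normal form attribute: `kind` models the
// attribute class (its TypeID), `params` its printed parameters.
struct NormalFormAttr {
  std::string kind;
  std::string params;
};

// Modeled on detail::verifyNormalFormList: two entries of the same class are
// rejected even if their parameters differ. Returns an error message on
// failure and std::nullopt on success.
std::optional<std::string>
verifyNormalFormList(const std::vector<NormalFormAttr> &forms) {
  std::map<std::string, NormalFormAttr> seen;
  for (const NormalFormAttr &form : forms) {
    auto [previous, inserted] = seen.try_emplace(form.kind, form);
    if (!inserted)
      return "duplicate normal form: " + form.kind + "<" + form.params +
             ">; previous instance: " + previous->second.kind + "<" +
             previous->second.params + ">";
  }
  return std::nullopt;
}
```

Keying on the class rather than on the full attribute means a list cannot ask for two conflicting configurations of the same normal form.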

This exposes the fact that the transform interpreter may be running the
verifier more often than necessary, but this is pre-existing behavior
that is orthogonal to this patch.

Assisted-by: Claude Opus 4.7 / Cursor


>From 7eeee80fa4f415f8656055e9c2f9214c3fdd8d0e Mon Sep 17 00:00:00 2001
From: Alex Zinenko <git at ozinenko.com>
Date: Thu, 16 Apr 2026 23:19:26 +0200
Subject: [PATCH] [mlir] add normal form checked transform interface

This interface can be implemented by operations that guarantee certain
normal forms for themselves and their regions. Such operations provide
the list of normal forms they guarantee, and the interface verifier
checks that the guarantee actually holds. This interface interacts with
typed transform handles, removing the need for them to re-check normal
forms that are already guaranteed (and preserved by transforms).

Provide a simple `transform.payload` operation to carry a list of normal
forms and implement the interface.

This exposes the fact that the transform interpreter may be running the
verifier more often than necessary, but this is pre-existing behavior
that is orthogonal to this patch.

Assisted-by: Claude Opus 4.7 / Cursor
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 17 ++++
 .../Interfaces/TransformInterfaces.h          | 15 ++++
 .../Interfaces/TransformInterfaces.td         | 29 +++++++
 .../Dialect/Transform/IR/TransformDialect.cpp |  6 +-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 10 +++
 .../Dialect/Transform/IR/TransformTypes.cpp   | 41 ++++-----
 .../Interfaces/TransformInterfaces.cpp        | 39 +++++++++
 mlir/test/Dialect/Transform/normal-forms.mlir | 83 +++++++++++++++++++
 .../TestTransformDialectExtension.cpp         | 16 ++++
 .../TestTransformDialectExtension.td          | 13 +++
 10 files changed, 242 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index d0de4aaed310c..2daa2ad655e34 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1138,6 +1138,23 @@ def ParamConstantOp : Op<Transform_Dialect, "param.constant", [
   let assemblyFormat = "$value attr-dict `->` type($param)";
 }
 
+def PayloadOp : Op<Transform_Dialect, "payload",
+                   [NoTerminator, NoRegionArguments,
+                    DeclareOpInterfaceMethods<NormalFormCheckedOpInterface>]> {
+  let summary = "Optional container for transform payloads";
+  let description = [{
+    Contains payload operations on which transforms operate and serves as a
+    storage location for information between multiple transform invocations.
+
+    The operations contained in this payload must satisfy the normal forms specified by the `normal_forms` attribute.
+  }];
+  let regions = (region AnyRegion:$body);
+  let arguments =
+      (ins TypedArrayAttrBase<NormalFormAttrInterface,
+                              "normal form attributes">:$normal_forms);
+  let assemblyFormat = "attr-dict-with-keyword regions";
+}
+
 def PrintOp : TransformDialectOp<"print",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index d5499fa2f3fc0..baf7d01e6811c 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -82,6 +82,21 @@ TransformState makeTransformStateForTesting(Region *region,
 /// Returns all operands that are handles and being consumed by the given op.
 SmallVector<OpOperand *>
 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
+
+/// Checks that the given payload operations satisfy the normal form
+/// constraints. Reports the first encountered definite failure, or the first
+/// encountered silenceable failure if there were no definite failures. Normal
+/// forms are checked in order, so trailing normal forms may assume earlier
+/// normal forms did not produce definite failures.
+mlir::DiagnosedSilenceableFailure checkNormalForms(
+    llvm::ArrayRef<mlir::transform::NormalFormAttrInterface> normalForms,
+    llvm::ArrayRef<mlir::Operation *> payload);
+
+/// Checks that the given list does not contain duplicate normal forms of the
+/// same class.
+llvm::LogicalResult verifyNormalFormList(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<mlir::transform::NormalFormAttrInterface> normalForms);
 } // namespace detail
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.td
index 60474a370be8f..61c673d173b71 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.td
@@ -151,6 +151,35 @@ def NormalFormAttrInterface : AttrInterface<"NormalFormAttrInterface"> {
   ];
 }
 
+def NormalFormCheckedOpInterface : OpInterface<"NormalFormCheckedOpInterface"> {
+  let description = [{
+    Interface for operations that satisfy normal forms. Such an operation
+    declares a list of normal forms it satisfies and the interface-level
+    verifier ensures this is the case by calling the corresponding checkers from
+    the normal form interfaces.
+  }];
+  let cppNamespace = "::mlir::transform";
+  let methods = [InterfaceMethod<
+      /*desc=*/[{Returns the list of normal forms the operation satisfies.}],
+      /*returnType=*/"void",
+      /*name=*/"getCheckedNormalForms",
+      /*arguments=*/
+      (ins "::llvm::SmallVectorImpl<::mlir::transform::NormalFormAttrInterface>"
+           " &":$normalForms)>];
+
+  let verify = [{
+    ::llvm::SmallVector<::mlir::transform::NormalFormAttrInterface> normalForms;
+    cast<::mlir::transform::NormalFormCheckedOpInterface>($_op)
+        .getCheckedNormalForms(normalForms);
+    if (::llvm::failed(::mlir::transform::detail::verifyNormalFormList(
+        [&] { return $_op->emitError(); }, normalForms)))
+      return ::llvm::failure();
+    return ::mlir::transform::detail::checkNormalForms(normalForms, $_op)
+        .checkAndReport();
+  }];
+  let verifyWithRegions = 1;
+}
+
 class TransformTypeInterfaceBase<string cppClass, string cppObjectType>
     : TypeInterface<cppClass> {
   let cppNamespace = "::mlir::transform";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 0448ed194217d..778303d8f2baa 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -33,13 +33,15 @@ void transform::detail::checkImplementsTransformOpInterface(
           opName.hasInterface<PatternDescriptorOpInterface>() ||
           opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
           opName.hasInterface<TypeConverterBuilderOpInterface>() ||
-          opName.hasTrait<OpTrait::IsTerminator>()) &&
+          opName.hasTrait<OpTrait::IsTerminator>() ||
+          opName.hasInterface<NormalFormCheckedOpInterface>()) &&
          "non-terminator ops injected into the transform dialect must "
          "implement TransformOpInterface or PatternDescriptorOpInterface or "
          "ConversionPatternDescriptorOpInterface");
   if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
       !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
-      !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
+      !opName.hasInterface<TypeConverterBuilderOpInterface>() &&
+      !opName.hasInterface<NormalFormCheckedOpInterface>()) {
     assert(opName.hasInterface<MemoryEffectOpInterface>() &&
            "ops injected into the transform dialect must implement "
            "MemoryEffectsOpInterface");
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b53c36a51038a..03b28ad0acfa2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2759,6 +2759,16 @@ LogicalResult transform::SplitHandleOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PayloadOp
+//===----------------------------------------------------------------------===//
+
+void transform::PayloadOp::getCheckedNormalForms(
+    SmallVectorImpl<NormalFormAttrInterface> &normalForms) {
+  llvm::append_range(normalForms,
+                     getNormalForms().getAsRange<NormalFormAttrInterface>());
+}
+
 //===----------------------------------------------------------------------===//
 // ReplicateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index c0b379cfbb063..20b33b80c99e5 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -111,36 +111,27 @@ transform::AnyParamType::checkPayload(Location loc,
 DiagnosedSilenceableFailure
 transform::NormalizedOpType::checkPayload(Location loc,
                                           ArrayRef<Operation *> payload) const {
-  // Return any definite failure or the first silenceable failure.
-  auto overallResult = DiagnosedSilenceableFailure::success();
-  for (Operation *op : payload) {
-    for (NormalFormAttrInterface normalForm : getNormalForms()) {
-      DiagnosedSilenceableFailure result = normalForm.checkOperation(op);
-      if (result.isDefiniteFailure())
-        return result;
-      if (result.isSilenceableFailure() && overallResult.succeeded())
-        overallResult = std::move(result);
-    }
-  }
-  return overallResult;
+  // Only check payload ops that do not already guarantee all required forms.
+  SmallVector<Operation *> payloadsToCheck =
+      llvm::filter_to_vector(payload, [this](Operation *op) {
+        auto normalFormCheckedOp = dyn_cast<NormalFormCheckedOpInterface>(op);
+        if (!normalFormCheckedOp)
+          return true;
+
+        SmallVector<NormalFormAttrInterface> checkedNormalForms;
+        normalFormCheckedOp.getCheckedNormalForms(checkedNormalForms);
+        return !llvm::all_of(
+            this->getNormalForms(), [&](NormalFormAttrInterface form) {
+              return llvm::is_contained(checkedNormalForms, form);
+            });
+      });
+  return detail::checkNormalForms(getNormalForms(), payloadsToCheck);
 }
 
 LogicalResult transform::NormalizedOpType::verify(
     function_ref<InFlightDiagnostic()> emitError,
     ArrayRef<NormalFormAttrInterface> normalForms) {
-  llvm::DenseMap<TypeID, NormalFormAttrInterface> seen;
-  for (NormalFormAttrInterface normalForm : normalForms) {
-    auto [previous, inserted] =
-        seen.try_emplace(normalForm.getTypeID(), normalForm);
-    if (!inserted) {
-      InFlightDiagnostic diag = emitError()
-                                << "duplicate normal form: " << normalForm;
-      diag.attachNote() << "previous instance: " << previous->second;
-      return diag;
-    }
-  }
-
-  return success();
+  return detail::verifyNormalFormList(emitError, normalForms);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 09225f14dea1d..0084327817075 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1987,6 +1987,45 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Normal form utilities.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::detail::checkNormalForms(
+    ArrayRef<NormalFormAttrInterface> normalForms,
+    ArrayRef<Operation *> payload) {
+  // Return any definite failure or the first silenceable failure.
+  auto overallResult = DiagnosedSilenceableFailure::success();
+  for (Operation *op : payload) {
+    for (NormalFormAttrInterface normalForm : normalForms) {
+      DiagnosedSilenceableFailure result = normalForm.checkOperation(op);
+      if (result.isDefiniteFailure())
+        return result;
+      if (result.isSilenceableFailure() && overallResult.succeeded())
+        overallResult = std::move(result);
+    }
+  }
+  return overallResult;
+}
+
+LogicalResult transform::detail::verifyNormalFormList(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<NormalFormAttrInterface> normalForms) {
+  llvm::DenseMap<TypeID, NormalFormAttrInterface> seen;
+  for (NormalFormAttrInterface normalForm : normalForms) {
+    auto [previous, inserted] =
+        seen.try_emplace(normalForm.getTypeID(), normalForm);
+    if (!inserted) {
+      InFlightDiagnostic diag = emitError()
+                                << "duplicate normal form: " << normalForm;
+      diag.attachNote() << "previous instance: " << previous->second;
+      return diag;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Entry point.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/normal-forms.mlir b/mlir/test/Dialect/Transform/normal-forms.mlir
index a1367960b424b..9c2984ec1039e 100644
--- a/mlir/test/Dialect/Transform/normal-forms.mlir
+++ b/mlir/test/Dialect/Transform/normal-forms.mlir
@@ -80,3 +80,86 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+module attributes {transform.with_named_sequence} {
+  // expected-remark @below {{matched}}
+  transform.payload attributes {
+      normal_forms = [#transform.test_single_block_normal_form<nested true>]} {
+    transform.test_dummy_payload_op : () -> ()
+  }
+
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %0 = transform.structured.match ops{["transform.payload"]} in %arg0
+        : (!transform.any_op) ->
+            !transform.normalized_op<#transform.test_single_block_normal_form<nested true>>
+    transform.debug.emit_remark_at %0, "matched"
+        : !transform.normalized_op<#transform.test_single_block_normal_form<nested true>>
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{previous instance}}
+// expected-error @below {{duplicate normal form}}
+transform.payload attributes {normal_forms = [
+    #transform.test_single_block_normal_form<nested false>,
+    #transform.test_single_block_normal_form<nested true>]} {
+}
+
+// -----
+
+transform.payload attributes {
+    normal_forms = [#transform.test_single_block_normal_form<nested true>]} {
+  // expected-error @below {{normal form test_single_block_normal_form requires payload operations to have a single region}}
+  "test.foo"() ({
+    cf.br ^bb1
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+}
+
+// -----
+
+transform.payload attributes {
+    normal_forms = [#transform.test_single_block_normal_form<nested true>]} {
+  // We should see the diagnostic from the inner op verifier, and never hit
+  // the normal form check.
+  // expected-error @below {{fail_to_verify is set}}
+  transform.test_dummy_payload_op {fail_to_verify} : () -> ()
+  "test.foo"() ({
+    cf.br ^bb1
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+}
+
+// -----
+
+// We have surprisingly many invocations of the verifier here:
+//  1. after the initial parsing (reasonable)
+//  2. in transform::detail::mergeSymbolsInto (looks excessive)
+//  3. also in transform::detail::mergeSymbolsInto (has a TODO to be removed)
+//  4. after the transform interpreter pass (reasonable)
+//  5. before printing (generally reasonable, but would be nice to avoid if
+//     the pass manager has already verified the IR).
+// Notably, this does not include an extra run from checkPayload, which is
+// what this test checks.
+
+// CHECK: transform.payload
+// CHECK-SAME: test.counting_normal_form_count = 5 : i64
+
+module attributes {transform.with_named_sequence} {
+  transform.payload attributes {
+      normal_forms = [#transform.test_counting_normal_form]} {
+    transform.test_dummy_payload_op : () -> ()
+  }
+
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %0 = transform.structured.match ops{["transform.payload"]} in %arg0
+        : (!transform.any_op) ->
+            !transform.normalized_op<#transform.test_counting_normal_form>
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 4b9008c398f6d..faf87316f1466 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -929,6 +929,22 @@ mlir::transform::TestSingleBlockNormalFormAttr::checkOperation(
   return wrapResult(loc, failure(walkResult.wasInterrupted()));
 }
 
+DiagnosedSilenceableFailure
+mlir::transform::TestCountingNormalFormAttr::checkOperation(
+    Operation *op) const {
+  // Record the number of invocations of this check on `op` as a discardable
+  // integer attribute. Tests that need to detect redundant checks can simply
+  // `FileCheck` the printed IR for the expected count.
+  Builder builder(op->getContext());
+  StringAttr counterName =
+      builder.getStringAttr("test.counting_normal_form_count");
+  unsigned count = 0;
+  if (auto prev = op->getAttrOfType<IntegerAttr>(counterName))
+    count = prev.getValue().getZExtValue();
+  op->setDiscardableAttr(counterName, builder.getI64IntegerAttr(count + 1));
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index eabb460869b10..d55311eb89651 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -35,6 +35,19 @@ def TestTransformSingleBlockNormalForm
   let assemblyFormat = "`<` `nested` $check_nested `>`";
 }
 
+def TestTransformCountingNormalForm
+    : AttrDef<Transform_Dialect, "TestCountingNormalForm",
+              [DeclareAttrInterfaceMethods<
+                  NormalFormAttrInterface, ["checkOperation"]>]> {
+  let description = [{
+    Normal form that always succeeds and updates a counter attribute on an
+    operation every time the check runs.
+  }];
+  let parameters = (ins);
+  let mnemonic = "test_counting_normal_form";
+  let assemblyFormat = "";
+}
+
 def TestTransformTestDialectHandleType
   : TypeDef<Transform_Dialect, "TestDialectOp",
       [DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
