[Mlir-commits] [mlir] 3948928 - [mlir] verify that transform ops have memory effects

Alex Zinenko llvmlistbot at llvm.org
Tue Jan 10 05:49:45 PST 2023


Author: Alex Zinenko
Date: 2023-01-10T13:49:40Z
New Revision: 394892841aeaf35a75fa626b07637da68be176a9

URL: https://github.com/llvm/llvm-project/commit/394892841aeaf35a75fa626b07637da68be176a9
DIFF: https://github.com/llvm/llvm-project/commit/394892841aeaf35a75fa626b07637da68be176a9.diff

LOG: [mlir] verify that transform ops have memory effects

Add a verifier to the TransformOpInterface ensuring that operations
implementing the interface specify memory effects on all of their operands
and an 'allocate' effect on each of their results.

Add the missing effects to TileToForeachThreadOp, specifically for the
operands that were added in a later version of the op without updating
`getEffects` accordingly.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141371

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b2c3827fdb825..6dbd1210d183b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -494,6 +494,9 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
 
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+
+/// Verification hook for TransformOpInterface.
+LogicalResult verifyTransformOpInterface(Operation *op);
 } // namespace detail
 
 /// This trait is supposed to be attached to Transform dialect operations that

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index b0b92daa3c855..f4d66c5f7fa67 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -101,6 +101,10 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       return diag;
     }
   }];
+
+  let verify = [{
+    return ::mlir::transform::detail::verifyTransformOpInterface($_op);
+  }];
 }
 
 class TransformTypeInterfaceBase<string cppClass, string cppObjectType>

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f170d0bd1199b..509739691ab5e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1760,6 +1760,8 @@ void transform::TileToForeachThreadOp::getEffects(
   consumesHandle(getTarget(), effects);
   onlyReadsHandle(getTileSizes(), effects);
   onlyReadsHandle(getNumThreads(), effects);
+  onlyReadsHandle(getPackedNumThreads(), effects);
+  onlyReadsHandle(getPackedTileSizes(), effects);
   producesHandle(getResults(), effects);
 }
 

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index e2e8aa2cea027..b8a4ee7d66d74 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -616,8 +616,8 @@ void transform::consumesHandle(
 
 /// Returns `true` if the given list of effects instances contains an instance
 /// with the effect type specified as template parameter.
-template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
-static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
+template <typename EffectTy, typename ResourceTy, typename Range>
+static bool hasEffect(Range &&effects) {
   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
     return isa<EffectTy>(effect.getEffect()) &&
            isa<ResourceTy>(effect.getResource());
@@ -664,6 +664,48 @@ void transform::onlyReadsPayload(
   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for TransformOpInterface.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
+  auto iface = cast<MemoryEffectOpInterface>(op);
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  iface.getEffects(effects);
+
+  auto effectsOn = [&](Value value) {
+    return llvm::make_filter_range(
+        effects, [value](const MemoryEffects::EffectInstance &instance) {
+          return instance.getValue() == value;
+        });
+  };
+
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto range = effectsOn(operand.get());
+    if (range.empty()) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires memory effects "
+                             "on operands to be specified";
+      diag.attachNote() << "no effects specified for operand #"
+                        << operand.getOperandNumber();
+      return diag;
+    }
+  }
+  for (OpResult result : op->getResults()) {
+    auto range = effectsOn(result);
+    if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
+            range)) {
+      InFlightDiagnostic diag =
+          op->emitError() << "TransformOpInterface requires 'allocate' memory "
+                             "effect to be specified for results";
+      diag.attachNote() << "no 'allocate' effect specified for result #"
+                        << result.getResultNumber();
+      return diag;
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Entry point.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index ec3f5537cf55a..e957d7a26e575 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -210,3 +210,21 @@ transform.sequence failures(propagate) {
   // expected-note @below {{used here as operand #0}}
   transform.test_consume_operand %0
 }
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}
+  // expected-note @below {{no effects specified for operand #0}}
+  transform.test_required_memory_effects %arg0 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{TransformOpInterface requires 'allocate' memory effect to be specified for results}}
+  // expected-note @below {{no 'allocate' effect specified for result #0}}
+  transform.test_required_memory_effects %arg0 {has_operand_effect} : (!transform.any_op) -> !transform.any_op
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 338d72e3042db..63d682831114b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -471,7 +471,9 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
 }
 
 void mlir::test::TestProduceNullParamOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::producesHandle(getOut(), effects);
+}
 
 DiagnosedSilenceableFailure
 mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
@@ -480,6 +482,23 @@ mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
   return DiagnosedSilenceableFailure::success();
 }
 
+void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (getHasOperandEffect())
+    transform::consumesHandle(getIn(), effects);
+
+  if (getHasResultEffect())
+    transform::producesHandle(getOut(), effects);
+  else
+    transform::onlyReadsHandle(getOut(), effects);
+}
+
+DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn()));
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 9ff5e30944e78..02e8a691db3c3 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -352,4 +352,16 @@ def TestProduceNullParamOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestRequiredMemoryEffectsOp
+  : Op<Transform_Dialect, "test_required_memory_effects",
+      [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$in,
+                       UnitAttr:$has_operand_effect,
+                       UnitAttr:$has_result_effect);
+  let results = (outs TransformHandleTypeInterface:$out);
+  let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list