[Mlir-commits] [mlir] 88b7e8e - [mlir][SCF] Add an scf.take_assumed_branch transform op.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Apr 12 08:51:03 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-12T08:47:20-07:00
New Revision: 88b7e8e0f06dd22d228d7fa7eb7e4d112342e3ed
URL: https://github.com/llvm/llvm-project/commit/88b7e8e0f06dd22d228d7fa7eb7e4d112342e3ed
DIFF: https://github.com/llvm/llvm-project/commit/88b7e8e0f06dd22d228d7fa7eb7e4d112342e3ed.diff
LOG: [mlir][SCF] Add an scf.take_assumed_branch transform op.
Given an scf.if conditional, using this transformation is akin to injecting
user-specified information that it is always safe to execute only the specified
`if` or `else` branch.
This is achieved by just replacing the scf.if by the content of one of its
branches.
This is particularly useful for user-controlled rewriting of conditionals
that exist solely to guard against out-of-bounds behavior.
At the moment, no assume or assert operation is emitted as it is not always
desirable. In the future, this may be controlled by a dedicated attribute.
Differential Revision: https://reviews.llvm.org/D148125
Added:
mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
Modified:
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
index 91e42140dc119..c5cc2da232846 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
@@ -19,6 +19,7 @@ class FuncOp;
} // namespace func
namespace scf {
class ForOp;
+class IfOp; // Needed by TakeAssumedBranchOp::applyToOne, which takes an IfOp.
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index b286850ad9895..0399a5a9afa5e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -215,4 +215,45 @@ def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
}];
}
+def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Given an scf.if conditional, inject user-defined information that it is
+    always safe to execute only its `then` branch or only its `else` branch.
+
+    This is achieved by just replacing the scf.if by the content of its `then`
+    branch, or of its `else` branch when `take_else_branch` is set.
+
+    This is particularly useful for user-controlled rewriting of conditionals
+    that exist solely to guard against out-of-bounds behavior.
+
+    At the moment, no assume or assert operation is emitted as it is not always
+    desirable. In the future, this may be controlled by a dedicated attribute.
+
+    #### Return modes
+
+    The transform consumes its operand, modifies the payload, and has no result.
+    The transform definitely fails if `take_else_branch` is specified and the
+    `else` region is empty.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       OptionalAttr<UnitAttr>:$take_else_branch);
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $target
+    (`take_else_branch` $take_else_branch^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::scf::IfOp ifOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
#endif // SCF_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index b35e1045b2ee0..b87fc777cdc4b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
@@ -245,6 +246,46 @@ transform::LoopCoalesceOp::applyToOne(Operation *op,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TakeAssumedBranchOp
+//===----------------------------------------------------------------------===//
+/// Replaces `op` with the contents of the single-block `region` (which must
+/// have no block arguments); the terminator operands replace `op`'s results.
+static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
+                                Region &region) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator); // Erase last: `results` views its operands.
+}
+
+DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
+    scf::IfOp ifOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(ifOp->getContext(), &listener);
+  rewriter.setInsertionPoint(ifOp);
+  // Pick the branch the user asserts is always safe to execute.
+  Region &region =
+      getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
+  if (!llvm::hasSingleElement(region))
+    return emitDefiniteFailure()
+           << "requires an scf.if op with a single-block "
+           << (getTakeElseBranch() ? "`else`" : "`then`") << " region";
+  // Inline the chosen branch in place of `ifOp` (no assume/assert is emitted).
+  replaceOpWithRegion(rewriter, ifOp, region);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TakeAssumedBranchOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(this->getTarget(), effects); // scf.if is replaced.
+  transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
new file mode 100644
index 0000000000000..15d9e56ad099a
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s
+
+func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+ scf.if %cond { // No else branch, so take_else_branch must fail on this op.
+ "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+ scf.yield
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %if = transform.structured.match ops{["scf.if"]} in %arg1 // then-only scf.if
+ : (!transform.any_op) -> !transform.any_op
+
+ // expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
+ transform.scf.take_assumed_branch %if take_else_branch
+ : (!transform.any_op) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+ %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
+ -> tensor<20x40xf32>
+{ // After taking the else branch, each tile keeps only its tensor.pad.
+ // CHECK: scf.forall
+ // CHECK-NOT: scf.if
+ // CHECK-NOT: tensor.generate
+ // CHECK-NOT: else
+ // CHECK: tensor.pad {{.*}} nofold
+ %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+ ^bb0(%arg9: index, %arg10: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<20x40xf32>
+ return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ : (!transform.any_op) -> !pdl.operation // NOTE(review): presumably the handle type tile_to_forall_op expects
+ transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+
+ %if = transform.structured.match ops{["scf.if"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.scf.take_assumed_branch %if take_else_branch // keep the tensor.pad branch
+ : (!transform.any_op) -> ()
+}
More information about the Mlir-commits
mailing list