[Mlir-commits] [mlir] [mlir] Implement DestinationStyleOpInterface for scf::ForallOp (PR #66981)
Felix Schneider
llvmlistbot at llvm.org
Thu Sep 21 00:50:48 PDT 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/66981
>From a84378d0f4e67cd7f68a27c87e12939fa40b65b4 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 21 Sep 2023 06:13:57 +0000
Subject: [PATCH 1/2] [mlir] Implement DestinationStyleOpInterface for
scf::ForallOp
`scf::ForallOp` has `shared_outs` tensor operands into which partial
results are inserted by the parallel terminator. The
`scf::ForallOp` returns one tensor for each `shared_out` which then
contains the combined result from all threads. Since the parallel
terminator cannot change the shape of the `shared_out`, ForallOp is a
`DestinationStyleOp` and this patch implements the interface by declaring
the `outputs` operands as `inits` in the language of the DPS interface.
For this change to work, we need to add an exception to the Pattern
that folds `tensor.cast` Ops into DPS Ops because `scf::Forall` needs
special handling of its `BlockArgument` Type during this folding.
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 +
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 ++++++++++
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
3 files changed, 15 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 915ab3016b688e7..644118ca884c6b1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 08b71e20a2bc079..adc7b2e4170cb89 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DestinationStyleOpInterface
]> {
let summary = "evaluate a block multiple times in parallel";
let description = [{
@@ -630,6 +632,14 @@ def ForallOp : SCF_Op<"forall", [
Location loc);
InParallelOp getTerminator();
+
+ // Implement this to declare all shared_outs as inits/outs to
+ // DestinationStyleOpInterface
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ int64_t numOperands = getNumOperands();
+ int64_t numOuts = getOutputs().size();
+ return {numOperands - numOuts, numOperands};
+ }
}];
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3e30e320bee8f83..fa91471f33d4bd3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -3970,6 +3971,9 @@ struct FoldTensorCastProducerOp
if (isa<InsertSliceOp>(op.getOperation()))
return failure();
+ if (isa<scf::ForallOp>(op.getOperation()))
+ return failure();
+
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
>From 2822e2e9401e0b0a807cb7bc0d5d4a91ddc24beb Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 21 Sep 2023 07:50:29 +0000
Subject: [PATCH 2/2] add comment clarifying the exception for ForallOp folding
with tensor.cast
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa91471f33d4bd3..e28be55a28add75 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3719,8 +3719,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
}
/// Returns true if the tiles and the tiled dims are constant.
-template <typename OpTy>
-bool areTilesAndTiledDimsAllConstant(OpTy op) {
+template <typename OpTy> bool areTilesAndTiledDimsAllConstant(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
@@ -3971,6 +3970,8 @@ struct FoldTensorCastProducerOp
if (isa<InsertSliceOp>(op.getOperation()))
return failure();
+ // scf::ForallOp also has its own folding logic taking the Types of its
+ // BlockArguments into consideration.
if (isa<scf::ForallOp>(op.getOperation()))
return failure();
More information about the Mlir-commits
mailing list