[Mlir-commits] [mlir] 24e33b5 - [mlir] Implement DestinationStyleOpInterface for scf::ForallOp (#66981)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 25 00:06:29 PDT 2023
Author: Felix Schneider
Date: 2023-09-25T09:06:25+02:00
New Revision: 24e33b594503f0283d31b89e2694e3d798def5ac
URL: https://github.com/llvm/llvm-project/commit/24e33b594503f0283d31b89e2694e3d798def5ac
DIFF: https://github.com/llvm/llvm-project/commit/24e33b594503f0283d31b89e2694e3d798def5ac.diff
LOG: [mlir] Implement DestinationStyleOpInterface for scf::ForallOp (#66981)
`scf::ForallOp` has `shared_outs` tensor operands into which the
parallel terminator inserts partial results. The
`scf::ForallOp` returns one tensor for each `shared_out` which then
contains the combined result from all threads. Since the parallel
terminator cannot change the shape of the `shared_out`, ForallOp is a
`DestinationStyleOp` and this patch implements the interface by
declaring the `outputs` operands as `inits` in the language of the DPS
interface.
For this change to work, we need to add an exception to the pattern that
folds `tensor.cast` ops into DPS ops, because `scf::ForallOp` needs special
handling of the types of its region's `BlockArgument`s during this folding.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 915ab3016b688e7..644118ca884c6b1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0c93989ca99a4eb..6130f031ca6ab2d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DestinationStyleOpInterface
]> {
let summary = "evaluate a block multiple times in parallel";
let description = [{
@@ -630,6 +632,9 @@ def ForallOp : SCF_Op<"forall", [
Location loc);
InParallelOp getTerminator();
+
+ // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
+ MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3e30e320bee8f83..f719cfed6b6dd30 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -3970,6 +3971,11 @@ struct FoldTensorCastProducerOp
if (isa<InsertSliceOp>(op.getOperation()))
return failure();
+ // Exclude DPS ops that are also LoopLike from this interface as they
+ // might need special handling of attached regions.
+ if (isa<LoopLikeOpInterface>(op.getOperation()))
+ return failure();
+
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
More information about the Mlir-commits
mailing list