[Mlir-commits] [mlir] [mlir] Implement DestinationStyleOpInterface for scf::ForallOp (PR #66981)

Felix Schneider llvmlistbot at llvm.org
Sun Sep 24 23:31:41 PDT 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/66981

>From e847677603fadb26b1772eddbd84371f7b21c8a7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 21 Sep 2023 06:13:57 +0000
Subject: [PATCH 1/4] [mlir] Implement DestinationStyleOpInterface for
 scf::ForallOp

`scf::ForallOp` has `shared_outs` tensor operands into which partial
results are inserted in the parallel terminator. The
`scf::ForallOp` returns one tensor for each `shared_out` which then
contains the combined result from all threads. Since the parallel
terminator cannot change the shape of the `shared_out`, ForallOp is a
`DestinationStyleOp` and this patch implements the interface by declaring
the `outputs` operands as `inits` in the language of the DPS interface.

For this change to work, we need to add an exception to the Pattern
that folds `tensor.cast` Ops into DPS Ops because `scf::ForallOp` needs
special handling of its `BlockArgument` Type during this folding.
---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h     |  1 +
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 ++++++++++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  4 ++++
 3 files changed, 15 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 915ab3016b688e7..644118ca884c6b1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0c93989ca99a4eb..a0894c81ad77955 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       DestinationStyleOpInterface
      ]> {
   let summary = "evaluate a block multiple times in parallel";
   let description = [{
@@ -630,6 +632,14 @@ def ForallOp : SCF_Op<"forall", [
                                  Location loc);
 
     InParallelOp getTerminator();
+
+    // Implement this to declare all shared_outs as inits/outs to 
+    // DestinationStyleOpInterface
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      int64_t numOperands = getNumOperands();
+      int64_t numOuts = getOutputs().size();
+      return {numOperands - numOuts, numOperands};
+    }
   }];
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3e30e320bee8f83..fa91471f33d4bd3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -3970,6 +3971,9 @@ struct FoldTensorCastProducerOp
     if (isa<InsertSliceOp>(op.getOperation()))
       return failure();
 
+    if (isa<scf::ForallOp>(op.getOperation()))
+      return failure();
+
     // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
         llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {

>From 614e91ab4ea1b585bb622a7ec482a625051c9d6d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Thu, 21 Sep 2023 07:52:23 +0000
Subject: [PATCH 2/4] add comment clarifying the exception for forallOp folding
 with tensor.cast

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa91471f33d4bd3..65a834acad8acfb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3971,6 +3971,8 @@ struct FoldTensorCastProducerOp
     if (isa<InsertSliceOp>(op.getOperation()))
       return failure();
 
+    // scf::ForallOp also has its own folding logic taking the Types of its
+    // BlockArguments into consideration.
     if (isa<scf::ForallOp>(op.getOperation()))
       return failure();
 

>From d511c715621c5beaf7b67d5cfc61692e7fc0ed1d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 22 Sep 2023 06:45:52 +0000
Subject: [PATCH 3/4] Exclude LoopLikeInterfaceOps instead of scf::ForallOp
 specifically

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 65a834acad8acfb..f719cfed6b6dd30 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -23,6 +22,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -3971,9 +3971,9 @@ struct FoldTensorCastProducerOp
     if (isa<InsertSliceOp>(op.getOperation()))
       return failure();
 
-    // scf::ForallOp also has its own folding logic taking the Types of its
-    // BlockArguments into consideration.
-    if (isa<scf::ForallOp>(op.getOperation()))
+    // Exclude DPS ops that are also LoopLike from this interface as they
+    // might need special handling of attached regions.
+    if (isa<LoopLikeOpInterface>(op.getOperation()))
       return failure();
 
     // If no operand comes from a tensor::CastOp and can be folded then fail.

>From 6420451da304272c06cb5f0f1f88b2fca6b35dd3 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Fri, 22 Sep 2023 11:52:31 +0000
Subject: [PATCH 4/4] Update DPS interface implementation

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a0894c81ad77955..6130f031ca6ab2d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -633,13 +633,8 @@ def ForallOp : SCF_Op<"forall", [
 
     InParallelOp getTerminator();
 
-    // Implement this to declare all shared_outs as inits/outs to 
-    // DestinationStyleOpInterface
-    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
-      int64_t numOperands = getNumOperands();
-      int64_t numOuts = getOutputs().size();
-      return {numOperands - numOuts, numOperands};
-    }
+    // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
+    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
   }];
 }
 



More information about the Mlir-commits mailing list