[Mlir-commits] [mlir] 0d9761d - [mlir][SCF] Add tensor.dim(scf.foreach_thread) folding

Matthias Springer llvmlistbot at llvm.org
Tue Nov 22 02:32:05 PST 2022


Author: Matthias Springer
Date: 2022-11-22T11:28:27+01:00
New Revision: 0d9761d50e738163c87d84a4328bc0a827ac8f34

URL: https://github.com/llvm/llvm-project/commit/0d9761d50e738163c87d84a4328bc0a827ac8f34
DIFF: https://github.com/llvm/llvm-project/commit/0d9761d50e738163c87d84a4328bc0a827ac8f34.diff

LOG: [mlir][SCF] Add tensor.dim(scf.foreach_thread) folding

The dim sizes of `scf.foreach_thread` op results match those of their respective tied `shared_outs` operands, so a `tensor.dim` of a loop result can be folded to a `tensor.dim` of the tied operand.
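
For illustration, a small before/after sketch of the fold (value names such as %t, %in, %r, and %d are made up for this example; the shapes mirror the new test below):

    // Before canonicalization: tensor.dim queries a loop result.
    %r = scf.foreach_thread (%tid) in (%num_threads) shared_outs(%o = %t) -> (tensor<?x?xf32>) {
      scf.foreach_thread.perform_concurrently {
        tensor.parallel_insert_slice %in into %o[%tid, 0] [1, 5] [1, 1]
            : tensor<1x5xf32> into tensor<?x?xf32>
      }
    }
    %d = tensor.dim %r, %c0 : tensor<?x?xf32>

    // After canonicalization: the dim is taken directly from the tied
    // shared_outs operand %t, whose sizes match those of %r by construction.
    %d = tensor.dim %t, %c0 : tensor<?x?xf32>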

Differential Revision: https://reviews.llvm.org/D138484

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 2d880ac52d2c4..af4db68fd7c87 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -487,6 +487,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
@@ -510,11 +511,20 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
           opOperand->getOperandNumber() - getRank());
     }
 
+    /// Return the num_threads operand that is tied to the given thread id
+    /// block argument.
     OpOperand *getTiedOpOperand(BlockArgument bbArg) {
       assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
       return &getOperation()->getOpOperand(bbArg.getArgNumber());
     }
 
+    /// Return the shared_outs operand that is tied to the given OpResult.
+    OpOperand *getTiedOpOperand(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
+      return &getOperation()->getOpOperand(
+          opResult.getResultNumber() + getRank());
+    }
+
     BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
       assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
       return getBody()->getArgument(opOperand->getOperandNumber());

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 118452aae10b3..6924107c1c52e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1323,6 +1323,31 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
   return dyn_cast<ForeachThreadOp>(containingOp);
 }
 
+namespace {
+/// Fold tensor.dim(foreach_thread shared_outs(... = %t)) to tensor.dim(%t).
+struct DimOfForeachThreadOp : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto foreachThreadOp = dimOp.getSource().getDefiningOp<ForeachThreadOp>();
+    if (!foreachThreadOp)
+      return failure();
+    Value sharedOut =
+        foreachThreadOp.getTiedOpOperand(dimOp.getSource().cast<OpResult>())
+            ->get();
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
+    return success();
+  }
+};
+} // namespace
+
+void ForeachThreadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<DimOfForeachThreadOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b6ac36282fc43..e5e2afcc2f735 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1478,3 +1478,25 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:   ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
+func.func @canonicalize_parallel_insert_slice_indices(
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>, %num_threads : index) -> index
+{
+  // CHECK: %[[c1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+
+  %2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %o[%tidx, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<?x?xf32>
+    }
+  }
+
+  // CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c1]]
+  %dim = tensor.dim %2, %c1 : tensor<?x?xf32>
+  // CHECK: return %[[dim]]
+  return %dim : index
+}
