[Mlir-commits] [mlir] a489aa7 - [mlir][SCF] Add scf::ForeachThread canonicalization.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 21 00:55:14 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-21T00:54:46-07:00
New Revision: a489aa745b621547427602dc4995e1e9ff3fcb57

URL: https://github.com/llvm/llvm-project/commit/a489aa745b621547427602dc4995e1e9ff3fcb57
DIFF: https://github.com/llvm/llvm-project/commit/a489aa745b621547427602dc4995e1e9ff3fcb57.diff

LOG: [mlir][SCF] Add scf::ForeachThread canonicalization.

This revision adds the necessary plumbing for canonicalizing scf::ForeachThread with the
`AffineOpSCFCanonicalizationPattern`.
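In the process, the `loopMatcher` helper is updated to take `OpFoldResult` instead of `Value` for the lower bound, upper bound and step.
This allows composing various scenarios (e.g. constant bounds or steps given directly as index attributes) without the need for an artificial builder to materialize them as values.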

Differential Revision: https://reviews.llvm.org/D128244

Added: 
    mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCF.h
    mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 1efa7ef84ff59..2c0dad6382009 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -49,6 +49,10 @@ ForOp getForInductionVarOwner(Value val);
 /// value is not an induction variable, then return nullptr.
 ParallelOp getParallelForInductionVarOwner(Value val);
 
+/// Returns the ForeachThreadOp parent of a thread index variable.
+/// If the provided value is not a thread index variable, then return nullptr.
+ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val);
+
 /// Return true if ops a and b (or their ancestors) are in mutually exclusive
 /// regions/blocks of an IfOp.
 // TODO: Consider moving this functionality to RegionBranchOpInterface.

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
index 7e775c5e90621..462d6b5c42412 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
@@ -20,6 +20,7 @@ namespace mlir {
 class AffineMap;
 struct LogicalResult;
 class Operation;
+class OpFoldResult;
 class RewriterBase;
 class Value;
 class ValueRange;
@@ -32,8 +33,8 @@ class IfOp;
 /// step size via the last parameter. The function should return `success` in
 /// that case. If the first parameter is not an iteration variable, return
 /// `failure`.
-using LoopMatcherFn =
-    function_ref<LogicalResult(Value, Value &, Value &, Value &)>;
+using LoopMatcherFn = function_ref<LogicalResult(
+    Value, OpFoldResult &, OpFoldResult &, OpFoldResult &)>;
 
 /// Try to canonicalize an min/max operations in the context of for `loops` with
 /// a known range.

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 012499f7dad38..878ddc60cee70 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1194,6 +1194,15 @@ PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
 }
 
+ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
+  auto tidxArg = val.dyn_cast<BlockArgument>();
+  if (!tidxArg)
+    return ForeachThreadOp();
+  assert(tidxArg.getOwner() && "unlinked block argument");
+  // getParentOp() may be null (detached region): dyn_cast would assert.
+  return dyn_cast_or_null<ForeachThreadOp>(tidxArg.getOwner()->getParentOp());
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelInsertSliceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 0f511af14811d..eda6bc6e1cf8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -138,7 +138,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
     unsigned resultNumber = opResult.getResultNumber();
     if (!isShapePreserving(forOp, resultNumber))
       return failure();
-    rewriter.updateRootInPlace(dimOp, [&](){
+    rewriter.updateRootInPlace(dimOp, [&]() {
       dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
     });
     return success();
@@ -153,7 +153,8 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
+    auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
+                          OpFoldResult &step) {
       if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
         lb = forOp.getLowerBound();
         ub = forOp.getUpperBound();
@@ -171,6 +172,18 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
         }
         return failure();
       }
+      if (scf::ForeachThreadOp foreachThreadOp =
+              scf::getForeachThreadOpThreadIndexOwner(iv)) {
+        for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
+          if (foreachThreadOp.getThreadIndices()[idx] == iv) {
+            lb = OpBuilder(iv.getContext()).getIndexAttr(0);
+            ub = foreachThreadOp.getNumThreads()[idx];
+            step = OpBuilder(iv.getContext()).getIndexAttr(1);
+            return success();
+          }
+        }
+        return failure();
+      }
       return failure();
     };
 

diff  --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 6c28cc3d83d87..958b5a2757148 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -201,7 +201,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
 
 static LogicalResult
 addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
-                        Value lb, Value ub, Value step,
+                        OpFoldResult lb, OpFoldResult ub, OpFoldResult step,
                         RewriterBase &rewriter) {
   // IntegerPolyhedron does not support semi-affine expressions.
   // Therefore, only constant step values are supported.
@@ -210,8 +210,12 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
     return failure();
 
   unsigned dimIv = constraints.appendDimId(iv);
-  unsigned dimLb = constraints.appendDimId(lb);
-  unsigned dimUb = constraints.appendDimId(ub);
+  auto lbv = lb.dyn_cast<Value>();
+  unsigned dimLb =
+      lbv ? constraints.appendDimId(lbv) : constraints.appendDimId(/*num=*/1);
+  auto ubv = ub.dyn_cast<Value>();
+  unsigned dimUb =
+      ubv ? constraints.appendDimId(ubv) : constraints.appendDimId(/*num=*/1);
 
   // If loop lower/upper bounds are constant: Add EQ constraint.
   Optional<int64_t> lbInt = getConstantIntValue(lb);
@@ -276,7 +280,7 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
     // If `operand` is an iteration variable: Find corresponding loop
     // bounds and step.
     Value iv = operand;
-    Value lb, ub, step;
+    OpFoldResult lb, ub, step;
     if (failed(loopMatcher(operand, lb, ub, step)))
       continue;
     allIvs.insert(iv);

diff  --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
new file mode 100644
index 0000000000000..b65d0c7049ab6
--- /dev/null
+++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s
+
+func.func @reduce() -> tensor<128xf32> {
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32>
+  %cst_0 = arith.constant -0.000000e+00 : f32
+  %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
+  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32>
+  %2 = linalg.init_tensor [128] : tensor<128xf32>
+  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
+  %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+    %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0)
+    %8 = affine.max affine_map<(d0) -> (0, d0)>(%7)
+    %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0)
+    %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0)
+
+    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32>
+    // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32>
+    %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor<?x384xf32>
+    %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor<?xf32>
+
+    // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) {
+    %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor<?x384xf32>) outs(%12 : tensor<?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %14 = arith.addf %arg1, %arg2 : f32
+      linalg.yield %14 : f32
+    } -> tensor<?xf32>
+
+    // TODO: canonicalize this cast away.
+    // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
+    // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
+    }
+  }
+  return %4 : tensor<128xf32>
+}


        


More information about the Mlir-commits mailing list