[Mlir-commits] [mlir] a489aa7 - [mlir][SCF] Add scf::ForeachThread canonicalization.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jun 21 00:55:14 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-21T00:54:46-07:00
New Revision: a489aa745b621547427602dc4995e1e9ff3fcb57
URL: https://github.com/llvm/llvm-project/commit/a489aa745b621547427602dc4995e1e9ff3fcb57
DIFF: https://github.com/llvm/llvm-project/commit/a489aa745b621547427602dc4995e1e9ff3fcb57.diff
LOG: [mlir][SCF] Add scf::ForeachThread canonicalization.
This revision adds the necessary plumbing for canonicalizing scf::ForeachThread with the
`AffineOpSCFCanonicalizationPattern`.
In the process the `loopMatcher` helper is updated to take OpFoldResult instead of just values.
This allows composing various scenarios without the need for an artificial builder.
Differential Revision: https://reviews.llvm.org/D128244
Added:
mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 1efa7ef84ff59..2c0dad6382009 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -49,6 +49,10 @@ ForOp getForInductionVarOwner(Value val);
/// value is not an induction variable, then return nullptr.
ParallelOp getParallelForInductionVarOwner(Value val);
+/// Returns the ForeachThreadOp parent of a thread index variable.
+/// If the provided value is not a thread index variable, then return nullptr.
+ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val);
+
/// Return true if ops a and b (or their ancestors) are in mutually exclusive
/// regions/blocks of an IfOp.
// TODO: Consider moving this functionality to RegionBranchOpInterface.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
index 7e775c5e90621..462d6b5c42412 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
@@ -20,6 +20,7 @@ namespace mlir {
class AffineMap;
struct LogicalResult;
class Operation;
+class OpFoldResult;
class RewriterBase;
class Value;
class ValueRange;
@@ -32,8 +33,8 @@ class IfOp;
/// step size via the last parameter. The function should return `success` in
/// that case. If the first parameter is not an iteration variable, return
/// `failure`.
-using LoopMatcherFn =
- function_ref<LogicalResult(Value, Value &, Value &, Value &)>;
+using LoopMatcherFn = function_ref<LogicalResult(
+ Value, OpFoldResult &, OpFoldResult &, OpFoldResult &)>;
/// Try to canonicalize a min/max operation in the context of `for` loops with
/// a known range.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 012499f7dad38..878ddc60cee70 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1194,6 +1194,15 @@ PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
}
+/// Returns the ForeachThreadOp that is the parent operation of the block owning
+/// `val`, when `val` is a block argument. Returns a null ForeachThreadOp when
+/// `val` is not a block argument, when its block has no parent operation, or
+/// when that parent operation is not a ForeachThreadOp.
+ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
+  auto tidxArg = val.dyn_cast<BlockArgument>();
+  if (!tidxArg)
+    return ForeachThreadOp();
+  assert(tidxArg.getOwner() && "unlinked block argument");
+  // The owning block may not be nested under any operation, in which case
+  // getParentOp() yields null. Plain dyn_cast asserts on a null pointer, so use
+  // the null-tolerant variant and let the null propagate as "no owner".
+  auto *containingOp = tidxArg.getOwner()->getParentOp();
+  return dyn_cast_or_null<ForeachThreadOp>(containingOp);
+}
+
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 0f511af14811d..eda6bc6e1cf8b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -138,7 +138,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
unsigned resultNumber = opResult.getResultNumber();
if (!isShapePreserving(forOp, resultNumber))
return failure();
- rewriter.updateRootInPlace(dimOp, [&](){
+ rewriter.updateRootInPlace(dimOp, [&]() {
dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
});
return success();
@@ -153,7 +153,8 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
+ auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
+ OpFoldResult &step) {
if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
lb = forOp.getLowerBound();
ub = forOp.getUpperBound();
@@ -171,6 +172,18 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
}
return failure();
}
+ if (scf::ForeachThreadOp foreachThreadOp =
+ scf::getForeachThreadOpThreadIndexOwner(iv)) {
+ for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
+ if (foreachThreadOp.getThreadIndices()[idx] == iv) {
+ lb = OpBuilder(iv.getContext()).getIndexAttr(0);
+ ub = foreachThreadOp.getNumThreads()[idx];
+ step = OpBuilder(iv.getContext()).getIndexAttr(1);
+ return success();
+ }
+ }
+ return failure();
+ }
return failure();
};
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 6c28cc3d83d87..958b5a2757148 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -201,7 +201,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
static LogicalResult
addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
- Value lb, Value ub, Value step,
+ OpFoldResult lb, OpFoldResult ub, OpFoldResult step,
RewriterBase &rewriter) {
// IntegerPolyhedron does not support semi-affine expressions.
// Therefore, only constant step values are supported.
@@ -210,8 +210,12 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
return failure();
unsigned dimIv = constraints.appendDimId(iv);
- unsigned dimLb = constraints.appendDimId(lb);
- unsigned dimUb = constraints.appendDimId(ub);
+ auto lbv = lb.dyn_cast<Value>();
+ unsigned dimLb =
+ lbv ? constraints.appendDimId(lbv) : constraints.appendDimId(/*num=*/1);
+ auto ubv = ub.dyn_cast<Value>();
+ unsigned dimUb =
+ ubv ? constraints.appendDimId(ubv) : constraints.appendDimId(/*num=*/1);
// If loop lower/upper bounds are constant: Add EQ constraint.
Optional<int64_t> lbInt = getConstantIntValue(lb);
@@ -276,7 +280,7 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
// If `operand` is an iteration variable: Find corresponding loop
// bounds and step.
Value iv = operand;
- Value lb, ub, step;
+ OpFoldResult lb, ub, step;
if (failed(loopMatcher(operand, lb, ub, step)))
continue;
allIvs.insert(iv);
diff --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
new file mode 100644
index 0000000000000..b65d0c7049ab6
--- /dev/null
+++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -scf-for-loop-canonicalization -canonicalize | FileCheck %s
+
+// Checks that the affine.min/max ops on a foreach_thread index bounded by a
+// constant thread count canonicalize, so the slice sizes become static (64).
+func.func @reduce() -> tensor<128xf32> {
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant dense<1.000000e+00> : tensor<1x128x384xf32>
+ %cst_0 = arith.constant -0.000000e+00 : f32
+ %0 = linalg.init_tensor [128, 384] : tensor<128x384xf32>
+ %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x384xf32>) -> tensor<128x384xf32>
+ %2 = linalg.init_tensor [128] : tensor<128xf32>
+ %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
+ %4 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+ %7 = affine.min affine_map<(d0) -> (d0 * -64 + 128, 64)>(%arg0)
+ %8 = affine.max affine_map<(d0) -> (0, d0)>(%7)
+ %9 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg0)
+ %10 = affine.min affine_map<(d0, d1) -> (d1 * -64 + 128, d0)>(%8, %arg0)
+
+ // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}, 0] [64, 384] [1, 1] : tensor<128x384xf32> to tensor<64x384xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [64] [1] : tensor<128xf32> to tensor<64xf32>
+ %11 = tensor.extract_slice %1[%9, 0] [%10, 384] [1, 1] : tensor<128x384xf32> to tensor<?x384xf32>
+ %12 = tensor.extract_slice %3[%9] [%10] [1] : tensor<128xf32> to tensor<?xf32>
+
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<64x384xf32>) outs(%{{.*}} : tensor<64xf32>) {
+ %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%11 : tensor<?x384xf32>) outs(%12 : tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %14 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %14 : f32
+ } -> tensor<?xf32>
+
+ // TODO: canonicalize this cast away.
+ // The second CHECK uses (rather than redefines) the captured cast result, so
+ // the test verifies that the inserted value is the tensor.cast output.
+ // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
+ // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
+ }
+ }
+ return %4 : tensor<128xf32>
+}
More information about the Mlir-commits
mailing list