[Mlir-commits] [mlir] 0d9761d - [mlir][SCF] Add tensor.dim(scf.foreach_thread) folding
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 22 02:32:05 PST 2022
Author: Matthias Springer
Date: 2022-11-22T11:28:27+01:00
New Revision: 0d9761d50e738163c87d84a4328bc0a827ac8f34
URL: https://github.com/llvm/llvm-project/commit/0d9761d50e738163c87d84a4328bc0a827ac8f34
DIFF: https://github.com/llvm/llvm-project/commit/0d9761d50e738163c87d84a4328bc0a827ac8f34.diff
LOG: [mlir][SCF] Add tensor.dim(scf.foreach_thread) folding
The dim sizes of `scf.foreach_thread` op results match the dim sizes of their respective tied shared_outs operands, so a `tensor.dim` of a result can be folded into a `tensor.dim` of that operand.
Differential Revision: https://reviews.llvm.org/D138484
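For illustration only (this sketch is not part of the commit), the new DimOfForeachThreadOp pattern behaves roughly as follows; the SSA names %t, %n, %c0, %tid and %slice are placeholders invented for this sketch. A `tensor.dim` that queries a loop result is redirected to the tied shared_outs operand, which has the same shape:

  // Hypothetical input:
  %r = scf.foreach_thread (%tid) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %slice into %o[%tid] [1] [1]
          : tensor<1xf32> into tensor<?xf32>
    }
  }
  %d = tensor.dim %r, %c0 : tensor<?xf32>

  // After canonicalization, the dim reads the tied shared_outs operand directly:
  %d = tensor.dim %t, %c0 : tensor<?xf32>

Note that the pattern only reassigns the source operand of the existing `tensor.dim` op (via `updateRootInPlace`) rather than creating a new op.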
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 2d880ac52d2c4..af4db68fd7c87 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -487,6 +487,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
@@ -510,11 +511,20 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
opOperand->getOperandNumber() - getRank());
}
+ /// Return the shared_outs operand that is tied to the given shared_outs
+ /// block argument.
OpOperand *getTiedOpOperand(BlockArgument bbArg) {
assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
return &getOperation()->getOpOperand(bbArg.getArgNumber());
}
+ /// Return the shared_outs operand that is tied to the given OpResult.
+ OpOperand *getTiedOpOperand(OpResult opResult) {
+ assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
+ return &getOperation()->getOpOperand(
+ opResult.getResultNumber() + getRank());
+ }
+
BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
return getBody()->getArgument(opOperand->getOperandNumber());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 118452aae10b3..6924107c1c52e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1323,6 +1323,31 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
return dyn_cast<ForeachThreadOp>(containingOp);
}
+namespace {
+/// Fold tensor.dim(foreach_thread shared_outs(... = %t)) to tensor.dim(%t).
+struct DimOfForeachThreadOp : public OpRewritePattern<tensor::DimOp> {
+ using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+ PatternRewriter &rewriter) const final {
+ auto foreachThreadOp = dimOp.getSource().getDefiningOp<ForeachThreadOp>();
+ if (!foreachThreadOp)
+ return failure();
+ Value sharedOut =
+ foreachThreadOp.getTiedOpOperand(dimOp.getSource().cast<OpResult>())
+ ->get();
+ rewriter.updateRootInPlace(
+ dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
+ return success();
+ }
+};
+} // namespace
+
+void ForeachThreadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<DimOfForeachThreadOp>(context);
+}
+
//===----------------------------------------------------------------------===//
// PerformConcurrentlyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b6ac36282fc43..e5e2afcc2f735 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1478,3 +1478,25 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
+// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
+func.func @canonicalize_parallel_insert_slice_indices(
+ %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>, %num_threads : index) -> index
+{
+ // CHECK: %[[c1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+
+ %2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %arg0 into %o[%tidx, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<?x?xf32>
+ }
+ }
+
+ // CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c1]]
+ %dim = tensor.dim %2, %c1 : tensor<?x?xf32>
+ // CHECK: return %[[dim]]
+ return %dim : index
+}