[Mlir-commits] [mlir] 3119677 - [mlir][SCF] Add scf.foreach_thread.parallel_insert_slice canonicalization.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 2 05:00:50 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-02T11:53:25Z
New Revision: 311967701a2a33b2753ec5db6977f3c3ef40c46e
URL: https://github.com/llvm/llvm-project/commit/311967701a2a33b2753ec5db6977f3c3ef40c46e
DIFF: https://github.com/llvm/llvm-project/commit/311967701a2a33b2753ec5db6977f3c3ef40c46e.diff
LOG: [mlir][SCF] Add scf.foreach_thread.parallel_insert_slice canonicalization.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D126761
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index b342b6a057a0..8c9a1e3ad1d8 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -544,6 +544,8 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 7e743109b6e0..a160ba028928 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1229,6 +1229,45 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+namespace {
+/// Rewrites a parallel_insert_slice so constant-index offsets/sizes/strides become static.
+class ParallelInsertSliceOpConstantArgumentFolder final
+ : public OpRewritePattern<ParallelInsertSliceOp> {
+public:
+ using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ // Bail out if no operand is a constant index: there is nothing to fold.
+ if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, matchConstantIndex());
+ }))
+ return failure();
+
+ // Some dynamic offset/size/stride is a constant index. Rebuild the mixed
+ // static/dynamic lists with constants folded into static entries. Since the
+ // guard above re-fires on any leftover constant, this terminates only if every such constant is folded (confirm).
+ SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+ SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+ SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+ canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+ canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+ // Rebuild the op from the folded lists; source and dest are passed through unchanged.
+ rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+ insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+ mixedOffsets, mixedSizes, mixedStrides);
+ return success();
+ }
+};
+} // namespace
+/// Registers ParallelInsertSliceOpConstantArgumentFolder as a canonicalization.
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// PerformConcurrentlyOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8e087fc0f38a..ad5afa9c3601 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1458,3 +1458,28 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return
+
+// -----
+// Constant index operands (%c0, %c1) must fold into static offsets/sizes/strides.
+// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
+func.func @canonicalize_parallel_insert_slice_indices(
+ %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %num_threads : index) -> tensor<?x?xf32>
+{
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // NOTE(review): %cst above is never used by this test.
+ // CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
+ // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+ // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
+ %2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+ }
+ }
+ return %2 : tensor<?x?xf32>
+}
More information about the Mlir-commits mailing list