[Mlir-commits] [mlir] 3119677 - [mlir][SCF] Add scf.foreach_thread.parallel_insert_slice canonicalization.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 2 05:00:50 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-02T11:53:25Z
New Revision: 311967701a2a33b2753ec5db6977f3c3ef40c46e

URL: https://github.com/llvm/llvm-project/commit/311967701a2a33b2753ec5db6977f3c3ef40c46e
DIFF: https://github.com/llvm/llvm-project/commit/311967701a2a33b2753ec5db6977f3c3ef40c46e.diff

LOG: [mlir][SCF] Add scf.foreach_thread.parallel_insert_slice canonicalization.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D126761

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index b342b6a057a0..8c9a1e3ad1d8 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -544,6 +544,8 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
+  
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 7e743109b6e0..a160ba028928 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1229,6 +1229,45 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+namespace {
+/// Canonicalizes a parallel_insert_slice op that has constant index operands.
+class ParallelInsertSliceOpConstantArgumentFolder final
+    : public OpRewritePattern<ParallelInsertSliceOp> {
+public:
+  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do unless some operand matches a constant index.
+    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, matchConstantIndex());
+        }))
+      return failure();
+
+    // Some operand is a constant index. Copy the op's mixed offsets, sizes and
+    // strides, then canonicalizeSubViewPart rewrites each list in place; this
+    // presumably turns constant values into attributes (confirm in helper).
+    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+    // Replace the op with one built from the same source/dest and new lists.
+    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8e087fc0f38a..ad5afa9c3601 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1458,3 +1458,28 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:   ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>, 
+//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
+//  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
+func.func @canonicalize_parallel_insert_slice_indices(
+    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %num_threads : index) -> tensor<?x?xf32>
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  //      CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     scf.foreach_thread.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
+  %2 = scf.foreach_thread (%tidx) in (%num_threads)  -> (tensor<?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %2 : tensor<?x?xf32>
+}


        


More information about the Mlir-commits mailing list