[Mlir-commits] [mlir] 98dbaed - [mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 21 01:19:44 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-21T01:19:18-07:00
New Revision: 98dbaed1e6317de54140bedfd787f443a61600d5

URL: https://github.com/llvm/llvm-project/commit/98dbaed1e6317de54140bedfd787f443a61600d5
DIFF: https://github.com/llvm/llvm-project/commit/98dbaed1e6317de54140bedfd787f443a61600d5.diff

LOG: [mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice

Differential Revision: https://reviews.llvm.org/D128247

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d15c51d4251e..b36f6b7a1dba 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -564,6 +564,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
   ];
   
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 878ddc60cee7..293238e82b16 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1276,6 +1276,40 @@ class ParallelInsertSliceOpConstantArgumentFolder final
 };
 } // namespace
 
+/// Fold a parallel_insert_slice source coming from a tensor.cast op, but only
+/// when tensor::canFoldIntoConsumerOp accepts the cast (its source is ranked
+/// and at least as static as its result); casts that would lose static type
+/// information, e.g. tensor<?xf32> to tensor<64xf32>, are kept. Example:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///        tensor<?xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+/// is folded into:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///        tensor<64xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast || !tensor::canFoldIntoConsumerOp(sourceCast))
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);

diff  --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
index b65d0c7049ab..688e8738e842 100644
--- a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
@@ -26,9 +26,8 @@ func.func @reduce() -> tensor<128xf32> {
       linalg.yield %14 : f32
     } -> tensor<?xf32>
 
-    // TODO: canonicalize this cast away.
-    // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
-    // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+    // CHECK-NOT: tensor.cast
+    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
     scf.foreach_thread.perform_concurrently {
       scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
     }


        


More information about the Mlir-commits mailing list