[Mlir-commits] [mlir] 98dbaed - [mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jun 21 01:19:44 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-21T01:19:18-07:00
New Revision: 98dbaed1e6317de54140bedfd787f443a61600d5
URL: https://github.com/llvm/llvm-project/commit/98dbaed1e6317de54140bedfd787f443a61600d5
DIFF: https://github.com/llvm/llvm-project/commit/98dbaed1e6317de54140bedfd787f443a61600d5.diff
LOG: [mlir][SCF] Fold tensor.cast feeding into scf.foreach_thread.parallel_insert_slice
Differential Revision: https://reviews.llvm.org/D128247
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d15c51d4251e..b36f6b7a1dba 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -564,6 +564,7 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 878ddc60cee7..293238e82b16 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1276,6 +1276,40 @@ class ParallelInsertSliceOpConstantArgumentFolder final
};
} // namespace
+/// Fold a parallel_insert_slice source coming from a tensor.cast op, provided
+/// the cast only erases static type information, i.e. the cast's source type
+/// is at least as static as its result type (see
+/// `tensor::canFoldIntoConsumerOp`).
+///
+/// Casts going the other way are kept. This covers refining casts such as
+/// `tensor<?xf32>` -> `tensor<64xf32>`, and unranked-to-ranked casts such as
+/// `tensor<*xf32>` -> `tensor<?xf32>`. Bypassing them would hand the insert a
+/// less static (or even unranked) source than the one it was built with.
+///
+/// Example:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///       tensor<?xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+///
+/// is folded into:
+/// ```
+/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///   %1 = compute_some_tensor() : tensor<64xf32>
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///       tensor<64xf32> into tensor<128xf32>
+///   }
+/// }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  // Only look through casts that lose static information. Folding a refining
+  // (or unranked -> ranked) cast would weaken the source type of this op.
+  if (!sourceCast || !tensor::canFoldIntoConsumerOp(sourceCast))
+    return failure();
+  // In-place fold: rewire the source operand to bypass the cast. `results`
+  // is intentionally left empty, which signals an in-place update.
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
diff --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
index b65d0c7049ab..688e8738e842 100644
--- a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
@@ -26,9 +26,8 @@ func.func @reduce() -> tensor<128xf32> {
linalg.yield %14 : f32
} -> tensor<?xf32>
- // TODO: canonicalize this cast away.
- // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
- // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+ // CHECK-NOT: tensor.cast
+ // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
scf.foreach_thread.perform_concurrently {
scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
}
More information about the Mlir-commits mailing list