[Mlir-commits] [mlir] 06e972e - [mlir][linalg] Fix `FoldTensorCastProducerOp` for generic with memref output
Ivan Butygin
llvmlistbot at llvm.org
Wed Nov 16 14:01:15 PST 2022
Author: Ivan Butygin
Date: 2022-11-16T22:59:54+01:00
New Revision: 06e972ed91e6d173025dc122d202f546d1a5e8ce
URL: https://github.com/llvm/llvm-project/commit/06e972ed91e6d173025dc122d202f546d1a5e8ce
DIFF: https://github.com/llvm/llvm-project/commit/06e972ed91e6d173025dc122d202f546d1a5e8ce.diff
LOG: [mlir][linalg] Fix `FoldTensorCastProducerOp` for generic with memref output
An output operand's type should only be added to the op's result types if it is a tensor; memref (buffer) outputs produce no results.
Differential Revision: https://reviews.llvm.org/D137801
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ec1c60386fb90..18e399e5211fa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1799,7 +1799,8 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
- newResultTypes.push_back(newOperands.back().getType());
+ if (!newOperands.back().getType().isa<MemRefType>())
+ newResultTypes.push_back(newOperands.back().getType());
}
// Clone op.
Operation *newOp =
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 55013c48e97dd..c9f1726f7020d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -845,3 +845,28 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
} -> tensor<4xf32>
return
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
+ %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]
+ } ins(%0 : tensor<?xf32>)
+ outs(%arg1 : memref<?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ }
+ return
+}
+
+// We need a mixed linalg as a bridge between tensor and memref worlds.
+// CHECK-LABEL: func @cast_producer_mixed
+// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#map, #map],
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
More information about the Mlir-commits
mailing list