[Mlir-commits] [mlir] 06e972e - [mlir][linalg] Fix `FoldTensorCastProducerOp` for generic with memref output

Ivan Butygin llvmlistbot at llvm.org
Wed Nov 16 14:01:15 PST 2022


Author: Ivan Butygin
Date: 2022-11-16T22:59:54+01:00
New Revision: 06e972ed91e6d173025dc122d202f546d1a5e8ce

URL: https://github.com/llvm/llvm-project/commit/06e972ed91e6d173025dc122d202f546d1a5e8ce
DIFF: https://github.com/llvm/llvm-project/commit/06e972ed91e6d173025dc122d202f546d1a5e8ce.diff

LOG: [mlir][linalg] Fix `FoldTensorCastProducerOp` for generic with memref output

The output type should only be added to the new result types if it is a tensor.

Differential Revision: https://reviews.llvm.org/D137801
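
A memref output is written in place and has no corresponding SSA result, so pushing
its type into the result-type list would make the rebuilt op's results mismatch its
outputs. Below is a minimal sketch of the fixed per-output step, assuming the headers
and `using namespace mlir;` already present in LinalgOps.cpp; the helper name
`processOutput` is hypothetical and its body mirrors the hunk further down:

    // Sketch only: mirrors the loop body of FoldTensorCastProducerOp.
    static void processOutput(OpOperand *output,
                              SmallVectorImpl<Value> &newOperands,
                              SmallVectorImpl<Type> &newResultTypes) {
      // If the output is produced by a foldable tensor.cast, use the cast's
      // source as the new output operand.
      auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
      // Memref outputs do not produce results, so only record tensor types.
      if (!newOperands.back().getType().isa<MemRefType>())
        newResultTypes.push_back(newOperands.back().getType());
    }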

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ec1c60386fb90..18e399e5211fa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1799,7 +1799,8 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
       auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
-      newResultTypes.push_back(newOperands.back().getType());
+      if (!newOperands.back().getType().isa<MemRefType>())
+        newResultTypes.push_back(newOperands.back().getType());
     }
     // Clone op.
     Operation *newOp =

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 55013c48e97dd..c9f1726f7020d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -845,3 +845,28 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
   } -> tensor<4xf32>
   return
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
+  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%0 : tensor<?xf32>)
+    outs(%arg1 : memref<?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return
+}
+
+// We need a mixed linalg as a bridge between tensor and memref worlds.
+// CHECK-LABEL: func @cast_producer_mixed
+//  CHECK-SAME:     (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
+//       CHECK:     linalg.generic {
+//  CHECK-SAME:    indexing_maps = [#map, #map],
+//  CHECK-SAME:    iterator_types = ["parallel"]
+//  CHECK-SAME:  } ins(%[[ARG1]] : tensor<5xf32>)
+//  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {

More information about the Mlir-commits mailing list