[Mlir-commits] [mlir] 966b720 - [mlir][memref] Fix expanded shape ops memref.cast folding with changed type
Benjamin Kramer
llvmlistbot at llvm.org
Mon Nov 22 13:56:35 PST 2021
Author: Benjamin Kramer
Date: 2021-11-22T22:56:15+01:00
New Revision: 966b72098363d44adf2882b9c34fcdbe344ff913
URL: https://github.com/llvm/llvm-project/commit/966b72098363d44adf2882b9c34fcdbe344ff913
DIFF: https://github.com/llvm/llvm-project/commit/966b72098363d44adf2882b9c34fcdbe344ff913.diff
LOG: [mlir][memref] Fix expanded shape ops memref.cast folding with changed type
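`memref.expand_shape` has verification logic requiring that a collapsed
dimension be static whenever all the expanded dimensions in its reassociation
group are static. Folding a type-changing `memref.cast` into the op (e.g. one
from `memref<?x?xf32>` to `memref<8x4xf32>`) would replace the static source
type with the dynamic one while keeping the fully static result type, producing
an op that fails verification. The fold is therefore removed from
`ExpandShapeOp::fold`. This can be relaxed once `expand_shape` supports more
dynamism.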
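For illustration, the per-group shape rule can be modeled in plain C++, independent of MLIR. This is a simplified sketch, not the actual verifier in `MemRefOps.cpp`. The function name `checkReshapeShapes`, the `kDynamic` sentinel, and the error strings are hypothetical. The sketch shows that expanding `8x4` into `2x4x4` with reassociation `[[0, 1], [2]]` passes, while substituting the pre-cast type `?x?` (which is what folding the `memref.cast` would do) fails on the first group.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sentinel for a dynamic ("?") dimension in this sketch.
constexpr int64_t kDynamic = -1;

// Each group lists the expanded-dimension indices that map to one collapsed
// dimension, e.g. [[0, 1], [2]].
using Reassociation = std::vector<std::vector<int>>;

// Simplified model of the per-group shape check:
//  - if any expanded dim in a group is dynamic, the collapsed dim must be dynamic;
//  - otherwise the collapsed dim must be static and equal the group's product.
// Returns an empty string on success, or a description of the first violation.
std::string checkReshapeShapes(const std::vector<int64_t> &collapsed,
                               const std::vector<int64_t> &expanded,
                               const Reassociation &groups) {
  if (collapsed.size() != groups.size())
    return "reassociation size mismatch";
  for (std::size_t g = 0; g < groups.size(); ++g) {
    bool anyDynamic = false;
    int64_t product = 1;
    for (int idx : groups[g]) {
      // Precondition of this sketch: indices are valid for the expanded shape.
      assert(static_cast<std::size_t>(idx) < expanded.size() &&
             "reassociation index out of range");
      if (expanded[idx] == kDynamic)
        anyDynamic = true;
      else
        product *= expanded[idx];
    }
    if (anyDynamic) {
      if (collapsed[g] != kDynamic)
        return "collapsed dim " + std::to_string(g) + " must be dynamic";
    } else if (collapsed[g] != product) {
      // This is the case hit when the cast is folded: 2x4 is static, but the
      // collapsed source dim would become "?".
      return "collapsed dim " + std::to_string(g) + " must be static value " +
             std::to_string(product);
    }
  }
  return "";
}
```

For example, `checkReshapeShapes({8, 4}, {2, 4, 4}, {{0, 1}, {2}})` returns an empty string, whereas `checkReshapeShapes({kDynamic, kDynamic}, {2, 4, 4}, {{0, 1}, {2}})` returns `"collapsed dim 0 must be static value 8"`.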
Differential Revision: https://reviews.llvm.org/D114391
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0837bd2f888ec..11674d51404eb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1640,8 +1640,6 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
- if (succeeded(foldMemRefCast(*this)))
- return getResult();
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 20d95267e7188..b7569adfdc2ee 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -600,6 +600,18 @@ func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
// -----
+func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2]]
+ : memref<8x4xf32> into memref<2x4x4xf32>
+ return %1 : memref<2x4x4xf32>
+}
+
+// CHECK-LABEL: @fold_memref_expand_cast
+// CHECK: memref.expand_shape
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]