[Mlir-commits] [mlir] 966b720 - [mlir][memref] Fix expanded shape ops memref.cast folding with changed type

Benjamin Kramer llvmlistbot at llvm.org
Mon Nov 22 13:56:35 PST 2021


Author: Benjamin Kramer
Date: 2021-11-22T22:56:15+01:00
New Revision: 966b72098363d44adf2882b9c34fcdbe344ff913

URL: https://github.com/llvm/llvm-project/commit/966b72098363d44adf2882b9c34fcdbe344ff913
DIFF: https://github.com/llvm/llvm-project/commit/966b72098363d44adf2882b9c34fcdbe344ff913.diff

LOG: [mlir][memref] Fix expanded shape ops memref.cast folding with changed type

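`memref.expand_shape` has verification logic to make sure that a
result dim is static if all the collapsing src dims are static.
Folding a `memref.cast` into the op changes its source type and can
produce an op that fails this check, so `ExpandShapeOp::fold` no longer
folds the cast.

This can be relaxed once expand_shape supports more dynamism.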
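
As a rough illustration of the check described above, the following
standalone C++ sketch models shapes as integer vectors. It is not MLIR's
verifier code: `kDynamic`, `Shape`, `Reassociation` and `expandShapeIsValid`
are hypothetical names, and the rule is assumed to be "a source dim is
dynamic exactly when its group of result dims contains a dynamic dim, and
otherwise equals the product of the group".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a dynamic dimension ('?' in MLIR syntax).
constexpr int64_t kDynamic = -1;

using Shape = std::vector<int64_t>;
// reassociation[i] lists the result dims that source dim i expands into.
using Reassociation = std::vector<std::vector<int>>;

// Simplified model of the static/dynamic consistency rule (an assumption,
// not the actual MLIR implementation).
bool expandShapeIsValid(const Shape &src, const Shape &result,
                        const Reassociation &reassociation) {
  assert(src.size() == reassociation.size() && "one group per source dim");
  for (std::size_t i = 0; i < src.size(); ++i) {
    bool groupIsDynamic = false;
    int64_t product = 1;
    for (int dim : reassociation[i]) {
      if (result[dim] == kDynamic)
        groupIsDynamic = true;
      else
        product *= result[dim];
    }
    bool srcIsDynamic = src[i] == kDynamic;
    // Static-ness of the source dim must match that of its result group.
    if (srcIsDynamic != groupIsDynamic)
      return false;
    // A fully static group must multiply out to the source dim.
    if (!srcIsDynamic && product != src[i])
      return false;
  }
  return true;
}
```

Under this model, `memref<8x4xf32>` into `memref<2x4x4xf32>` with
reassociation `[[0, 1], [2]]` passes, while the same expansion applied
directly to the pre-cast type `memref<?x?xf32>` does not, which is the
situation the removed cast folding could create.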

Differential Revision: https://reviews.llvm.org/D114391

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0837bd2f888ec..11674d51404eb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1640,8 +1640,6 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
               CollapseShapeOpMemRefCastFolder>(context);
 }
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
 }
 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 20d95267e7188..b7569adfdc2ee 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -600,6 +600,18 @@ func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
 
 // -----
 
+func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2]]
+      : memref<8x4xf32> into memref<2x4x4xf32>
+  return %1 : memref<2x4x4xf32>
+}
+
+// CHECK-LABEL: @fold_memref_expand_cast
+// CHECK: memref.expand_shape
+
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]


        

