[Mlir-commits] [mlir] Fix `memref.expand_shape` verifier (PR #91501)

Benoit Jacob llvmlistbot at llvm.org
Wed May 8 09:57:40 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/91501

>From 420a815c7c9ccf36a839dd0e93878ff0e2594828 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 8 May 2024 12:27:35 -0400
Subject: [PATCH] [mlir] Fix `memref.expand_shape` verifier

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp            | 11 +++++------
 mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir |  2 +-
 mlir/test/Dialect/MemRef/ops.mlir                   |  7 ++++++-
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9b..c9a85919ec799 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2356,12 +2356,11 @@ LogicalResult ExpandShapeOp::verify() {
   // Verify if provided output shapes are in agreement with output type.
   DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
   ArrayRef<int64_t> resShape = getResult().getType().getShape();
-  unsigned staticShapeNum = 0;
-
-  for (auto [pos, shape] : llvm::enumerate(resShape))
-    if (!ShapedType::isDynamic(shape) &&
-        shape != staticOutputShapes[staticShapeNum++])
-      emitOpError("invalid output shape provided at pos ") << pos;
+  for (auto [pos, shape] : llvm::enumerate(resShape)) {
+    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
+      return emitOpError("invalid output shape provided at pos ") << pos;
+    }
+  }
 
   return success();
 }
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 99b5f78b03fba..e49dff44ae0d6 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -502,7 +502,7 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
   %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
-  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
   %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
 
   affine.for %arg6 = 0 to %dim step 64 {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 60fb0ffeee240..b60894377f22f 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -203,7 +203,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
          %arg4: index,
          %arg5: index,
-         %arg6: index) {
+         %arg6: index,
+         %arg7: memref<4x?x4xf32>) {
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -248,6 +249,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //  CHECK-SAME:     memref<?xf32, strided<[1]>> into memref<?x42xf32>
   %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
     memref<?xf32, strided<[1]>> into memref<?x42xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+  %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
+        : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
   return
 }
 



More information about the Mlir-commits mailing list