[Mlir-commits] [mlir] Fix `memref.expand_shape` verifier (PR #91501)

Benoit Jacob llvmlistbot at llvm.org
Wed May 8 09:33:23 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/91501

>From bae8252983a73369a3acb346244ff6e5d9b186fe Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 8 May 2024 12:27:35 -0400
Subject: [PATCH 1/2] fix-expand-verifier

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9b..dc6fd770d9bd7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2356,12 +2356,11 @@ LogicalResult ExpandShapeOp::verify() {
   // Verify if provided output shapes are in agreement with output type.
   DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
   ArrayRef<int64_t> resShape = getResult().getType().getShape();
-  unsigned staticShapeNum = 0;
-
-  for (auto [pos, shape] : llvm::enumerate(resShape))
-    if (!ShapedType::isDynamic(shape) &&
-        shape != staticOutputShapes[staticShapeNum++])
+  for (auto [pos, shape] : llvm::enumerate(resShape)) {
+    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
       emitOpError("invalid output shape provided at pos ") << pos;
+    }
+  }
 
   return success();
 }

>From e8746918a4446f23e360a2ec1681a1463b5c7dae Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 8 May 2024 12:33:11 -0400
Subject: [PATCH 2/2] review-comment

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dc6fd770d9bd7..c9a85919ec799 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2358,7 +2358,7 @@ LogicalResult ExpandShapeOp::verify() {
   ArrayRef<int64_t> resShape = getResult().getType().getShape();
   for (auto [pos, shape] : llvm::enumerate(resShape)) {
     if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
-      emitOpError("invalid output shape provided at pos ") << pos;
+      return emitOpError("invalid output shape provided at pos ") << pos;
     }
   }
 



More information about the Mlir-commits mailing list