[Mlir-commits] [mlir] [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user (PR #91245)
    Benoit Jacob 
    llvmlistbot at llvm.org
       
    Wed May  8 09:13:48 PDT 2024
    
    
  
================
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
            << " dynamic dims while output_shape has " << getOutputShape().size()
            << " values";
 
+  // Verify if provided output shapes are in agreement with output type.
+  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
+  ArrayRef<int64_t> resShape = getResult().getType().getShape();
+  unsigned staticShapeNum = 0;
+
+  for (auto [pos, shape] : llvm::enumerate(resShape))
+    if (!ShapedType::isDynamic(shape) &&
+        shape != staticOutputShapes[staticShapeNum++])
+      emitOpError("invalid output shape provided at pos ") << pos;
+
----------------
bjacob wrote:
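If my understanding is correct — `static_output_shape` holds one entry per result dimension (dynamic dimensions carry the `ShapedType::kDynamic` sentinel) — then it should be indexed directly by `pos`. The separate `staticShapeNum` counter only advances on static dimensions, so it falls out of alignment as soon as a dynamic dimension precedes a static one. I would suggest this modification: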
```diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9..4f33137dcff4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2356,11 +2356,8 @@ LogicalResult ExpandShapeOp::verify() {
   // Verify if provided output shapes are in agreement with output type.
   DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
   ArrayRef<int64_t> resShape = getResult().getType().getShape();
-  unsigned staticShapeNum = 0;
-
   for (auto [pos, shape] : llvm::enumerate(resShape))
-    if (!ShapedType::isDynamic(shape) &&
-        shape != staticOutputShapes[staticShapeNum++])
+    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos])
       emitOpError("invalid output shape provided at pos ") << pos;
 
   return success();
```
https://github.com/llvm/llvm-project/pull/91245
    
    
More information about the Mlir-commits
mailing list