[Mlir-commits] [mlir] 54401b4 - [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user (#91245)

llvmlistbot at llvm.org
Tue May 7 16:19:59 PDT 2024


Author: Prathamesh Tagore
Date: 2024-05-07T16:19:55-07:00
New Revision: 54401b43494a57baae9d3663cd7c694b040ef01c

URL: https://github.com/llvm/llvm-project/commit/54401b43494a57baae9d3663cd7c694b040ef01c
DIFF: https://github.com/llvm/llvm-project/commit/54401b43494a57baae9d3663cd7c694b040ef01c.diff

LOG: [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user (#91245)

The verifier did not catch the case where a static size given by the user
in output_shape differs from the corresponding size in the result type.
Fix this.
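
A minimal standalone sketch of the check described above, written without MLIR so it can run on its own. It assumes, as the output_shape [2, 15, 21] syntax in the test suggests, that the static output shape holds one entry per result dimension, with a sentinel marking dynamic sizes. kDynamic below stands in for MLIR's ShapedType::kDynamic, and findOutputShapeMismatch is a hypothetical name, not an MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Assumed sentinel for a dynamic size; MLIR's ShapedType::kDynamic plays
// this role in the real verifier.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: returns the first result position whose static size
// differs from the user-provided static output shape, or std::nullopt if all
// static sizes agree. Assumes staticOutputShape has one entry per result
// dimension, with kDynamic marking dynamic ones.
std::optional<std::size_t>
findOutputShapeMismatch(const std::vector<int64_t> &resultShape,
                        const std::vector<int64_t> &staticOutputShape) {
  for (std::size_t pos = 0; pos < resultShape.size(); ++pos) {
    // A too-short output shape is reported at the first missing position.
    if (pos >= staticOutputShape.size())
      return pos;
    // Dynamic result sizes are skipped; this sketch only compares static ones.
    if (resultShape[pos] == kDynamic)
      continue;
    // Compare by position so dynamic entries never shift the alignment.
    if (resultShape[pos] != staticOutputShape[pos])
      return pos;
  }
  return std::nullopt;
}
```

With the sizes from the test case below (result 2x15x20, output_shape [2, 15, 21]), findOutputShapeMismatch returns position 2. Indexing the static output shape by pos, rather than by a counter that only advances on static dimensions, keeps the two arrays aligned when a dynamic dimension comes first. Inside an MLIR verifier the usual pattern is to return the diagnostic (return emitOpError(...) << pos;), since an InFlightDiagnostic converts to failure().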

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 393f73dc65cd8..78201ae29cd9b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
            << " dynamic dims while output_shape has " << getOutputShape().size()
            << " values";
 
+  // Verify if provided output shapes are in agreement with output type.
+  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
+  ArrayRef<int64_t> resShape = getResult().getType().getShape();
+  unsigned staticShapeNum = 0;
+
+  for (auto [pos, shape] : llvm::enumerate(resShape))
+    if (!ShapedType::isDynamic(shape) &&
+        shape != staticOutputShapes[staticShapeNum++])
+      emitOpError("invalid output shape provided at pos ") << pos;
+
   return success();
 }
 

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 70c96aad9555e..0f533cb95a0ca 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>)
       : memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
   return
 }
+
+// -----
+
+func.func @expand_shape_invalid_output_shape(
+    %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
+  // expected-error @+1 {{invalid output shape provided at pos 2}}
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] :
+      memref<30x20xf32, strided<[4000, 2], offset: 100>>
+      into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
+  return
+}
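
The new test keeps the result type otherwise consistent with the source, so the only defect is the 21 at position 2 of output_shape. The hedged sketch below covers only the fully static case; expandStrides is a hypothetical helper, not an MLIR API. It shows how the result strides [60000, 4000, 2] follow from the source strides [4000, 2] under the reassociation [[0, 1], [2]]. Within each group, the innermost expanded dimension keeps the source stride, and each outer dimension's stride is the next inner stride times the next inner size. The sketch also checks that each group's sizes multiply to the source size.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (static sizes only): computes the strides of the
// expanded memref from the source strides and the reassociation groups.
// Returns false if the inputs are malformed or if a group's expanded sizes
// do not multiply to the corresponding source size.
bool expandStrides(const std::vector<int64_t> &srcSizes,
                   const std::vector<int64_t> &srcStrides,
                   const std::vector<std::vector<std::size_t>> &groups,
                   const std::vector<int64_t> &dstSizes,
                   std::vector<int64_t> &dstStrides) {
  if (groups.size() != srcSizes.size() || srcStrides.size() != srcSizes.size())
    return false;
  dstStrides.assign(dstSizes.size(), 0);
  for (std::size_t src = 0; src < groups.size(); ++src) {
    const auto &group = groups[src];
    if (group.empty())
      return false;
    int64_t product = 1;
    int64_t stride = srcStrides[src];
    // Walk the group from its innermost to its outermost expanded dimension.
    for (std::size_t i = group.size(); i-- > 0;) {
      std::size_t dst = group[i];
      if (dst >= dstSizes.size())
        return false;
      dstStrides[dst] = stride;
      stride *= dstSizes[dst];
      product *= dstSizes[dst];
    }
    // The expanded sizes of a group must multiply to the source size.
    if (product != srcSizes[src])
      return false;
  }
  return true;
}
```

For the test's types, 2 * 15 = 30 and 20 = 20, giving strides [15 * 4000, 4000, 2] = [60000, 4000, 2]; the offset of 100 is the same in both types. Feeding the sizes [2, 15, 21] instead fails the product check for the second group, since 21 != 20.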
