[Mlir-commits] [mlir] [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided (PR #91245)
llvmlistbot at llvm.org
Mon May 6 10:27:46 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Prathamesh Tagore (meshtag)
<details>
<summary>Changes</summary>
The verifier did not check whether the static sizes given in `output_shape` match the corresponding dimensions of the result type. Fix this.
---
Full diff: https://github.com/llvm/llvm-project/pull/91245.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+10)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 393f73dc65cd8d..e6a93bf42199a4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
<< " dynamic dims while output_shape has " << getOutputShape().size()
<< " values";
+ // Verify if provided output shapes are in agreement with output type.
+ DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
+ ArrayRef<int64_t> resShape = getResult().getType().getShape();
+ // static_output_shape has one entry per result dim (kDynamic for SSA sizes).
+
+ for (unsigned i = 0, e = resShape.size(); i < e; ++i)
+ if (!ShapedType::isDynamic(resShape[i]) &&
+ resShape[i] != staticOutputShapes[i])
+ return emitOpError("invalid output shape provided at pos ") << i;
+
return success();
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 70c96aad9555ef..0f533cb95a0ca9 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>)
: memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
return
}
+
+// -----
+
+func.func @expand_shape_invalid_output_shape(
+ %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
+ // expected-error @+1 {{invalid output shape provided at pos 2}}
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] :
+ memref<30x20xf32, strided<[4000, 2], offset: 100>>
+ into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
+ return
+}
``````````
</details>
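The check described above can be illustrated outside MLIR. What follows is a minimal standalone C++ sketch, not the upstream implementation. `findMismatchedDim` is a hypothetical helper; `kDynamic` is assumed to mirror `ShapedType::kDynamic`; and `static_output_shape` is assumed to hold one entry per result dimension, with the dynamic sentinel in each slot whose size comes from an SSA value.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Sentinel for a dynamic size. Assumed to mirror MLIR's ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: returns the first result dimension whose static size
// disagrees with the user-provided static output shape, or std::nullopt if
// every static dimension agrees.
// Assumption: staticOutputShape has one entry per result dimension; .at()
// is used so that a shorter vector throws instead of reading out of bounds.
std::optional<std::size_t>
findMismatchedDim(const std::vector<int64_t> &resShape,
                  const std::vector<int64_t> &staticOutputShape) {
  for (std::size_t i = 0; i < resShape.size(); ++i) {
    // Dynamic result dims are covered by the separate count check on
    // the output_shape SSA values, so skip them here.
    if (resShape[i] == kDynamic)
      continue;
    // Index both arrays by the same result dimension i.
    if (resShape[i] != staticOutputShape.at(i))
      return i;
  }
  return std::nullopt;
}
```

For the shapes in the new test (result `2x15x20`, `output_shape [2, 15, 21]`), `findMismatchedDim({2, 15, 20}, {2, 15, 21})` returns `2`, which matches the expected "pos 2" diagnostic. Indexing the static sizes by the result dimension `i`, rather than by a counter that only advances on static dims, keeps the two arrays aligned when a dynamic dim comes first. For a result such as `memref<?x5xf32>`, a counter that skips dynamic dims would compare `5` against the sentinel in slot 0 and report a spurious error.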
https://github.com/llvm/llvm-project/pull/91245