[Mlir-commits] [mlir] [mlir][affine] Guard invalid dim attribute in the test-reify-bound pass (PR #129013)
llvmlistbot at llvm.org
Wed Feb 26 22:54:55 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Kai Sasaki (Lewuathe)
<details>
<summary>Changes</summary>
Computing the bound of an affine op (ValueBoundsConstraintSet::computeBound) crashes when the op is given an invalid dim value. The pass must check that the dim attribute is less than the rank of the input type, and report an op error instead of crashing when it is not.
Fixes https://github.com/llvm/llvm-project/issues/128807
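
For illustration, the guard reduces to a small range check. The following is a minimal standalone sketch, not the code from the patch: `isDimInRange` is a hypothetical helper, and the rank is modeled as a `std::optional<int64_t>` (empty for an unranked or non-shaped value) rather than an MLIR `ShapedType`. Like the condition in the patch, it only rejects `dim >= rank` and does not separately handle negative values.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper (not part of MLIR): returns true when `dim` is
// acceptable for a value of the given rank.
//   rank == nullopt  -> value is unranked or not a shaped type; nothing to check
//   dim  == nullopt  -> op has no dim attribute; nothing to check
// Otherwise dim must be strictly less than rank, since dimensions are
// zero-indexed.
bool isDimInRange(std::optional<int64_t> rank, std::optional<int64_t> dim) {
  if (!rank.has_value() || !dim.has_value())
    return true;
  return dim.value() < rank.value();
}
```

With the test case from this PR, `tensor<?xf32>` has rank 1 and `dim = 1`, so the check fails and the pass emits the diagnostic rather than calling into the bound computation.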
---
Full diff: https://github.com/llvm/llvm-project/pull/129013.diff
2 Files Affected:
- (added) mlir/test/Dialect/Affine/invalid-reify-bound-dim.mlir (+13)
- (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+7)
``````````diff
diff --git a/mlir/test/Dialect/Affine/invalid-reify-bound-dim.mlir b/mlir/test/Dialect/Affine/invalid-reify-bound-dim.mlir
new file mode 100644
index 0000000000000..8c878b2664042
--- /dev/null
+++ b/mlir/test/Dialect/Affine/invalid-reify-bound-dim.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-affine-reify-value-bounds))' -verify-diagnostics
+
+// -----
+
+func.func @test_invalid_reify_dim(%size: index) -> (index) {
+ %zero = arith.constant 0 : index
+ %tensor_val = tensor.empty(%size) : tensor<?xf32>
+
+ // expected-error @+1 {{'test.reify_bound' op invalid dim for shaped type}}
+ %dim = "test.reify_bound"(%tensor_val) {dim = 1 : i64} : (tensor<?xf32>) -> index
+
+ return %dim: index
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 891b3bab8629d..fc44da4b53865 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -84,6 +84,13 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
auto boundType = op.getBoundType();
Value value = op.getVar();
std::optional<int64_t> dim = op.getDim();
+ auto shapedType = dyn_cast<ShapedType>(value.getType());
+ if (shapedType && shapedType.hasRank() && dim.has_value() &&
+ dim.value() >= shapedType.getRank()) {
+ op->emitOpError("invalid dim for shaped type");
+ return WalkResult::interrupt();
+ }
+
bool constant = op.getConstant();
bool scalable = op.getScalable();
``````````
</details>
https://github.com/llvm/llvm-project/pull/129013