[Mlir-commits] [mlir] [mlir] Guard sccp pass from crashing with different source type (PR #120656)
llvmlistbot at llvm.org
Thu Dec 19 16:10:55 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Kai Sasaki (Lewuathe)
<details>
<summary>Changes</summary>
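`vector::BroadcastOp` expects an identical element type when folding: a constant scalar source attribute is passed directly to `DenseElementsAttr::get`, which requires it to match the result vector's element type. The SCCP pass can hand the folder a source attribute whose type differs from the operand type, and this causes a crash. We need to guard the pass against crashing when the element type is not identical but still compatible (e.g., `index` vs. an integer type).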
https://github.com/llvm/llvm-project/issues/120193
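To make the failure mode concrete, here is a minimal, self-contained C++ model of the guard. It does not use MLIR. `ElemType`, `ScalarAttr`, `VectorType`, `SplatVector` and `foldBroadcast` are hypothetical stand-ins, and returning `std::nullopt` plays the role of returning an empty `OpFoldResult` (`return {};`). The index-typed attribute mirrors the test case below, where `llvm.mlir.constant(0 : index) : i64` produces an SSA value of type `i64` whose constant attribute is typed `index`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-ins for MLIR element types; this is NOT the mlir::Type API.
enum class ElemType { Index, I64, F32 };

// Hypothetical scalar attribute: an integer payload plus the type recorded on
// the attribute itself (which may differ from the SSA value's type).
struct ScalarAttr {
  int64_t value;
  ElemType type;
};

// Hypothetical result vector type, e.g. vector<128xi64>.
struct VectorType {
  ElemType elementType;
  int64_t numElements;
};

// Hypothetical splat constant: one scalar repeated numElements times.
struct SplatVector {
  ScalarAttr element;
  int64_t numElements;
};

// Models the guarded BroadcastOp fold: instead of building a splat whose
// element type does not match the vector's element type (the case that
// crashed), decline to fold by returning std::nullopt.
std::optional<SplatVector> foldBroadcast(const VectorType &resultType,
                                         const std::optional<ScalarAttr> &source) {
  if (!source)
    return std::nullopt;  // operand is not a known constant
  if (source->type != resultType.elementType)
    return std::nullopt;  // e.g. an index-typed attribute feeding an i64 broadcast
  return SplatVector{*source, resultType.numElements};
}
```

In this model, folding `ScalarAttr{0, ElemType::Index}` into a `VectorType{ElemType::I64, 128}` returns `std::nullopt`, so the broadcast is simply left unfolded. That is the behavior the test's `CHECK: vector.broadcast` line expects, rather than an assertion failure inside the folder.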
---
Full diff: https://github.com/llvm/llvm-project/pull/120656.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+10-2)
- (modified) mlir/test/Transforms/sccp.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad709813c6216a..4db3ba239eeaf3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();
-  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
-    return DenseElementsAttr::get(vectorType, adaptor.getSource());
+  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
+  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index dcae052c29c248..035fc5c21a38ef 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -246,3 +246,16 @@ func.func @op_with_region() -> (i32) {
 ^b:
   return %1 : i32
 }
+
+// CHECK-LABEL: no_crash_with_different_source_type
+func.func @no_crash_with_different_source_type() {
+  // CHECK: llvm.mlir.constant(0 : index) : i64
+  %0 = llvm.mlir.constant(0 : index) : i64
+  llvm.br ^b1(%0 : i64)
+^b1(%1: i64):
+  llvm.br ^b2
+^b2:
+  // CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
+  %2 = vector.broadcast %1 : i64 to vector<128xi64>
+  llvm.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/120656