[Mlir-commits] [mlir] d3846ec - [mlir] Guard sccp pass from crashing with different source type (#120656)
llvmlistbot at llvm.org
Tue Dec 24 19:19:55 PST 2024
Author: Kai Sasaki
Date: 2024-12-25T12:19:52+09:00
New Revision: d3846eca2061e6e9a8d654551153f7362c27b59a
URL: https://github.com/llvm/llvm-project/commit/d3846eca2061e6e9a8d654551153f7362c27b59a
DIFF: https://github.com/llvm/llvm-project/commit/d3846eca2061e6e9a8d654551153f7362c27b59a.diff
LOG: [mlir] Guard sccp pass from crashing with different source type (#120656)
vector::BroadcastOp folding expects the source attribute's type to be identical
to the result vector's element type, so the SCCP pass crashes when it is given
a source of a different type. Guard against this crash by not folding when the
types are compatible but not identical (e.g. index vs. an integer type).
https://github.com/llvm/llvm-project/issues/120193
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Transforms/sccp.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 491b5f44b722b1..ae1cf95732336a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
- if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
- return DenseElementsAttr::get(vectorType, adaptor.getSource());
+ if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
+ if (vectorType.getElementType() != attr.getType())
+ return {};
+ return DenseElementsAttr::get(vectorType, attr);
+ }
+ if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
+ if (vectorType.getElementType() != attr.getType())
+ return {};
+ return DenseElementsAttr::get(vectorType, attr);
+ }
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index dcae052c29c248..c78c8594c0ba51 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -246,3 +246,12 @@ func.func @op_with_region() -> (i32) {
^b:
return %1 : i32
}
+
+// CHECK-LABEL: no_crash_with_
diff erent_source_type
+func.func @no_crash_with_
diff erent_source_type() {
+ // CHECK: llvm.mlir.constant(0 : index) : i64
+ %0 = llvm.mlir.constant(0 : index) : i64
+ // CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
+ %1 = vector.broadcast %0 : i64 to vector<128xi64>
+ llvm.return
+}
More information about the Mlir-commits
mailing list