[Mlir-commits] [mlir] [mlir] Guard sccp pass from crashing with different source type (PR #120656)
Kai Sasaki
llvmlistbot at llvm.org
Thu Dec 19 16:10:21 PST 2024
https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/120656
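vector::BroadcastOp expects identical element types when folding: the source constant's type must match the result vector's element type. When the SCCP pass propagates a constant whose type differs (e.g. an index-typed constant flowing into a broadcast to vector<128xi64>), the fold crashes. This patch guards the fold so that it bails out instead of crashing when the element types are not identical but are still compatible (e.g. index vs integer type).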
https://github.com/llvm/llvm-project/issues/120193
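The guard can be illustrated with a small, self-contained model. This is a sketch only: ElemType, ScalarAttr, VectorType, DenseAttr and foldBroadcast are hypothetical stand-ins invented for illustration, not MLIR's API. The model mirrors the control flow of the patched BroadcastOp::fold. With no known constant there is no fold, and a mismatched element type also means no fold. Otherwise the scalar is splatted across the vector.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-ins for MLIR element types; NOT the MLIR API.
enum class ElemType { Index, I64 };

// A scalar constant together with the type it was built with
// (models an IntegerAttr/FloatAttr that carries its own type).
struct ScalarAttr {
  ElemType type;
  long long value;
};

// Models the broadcast result type, e.g. vector<128xi64>.
struct VectorType {
  ElemType elementType;
  int numElements;
};

// Models a splat dense attribute: every element holds the same value.
struct DenseAttr {
  ElemType elementType;
  std::vector<long long> values;
};

// Mirrors the guarded fold. std::nullopt plays the role of the empty
// OpFoldResult ("return {};"), i.e. "leave the op unfolded".
std::optional<DenseAttr> foldBroadcast(const VectorType &resultType,
                                       const std::optional<ScalarAttr> &source) {
  // No constant operand is known: nothing to fold.
  if (!source)
    return std::nullopt;
  // The new guard: e.g. an index-typed constant feeding an i64 broadcast
  // is rejected instead of building a mismatched attribute.
  if (source->type != resultType.elementType)
    return std::nullopt;
  // Types agree: splat the scalar across every lane.
  return DenseAttr{resultType.elementType,
                   std::vector<long long>(
                       static_cast<std::size_t>(resultType.numElements),
                       source->value)};
}
```

With a 128-element i64 result type, an index-typed constant is rejected, while an i64 constant of 7 folds to 128 copies of 7. Bailing out rather than coercing the constant keeps the folder conservative: the broadcast simply stays in the IR, which is what the new sccp.mlir test checks for.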
From 93e8e1bd2a13b44eea94163e400a6224f6376c21 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Wed, 18 Dec 2024 15:43:30 +0900
Subject: [PATCH] [mlir] Guard sccp pass from crashing with different source
type
vector::BroadcastOp expects identical element types when folding. It
crashes if the SCCP pass hands it a source constant of a different type.
Guard the fold so that it bails out instead of crashing when the element
types are not identical but still compatible (e.g. index vs integer type).
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++++++--
mlir/test/Transforms/sccp.mlir | 13 +++++++++++++
2 files changed, 23 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad709813c6216a..4db3ba239eeaf3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2523,8 +2523,16 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();
-  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
-    return DenseElementsAttr::get(vectorType, adaptor.getSource());
+  if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
+  if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
+    if (vectorType.getElementType() != attr.getType())
+      return {};
+    return DenseElementsAttr::get(vectorType, attr);
+  }
   if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index dcae052c29c248..035fc5c21a38ef 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -246,3 +246,16 @@ func.func @op_with_region() -> (i32) {
 ^b:
   return %1 : i32
 }
+
+// CHECK-LABEL: no_crash_with_different_source_type
+func.func @no_crash_with_different_source_type() {
+  // CHECK: llvm.mlir.constant(0 : index) : i64
+  %0 = llvm.mlir.constant(0 : index) : i64
+  llvm.br ^b1(%0 : i64)
+^b1(%1: i64):
+  llvm.br ^b2
+^b2:
+  // CHECK: vector.broadcast %[[CST:.*]] : i64 to vector<128xi64>
+  %2 = vector.broadcast %1 : i64 to vector<128xi64>
+  llvm.return
+}