[llvm] [instCombine][bugfix] Fix crash caused by use of cast in instCombin… (PR #102472)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 8 07:11:40 PDT 2024
https://github.com/cceerczw created https://github.com/llvm/llvm-project/pull/102472
…eSVECmpNE.
The function instCombineSVECmpNE recognizes a specific pattern of the 'svecmpne' instruction, predicts its result,
and replaces the 'svecmpne' instruction with that result. The pattern is as follows:
1. The svecmpne must compare all elements of the vector.
2. The svecmpne must compare its vector operand with zero.
3. The vector compared by the svecmpne must be produced by a dupqlane whose lane index (operand 1) is zero.
For condition 3 above, instCombineSVECmpNE uses 'cast' to convert operand 1 of the dupqlane to a ConstantInt without checking whether the operand is actually a constant, so it crashes when the lane index is not a constant (for example, a loop-variant value).
>From 0b7ff037a078a72f19e85438f827ef6292449b60 Mon Sep 17 00:00:00 2001
From: chengzhiwei <chengzhiwei6 at huawei.com>
Date: Thu, 8 Aug 2024 15:02:45 +0800
Subject: [PATCH] [instCombine][bugfix] Fix crash caused by use of cast in
instCombineSVECmpNE.
The function instCombineSVECmpNE recognizes a specific pattern of the
'svecmpne' instruction, predicts its result, and replaces the 'svecmpne'
instruction with that result. The pattern is as follows:
1. The svecmpne must compare all elements of the vector.
2. The svecmpne must compare its vector operand with zero.
3. The vector compared by the svecmpne must be produced by a dupqlane
   whose lane index (operand 1) is zero.
For condition 3 above, instCombineSVECmpNE uses 'cast' to convert operand 1
of the dupqlane to a ConstantInt without checking whether the operand is
actually a constant, so it crashes when the lane index is not a constant
(for example, a loop-variant value).
---
.../AArch64/AArch64TargetTransformInfo.cpp | 3 +-
.../AArch64/sve-inst-combine-cmpne.ll | 58 +++++++++++++++++++
2 files changed, 60 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/Transforms/InstCombine/AArch64/sve-inst-combine-cmpne.ll
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 79c0e45e3aa5b5..c27c56d4557b14 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1174,7 +1174,8 @@ static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
return std::nullopt;
// Where the dupq is a lane 0 replicate of a vector insert
- if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
+ auto *DupQLaneOp1 = dyn_cast<ConstantInt>(DupQLane->getArgOperand(1));
+ if (!DupQLaneOp1 || !DupQLaneOp1->isZero())
return std::nullopt;
auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-inst-combine-cmpne.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-inst-combine-cmpne.ll
new file mode 100644
index 00000000000000..b21f2f538d3cab
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-inst-combine-cmpne.ll
@@ -0,0 +1,58 @@
+; RUN: opt -S -mtriple=aarch64-unknown-linux-gnu -O2 < %s | FileCheck %s
+; Regression test: instCombineSVECmpNE must not crash, and must leave the cmpne in place, when the dupq.lane lane index (operand 1) is not a ConstantInt.
+; Function Attrs: nofree nosync nounwind readnone uwtable vscale_range(1,16)
+define dso_local i32 @testInstCombineSVECmpNE() local_unnamed_addr #0 {
+entry:
+ %0 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.index.nxv16i8(i8 42, i8 1)
+ %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
+ br label %for.body
+
+for.cond.cleanup: ; preds = %for.inc
+ %2 = tail call i1 @llvm.aarch64.sve.ptest.any.nxv16i1(<vscale x 16 x i1> %1, <vscale x 16 x i1> %cmp_rslt.1)
+ %not. = xor i1 %2, true
+ %. = zext i1 %not. to i32
+ ret i32 %.
+
+for.body: ; preds = %entry, %for.inc
+ %i.010 = phi i64 [ 0, %entry ], [ %inc, %for.inc ]
+ %cmp1 = icmp ugt i64 %i.010, 32
+ %3 = tail call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 %i.010) ; lane index is the loop-variant phi %i.010, not a ConstantInt
+ br i1 %cmp1, label %if.then, label %if.else
+
+if.then: ; preds = %for.body
+ %4 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %1, <vscale x 16 x i8> %3, <vscale x 16 x i8> zeroinitializer) ; cmpne against zero of a dupq.lane with a non-constant lane: the input that used to reach cast<ConstantInt>
+ br label %for.inc
+ ; CHECK: %4 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %1, <vscale x 16 x i8> %3, <vscale x 16 x i8> zeroinitializer)
+ ; CHECK-NEXT: br label %for.inc
+; NOTE(review): the FileCheck patterns above pin unnamed value numbers (%1, %3, %4) from -O2 output, which other passes may renumber; consider running only -passes=instcombine.
+if.else: ; preds = %for.body
+ %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %1, <vscale x 16 x i8> %3, <vscale x 16 x i8> shufflevector (<vscale x 16 x i8> insertelement (<vscale x 16 x i8> poison, i8 1, i32 0), <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer)) ; compares against a splat of 1, not zero
+ br label %for.inc
+
+for.inc: ; preds = %if.then, %if.else
+ %cmp_rslt.1 = phi <vscale x 16 x i1> [ %4, %if.then ], [ %5, %if.else ]
+ %inc = add nuw nsw i64 %i.010, 1
+ %exitcond.not = icmp eq i64 %inc, 63
+ br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !6
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone willreturn
+declare <vscale x 16 x i8> @llvm.aarch64.sve.index.nxv16i8(i8, i8) #1
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone willreturn
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 immarg) #1
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone willreturn
+declare <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8>, i64) #1
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone willreturn
+declare <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>) #1
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone willreturn
+declare i1 @llvm.aarch64.sve.ptest.any.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>) #1
+
+attributes #0 = { nofree nosync nounwind readnone uwtable vscale_range(1,16) "frame-pointer"="non-leaf" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon,+sve,+v8.2a" }
+attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone willreturn }
+
+!6 = distinct !{!6, !7}
+!7 = !{!"llvm.loop.mustprogress"}
More information about the llvm-commits
mailing list