[llvm] [llvm] Fix crash when complex deinterleaving operates on an unrolled loop (PR #129735)
Igor Kirillov via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 14 10:39:20 PDT 2025
================
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
auto *FinalReduction = ReductionInfo[Real].second;
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
- auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+ Value *Other;
+ bool EraseFinalReductionHere = false;
+ if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
----------------
igogo-x86 wrote:
For regular reductions (without cdot), we needed to analyse and rewrite use outside of the loop due to Real and Imaginary part extraction. See cases in `complex-deinterleaving-reductions.ll`. But for `cdot`, we don't need to do any of that. Here's a test from `complex-deinterleaving-cdot.ll`:
```
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
%a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
%b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
%a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
%a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
%b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
%b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
%a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
%a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
%b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
%b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
%real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
%real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
%imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
%imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
%partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%0 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce.sub)
ret i32 %0
}
```
It is currently transformed into:
```
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%0 = phi <vscale x 8 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
%1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
%2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
%3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
%4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
%5 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 0)
%6 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 4)
%7 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
%8 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %6, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
%9 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> %7, i64 0)
%10 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> %9, <vscale x 4 x i32> %8, i64 4)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%11 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %10)
ret i32 %11
}
```
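Instead, we could ignore everything that happens after the final `llvm.experimental.vector.partial.reduce.add` and simply chain one cdot into the next, using the result of the first cdot as the accumulator of the second. The phi then keeps its original `<vscale x 4 x i32>` type, and the `vector.extract`/`vector.insert` calls on the accumulator disappear: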
```
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%0 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %6, %vector.body ]
%1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
%2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
%3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
%4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
%5 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %0, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
%6 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%result = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %6)
ret i32 %result
}
```
https://github.com/llvm/llvm-project/pull/129735
More information about the llvm-commits
mailing list