[llvm] [SROA] Canonicalize homogeneous structs into fixed vectors to elimina… (PR #165159)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Sun Oct 26 08:54:16 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Yaxun (Sam) Liu (yxsamliu)
<details>
<summary>Changes</summary>
…te allocas
Motivation: SROA would keep temporary allocas (e.g. copies and zero-inits) for homogeneous, 16-byte structs. On targets like AMDGPU these map to scratch memory and can severely hurt performance.
The following example could not eliminate the allocas before this change:
```
struct alignas(16) myint4 {
  int x, y, z, w;
};
void foo(myint4* x, myint4 y, int cond) {
  myint4 temp = y;
  myint4 zero{0,0,0,0};
  myint4 data = cond ? temp : zero;
  *x = data;
}
```
Method: During rewritePartition, when the slice type is a struct of 2 or 4 identical element types, and DataLayout proves it is tightly packed (no padding; element offsets are i*EltSize; StructSize == N*EltSize), and the element type is a valid fixed-size vector element, and the total size is at or below a configurable threshold, rewrite the slice type to a fixed vector <N x EltTy>. This runs before the alloca-reuse fast path.
Why it works: For tightly packed homogeneous structs, the in-memory representation is bitwise-identical to the corresponding fixed vector, so the transformation is semantics-preserving. The vector form enables SROA/InstCombine/GVN to replace memcpy/memset and conditional copies with vector selects and a single vector store, allowing the allocas to be eliminated. Tests (flat and nested struct) show allocas/mem* disappear and a <4 x i32> store remains.
Control: Introduces -sroa-max-struct-to-vector-bytes=N (default 0 = disabled) to guard the transform by struct size. Enable via:
  - opt: -passes='sroa,gvn,instcombine,simplifycfg' -sroa-max-struct-to-vector-bytes=16
  - clang/llc: -mllvm -sroa-max-struct-to-vector-bytes=16. Set to 0 to turn the optimization off if regressions are observed.
---
Full diff: https://github.com/llvm/llvm-project/pull/165159.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/SROA.cpp (+58) 
- (added) llvm/test/Transforms/SROA/struct-to-vector.ll (+311) 
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 5c60fad6f91aa..d31aca0338c91 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -122,6 +122,12 @@ namespace llvm {
 /// Disable running mem2reg during SROA in order to test or debug SROA.
 static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false),
                                      cl::Hidden);
+/// Maximum struct size in bytes to canonicalize homogeneous structs to vectors.
+/// 0 disables the transformation to avoid regressions by default.
+static cl::opt<unsigned> SROAMaxStructToVectorBytes(
+    "sroa-max-struct-to-vector-bytes", cl::init(0), cl::Hidden,
+    cl::desc("Max struct size in bytes to canonicalize homogeneous structs to "
+             "fixed vectors (0=disable)"));
 extern cl::opt<bool> ProfcheckDisableMetadataFixes;
 } // namespace llvm
 
@@ -5267,6 +5273,58 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
   if (VecTy)
     SliceTy = VecTy;
 
+  // Canonicalize homogeneous, tightly-packed 2- or 4-field structs to
+  // a fixed-width vector type when the DataLayout proves bitwise identity.
+  // Do this BEFORE the alloca reuse fast-path so that we don't miss
+  // opportunities to vectorize memcpy on allocas whose SliceTy initially
+  // equals the allocated type.
+  if (SROAMaxStructToVectorBytes) {
+    if (auto *STy = dyn_cast<StructType>(SliceTy)) {
+      unsigned NumElts = STy->getNumElements();
+      if (NumElts == 2 || NumElts == 4) {
+        Type *EltTy =
+            STy->getNumElements() > 0 ? STy->getElementType(0) : nullptr;
+        bool IsAllowedElt = false;
+        if (EltTy && VectorType::isValidElementType(EltTy)) {
+          if (auto *IT = dyn_cast<IntegerType>(EltTy))
+            IsAllowedElt = IT->getBitWidth() >= 8;
+          else if (EltTy->isFloatingPointTy())
+            IsAllowedElt = true;
+        }
+        bool AllSame = IsAllowedElt;
+        for (unsigned I = 1; AllSame && I < NumElts; ++I)
+          if (STy->getElementType(I) != EltTy)
+            AllSame = false;
+        if (AllSame) {
+          const StructLayout *SL = DL.getStructLayout(STy);
+          TypeSize EltTS = DL.getTypeAllocSize(EltTy);
+          if (EltTS.isFixed()) {
+            const uint64_t EltSize = EltTS.getFixedValue();
+            if (EltSize >= 1) {
+              const uint64_t StructSize = SL->getSizeInBytes();
+              if (StructSize != 0 &&
+                  StructSize <= SROAMaxStructToVectorBytes) {
+                bool TightlyPacked = (StructSize == NumElts * EltSize);
+                if (TightlyPacked) {
+                  for (unsigned I = 0; I < NumElts; ++I) {
+                    if (SL->getElementOffset(I) != I * EltSize) {
+                      TightlyPacked = false;
+                      break;
+                    }
+                  }
+                }
+                if (TightlyPacked) {
+                  Type *NewSliceTy = FixedVectorType::get(EltTy, NumElts);
+                  SliceTy = NewSliceTy;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   // Check for the case where we're going to rewrite to a new alloca of the
   // exact same type as the original, and with the same access offsets. In that
   // case, re-use the existing alloca, but still run through the rewriter to
diff --git a/llvm/test/Transforms/SROA/struct-to-vector.ll b/llvm/test/Transforms/SROA/struct-to-vector.ll
new file mode 100644
index 0000000000000..ceaf8ea435abb
--- /dev/null
+++ b/llvm/test/Transforms/SROA/struct-to-vector.ll
@@ -0,0 +1,311 @@
+; RUN: opt -passes='sroa,gvn,instcombine,simplifycfg' -S \
+; RUN:   -sroa-max-struct-to-vector-bytes=16 %s \
+; RUN:   | FileCheck %s \
+; RUN:       --check-prefixes=FLAT,NESTED,PADDED,NONHOMO,I1,PTR
+%struct.myint4 = type { i32, i32, i32, i32 }
+
+; FLAT-LABEL: define dso_local void @foo_flat(
+; FLAT-NOT: alloca
+; FLAT-NOT: llvm.memcpy
+; FLAT-NOT: llvm.memset
+; FLAT: insertelement <2 x i64>
+; FLAT: bitcast <2 x i64> %{{[^ ]+}} to <4 x i32>
+; FLAT: select i1 %{{[^,]+}}, <4 x i32> zeroinitializer, <4 x i32> %{{[^)]+}}
+; FLAT: store <4 x i32> %{{[^,]+}}, ptr %x, align 16
+; FLAT: ret void
+define dso_local void @foo_flat(ptr noundef %x, i64 %y.coerce0, i64 %y.coerce1, i32 noundef %cond) {
+entry:
+  %y = alloca %struct.myint4, align 16
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.myint4, align 16
+  %zero = alloca %struct.myint4, align 16
+  %data = alloca %struct.myint4, align 16
+  %0 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 0
+  store i64 %y.coerce0, ptr %0, align 16
+  %1 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 1
+  store i64 %y.coerce1, ptr %1, align 8
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %temp, ptr align 16 %y, i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 16 %zero, i8 0, i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %2 = load i32, ptr %cond.addr, align 4
+  %tobool = icmp ne i32 %2, 0
+  br i1 %tobool, label %cond.true, label %cond.false
+
+cond.true:
+  br label %cond.end
+
+cond.false:
+  br label %cond.end
+
+cond.end:
+  %cond1 = phi ptr [ %temp, %cond.true ], [ %zero, %cond.false ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %data, ptr align 16 %cond1, i64 16, i1 false)
+  %3 = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %3, ptr align 16 %data, i64 16, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
+%struct.myint4_base_n = type { i32, i32, i32, i32 }
+%struct.myint4_nested = type { %struct.myint4_base_n }
+
+; NESTED-LABEL: define dso_local void @foo_nested(
+; NESTED-NOT: alloca
+; NESTED-NOT: llvm.memcpy
+; NESTED-NOT: llvm.memset
+; NESTED: insertelement <2 x i64>
+; NESTED: bitcast <2 x i64> %{{[^ ]+}} to <4 x i32>
+; NESTED: select i1 %{{[^,]+}}, <4 x i32> zeroinitializer, <4 x i32> %{{[^)]+}}
+; NESTED: store <4 x i32> %{{[^,]+}}, ptr %x, align 16
+; NESTED: ret void
+define dso_local void @foo_nested(ptr noundef %x, i64 %y.coerce0, i64 %y.coerce1, i32 noundef %cond) {
+entry:
+  %y = alloca %struct.myint4_nested, align 16
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.myint4_nested, align 16
+  %zero = alloca %struct.myint4_nested, align 16
+  %data = alloca %struct.myint4_nested, align 16
+  %0 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 0
+  store i64 %y.coerce0, ptr %0, align 16
+  %1 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 1
+  store i64 %y.coerce1, ptr %1, align 8
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %temp, ptr align 16 %y, i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 16 %zero, i8 0, i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %2 = load i32, ptr %cond.addr, align 4
+  %tobool = icmp ne i32 %2, 0
+  br i1 %tobool, label %cond.true, label %cond.false
+
+cond.true:
+  br label %cond.end
+
+cond.false:
+  br label %cond.end
+
+cond.end:
+  %cond1 = phi ptr [ %temp, %cond.true ], [ %zero, %cond.false ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %data, ptr align 16 %cond1, i64 16, i1 false)
+  %3 = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %3, ptr align 16 %data, i64 16, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
+
+; PADDED-LABEL: define dso_local void @foo_padded(
+; PADDED: llvm.memcpy
+; PADDED-NOT: store <
+; PADDED: ret void
+%struct.padded = type { i32, i8, i32, i8 }
+define dso_local void @foo_padded(ptr noundef %x, i32 %a0, i8 %a1,
+                                  i32 %a2, i8 %a3,
+                                  i32 noundef %cond) {
+entry:
+  %y = alloca %struct.padded, align 4
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.padded, align 4
+  %zero = alloca %struct.padded, align 4
+  %data = alloca %struct.padded, align 4
+  %y_i32_0 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 0
+  store i32 %a0, ptr %y_i32_0, align 4
+  %y_i8_1 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 1
+  store i8 %a1, ptr %y_i8_1, align 1
+  %y_i32_2 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 2
+  store i32 %a2, ptr %y_i32_2, align 4
+  %y_i8_3 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 3
+  store i8 %a3, ptr %y_i8_3, align 1
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %temp, ptr align 4 %y,
+                                   i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 4 %zero, i8 0, i64 16, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %c.pad = load i32, ptr %cond.addr, align 4
+  %tobool.pad = icmp ne i32 %c.pad, 0
+  br i1 %tobool.pad, label %cond.true.pad, label %cond.false.pad
+
+cond.true.pad:
+  br label %cond.end.pad
+
+cond.false.pad:
+  br label %cond.end.pad
+
+cond.end.pad:
+  %cond1.pad = phi ptr [ %temp, %cond.true.pad ], [ %zero, %cond.false.pad ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %data, ptr align 4 %cond1.pad,
+                                   i64 16, i1 false)
+  %xv.pad = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %xv.pad, ptr align 4 %data,
+                                   i64 16, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
+
+; NONHOMO-LABEL: define dso_local void @foo_nonhomo(
+; NONHOMO: llvm.memcpy
+; NONHOMO-NOT: store <
+; NONHOMO: ret void
+%struct.nonhomo = type { i32, i64, i32, i64 }
+define dso_local void @foo_nonhomo(ptr noundef %x, i32 %a0, i64 %a1,
+                                   i32 %a2, i64 %a3,
+                                   i32 noundef %cond) {
+entry:
+  %y = alloca %struct.nonhomo, align 8
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.nonhomo, align 8
+  %zero = alloca %struct.nonhomo, align 8
+  %data = alloca %struct.nonhomo, align 8
+  %y_i32_0n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 0
+  store i32 %a0, ptr %y_i32_0n, align 4
+  %y_i64_1n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 1
+  store i64 %a1, ptr %y_i64_1n, align 8
+  %y_i32_2n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 2
+  store i32 %a2, ptr %y_i32_2n, align 4
+  %y_i64_3n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 3
+  store i64 %a3, ptr %y_i64_3n, align 8
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %temp, ptr align 8 %y,
+                                   i64 32, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 8 %zero, i8 0, i64 32, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %c.nh = load i32, ptr %cond.addr, align 4
+  %tobool.nh = icmp ne i32 %c.nh, 0
+  br i1 %tobool.nh, label %cond.true.nh, label %cond.false.nh
+
+cond.true.nh:
+  br label %cond.end.nh
+
+cond.false.nh:
+  br label %cond.end.nh
+
+cond.end.nh:
+  %cond1.nh = phi ptr [ %temp, %cond.true.nh ], [ %zero, %cond.false.nh ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %data, ptr align 8 %cond1.nh,
+                                   i64 32, i1 false)
+  %xv.nh = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %xv.nh, ptr align 8 %data,
+                                   i64 32, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
+
+; I1-LABEL: define dso_local void @foo_i1(
+; I1-NOT: <4 x i1>
+; I1: ret void
+%struct.i1x4 = type { i1, i1, i1, i1 }
+define dso_local void @foo_i1(ptr noundef %x, i64 %dummy0, i64 %dummy1,
+                              i32 noundef %cond) {
+entry:
+  %y = alloca %struct.i1x4, align 1
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.i1x4, align 1
+  %zero = alloca %struct.i1x4, align 1
+  %data = alloca %struct.i1x4, align 1
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %temp, ptr align 1 %y,
+                                   i64 4, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 1 %zero, i8 0, i64 4, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %c.i1 = load i32, ptr %cond.addr, align 4
+  %tobool.i1 = icmp ne i32 %c.i1, 0
+  br i1 %tobool.i1, label %cond.true.i1, label %cond.false.i1
+
+cond.true.i1:
+  br label %cond.end.i1
+
+cond.false.i1:
+  br label %cond.end.i1
+
+cond.end.i1:
+  %cond1.i1 = phi ptr [ %temp, %cond.true.i1 ], [ %zero, %cond.false.i1 ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %data, ptr align 1 %cond1.i1,
+                                   i64 4, i1 false)
+  %xv.i1 = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %xv.i1, ptr align 1 %data,
+                                   i64 4, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
+
+; PTR-LABEL: define dso_local void @foo_ptr(
+; PTR: llvm.memcpy
+; PTR-NOT: <4 x ptr>
+; PTR: ret void
+%struct.ptr4 = type { ptr, ptr, ptr, ptr }
+define dso_local void @foo_ptr(ptr noundef %x, ptr %p0, ptr %p1,
+                               ptr %p2, ptr %p3,
+                               i32 noundef %cond) {
+entry:
+  %y = alloca %struct.ptr4, align 8
+  %x.addr = alloca ptr, align 8
+  %cond.addr = alloca i32, align 4
+  %temp = alloca %struct.ptr4, align 8
+  %zero = alloca %struct.ptr4, align 8
+  %data = alloca %struct.ptr4, align 8
+  %y_p0 = getelementptr inbounds %struct.ptr4, ptr %y, i32 0, i32 0
+  store ptr %p0, ptr %y_p0, align 8
+  %y_p1 = getelementptr inbounds %struct.ptr4, ptr %y, i32 0, i32 1
+  store ptr %p1, ptr %y_p1, align 8
+  %y_p2 = getelementptr inbounds %struct.ptr4, ptr %y, i32 0, i32 2
+  store ptr %p2, ptr %y_p2, align 8
+  %y_p3 = getelementptr inbounds %struct.ptr4, ptr %y, i32 0, i32 3
+  store ptr %p3, ptr %y_p3, align 8
+  store ptr %x, ptr %x.addr, align 8
+  store i32 %cond, ptr %cond.addr, align 4
+  call void @llvm.lifetime.start.p0(ptr %temp)
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %temp, ptr align 8 %y,
+                                   i64 32, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %zero)
+  call void @llvm.memset.p0.i64(ptr align 8 %zero, i8 0, i64 32, i1 false)
+  call void @llvm.lifetime.start.p0(ptr %data)
+  %c.ptr = load i32, ptr %cond.addr, align 4
+  %tobool.ptr = icmp ne i32 %c.ptr, 0
+  br i1 %tobool.ptr, label %cond.true.ptr, label %cond.false.ptr
+
+cond.true.ptr:
+  br label %cond.end.ptr
+
+cond.false.ptr:
+  br label %cond.end.ptr
+
+cond.end.ptr:
+  %cond1.ptr = phi ptr [ %temp, %cond.true.ptr ], [ %zero, %cond.false.ptr ]
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %data, ptr align 8 %cond1.ptr,
+                                   i64 32, i1 false)
+  %xv.ptr = load ptr, ptr %x.addr, align 8
+  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %xv.ptr, ptr align 8 %data,
+                                   i64 32, i1 false)
+  call void @llvm.lifetime.end.p0(ptr %data)
+  call void @llvm.lifetime.end.p0(ptr %zero)
+  call void @llvm.lifetime.end.p0(ptr %temp)
+  ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/165159
    
    
More information about the llvm-commits
mailing list