[llvm] 6d35bd1 - [CodeGenPrepare] Update optimizeGatherScatterInst for scalable vectors.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 15 03:55:48 PST 2020


Author: Paul Walker
Date: 2020-12-15T10:57:51Z
New Revision: 6d35bd1d48e9fdde38483e6b22a900daa7e3d46a

URL: https://github.com/llvm/llvm-project/commit/6d35bd1d48e9fdde38483e6b22a900daa7e3d46a
DIFF: https://github.com/llvm/llvm-project/commit/6d35bd1d48e9fdde38483e6b22a900daa7e3d46a.diff

LOG: [CodeGenPrepare] Update optimizeGatherScatterInst for scalable vectors.

optimizeGatherScatterInst does nothing specific to fixed length vectors
but uses FixedVectorType to extract the number of elements.  This patch
simply updates the code to use VectorType and getElementCount instead.

For testing I just copied Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll,
replacing `<4 x ` with `<vscale x 4 x `.

Differential Revision: https://reviews.llvm.org/D92572

Added: 
    llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll

Modified: 
    llvm/lib/CodeGen/CodeGenPrepare.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9b44b30e7a1b..8fe5cb9faba4 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -5304,14 +5304,10 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 ///
 /// If the final index isn't a vector or is a splat, we can emit a scalar GEP
 /// followed by a GEP with an all zeroes vector index. This will enable
-/// SelectionDAGBuilder to use a the scalar GEP as the uniform base and have a
+/// SelectionDAGBuilder to use the scalar GEP as the uniform base and have a
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  // FIXME: Support scalable vectors.
-  if (isa<ScalableVectorType>(Ptr->getType()))
-    return false;
-
   Value *NewAddr;
 
   if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
@@ -5370,7 +5366,7 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
     if (!RewriteGEP && Ops.size() == 2)
       return false;
 
-    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+    auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
 
     IRBuilder<> Builder(MemoryInst);
 
@@ -5380,7 +5376,7 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
     // and a vector GEP with all zeroes final index.
     if (!Ops[FinalIndex]->getType()->isVectorTy()) {
       NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
-      auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+      auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
       NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
     } else {
       Value *Base = Ops[0];
@@ -5403,13 +5399,13 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
     if (!V)
       return false;
 
-    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+    auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
 
     IRBuilder<> Builder(MemoryInst);
 
     // Emit a vector GEP with a scalar pointer and all 0s vector index.
     Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
-    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
     NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
   } else {
     // Constant, SelectionDAGBuilder knows to check if its a splat.

diff  --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
new file mode 100644
index 000000000000..08011b6b5b6a
--- /dev/null
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
@@ -0,0 +1,113 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -codegenprepare < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+%struct.a = type { i32, i32 }
+@c = external dso_local global %struct.a, align 4
+@glob_array = internal unnamed_addr constant [16 x i32] [i32 1, i32 1, i32 2, i32 3, i32 5, i32 8, i32 13, i32 21, i32 34, i32 55, i32 89, i32 144, i32 233, i32 377, i32 610, i32 987], align 16
+
+define <vscale x 4 x i32> @splat_base(i32* %base, <vscale x 4 x i64> %index, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @splat_base(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[BASE:%.*]], <vscale x 4 x i64> [[INDEX:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[RES]]
+;
+  %broadcast.splatinsert = insertelement <vscale x 4 x i32*> undef, i32* %base, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i32*> %broadcast.splatinsert, <vscale x 4 x i32*> undef, <vscale x 4 x i32> zeroinitializer
+  %gep = getelementptr i32, <vscale x 4 x i32*> %broadcast.splat, <vscale x 4 x i64> %index
+  %res = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %gep, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @splat_struct(%struct.a* %base, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @splat_struct(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [[STRUCT_A:%.*]], %struct.a* [[BASE:%.*]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, i32* [[TMP1]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP2]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[RES]]
+;
+  %gep = getelementptr %struct.a, %struct.a* %base, <vscale x 4 x i64> zeroinitializer, i32 1
+  %res = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %gep, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @scalar_index(i32* %base, i64 %index, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @scalar_index(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[BASE:%.*]], i64 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, i32* [[TMP1]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP2]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[RES]]
+;
+  %broadcast.splatinsert = insertelement <vscale x 4 x i32*> undef, i32* %base, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i32*> %broadcast.splatinsert, <vscale x 4 x i32*> undef, <vscale x 4 x i32> zeroinitializer
+  %gep = getelementptr i32, <vscale x 4 x i32*> %broadcast.splat, i64 %index
+  %res = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %gep, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @splat_index(i32* %base, i64 %index, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @splat_index(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[BASE:%.*]], i64 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, i32* [[TMP1]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP2]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[RES]]
+;
+  %broadcast.splatinsert = insertelement <vscale x 4 x i64> undef, i64 %index, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i64> %broadcast.splatinsert, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %gep = getelementptr i32, i32* %base, <vscale x 4 x i64> %broadcast.splat
+  %res = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %gep, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 4 x i32> @test_global_array(<vscale x 4 x i64> %indxs, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @test_global_array(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* getelementptr inbounds ([16 x i32], [16 x i32]* @glob_array, i64 0, i64 0), <vscale x 4 x i64> [[INDXS:%.*]]
+; CHECK-NEXT:    [[G:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[G]]
+;
+  %p = getelementptr inbounds [16 x i32], [16 x i32]* @glob_array, i64 0, <vscale x 4 x i64> %indxs
+  %g = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %g
+}
+
+define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @global_struct_splat(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> shufflevector (<vscale x 4 x i32*> insertelement (<vscale x 4 x i32*> undef, i32* getelementptr inbounds (%struct.a, %struct.a* @c, i64 0, i32 1), i32 0), <vscale x 4 x i32*> undef, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> undef)
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP1]]
+;
+  %1 = insertelement <vscale x 4 x %struct.a*> undef, %struct.a* @c, i32 0
+  %2 = shufflevector <vscale x 4 x %struct.a*> %1, <vscale x 4 x %struct.a*> undef, <vscale x 4 x i32> zeroinitializer
+  %3 = getelementptr %struct.a, <vscale x 4 x %struct.a*> %2, <vscale x 4 x i64> zeroinitializer, i32 1
+  %4 = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %3, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %4
+}
+
+define <vscale x 4 x i32> @splat_ptr_gather(i32* %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
+; CHECK-LABEL: @splat_ptr_gather(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
+;
+  %1 = insertelement <vscale x 4 x i32*> undef, i32* %ptr, i32 0
+  %2 = shufflevector <vscale x 4 x i32*> %1, <vscale x 4 x i32*> undef, <vscale x 4 x i32> zeroinitializer
+  %3 = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %2, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru)
+  ret <vscale x 4 x i32> %3
+}
+
+define void @splat_ptr_scatter(i32* %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
+; CHECK-LABEL: @splat_ptr_scatter(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x i32*> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %1 = insertelement <vscale x 4 x i32*> undef, i32* %ptr, i32 0
+  %2 = shufflevector <vscale x 4 x i32*> %1, <vscale x 4 x i32*> undef, <vscale x 4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> %val, <vscale x 4 x i32*> %2, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32>, <vscale x 4 x i32*>, i32, <vscale x 4 x i1>)
+
+attributes #0 = { "target-features"="+sve" }


        


More information about the llvm-commits mailing list