[llvm] [LLVM][IR] Add textual shorthand for specifying constant vector splats. (PR #74620)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 6 10:20:37 PST 2023


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/74620

>From fa97697fc152dfc99bdb11b1491f11bce05d0275 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Nov 2023 13:54:33 +0000
Subject: [PATCH] [LLVM][IR] Add textual shorthand for specifying constant
 vector splats.

Add LL parsing for `<N x ty> splat(ty <imm>)` that lowers onto
ConstantInt::get() for integer types and ConstantFP::get() for
floating-point types.

The intent is to extend ConstantInt/FP classes to support vector
types rather than redirecting to other constant classes as the
get() methods do today.

This patch gives IR writers the convenience of using the shorthand
today, thus allowing existing tests to be ported.
---
 llvm/docs/LangRef.rst                 |  5 ++
 llvm/include/llvm/AsmParser/LLToken.h |  1 +
 llvm/lib/AsmParser/LLLexer.cpp        |  1 +
 llvm/lib/AsmParser/LLParser.cpp       | 62 +++++++++++++++++++------
 llvm/test/Assembler/constant-splat.ll | 67 +++++++++++++++++++++++++++
 5 files changed, 122 insertions(+), 14 deletions(-)
 create mode 100644 llvm/test/Assembler/constant-splat.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index f5e8065ca1dc6..f7977488a54fd 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4299,6 +4299,11 @@ constants and smaller complex constants.
     "``< i32 42, i32 11, i32 74, i32 100 >``". Vector constants
     must have :ref:`vector type <t_vector>`, and the number and types of
     elements must match those specified by the type.
+
+    When creating a vector whose elements have the same constant value, the
+    preferred syntax is ``splat (<Ty> Val)``. For example: "``splat (i32 11)``".
+    These vector constants must have :ref:`vector type <t_vector>` with an
+    element type that matches the ``splat`` operand.
 **Zero initialization**
     The string '``zeroinitializer``' can be used to zero initialize a
     value to zero of *any* type, including scalar and
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0aa0093e8efbd..147cf56c821aa 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -336,6 +336,7 @@ enum Kind {
   kw_extractelement,
   kw_insertelement,
   kw_shufflevector,
+  kw_splat,
   kw_extractvalue,
   kw_insertvalue,
   kw_blockaddress,
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index bf01b39e6f971..919c69fe2783e 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -698,6 +698,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(uinc_wrap);
   KEYWORD(udec_wrap);
 
+  KEYWORD(splat);
   KEYWORD(vscale);
   KEYWORD(x);
   KEYWORD(blockaddress);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 5c7d551b2b31a..d275b2bab06ef 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3981,6 +3981,33 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
 
+  case lltok::kw_splat: {
+    Lex.Lex();
+    if (parseToken(lltok::lparen, "expected '(' after vector splat"))
+      return true;
+    Constant *C;
+    if (parseGlobalTypeAndValue(C))
+      return true;
+    if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
+      return true;
+
+    if (auto *CI = dyn_cast<ConstantInt>(C)) {
+      ID.APSIntVal = CI->getValue();
+      ID.Kind = ValID::t_APSInt;
+      return false;
+    }
+
+    if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+      ID.APFloatVal = CFP->getValue();
+      ID.Kind = ValID::t_APFloat;
+      return false;
+    }
+
+    ID.ConstantVal = C;
+    ID.Kind = ValID::t_Constant;
+    return false;
+  }
+
   case lltok::kw_getelementptr:
   case lltok::kw_shufflevector:
   case lltok::kw_insertelement:
@@ -5740,14 +5767,15 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       V = NoCFIValue::get(cast<GlobalValue>(V));
     return V == nullptr;
   case ValID::t_APSInt:
-    if (!Ty->isIntegerTy())
+    if (!Ty->isIntOrIntVectorTy())
       return error(ID.Loc, "integer constant must have integer type");
-    ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getPrimitiveSizeInBits());
-    V = ConstantInt::get(Context, ID.APSIntVal);
+    ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getScalarSizeInBits());
+    V = ConstantInt::get(Ty, ID.APSIntVal);
     return false;
-  case ValID::t_APFloat:
-    if (!Ty->isFloatingPointTy() ||
-        !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
+  case ValID::t_APFloat: {
+    Type *ScalarTy = Ty->getScalarType();
+    if (!ScalarTy->isFloatingPointTy() ||
+        !ConstantFP::isValueValidForType(ScalarTy, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
     // The lexer has no type info, so builds all half, bfloat, float, and double
@@ -5756,13 +5784,13 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (ScalarTy->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isBFloatTy())
+      else if (ScalarTy->isBFloatTy())
         ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isFloatTy())
+      else if (ScalarTy->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       if (IsSNAN) {
@@ -5774,13 +5802,15 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
                                          ID.APFloatVal.isNegative(), &Payload);
       }
     }
-    V = ConstantFP::get(Context, ID.APFloatVal);
 
-    if (V->getType() != Ty)
+    if (Type::getFloatingPointTy(Context, ID.APFloatVal.getSemantics()) !=
+        ScalarTy)
       return error(ID.Loc, "floating point constant does not have type '" +
                                getTypeString(Ty) + "'");
 
+    V = ConstantFP::get(Ty, ID.APFloatVal);
     return false;
+  }
   case ValID::t_Null:
     if (!Ty->isPointerTy())
       return error(ID.Loc, "null must be a pointer type");
@@ -5818,11 +5848,15 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
     V = PoisonValue::get(Ty);
     return false;
   case ValID::t_Constant:
-    if (ID.ConstantVal->getType() != Ty)
+    if (ID.ConstantVal->getType() != Ty->getScalarType())
       return error(ID.Loc, "constant expression type mismatch: got type '" +
                                getTypeString(ID.ConstantVal->getType()) +
-                               "' but expected '" + getTypeString(Ty) + "'");
-    V = ID.ConstantVal;
+                               "' but expected '" +
+                               getTypeString(Ty->getScalarType()) + "'");
+    if (auto *VTy = dyn_cast<VectorType>(Ty))
+      V = ConstantVector::getSplat(VTy->getElementCount(), ID.ConstantVal);
+    else
+      V = ID.ConstantVal;
     return false;
   case ValID::t_ConstantStruct:
   case ValID::t_PackedConstantStruct:
diff --git a/llvm/test/Assembler/constant-splat.ll b/llvm/test/Assembler/constant-splat.ll
new file mode 100644
index 0000000000000..565f896b098c6
--- /dev/null
+++ b/llvm/test/Assembler/constant-splat.ll
@@ -0,0 +1,67 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+; NOTE: Tests the expansion of the "splat" shorthand method to create vector
+; constants.  Future work will change how "splat" is expanded, ultimately
+; leading to a point where "splat" is emitted as the disassembly.
+
+@my_global = external global i32
+
+; CHECK: @constant.splat.i1 = constant <1 x i1> <i1 true>
+@constant.splat.i1 = constant <1 x i1> splat (i1 true)
+
+; CHECK: @constant.splat.i32 = constant <5 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7>
+@constant.splat.i32 = constant <5 x i32> splat (i32 7)
+
+; CHECK: @constant.splat.i128 = constant <2 x i128> <i128 85070591730234615870450834276742070272, i128 85070591730234615870450834276742070272>
+@constant.splat.i128 = constant <2 x i128> splat (i128 85070591730234615870450834276742070272)
+
+; CHECK: @constant.splat.f16 = constant <4 x half> <half 0xHBC00, half 0xHBC00, half 0xHBC00, half 0xHBC00>
+@constant.splat.f16 = constant <4 x half> splat (half 0xHBC00)
+
+; CHECK: @constant.splat.f32 = constant <5 x float> <float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00, float -2.000000e+00>
+@constant.splat.f32 = constant <5 x float> splat (float -2.000000e+00)
+
+; CHECK: @constant.splat.f64 = constant <3 x double> <double -3.000000e+00, double -3.000000e+00, double -3.000000e+00>
+@constant.splat.f64 = constant <3 x double> splat (double -3.000000e+00)
+
+; CHECK: @constant.splat.128 = constant <2 x fp128> <fp128 0xL00000000000000018000000000000000, fp128 0xL00000000000000018000000000000000>
+@constant.splat.128 = constant <2 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+
+; CHECK: @constant.splat.bf16 = constant <4 x bfloat> <bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0, bfloat 0xRC0A0>
+@constant.splat.bf16 = constant <4 x bfloat> splat (bfloat 0xRC0A0)
+
+; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> <x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800, x86_fp80 0xK4000C8F5C28F5C28F800>
+@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+
+; CHECK: @constant.splat.ppc_fp128 = constant <1 x ppc_fp128> <ppc_fp128 0xM80000000000000000000000000000000>
+@constant.splat.ppc_fp128 = constant <1 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+
+; CHECK: @constant.splat.global.ptr = constant <4 x ptr> <ptr @my_global, ptr @my_global, ptr @my_global, ptr @my_global>
+@constant.splat.global.ptr = constant <4 x ptr> splat (ptr @my_global)
+
+define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) {
+; CHECK: %add = add <4 x i32> %a, <i32 137, i32 137, i32 137, i32 137>
+  %add = add <4 x i32> %a, splat (i32 137)
+  ret void
+}
+
+define <4 x i32> @ret_fixed_lenth_vector_splat_i32() {
+; CHECK: ret <4 x i32> <i32 56, i32 56, i32 56, i32 56>
+  ret <4 x i32> splat (i32 56)
+}
+
+define void @add_fixed_lenth_vector_splat_double(<vscale x 2 x double> %a) {
+; CHECK: %add = fadd <vscale x 2 x double> %a, shufflevector (<vscale x 2 x double> insertelement (<vscale x 2 x double> poison, double 5.700000e+00, i64 0), <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer)
+  %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  ret void
+}
+
+define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
+; CHECK: ret <vscale x 4 x i32> shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 78, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x i32> splat (i32 78)
+}
+
+define <vscale x 4 x ptr> @ret_scalable_vector_ptr() {
+; CHECK: ret <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @my_global, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x ptr> splat (ptr @my_global)
+}



More information about the llvm-commits mailing list