[clang] [llvm] [SPIRV] add pre legalization instruction combine (PR #122839)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 10:12:42 PST 2025
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/122839
>From ae1a274f856518d710cffba324c603d9e95adf54 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Thu, 9 Jan 2025 19:19:27 -0500
Subject: [PATCH 1/2] [SPIRV] add pre legalization instruction combine - Add
the boilerplate to support instcombine in SPIRV - instcombine length(X-Y) to
distance(X,Y) - switch HLSL's distance intrinsic to not special case for
SPIRV. - fixes #122766
---
clang/include/clang/Basic/BuiltinsSPIRV.td | 6 +
clang/lib/CodeGen/CGBuiltin.cpp | 10 +
clang/lib/Headers/hlsl/hlsl_detail.h | 8 +-
clang/lib/Sema/SemaSPIRV.cpp | 18 ++
clang/test/CodeGenHLSL/builtins/distance.hlsl | 30 ++-
clang/test/CodeGenHLSL/builtins/length.hlsl | 95 +++++--
clang/test/CodeGenSPIRV/Builtins/length.c | 31 +++
clang/test/SemaSPIRV/BuiltIns/length-errors.c | 25 ++
llvm/lib/Target/SPIRV/CMakeLists.txt | 3 +
llvm/lib/Target/SPIRV/SPIRV.h | 2 +
llvm/lib/Target/SPIRV/SPIRV.td | 1 +
llvm/lib/Target/SPIRV/SPIRVCombine.td | 26 ++
.../SPIRV/SPIRVPreLegalizerCombiner.cpp | 252 ++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 2 +
.../CodeGen/SPIRV/hlsl-intrinsics/distance.ll | 77 +++---
llvm/test/CodeGen/SPIRV/opencl/distance.ll | 11 +
16 files changed, 525 insertions(+), 72 deletions(-)
create mode 100644 clang/test/CodeGenSPIRV/Builtins/length.c
create mode 100644 clang/test/SemaSPIRV/BuiltIns/length-errors.c
create mode 100644 llvm/lib/Target/SPIRV/SPIRVCombine.td
create mode 100644 llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
index 1e66939b822ef8..f72c555921dfe6 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -13,3 +13,9 @@ def SPIRVDistance : Builtin {
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}
+
+def SPIRVLength : Builtin {
+ let Spellings = ["__builtin_spirv_length"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 1b25d365932c30..f541ff0e7ef6f9 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -20487,6 +20487,16 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_distance,
ArrayRef<Value *>{X, Y}, nullptr, "spv.distance");
}
+ case SPIRV::BI__builtin_spirv_length: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ "length operand must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ "length operand must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_length,
+ ArrayRef<Value *>{X}, nullptr, "spv.length");
+ }
}
return nullptr;
}
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 3eb4a3dc861e36..b2c8cc6c5c3dbb 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -61,7 +61,11 @@ length_impl(T X) {
template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
length_vec_impl(vector<T, N> X) {
+#if (__has_builtin(__builtin_spirv_length))
+ return __builtin_spirv_length(X);
+#else
return __builtin_elementwise_sqrt(__builtin_hlsl_dot(X, X));
+#endif
}
template <typename T>
@@ -73,11 +77,7 @@ distance_impl(T X, T Y) {
template <typename T, int N>
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
-#if (__has_builtin(__builtin_spirv_distance))
- return __builtin_spirv_distance(X, Y);
-#else
return length_vec_impl(X - Y);
-#endif
}
} // namespace __detail
} // namespace hlsl
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index d2de64826c6eb3..dc49fc79073572 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -51,6 +51,24 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
+ case SPIRV::BI__builtin_spirv_length: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+ auto *VTy = ArgTyA->getAs<VectorType>();
+ if (VTy == nullptr) {
+ SemaRef.Diag(A.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyA
+ << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+ QualType RetTy = VTy->getElementType();
+ TheCall->setType(RetTy);
+ break;
+ }
}
return false;
}
diff --git a/clang/test/CodeGenHLSL/builtins/distance.hlsl b/clang/test/CodeGenHLSL/builtins/distance.hlsl
index 6952700a87f1df..e830903261c8cf 100644
--- a/clang/test/CodeGenHLSL/builtins/distance.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/distance.hlsl
@@ -33,8 +33,9 @@ half test_distance_half(half X, half Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z19test_distance_half2Dv2_DhS_(
// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[X:%.*]], <2 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v2f16(<2 x half> [[X]], <2 x half> [[Y]])
-// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v2f16(<2 x half> [[SUB_I]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
//
half test_distance_half2(half2 X, half2 Y) { return distance(X, Y); }
@@ -49,8 +50,9 @@ half test_distance_half2(half2 X, half2 Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z19test_distance_half3Dv3_DhS_(
// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[X:%.*]], <3 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v3f16(<3 x half> [[X]], <3 x half> [[Y]])
-// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v3f16(<3 x half> [[SUB_I]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
//
half test_distance_half3(half3 X, half3 Y) { return distance(X, Y); }
@@ -65,8 +67,9 @@ half test_distance_half3(half3 X, half3 Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z19test_distance_half4Dv4_DhS_(
// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[X:%.*]], <4 x half> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.distance.v4f16(<4 x half> [[X]], <4 x half> [[Y]])
-// SPVCHECK-NEXT: ret half [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v4f16(<4 x half> [[SUB_I]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
//
half test_distance_half4(half4 X, half4 Y) { return distance(X, Y); }
@@ -97,8 +100,9 @@ float test_distance_float(float X, float Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float2Dv2_fS_(
// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[X:%.*]], <2 x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
-// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v2f32(<2 x float> [[SUB_I]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
//
float test_distance_float2(float2 X, float2 Y) { return distance(X, Y); }
@@ -113,8 +117,9 @@ float test_distance_float2(float2 X, float2 Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float3Dv3_fS_(
// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[X:%.*]], <3 x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v3f32(<3 x float> [[X]], <3 x float> [[Y]])
-// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v3f32(<3 x float> [[SUB_I]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
//
float test_distance_float3(float3 X, float3 Y) { return distance(X, Y); }
@@ -129,7 +134,8 @@ float test_distance_float3(float3 X, float3 Y) { return distance(X, Y); }
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z20test_distance_float4Dv4_fS_(
// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[X:%.*]], <4 x float> noundef nofpclass(nan inf) [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
// SPVCHECK-NEXT: [[ENTRY:.*:]]
-// SPVCHECK-NEXT: [[SPV_DISTANCE_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.distance.v4f32(<4 x float> [[X]], <4 x float> [[Y]])
-// SPVCHECK-NEXT: ret float [[SPV_DISTANCE_I]]
+// SPVCHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[X]], [[Y]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v4f32(<4 x float> [[SUB_I]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
//
float test_distance_float4(float4 X, float4 Y) { return distance(X, Y); }
diff --git a/clang/test/CodeGenHLSL/builtins/length.hlsl b/clang/test/CodeGenHLSL/builtins/length.hlsl
index fcf3ee76ba5bbd..2d4bbd995298f2 100644
--- a/clang/test/CodeGenHLSL/builtins/length.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/length.hlsl
@@ -1,114 +1,163 @@
-// RUN: %clang_cc1 -finclude-default-header -triple \
-// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,DXCHECK \
-// RUN: -DTARGET=dx
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -O1 -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -triple \
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
-// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefixes=CHECK,SPVCHECK \
-// RUN: -DTARGET=spv
+// RUN: -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z16test_length_halfDh(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z16test_length_halfDh(
+//
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z16test_length_halfDh(
// CHECK-SAME: half noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.fabs.f16(half [[P0]])
// CHECK-NEXT: ret half [[ELT_ABS_I]]
//
-
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z16test_length_halfDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.fabs.f16(half [[P0]])
+// SPVCHECK-NEXT: ret half [[ELT_ABS_I]]
+//
half test_length_half(half p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh(
+//
+
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh(
// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v2f16(<2 x half> [[P0]], <2 x half> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[P0]], <2 x half> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]])
// CHECK-NEXT: ret half [[TMP0]]
//
-
-
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half2Dv2_Dh(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v2f16(<2 x half> [[P0]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
+//
half test_length_half2(half2 p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh(
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh(
// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v3f16(<3 x half> [[P0]], <3 x half> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[P0]], <3 x half> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]])
// CHECK-NEXT: ret half [[TMP0]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half3Dv3_Dh(
+// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v3f16(<3 x half> [[P0]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
+//
half test_length_half3(half3 p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh(
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh(
// CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.[[TARGET]].fdot.v4f16(<4 x half> [[P0]], <4 x half> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> [[P0]], <4 x half> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.sqrt.f16(half [[HLSL_DOT_I]])
// CHECK-NEXT: ret half [[TMP0]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_length_half4Dv4_Dh(
+// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.length.v4f16(<4 x half> [[P0]])
+// SPVCHECK-NEXT: ret half [[SPV_LENGTH_I]]
+//
half test_length_half4(half4 p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z17test_length_floatf(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z17test_length_floatf(
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z17test_length_floatf(
// CHECK-SAME: float noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.fabs.f32(float [[P0]])
// CHECK-NEXT: ret float [[ELT_ABS_I]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z17test_length_floatf(
+// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[ELT_ABS_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.fabs.f32(float [[P0]])
+// SPVCHECK-NEXT: ret float [[ELT_ABS_I]]
+//
float test_length_float(float p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f(
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f(
// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v2f32(<2 x float> [[P0]], <2 x float> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[P0]], <2 x float> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]])
// CHECK-NEXT: ret float [[TMP0]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float2Dv2_f(
+// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v2f32(<2 x float> [[P0]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
+//
float test_length_float2(float2 p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f(
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f(
// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v3f32(<3 x float> [[P0]], <3 x float> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[P0]], <3 x float> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]])
// CHECK-NEXT: ret float [[TMP0]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float3Dv3_f(
+// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v3f32(<3 x float> [[P0]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
+//
float test_length_float3(float3 p0)
{
return length(p0);
}
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f(
// DXCHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f(
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f(
// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.[[TARGET]].fdot.v4f32(<4 x float> [[P0]], <4 x float> [[P0]])
+// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[P0]], <4 x float> [[P0]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.sqrt.f32(float [[HLSL_DOT_I]])
// CHECK-NEXT: ret float [[TMP0]]
//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_length_float4Dv4_f(
+// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[P0:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT: [[ENTRY:.*:]]
+// SPVCHECK-NEXT: [[SPV_LENGTH_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.length.v4f32(<4 x float> [[P0]])
+// SPVCHECK-NEXT: ret float [[SPV_LENGTH_I]]
+//
float test_length_float4(float4 p0)
{
return length(p0);
diff --git a/clang/test/CodeGenSPIRV/Builtins/length.c b/clang/test/CodeGenSPIRV/Builtins/length.c
new file mode 100644
index 00000000000000..59e7c298dd8167
--- /dev/null
+++ b/clang/test/CodeGenSPIRV/Builtins/length.c
@@ -0,0 +1,31 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+
+// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float3 __attribute__((ext_vector_type(3)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+// CHECK-LABEL: define spir_func float @test_length_float2(
+// CHECK-SAME: <2 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v2f32(<2 x float> [[X]])
+// CHECK-NEXT: ret float [[SPV_LENGTH]]
+//
+float test_length_float2(float2 X) { return __builtin_spirv_length(X); }
+
+// CHECK-LABEL: define spir_func float @test_length_float3(
+// CHECK-SAME: <3 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v3f32(<3 x float> [[X]])
+// CHECK-NEXT: ret float [[SPV_LENGTH]]
+//
+float test_length_float3(float3 X) { return __builtin_spirv_length(X); }
+
+// CHECK-LABEL: define spir_func float @test_length_float4(
+// CHECK-SAME: <4 x float> noundef [[X:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[SPV_LENGTH:%.*]] = tail call float @llvm.spv.length.v4f32(<4 x float> [[X]])
+// CHECK-NEXT: ret float [[SPV_LENGTH]]
+//
+float test_length_float4(float4 X) { return __builtin_spirv_length(X); }
diff --git a/clang/test/SemaSPIRV/BuiltIns/length-errors.c b/clang/test/SemaSPIRV/BuiltIns/length-errors.c
new file mode 100644
index 00000000000000..3244bd6737f116
--- /dev/null
+++ b/clang/test/SemaSPIRV/BuiltIns/length-errors.c
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 %s -triple spirv-pc-vulkan-compute -verify
+
+typedef float float2 __attribute__((ext_vector_type(2)));
+
+void test_too_few_arg()
+{
+ return __builtin_spirv_length();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+ return __builtin_spirv_length(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float test_double_scalar_inputs(double p0) {
+ return __builtin_spirv_length(p0);
+ // expected-error@-1 {{passing 'double' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(double)))) double' (vector of 2 'double' values)}}
+}
+
+float test_int_scalar_inputs(int p0) {
+ return __builtin_spirv_length(p0);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(int)))) int' (vector of 2 'int' values)}}
+}
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index a79e19fcd753dc..efdd8c8d24fbd5 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -10,6 +10,8 @@ tablegen(LLVM SPIRVGenRegisterBank.inc -gen-register-bank)
tablegen(LLVM SPIRVGenRegisterInfo.inc -gen-register-info)
tablegen(LLVM SPIRVGenSubtargetInfo.inc -gen-subtarget)
tablegen(LLVM SPIRVGenTables.inc -gen-searchable-tables)
+tablegen(LLVM SPIRVGenPreLegalizeGICombiner.inc -gen-global-isel-combiner
+ -combiners="SPIRVPreLegalizerCombiner")
add_public_tablegen_target(SPIRVCommonTableGen)
@@ -33,6 +35,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVModuleAnalysis.cpp
SPIRVStructurizer.cpp
SPIRVPreLegalizer.cpp
+ SPIRVPreLegalizerCombiner.cpp
SPIRVPostLegalizer.cpp
SPIRVPrepareFunctions.cpp
SPIRVRegisterBankInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 81b57202644256..6d00a046ff7caa 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -24,6 +24,7 @@ FunctionPass *createSPIRVStructurizerPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
+FunctionPass *createSPIRVPreLegalizerCombiner();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVPostLegalizerPass();
ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
@@ -36,6 +37,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
void initializeSPIRVPreLegalizerPass(PassRegistry &);
+void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &);
void initializeSPIRVPostLegalizerPass(PassRegistry &);
void initializeSPIRVStructurizerPass(PassRegistry &);
void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
diff --git a/llvm/lib/Target/SPIRV/SPIRV.td b/llvm/lib/Target/SPIRV/SPIRV.td
index 108c7e6d3861f0..39a4131c7f1bdf 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.td
+++ b/llvm/lib/Target/SPIRV/SPIRV.td
@@ -11,6 +11,7 @@ include "llvm/Target/Target.td"
include "SPIRVRegisterInfo.td"
include "SPIRVRegisterBanks.td"
include "SPIRVInstrInfo.td"
+include "SPIRVCombine.td"
include "SPIRVBuiltins.td"
def SPIRVInstrInfo : InstrInfo;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td
new file mode 100644
index 00000000000000..11851894e2f752
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td
@@ -0,0 +1,26 @@
+//=- SPIRVCombine.td - Define SPIRV Combine Rules -------------*-tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+include "llvm/Target/GlobalISel/Combine.td"
+
+
+def vector_length_sub_to_distance_lowering : GICombineRule <
+ (defs root:$root),
+ (match (wip_match_opcode G_INTRINSIC):$root,
+ [{ return matchLengthToDistance(*${root}, MRI); }]),
+ (apply [{ applySPIRVDistance(*${root}, MRI, B); }])
+>;
+
+def SPIRVPreLegalizerCombiner
+ : GICombiner<"SPIRVPreLegalizerCombinerImpl",
+ [vector_length_sub_to_distance_lowering]> {
+ let CombineAllMethodName = "tryCombineAllImpl";
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
new file mode 100644
index 00000000000000..54b65e8b04d622
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizerCombiner.cpp
@@ -0,0 +1,252 @@
+
+//===-- SPIRVPreLegalizerCombiner.cpp - pre-legalize combines ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass does combining of machine instructions at the generic MI level,
+// before the legalizer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
+#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Combiner.h"
+#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
+#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
+#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
+#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
+#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
+
+#define GET_GICOMBINER_DEPS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_DEPS
+
+#define DEBUG_TYPE "spirv-prelegalizer-combiner"
+
+using namespace llvm;
+using namespace MIPatternMatch;
+
+namespace {
+
+#define GET_GICOMBINER_TYPES
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_TYPES
+
+bool matchLengthToDistance(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  // Only calls to the `spv_length` intrinsic are candidates for the rewrite.
+  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC)
+    return false;
+  if (cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
+    return false;
+
+  // Operand 0 is the result and operand 1 the intrinsic ID, so the
+  // argument is operand 2. It has to be defined by an ASSIGN_TYPE.
+  const MachineInstr *TypedArg = MRI.getVRegDef(MI.getOperand(2).getReg());
+  if (!TypedArg)
+    return false;
+  if (TypedArg->getOpcode() != SPIRV::ASSIGN_TYPE)
+    return false;
+
+  // The value being typed must in turn be the result of a G_FSUB.
+  const MachineInstr *Diff = MRI.getVRegDef(TypedArg->getOperand(1).getReg());
+  return Diff && Diff->getOpcode() == TargetOpcode::G_FSUB;
+}
+void applySPIRVDistance(MachineInstr &MI, MachineRegisterInfo &MRI,
+                        MachineIRBuilder &B) {
+  // Replace length(X - Y) with distance(X, Y); see matchLengthToDistance.
+  // Walk spv_length -> ASSIGN_TYPE -> G_FSUB to recover X and Y.
+  Register SubAssignTypeReg = MI.getOperand(2).getReg();
+  MachineInstr *Sub1AssignTypeInst = MRI.getVRegDef(SubAssignTypeReg);
+  Register SubDestReg = Sub1AssignTypeInst->getOperand(1).getReg();
+  MachineInstr *SubInstr = MRI.getVRegDef(SubDestReg);
+  Register SubOperand1 = SubInstr->getOperand(1).getReg();
+  Register SubOperand2 = SubInstr->getOperand(2).getReg();
+
+  // Reuse the result register, debug location and position of `spv_length`.
+  Register ResultReg = MI.getOperand(0).getReg();
+  DebugLoc DL = MI.getDebugLoc();
+  MachineBasicBlock &MBB = *MI.getParent();
+  MachineBasicBlock::iterator InsertPt = MI.getIterator();
+
+  // Build the `spv_distance` intrinsic.
+  MachineInstrBuilder NewInstr =
+      BuildMI(MBB, InsertPt, DL, B.getTII().get(TargetOpcode::G_INTRINSIC));
+  NewInstr
+      .addDef(ResultReg)                       // Result register
+      .addIntrinsicID(Intrinsic::spv_distance) // Intrinsic ID
+      .addUse(SubOperand1)                     // Operand X
+      .addUse(SubOperand2);                    // Operand Y
+
+  // `spv_length` is the only instruction the new intrinsic replaces, so it is
+  // the only one erased unconditionally.
+  MI.eraseFromParent();
+
+  // Other users of X - Y still need the subtraction: drop ASSIGN_TYPE and
+  // G_FSUB only once nothing reads their results any more.
+  if (!MRI.use_empty(SubAssignTypeReg))
+    return;
+  Sub1AssignTypeInst->eraseFromParent();
+  if (!MRI.use_empty(SubDestReg))
+    return;
+  SubInstr->eraseFromParent();
+}
+
+class SPIRVPreLegalizerCombinerImpl : public Combiner {
+protected:
+ const CombinerHelper Helper;
+ const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig;
+ const SPIRVSubtarget &STI;
+
+public:
+ SPIRVPreLegalizerCombinerImpl(
+ MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
+ GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
+ const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig,
+ const SPIRVSubtarget &STI, MachineDominatorTree *MDT,
+ const LegalizerInfo *LI);
+
+ static const char *getName() { return "SPIRV00PreLegalizerCombiner"; }
+
+ bool tryCombineAll(MachineInstr &I) const override;
+
+ bool tryCombineAllImpl(MachineInstr &I) const;
+
+private:
+#define GET_GICOMBINER_CLASS_MEMBERS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_CLASS_MEMBERS
+};
+
+#define GET_GICOMBINER_IMPL
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_IMPL
+
+SPIRVPreLegalizerCombinerImpl::SPIRVPreLegalizerCombinerImpl(
+ MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
+ GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
+ const SPIRVPreLegalizerCombinerImplRuleConfig &RuleConfig,
+ const SPIRVSubtarget &STI, MachineDominatorTree *MDT,
+ const LegalizerInfo *LI)
+ : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
+ Helper(Observer, B, /*IsPreLegalize*/ true, &KB, MDT, LI),
+ RuleConfig(RuleConfig), STI(STI),
+#define GET_GICOMBINER_CONSTRUCTOR_INITS
+#include "SPIRVGenPreLegalizeGICombiner.inc"
+#undef GET_GICOMBINER_CONSTRUCTOR_INITS
+{
+}
+
+bool SPIRVPreLegalizerCombinerImpl::tryCombineAll(MachineInstr &MI) const {
+ return tryCombineAllImpl(MI);
+}
+
+// Pass boilerplate
+// ================
+
+class SPIRVPreLegalizerCombiner : public MachineFunctionPass {
+public:
+ static char ID;
+
+ SPIRVPreLegalizerCombiner();
+
+ StringRef getPassName() const override { return "SPIRVPreLegalizerCombiner"; }
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+private:
+ SPIRVPreLegalizerCombinerImplRuleConfig RuleConfig;
+};
+
+} // end anonymous namespace
+
+void SPIRVPreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<TargetPassConfig>();
+ AU.setPreservesCFG();
+ getSelectionDAGFallbackAnalysisUsage(AU);
+ AU.addRequired<GISelKnownBitsAnalysis>();
+ AU.addPreserved<GISelKnownBitsAnalysis>();
+ AU.addRequired<MachineDominatorTreeWrapperPass>();
+ AU.addPreserved<MachineDominatorTreeWrapperPass>();
+ AU.addRequired<GISelCSEAnalysisWrapperPass>();
+ AU.addPreserved<GISelCSEAnalysisWrapperPass>();
+ MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+SPIRVPreLegalizerCombiner::SPIRVPreLegalizerCombiner()
+ : MachineFunctionPass(ID) {
+ initializeSPIRVPreLegalizerCombinerPass(*PassRegistry::getPassRegistry());
+
+ if (!RuleConfig.parseCommandLineOption())
+ report_fatal_error("Invalid rule identifier");
+}
+
+bool SPIRVPreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
+ if (MF.getProperties().hasProperty(
+ MachineFunctionProperties::Property::FailedISel))
+ return false; // MF is already marked FailedISel; report no change.
+ auto &TPC = getAnalysis<TargetPassConfig>();
+
+ // Enable CSE, using the CSE configuration supplied by the target pass config.
+ GISelCSEAnalysisWrapper &Wrapper =
+ getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
+ auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());
+
+ const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+ const auto *LI = ST.getLegalizerInfo();
+
+ const Function &F = MF.getFunction();
+ bool EnableOpt =
+ MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);
+ GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
+ MachineDominatorTree *MDT =
+ &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+ CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
+ /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
+ F.hasMinSize());
+ // Disable fixed-point iteration to reduce compile time: allow one iteration.
+ CInfo.MaxIterations = 1;
+ CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
+ // This is the first Combiner, so the input IR might contain dead
+ // instructions; enable full dead-code elimination for them.
+ CInfo.EnableFullDCE = true;
+ SPIRVPreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *KB, CSEInfo, RuleConfig,
+ ST, MDT, LI);
+ return Impl.combineMachineInstrs();
+}
+
+char SPIRVPreLegalizerCombiner::ID = 0;
+INITIALIZE_PASS_BEGIN(SPIRVPreLegalizerCombiner, DEBUG_TYPE,
+ "Combine SPIRV machine instrs before legalization", false,
+ false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
+INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
+INITIALIZE_PASS_END(SPIRVPreLegalizerCombiner, DEBUG_TYPE,
+ "Combine SPIRV machine instrs before legalization", false,
+ false)
+
+namespace llvm {
+FunctionPass *createSPIRVPreLegalizerCombiner() {
+ return new SPIRVPreLegalizerCombiner();
+}
+} // end namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index dca67cb6c632bd..c9cee09cafca3f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -48,6 +48,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
initializeSPIRVModuleAnalysisPass(PR);
initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PR);
initializeSPIRVStructurizerPass(PR);
+ initializeSPIRVPreLegalizerCombinerPass(PR);
}
static std::string computeDataLayout(const Triple &TT) {
@@ -218,6 +219,7 @@ bool SPIRVPassConfig::addIRTranslator() {
void SPIRVPassConfig::addPreLegalizeMachineIR() {
addPass(createSPIRVPreLegalizerPass());
+ addPass(createSPIRVPreLegalizerCombiner());
}
// Use the default legalizer.
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll
index 85a24a0127ae04..fac5d5f9fbd0d2 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/distance.ll
@@ -1,33 +1,44 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-
-; Make sure SPIRV operation function calls for distance are lowered correctly.
-
-; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
-; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
-; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
-; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
-; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
-
-define noundef half @distance_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
-entry:
- ; CHECK: %[[#]] = OpFunction %[[#float_16]] None %[[#]]
- ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
- ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
- ; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]]
- %spv.distance = call half @llvm.spv.distance.f16(<4 x half> %a, <4 x half> %b)
- ret half %spv.distance
-}
-
-define noundef float @distance_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
-entry:
- ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]]
- ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
- ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
- ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]]
- %spv.distance = call float @llvm.spv.distance.f32(<4 x float> %a, <4 x float> %b)
- ret float %spv.distance
-}
-
-declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>)
-declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>)
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for distance are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+define noundef half @distance_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_16]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
+ ; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]]
+ %spv.distance = call half @llvm.spv.distance.f16(<4 x half> %a, <4 x half> %b)
+ ret half %spv.distance
+}
+
+define noundef float @distance_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]]
+ %spv.distance = call float @llvm.spv.distance.f32(<4 x float> %a, <4 x float> %b)
+ ret float %spv.distance
+}
+
+define noundef float @distance_instcombine_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] Distance %[[#arg0]] %[[#arg1]]
+ %delta = fsub <4 x float> %a, %b
+ %spv.length = call float @llvm.spv.length.f32(<4 x float> %delta)
+ ret float %spv.length
+}
+
+declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>)
+declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/opencl/distance.ll b/llvm/test/CodeGen/SPIRV/opencl/distance.ll
index ac18804c00c9ab..ed329175e9c07f 100644
--- a/llvm/test/CodeGen/SPIRV/opencl/distance.ll
+++ b/llvm/test/CodeGen/SPIRV/opencl/distance.ll
@@ -30,5 +30,16 @@ entry:
ret float %spv.distance
}
+define noundef float @distance_instcombine_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#float_32]] None %[[#]]
+ ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+ ; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_cl]] distance %[[#arg0]] %[[#arg1]]
+ %delta = fsub <4 x float> %a, %b
+ %spv.length = call float @llvm.spv.length.f32(<4 x float> %delta)
+ ret float %spv.length
+}
+
declare half @llvm.spv.distance.f16(<4 x half>, <4 x half>)
declare float @llvm.spv.distance.f32(<4 x float>, <4 x float>)
>From 0d1bcdb1270c658000a8b881310fa98edfcec1f5 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Wed, 15 Jan 2025 13:12:25 -0500
Subject: [PATCH 2/2] address pr comments
---
llvm/lib/Target/SPIRV/SPIRVCombine.td | 4 ----
1 file changed, 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td
index 11851894e2f752..6f726e024de525 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombine.td
+++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td
@@ -4,10 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-//===----------------------------------------------------------------------===//
-//
-//
-//===----------------------------------------------------------------------===//
include "llvm/Target/GlobalISel/Combine.td"
More information about the llvm-commits
mailing list