[Mlir-commits] [mlir] [mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions (PR #137917)
llvmlistbot at llvm.org
Tue Apr 29 19:53:46 PDT 2025
https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/137917
Adds AVX broadcast and conversion from F16 to packed F32 (similar to PR https://github.com/llvm/llvm-project/pull/136830). The following instructions are added (a usage sketch follows the list):
- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
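A minimal usage sketch of the new ops, mirroring the test cases added in this patch (the value names `%src` and `%scalar` are illustrative):
```mlir
// Convert the 8 even-indexed f16 elements of a 16-element buffer to f32.
%even = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %src : memref<16xf16> -> vector<8xf32>
// Convert the 8 odd-indexed f16 elements of the same buffer.
%odd = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %src : memref<16xf16> -> vector<8xf32>
// Load one f16 element, convert it to f32, and broadcast it to all 8 lanes.
%bcst = x86vector.avx.bcst.f16_to_f32.packed %scalar : memref<1xf16> -> vector<8xf32>
```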
From fd523e58f9ff41ab997ce836fa47fceb9f4b1acc Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 29 Apr 2025 03:26:41 -0700
Subject: [PATCH 1/2] new avx2 f16 ops in x86vector dialect to handle f16
conversions to f32
---
.../mlir/Dialect/X86Vector/X86Vector.td | 122 ++++++++++++++++++
.../Dialect/X86Vector/IR/X86VectorDialect.cpp | 17 +++
2 files changed, 139 insertions(+)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 126fa0e352656..37bdfc18a17a3 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -527,4 +527,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
+//----------------------------------------------------------------------------//
+// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneeph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneoph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert F16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Broadcasts F16 into packed F32 Data.";
+
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert scalar F16 (16-bit) floating-point element stored at memory locations
+ starting at location `__A` to a single-precision (32-bit) floating-point,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vbcstnesh2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+
+}
+
#endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index f5e5070c74f8f..2e01a11921950 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -112,5 +112,22 @@ x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
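For reference, a minimal sketch of the intended lowering, assuming the same one-to-one intrinsic conversion path already used by the existing BF16 ops and based on the legalize-for-llvm tests below (value names are illustrative). getIntrinsicName() builds the intrinsic name by appending the result bit width, i.e. vector length x 32, so a vector<8xf32> result selects the 256-bit variant:
```mlir
// Input op (illustrative):
%0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
// After x86vector legalization, the op maps to a direct intrinsic call on the
// buffer pointer extracted from the memref descriptor (sketch, not exact output):
%r = llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"(%ptr) : (!llvm.ptr) -> vector<8xf32>
```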
From 96e1debc70ac168cd84736951d7bd332cbfd741f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 29 Apr 2025 19:48:47 -0700
Subject: [PATCH 2/2] adding new test-cases
---
.../Dialect/X86Vector/legalize-for-llvm.mlir | 54 +++++++++++++++++
mlir/test/Dialect/X86Vector/roundtrip.mlir | 60 +++++++++++++++++++
mlir/test/Target/LLVMIR/x86vector.mlir | 54 +++++++++++++++++
3 files changed, 168 insertions(+)
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 93b304c44de8e..3888ec05ad866 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -149,6 +149,60 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_128
+func.func @avxf16_bcst_f16_to_f32_packed_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_256
+func.func @avxf16_bcst_f16_to_f32_packed_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index b783cc869b981..a2fdb0cf6d457 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -154,6 +154,66 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
+func.func @avxf16_bcst_f16_to_f32_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
+func.func @avxf16_bcst_f16_to_f32_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index a8bc180d1d0ac..f474ae281ece3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -163,6 +163,60 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
+func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
+func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
+func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
+func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
+func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
+func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{