[Mlir-commits] [mlir] 2b71df5 - [mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (#125685)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 18 01:50:15 PST 2025
Author: Adam Siemieniuk
Date: 2025-02-18T10:50:11+01:00
New Revision: 2b71df5a74cb5bd67f3f34277749dc920fd35105
URL: https://github.com/llvm/llvm-project/commit/2b71df5a74cb5bd67f3f34277749dc920fd35105
DIFF: https://github.com/llvm/llvm-project/commit/2b71df5a74cb5bd67f3f34277749dc920fd35105.diff
LOG: [mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (#125685)
Adds AVX512 bf16 conversion from packed f32 to bf16 elements.
Tests are slightly refactored to better follow the existing file conventions.
Added:
mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
Modified:
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
mlir/test/Dialect/X86Vector/roundtrip.mlir
mlir/test/Target/LLVMIR/x86vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db..566013e73f4b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
}
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+// Width-agnostic F32 -> BF16 conversion op. `AllElementCountsMatch` requires
+// the input and result vectors to have the same number of lanes (8 or 16).
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+  AllElementCountsMatch<["a", "dst"]>]> {
+  let summary = "Convert packed F32 to packed BF16 Data.";
+  let description = [{
+    The `cvt.packed.f32_to_bf16` op is an AVX512-BF16 specific op that can lower
+    to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
+    the width of MLIR vectors it is applied to: 8-lane inputs map to
+    `x86vector.avx512.intr.cvtneps2bf16.256` and 16-lane inputs map to
+    `x86vector.avx512.intr.cvtneps2bf16.512`.
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed single-precision (32-bit) floating-point elements in `a` to
+    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+// Intrinsic-level op for the 256-bit input variant: takes one 8 x f32 vector
+// and yields one 8 x bf16 vector. It is the target of the 8-lane
+// `cvt.packed.f32_to_bf16` legalization and belongs to the "bf16" extension.
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+// Intrinsic-level op for the 512-bit input variant: takes one 16 x f32 vector
+// and yields one 16 x bf16 vector. It is the target of the 16-lane
+// `cvt.packed.f32_to_bf16` legalization and belongs to the "bf16" extension.
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+ /*extension=*/"bf16"> {
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+ let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a3..f1fbb39b97fc4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
}
};
+/// Legalizes the width-agnostic `CvtPackedF32ToBF16Op` to the AVX512-BF16
+/// intrinsic op that matches the total bit width of its f32 input vector:
+///   256 bits -> `CvtNeF32ToBF16Ps256IntrOp`
+///   512 bits -> `CvtNeF32ToBF16Ps512IntrOp`
+/// Any other width is reported as a match failure.
+struct CvtPackedF32ToBF16Conversion
+    : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+  using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The op definition constrains `a` to a vector type, so `cast` (which
+    // asserts) is the right tool; the previous `dyn_cast` result was
+    // dereferenced without a null check.
+    auto typeA = cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    auto opType = op.getDst().getType();
+    // Build the replacement from the adaptor's (already remapped) operand
+    // rather than the original op's operand, as conversion patterns should.
+    Value opA = adaptor.getA();
+
+    switch (opBitWidth) {
+    case 256: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+      break;
+    }
+    case 512: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+      break;
+    }
+    default: {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+    }
+    }
+
+    return success();
+  }
+};
+
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
- DotOpConversion>(converter);
+ patterns
+ .add<MaskCompressOpConversion, DotBF16OpConversion,
+ CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+ converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
target.addLegalOp<DotBF16Ps256IntrOp>();
target.addLegalOp<DotBF16Ps512IntrOp>();
target.addIllegalOp<DotBF16Op>();
+ target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+ target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+ target.addIllegalOp<CvtPackedF32ToBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 0000000000000..c97c52f01c3b0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+// End-to-end codegen for the 8-lane variant. The 8 x bf16 result occupies
+// 128 bits (half of the 256-bit f32 input), which is why the emitted
+// instruction is expected to mention an xmm register.
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+// End-to-end codegen for the 16-lane variant. The 16 x bf16 result occupies
+// 256 bits (half of the 512-bit f32 input), which is why the emitted
+// instruction is expected to mention a ymm register.
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16> {
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce..59be7dd75b3b0 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// Legalization: the 8-lane generic conversion op must be rewritten to the
+// 256-bit intrinsic op.
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// Legalization: the 16-lane generic conversion op must be rewritten to the
+// 512-bit intrinsic op.
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee60255..0d00448c63da8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
return %0 : vector<16xf32>
}
+// Round-trip: the 8-lane op must print back in its custom assembly form with
+// both the operand and result vector types.
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+ %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// Round-trip: the 16-lane op must print back in its custom assembly form with
+// both the operand and result vector types.
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+ %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+ // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+ // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+ %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+ return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c9321..db1c10cd5cd37 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
- %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+ %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
- %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
llvm.return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
- %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+ %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
- %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
llvm.return %0 : vector<8xf32>
}
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
- %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+ %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
) -> vector<16xf32>
{
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
- %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+ %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
llvm.return %0 : vector<16xf32>
}
+// Translation: the 256-bit intrinsic op must become a call to the
+// corresponding LLVM x86 intrinsic returning <8 x bfloat>.
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+ %a: vector<8xf32>) -> vector<8xbf16>
+{
+ // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+ : (vector<8xf32>) -> vector<8xbf16>
+ llvm.return %0 : vector<8xbf16>
+}
+
+// Translation: the 512-bit intrinsic op must become a call to the
+// corresponding LLVM x86 intrinsic returning <16 x bfloat>.
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+ %a: vector<16xf32>) -> vector<16xbf16>
+{
+ // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+ %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+ : (vector<16xf32>) -> vector<16xbf16>
+ llvm.return %0 : vector<16xbf16>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
llvm.func @LLVM_x86_avx_dp_ps_256(
- %arg0: vector<8xf32>, %arg1: vector<8xf32>
+ %a: vector<8xf32>, %b: vector<8xf32>
) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
- %0 = llvm.mlir.constant(-1 : i8) : i8
- %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+ %c = llvm.mlir.constant(-1 : i8) : i8
+ %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
llvm.return %1 : vector<8xf32>
}
More information about the Mlir-commits
mailing list