[Mlir-commits] [mlir] 2b71df5 - [mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (#125685)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 01:50:15 PST 2025


Author: Adam Siemieniuk
Date: 2025-02-18T10:50:11+01:00
New Revision: 2b71df5a74cb5bd67f3f34277749dc920fd35105

URL: https://github.com/llvm/llvm-project/commit/2b71df5a74cb5bd67f3f34277749dc920fd35105
DIFF: https://github.com/llvm/llvm-project/commit/2b71df5a74cb5bd67f3f34277749dc920fd35105.diff

LOG: [mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (#125685)

Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow each file's conventions.

Added: 
    mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir
    mlir/test/Target/LLVMIR/x86vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db..566013e73f4b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
   let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
 }
 
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+  AllElementCountsMatch<["a", "dst"]>]> {
+  let summary = "Convert packed F32 to packed BF16 Data.";
+  let description = [{
+    The `cvt.packed.f32_to_bf16` op is an AVX512-BF16 specific op that can
+    lower to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending
+    on the width of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed single-precision (32-bit) floating-point elements in `a` to
+    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a3..f1fbb39b97fc4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
   }
 };
 
+/// Lowers `x86vector.avx512.cvt.packed.f32_to_bf16` to the 256-bit or 512-bit
+/// `cvtneps2bf16` intrinsic op, chosen by the total bit width of operand `a`.
+struct CvtPackedF32ToBF16Conversion
+    : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+  using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // `a` is constrained to a vector type by the op definition, so use the
+    // asserting cast<> instead of dereferencing an unchecked dyn_cast<>.
+    auto typeA = cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    auto opType = op.getDst().getType();
+    // Build the intrinsic from the adaptor's (already remapped) operand.
+    auto opA = adaptor.getA();
+
+    switch (opBitWidth) {
+    case 256:
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+      return success();
+    case 512:
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+      return success();
+    default:
+      return rewriter.notifyMatchFailure(
+          op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+    }
+  }
+};
+
 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
-               DotOpConversion>(converter);
+  patterns
+      .add<MaskCompressOpConversion, DotBF16OpConversion,
+           CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+          converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   target.addLegalOp<DotBF16Ps256IntrOp>();
   target.addLegalOp<DotBF16Ps512IntrOp>();
   target.addIllegalOp<DotBF16Op>();
+  target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+  target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+  target.addIllegalOp<CvtPackedF32ToBF16Op>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
   target.addLegalOp<DotIntrOp>();

diff  --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 0000000000000..c97c52f01c3b0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN:   -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN:   -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+    %a: vector<8xf32>) -> vector<8xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+    %a: vector<16xf32>) -> vector<16xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm

diff  --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce..59be7dd75b3b0 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {

diff  --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee60255..0d00448c63da8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {

diff  --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c9321..db1c10cd5cd37 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
 
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
-    %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+    %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
   ) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
-  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
     : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+    %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
-  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
     : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
   llvm.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
-    %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+    %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
   ) -> vector<16xf32>
 {
   // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
-  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
     : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
   llvm.return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+  %a: vector<8xf32>) -> vector<8xbf16>
+{
+  // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+    : (vector<8xf32>) -> vector<8xbf16>
+  llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+  %a: vector<16xf32>) -> vector<16xbf16>
+{
+  // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+    : (vector<16xf32>) -> vector<16xbf16>
+  llvm.return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
 llvm.func @LLVM_x86_avx_dp_ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<8xf32>
+    %a: vector<8xf32>, %b: vector<8xf32>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
-  %0 = llvm.mlir.constant(-1 : i8) : i8
-  %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+  %c = llvm.mlir.constant(-1 : i8) : i8
+  %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
   llvm.return %1 : vector<8xf32>
 }


        


More information about the Mlir-commits mailing list