[Mlir-commits] [mlir] [mlir][x86vector] AVX512-BF16 Dot op (PR #124800)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Jan 29 03:11:06 PST 2025
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/124800
>From 62280ae45e5de52d927aabc2eea1365c33cd134e Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 12:44:36 +0100
Subject: [PATCH 01/10] AVX512 dot td
---
.../mlir/Dialect/X86Vector/X86Vector.td | 27 +++++++++++++++++++
1 file changed, 27 insertions(+)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index fa3f0ee0460b1d..bab65a90dca45f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -271,6 +271,33 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
VectorOfLengthAndType<[8], [I64]>:$b);
}
+//----------------------------------------------------------------------------//
+// AVX512 Dot
+//----------------------------------------------------------------------------//
+
+def AVX512DotOp : AVX512_Op<"dot", [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "dst"]>]> {
+ let summary = "Dot op";
+ let description = [{
+ The `dot` op is an AVX512 specific op that can lower to the
+ `llvm.dpbf16ps.512` instruction.
+
+ #### From the Intel Intrinsics Guide:
+
+ Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
+ accumulating the intermediate single-precision (32-bit) floating-point
+ elements with elements in `src`, and store the results in `dst`.
+ }];
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+ VectorOfLengthAndType<[32], [BF16]>:$a,
+ VectorOfLengthAndType<[32], [BF16]>:$b
+ );
+ let results = (outs VectorOfLengthAndType<[16], [F32]>:$dst);
+ let assemblyFormat =
+ "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
>From 3af7812031b604f0ac90047d9421d64aaac41b6f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 15:30:17 +0100
Subject: [PATCH 02/10] AVX512-BF16 td
---
.../mlir/Dialect/X86Vector/X86Vector.td | 76 ++++++++++++++++---
1 file changed, 66 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index bab65a90dca45f..71c7f85159e0d6 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -271,17 +271,46 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
VectorOfLengthAndType<[8], [I64]>:$b);
}
+//===----------------------------------------------------------------------===//
+// AVX512-BF16 op definitions
+//===----------------------------------------------------------------------===//
+
+// Operation that is part of the input dialect.
+class AVX512BF16_Op<string mnemonic, list<Trait> traits = []> :
+ Op<X86Vector_Dialect, "avx512bf16." # mnemonic, traits> {}
+
+// Intrinsic operation used during lowering to LLVM IR.
+class AVX512BF16_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
+ LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+ "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+ [], [], traits, numResults>;
+
+// Defined by first result overload. May have to be extended for other
+// instructions in the future.
+class AVX512BF16_IntrOverloadedOp<string mnemonic,
+ list<Trait> traits = []> :
+ LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+ "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/[0],
+ /*list<int> overloadedOperands=*/[],
+ traits, /*numResults=*/1>;
+
//----------------------------------------------------------------------------//
-// AVX512 Dot
+// AVX512-BF16 Dot
//----------------------------------------------------------------------------//
-def AVX512DotOp : AVX512_Op<"dot", [Pure,
+def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "dst"]>]> {
- let summary = "Dot op";
+ AllTypesMatch<["src", "dst"]>,
+ TypesMatchWith<"`a` has twice an many elements as `src`",
+ "src", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
+ "BFloat16Type::get($_self.getContext()))">]> {
+ let summary = "Dot BF16 op";
let description = [{
- The `dot` op is an AVX512 specific op that can lower to the
- `llvm.dpbf16ps.512` instruction.
+ The `dot` op is an AVX512-BF16 specific op that can lower to the proper
+ LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
+ vectors it is applied to.
#### From the Intel Intrinsics Guide:
@@ -289,15 +318,42 @@ def AVX512DotOp : AVX512_Op<"dot", [Pure,
accumulating the intermediate single-precision (32-bit) floating-point
elements with elements in `src`, and store the results in `dst`.
}];
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
- VectorOfLengthAndType<[32], [BF16]>:$a,
- VectorOfLengthAndType<[32], [BF16]>:$b
+ let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
);
- let results = (outs VectorOfLengthAndType<[16], [F32]>:$dst);
+ let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
let assemblyFormat =
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
}
+def DotBF16Ps128IntrOp : AVX512BF16_IntrOp<"dpbf16ps.128", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
+ VectorOfLengthAndType<[8], [BF16]>:$a,
+ VectorOfLengthAndType<[8], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
+}
+
+def DotBF16Ps256IntrOp : AVX512BF16_IntrOp<"dpbf16ps.256", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
+ VectorOfLengthAndType<[16], [BF16]>:$a,
+ VectorOfLengthAndType<[16], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
+}
+
+def DotBF16Ps512IntrOp : AVX512BF16_IntrOp<"dpbf16ps.512", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+ VectorOfLengthAndType<[32], [BF16]>:$a,
+ VectorOfLengthAndType<[32], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
>From 00ff34e2412838c3e5d91f67c931095457d1a82d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 16:20:36 +0100
Subject: [PATCH 03/10] dotbf16 conversion
---
.../Transforms/LegalizeForLLVMExport.cpp | 49 ++++++++++++++++++-
1 file changed, 47 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e918473cae9e3a..260ac9ce589a38 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -90,6 +90,47 @@ struct MaskCompressOpConversion
}
};
+/// Lowers `DotBF16Op` to one of the `DotBF16Ps{128,256,512}IntrOp` intrinsic
+/// ops. The variant is selected by the total bit width of operand `a`
+/// (number of elements in its first dimension times the element bit width).
+/// Any other width is reported as a match failure.
+struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
+  using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // `cast` instead of `dyn_cast`: the result is used unconditionally on the
+    // next two lines, so a failed `dyn_cast` would be dereferenced as null.
+    // `cast` asserts on a type mismatch instead of silently yielding null.
+    auto typeA = cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    // The replacement op is built from the adaptor's operands; its result
+    // type is the type of the adaptor's `src` operand.
+    auto opType = adaptor.getSrc().getType();
+    auto opSrc = adaptor.getSrc();
+    auto opA = adaptor.getA();
+    auto opB = adaptor.getB();
+
+    // Dispatch on the full width of `a` to pick the matching intrinsic op.
+    switch (opBitWidth) {
+    case 128: {
+      rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    }
+    case 256: {
+      rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    }
+    case 512: {
+      rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    }
+    default: {
+      // No intrinsic op is defined here for other widths; leave the op as is.
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported AVX512-BF16 dot variant");
+    }
+    }
+
+    return success();
+  }
+};
+
+
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -161,8 +202,8 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
- converter);
+ patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
+ DotOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -170,6 +211,10 @@ void mlir::configureX86VectorLegalizeForExportTarget(
Registry::configureTarget(target);
target.addLegalOp<MaskCompressIntrOp>();
target.addIllegalOp<MaskCompressOp>();
+ target.addLegalOp<DotBF16Ps128IntrOp>();
+ target.addLegalOp<DotBF16Ps256IntrOp>();
+ target.addLegalOp<DotBF16Ps512IntrOp>();
+ target.addIllegalOp<DotBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
>From f5e32c5c2e851ab72ed94d43666256fc4b9fc3f4 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 17:56:49 +0100
Subject: [PATCH 04/10] avx dot-b16 test
---
.../Vector/CPU/X86Vector/dot-bf16.mlir | 25 +++++++++++++++++++
1 file changed, 25 insertions(+)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
new file mode 100644
index 00000000000000..fe333f49fc8e14
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx512bf16" --dlopen=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+ %i0 = arith.constant 0 : i32
+ %i3 = arith.constant 3 : i32
+
+ %src = arith.constant dense<1.0> : vector<4xf32>
+ %a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xbf16>
+ %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xbf16>
+ %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+
+ %1 = vector.extractelement %dst[%i0 : i32] : vector<4xf32>
+ %2 = vector.extractelement %dst[%i3 : i32] : vector<4xf32>
+ %d = arith.addf %1, %2 : f32
+
+ // CHECK: ( 30, 82, 150, 234 )
+ // CHECK: 264
+ vector.print %dst : vector<4xf32>
+ vector.print %d : f32
+
+ return %i0 : i32
+}
>From 34aae4170dd2a1ca200036718d85391466c195da Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 18:08:01 +0100
Subject: [PATCH 05/10] Op example
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 71c7f85159e0d6..409ef9ce16054e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -317,6 +317,11 @@ def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
accumulating the intermediate single-precision (32-bit) floating-point
elements with elements in `src`, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ ```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
>From 1d65c11114c26e44437fe5cc424ccfa899860654 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 Jan 2025 18:08:11 +0100
Subject: [PATCH 06/10] Test cases
---
.../Dialect/X86Vector/legalize-for-llvm.mlir | 27 +++++++++++++++++++
mlir/test/Dialect/X86Vector/roundtrip.mlir | 27 +++++++++++++++++++
2 files changed, 54 insertions(+)
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 8b9006395fdfe4..cbc8c3051c6ab1 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -43,6 +43,33 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
+ %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.128
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
+ %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.256
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
+ %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.512
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 557978b51c5123..f7111f75db6180 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -47,6 +47,33 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
}
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
+ %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
+ %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
+ %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
>From 6bfbc9bd67b62ca887332e5238a33a1e67bf8729 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 29 Jan 2025 11:36:49 +0100
Subject: [PATCH 07/10] Remove integration test
---
.../Vector/CPU/X86Vector/dot-bf16.mlir | 25 -------------------
1 file changed, 25 deletions(-)
delete mode 100644 mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
deleted file mode 100644
index fe333f49fc8e14..00000000000000
--- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
+++ /dev/null
@@ -1,25 +0,0 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm -reconcile-unrealized-casts | \
-// RUN: mlir-translate --mlir-to-llvmir | \
-// RUN: %lli --entry-function=entry --mattr="avx512bf16" --dlopen=%mlir_c_runner_utils | \
-// RUN: FileCheck %s
-
-func.func @entry() -> i32 {
- %i0 = arith.constant 0 : i32
- %i3 = arith.constant 3 : i32
-
- %src = arith.constant dense<1.0> : vector<4xf32>
- %a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xbf16>
- %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xbf16>
- %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
-
- %1 = vector.extractelement %dst[%i0 : i32] : vector<4xf32>
- %2 = vector.extractelement %dst[%i3 : i32] : vector<4xf32>
- %d = arith.addf %1, %2 : f32
-
- // CHECK: ( 30, 82, 150, 234 )
- // CHECK: 264
- vector.print %dst : vector<4xf32>
- vector.print %d : f32
-
- return %i0 : i32
-}
>From 4f0fcaf2537a0822002708bb822da04c8abbff2d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 29 Jan 2025 11:37:33 +0100
Subject: [PATCH 08/10] Add llc test to verify correct lowering
---
mlir/test/Dialect/X86Vector/dot-bf16.mlir | 30 +++++++++++++++++++++++
1 file changed, 30 insertions(+)
create mode 100644 mlir/test/Dialect/X86Vector/dot-bf16.mlir
diff --git a/mlir/test/Dialect/X86Vector/dot-bf16.mlir b/mlir/test/Dialect/X86Vector/dot-bf16.mlir
new file mode 100644
index 00000000000000..0ed13d5824e6e8
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/dot-bf16.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
+ %b: vector<8xbf16>) -> vector<4xf32> {
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// CHECK-LABEL: avx512bf16_dot_128:
+// CHECK: vdpbf16ps %xmm2, %xmm1, %xmm0
+
+func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
+ %b: vector<16xbf16>) -> vector<8xf32> {
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+// CHECK-LABEL: avx512bf16_dot_256:
+// CHECK: vdpbf16ps %ymm2, %ymm1, %ymm0
+
+func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
+ %b: vector<32xbf16>) -> vector<16xf32> {
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+// CHECK-LABEL: avx512bf16_dot_512:
+// CHECK: vdpbf16ps %zmm2, %zmm1, %zmm0
>From 169086d0c098998f1339ccdfae52774f4b135c46 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 29 Jan 2025 12:10:16 +0100
Subject: [PATCH 09/10] Collapse extensions into avx512 namespace
---
.../mlir/Dialect/X86Vector/X86Vector.td | 56 +++++++------------
mlir/test/Dialect/X86Vector/dot-bf16.mlir | 6 +-
.../Dialect/X86Vector/legalize-for-llvm.mlir | 12 ++--
mlir/test/Dialect/X86Vector/roundtrip.mlir | 12 ++--
4 files changed, 34 insertions(+), 52 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 409ef9ce16054e..16181d7e760db5 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -35,17 +35,20 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
// Intrinsic operation used during lowering to LLVM IR.
-class AVX512_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
+class AVX512_IntrOp<string mnemonic, int numResults,
+ list<Trait> traits = [],
+ string extension = ""> :
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
- "x86_avx512_" # !subst(".", "_", mnemonic),
+ !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
[], [], traits, numResults>;
// Defined by first result overload. May have to be extended for other
// instructions in the future.
class AVX512_IntrOverloadedOp<string mnemonic,
- list<Trait> traits = []> :
+ list<Trait> traits = [],
+ string extension = ""> :
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
- "x86_avx512_" # !subst(".", "_", mnemonic),
+ !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[],
traits, /*numResults=*/1>;
@@ -271,35 +274,11 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
VectorOfLengthAndType<[8], [I64]>:$b);
}
-//===----------------------------------------------------------------------===//
-// AVX512-BF16 op definitions
-//===----------------------------------------------------------------------===//
-
-// Operation that is part of the input dialect.
-class AVX512BF16_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86Vector_Dialect, "avx512bf16." # mnemonic, traits> {}
-
-// Intrinsic operation used during lowering to LLVM IR.
-class AVX512BF16_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
- "x86_avx512bf16_" # !subst(".", "_", mnemonic),
- [], [], traits, numResults>;
-
-// Defined by first result overload. May have to be extended for other
-// instructions in the future.
-class AVX512BF16_IntrOverloadedOp<string mnemonic,
- list<Trait> traits = []> :
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
- "x86_avx512bf16_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/[0],
- /*list<int> overloadedOperands=*/[],
- traits, /*numResults=*/1>;
-
//----------------------------------------------------------------------------//
-// AVX512-BF16 Dot
+// Dot BF16
//----------------------------------------------------------------------------//
-def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
+def DotBF16Op : AVX512_Op<"dot", [Pure,
AllTypesMatch<["a", "b"]>,
AllTypesMatch<["src", "dst"]>,
TypesMatchWith<"`a` has twice an many elements as `src`",
@@ -320,7 +299,7 @@ def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
Example:
```mlir
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
```
}];
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -332,27 +311,30 @@ def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
}
-def DotBF16Ps128IntrOp : AVX512BF16_IntrOp<"dpbf16ps.128", 1, [Pure,
+def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "res"]>]> {
+ AllTypesMatch<["src", "res"]>],
+ /*extension=*/"bf16"> {
let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
VectorOfLengthAndType<[8], [BF16]>:$a,
VectorOfLengthAndType<[8], [BF16]>:$b);
let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
}
-def DotBF16Ps256IntrOp : AVX512BF16_IntrOp<"dpbf16ps.256", 1, [Pure,
+def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "res"]>]> {
+ AllTypesMatch<["src", "res"]>],
+ /*extension=*/"bf16"> {
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
VectorOfLengthAndType<[16], [BF16]>:$a,
VectorOfLengthAndType<[16], [BF16]>:$b);
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
}
-def DotBF16Ps512IntrOp : AVX512BF16_IntrOp<"dpbf16ps.512", 1, [Pure,
+def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
AllTypesMatch<["a", "b"]>,
- AllTypesMatch<["src", "res"]>]> {
+ AllTypesMatch<["src", "res"]>],
+ /*extension=*/"bf16"> {
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
VectorOfLengthAndType<[32], [BF16]>:$a,
VectorOfLengthAndType<[32], [BF16]>:$b);
diff --git a/mlir/test/Dialect/X86Vector/dot-bf16.mlir b/mlir/test/Dialect/X86Vector/dot-bf16.mlir
index 0ed13d5824e6e8..9e5417da461044 100644
--- a/mlir/test/Dialect/X86Vector/dot-bf16.mlir
+++ b/mlir/test/Dialect/X86Vector/dot-bf16.mlir
@@ -7,7 +7,7 @@
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> vector<4xf32> {
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: avx512bf16_dot_128:
@@ -15,7 +15,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> vector<8xf32> {
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: avx512bf16_dot_256:
@@ -23,7 +23,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> vector<16xf32> {
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: avx512bf16_dot_512:
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index cbc8c3051c6ab1..ed9177eaec9ce4 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -47,8 +47,8 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>)
{
- // CHECK: x86vector.avx512bf16.intr.dpbf16ps.128
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ // CHECK: x86vector.avx512.intr.dpbf16ps.128
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -56,8 +56,8 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>)
{
- // CHECK: x86vector.avx512bf16.intr.dpbf16ps.256
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ // CHECK: x86vector.avx512.intr.dpbf16ps.256
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -65,8 +65,8 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>)
{
- // CHECK: x86vector.avx512bf16.intr.dpbf16ps.512
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ // CHECK: x86vector.avx512.intr.dpbf16ps.512
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index f7111f75db6180..cf74a7ee602558 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -51,8 +51,8 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
%b: vector<8xbf16>) -> (vector<4xf32>)
{
- // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ // CHECK: x86vector.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -60,8 +60,8 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
%b: vector<16xbf16>) -> (vector<8xf32>)
{
- // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ // CHECK: x86vector.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -69,8 +69,8 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
%b: vector<32xbf16>) -> (vector<16xf32>)
{
- // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
- %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ // CHECK: x86vector.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
+ %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
return %0 : vector<16xf32>
}
>From eaa9333921f6db077b08897f9a10476523368021 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 29 Jan 2025 12:10:44 +0100
Subject: [PATCH 10/10] Add translation tests
---
mlir/test/Target/LLVMIR/x86vector.mlir | 44 ++++++++++++++++++++++++++
1 file changed, 44 insertions(+)
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 190732868cb7ad..1df03f10c93214 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -60,6 +60,39 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
+llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
+ %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+ ) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
+ %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+ : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
+ llvm.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
+llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
+ %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+ ) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
+ %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+ : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
+ llvm.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
+llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
+ %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+ ) -> vector<16xf32>
+{
+ // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
+ %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+ : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
+ llvm.return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
@@ -67,3 +100,14 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
%0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>)
llvm.return %0 : vector<8xf32>
}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
+llvm.func @LLVM_x86_avx_dp_ps_256(
+ %arg0: vector<8xf32>, %arg1: vector<8xf32>
+ ) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
+ %0 = llvm.mlir.constant(-1 : i8) : i8
+ %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+ llvm.return %1 : vector<8xf32>
+}
More information about the Mlir-commits
mailing list