[Mlir-commits] [mlir] [mlir][x86vector] AVX512-B16 Dot op (PR #124800)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 09:39:44 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Adds AVX512-BF16 operation definitions and a bf16 dot-product operation.
Defines lowering to LLVM intrinsics.

---
Full diff: https://github.com/llvm/llvm-project/pull/124800.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+88) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (+47-2) 
- (modified) mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir (+27) 
- (modified) mlir/test/Dialect/X86Vector/roundtrip.mlir (+27) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir (+25) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index fa3f0ee0460b1d..409ef9ce16054e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -271,6 +271,94 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
                    VectorOfLengthAndType<[8], [I64]>:$b);
 }
 
+//===----------------------------------------------------------------------===//
+// AVX512-BF16 op definitions
+//===----------------------------------------------------------------------===//
+
+// Input-dialect op; its mnemonic is prefixed with "avx512bf16.".
+class AVX512BF16_Op<string mnemonic, list<Trait> traits = []> :
+  Op<X86Vector_Dialect, "avx512bf16." # mnemonic, traits> {}
+
+// Intrinsic op used during lowering to LLVM IR (intrinsic x86_avx512bf16_*).
+class AVX512BF16_IntrOp<string mnemonic, int numResults,
+                        list<Trait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+                  "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+                  [], [], traits, numResults>;
+
+// Intrinsic op overloaded on the type of its single result (result 0). May
+// have to be extended for other instructions in the future.
+class AVX512BF16_IntrOverloadedOp<string mnemonic, list<Trait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+                  "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[],
+                  traits, /*numResults=*/1>;
+
+//----------------------------------------------------------------------------//
+// AVX512-BF16 Dot
+//----------------------------------------------------------------------------//
+
+def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
+  AllTypesMatch<["a", "b"]>,
+  AllTypesMatch<["src", "dst"]>,
+  TypesMatchWith<"`a` has twice as many elements as `src`",
+                 "src", "a",
+                 "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
+                 "BFloat16Type::get($_self.getContext()))">]> {
+  let summary = "Dot BF16 op";
+  let description = [{
+    The `dot` op is an AVX512-BF16 specific op that lowers to the matching
+    LLVM intrinsic `llvm.x86.avx512bf16.dpbf16ps.{128,256,512}`, selected by
+    the width of the MLIR vectors it is applied to. `src` and `dst` hold 4, 8
+    or 16 f32 elements; `a` and `b` hold twice as many bf16 elements.
+
+    #### From the Intel Intrinsics Guide:
+
+    Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
+    accumulating the intermediate single-precision (32-bit) floating-point
+    elements with elements in `src`, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
+                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
+                   VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b);
+  let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
+  let assemblyFormat =
+    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+}
+
+def DotBF16Ps128IntrOp : AVX512BF16_IntrOp<"dpbf16ps.128", 1,
+    [Pure, AllTypesMatch<["a", "b"]>, AllTypesMatch<["src", "res"]>]> {
+  // Intrinsic x86_avx512bf16_dpbf16ps_128: 4 x f32 src/res, 8 x bf16 a/b.
+  let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
+  let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
+                       VectorOfLengthAndType<[8], [BF16]>:$a,
+                       VectorOfLengthAndType<[8], [BF16]>:$b);
+}
+
+def DotBF16Ps256IntrOp : AVX512BF16_IntrOp<"dpbf16ps.256", 1,
+    [Pure, AllTypesMatch<["a", "b"]>, AllTypesMatch<["src", "res"]>]> {
+  // Intrinsic x86_avx512bf16_dpbf16ps_256: 8 x f32 src/res, 16 x bf16 a/b.
+  let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
+                       VectorOfLengthAndType<[16], [BF16]>:$a,
+                       VectorOfLengthAndType<[16], [BF16]>:$b);
+}
+
+def DotBF16Ps512IntrOp : AVX512BF16_IntrOp<"dpbf16ps.512", 1,
+    [Pure, AllTypesMatch<["a", "b"]>, AllTypesMatch<["src", "res"]>]> {
+  // Intrinsic x86_avx512bf16_dpbf16ps_512: 16 x f32 src/res, 32 x bf16 a/b.
+  let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+                       VectorOfLengthAndType<[32], [BF16]>:$a,
+                       VectorOfLengthAndType<[32], [BF16]>:$b);
+}
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e918473cae9e3a..260ac9ce589a38 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -90,6 +90,47 @@ struct MaskCompressOpConversion
   }
 };
 
+struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
+  using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
+
+  // Lowers x86vector.avx512bf16.dot to the dpbf16ps intrinsic op whose width
+  // (128, 256 or 512 bits) equals the total bit width of the bf16 operand `a`.
+  LogicalResult
+  matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // ODS constrains `a` to a 1-D bf16 vector, so the asserting cast is safe.
+    auto typeA = cast<VectorType>(op.getA().getType());
+    unsigned opBitWidth =
+        typeA.getNumElements() * typeA.getElementTypeBitWidth();
+
+    auto opType = adaptor.getSrc().getType();
+    auto opSrc = adaptor.getSrc();
+    auto opA = adaptor.getA();
+    auto opB = adaptor.getB();
+
+    switch (opBitWidth) {
+    case 128:
+      rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    case 256:
+      rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    case 512:
+      rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
+                                                      opB);
+      break;
+    default:
+      // Unreachable for verified ops: `a` is limited to 8/16/32 bf16 lanes.
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported AVX512-BF16 dot variant");
+    }
+
+    return success();
+  }
+};
+
 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
@@ -161,8 +202,8 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
-      converter);
+  patterns.add<MaskCompressOpConversion, DotBF16OpConversion,
+               RsqrtOpConversion, DotOpConversion>(converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -170,6 +211,10 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   Registry::configureTarget(target);
   target.addLegalOp<MaskCompressIntrOp>();
   target.addIllegalOp<MaskCompressOp>();
+  target.addLegalOp<DotBF16Ps128IntrOp>();
+  target.addLegalOp<DotBF16Ps256IntrOp>();
+  target.addLegalOp<DotBF16Ps512IntrOp>();
+  target.addIllegalOp<DotBF16Op>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
   target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 8b9006395fdfe4..cbc8c3051c6ab1 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -43,6 +43,33 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
   return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }
 
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>,
+  %a: vector<8xbf16>, %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+  // CHECK: x86vector.avx512bf16.intr.dpbf16ps.128
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+  return %dst : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>,
+  %a: vector<16xbf16>, %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx512bf16.intr.dpbf16ps.256
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+  return %dst : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>,
+  %a: vector<32xbf16>, %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+  // CHECK: x86vector.avx512bf16.intr.dpbf16ps.512
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+  return %dst : vector<16xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 557978b51c5123..f7111f75db6180 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -47,6 +47,33 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
   return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
 }
 
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>,
+  %a: vector<8xbf16>, %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+  // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+  return %dst : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>,
+  %a: vector<16xbf16>, %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+  // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+  return %dst : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>,
+  %a: vector<32xbf16>, %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+  // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+  return %dst : vector<16xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
new file mode 100644
index 00000000000000..fe333f49fc8e14
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx512bf16" --dlopen=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+  %zero = arith.constant 0 : i32
+
+  // Each f32 lane i gets src[i] + a[2i] * b[2i] + a[2i+1] * b[2i+1].
+  %src = arith.constant dense<1.0> : vector<4xf32>
+  %a = arith.constant dense<[1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0]> : vector<8xbf16>
+  %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xbf16>
+  %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+
+  %first = vector.extract %dst[0] : f32 from vector<4xf32>
+  %last = vector.extract %dst[3] : f32 from vector<4xf32>
+  %d = arith.addf %first, %last : f32
+
+  // CHECK: ( 30, 82, 150, 234 )
+  // CHECK: 264
+  vector.print %dst : vector<4xf32>
+  vector.print %d : f32
+
+  return %zero : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/124800


More information about the Mlir-commits mailing list