[Mlir-commits] [mlir] [mlir][x86vector] AVX Convert/Broadcast BF16 to F32 instructions (PR #135143)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 22 07:04:05 PDT 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/135143
>From 860ccf78f14934ca6aefeb5af7de5a705ca8245c Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 10 Apr 2025 01:30:12 -0700
Subject: [PATCH 1/7] new x86 avx instructions: vbcstnebf162ps, vcvtneebf162ps,
vcvtneobf162ps
---
.../mlir/Dialect/X86Vector/X86Vector.td | 106 ++++++++++++++++++
.../mlir/Dialect/X86Vector/X86VectorDialect.h | 1 +
.../Transforms/LegalizeForLLVMExport.cpp | 3 +-
.../bcst-avx-bf16-to-f32-packed.mlir | 22 ++++
.../X86Vector/cvt-packed-avx-bf16-to-f32.mlir | 48 ++++++++
.../Dialect/X86Vector/legalize-for-llvm.mlir | 54 +++++++++
mlir/test/Dialect/X86Vector/roundtrip.mlir | 60 ++++++++++
mlir/test/Target/LLVMIR/x86vector.mlir | 54 +++++++++
8 files changed, 347 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
create mode 100644 mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 5be0d92db4630..a235685f773f8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -408,4 +408,110 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [Pure,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16>
+ ```
+ }];
+ let arguments = (ins LLVM_AnyPointer:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneebf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+}
+
+def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [Pure,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16>
+ ```
+ }];
+ let arguments = (ins LLVM_AnyPointer:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneobf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert BF16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert scalar BF16 (16-bit) floating-point element stored at memory locations
+ starting at location `__A` to a single-precision (32-bit) floating-point,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xbf16>
+ ```
+ }];
+ let arguments = (ins LLVM_AnyPointer:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vbcstnebf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+}
+
#endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
index 7bcf4c69b0a6c..f2f8d36fdfd01 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
/// Include the generated interface declarations.
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index c0c7f61f55f88..668888eab1c2a 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -115,6 +115,7 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
- Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, RsqrtOp,
+ Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
+ CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp,
DotOp>();
}
diff --git a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
new file mode 100644
index 0000000000000..8243e628f7e2b
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
@@ -0,0 +1,22 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sierraforest | \
+// RUN: FileCheck %s
+
+func.func @avxbf16_bcst_bf16_to_f32_packed_128(%arg0: !llvm.ptr) -> vector<4xf32> {
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_128:
+// CHECK: vbcstnebf162ps{{.*}}%xmm
+
+func.func @avxbf16_bcst_bf16_to_f32_packed_256(%arg0: !llvm.ptr) -> vector<8xf32> {
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_256:
+// CHECK: vbcstnebf162ps{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
new file mode 100644
index 0000000000000..08ad9c1c4a8d0
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
@@ -0,0 +1,48 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sierraforest | \
+// RUN: FileCheck %s
+
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
+ %0 = arith.index_cast %intptr : index to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
+ %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
+ return %2 : vector<4xf32>
+}
+// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_128:
+// CHECK: vcvtneebf162ps{{.*}}%xmm
+
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
+ %0 = arith.index_cast %intptr : index to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
+ %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
+ return %2 : vector<8xf32>
+}
+// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_256:
+// CHECK: vcvtneebf162ps{{.*}}%ymm
+
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
+ %0 = arith.index_cast %intptr : index to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
+ %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
+ return %2 : vector<4xf32>
+}
+// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128:
+// CHECK: vcvtneobf162ps{{.*}}%xmm
+
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
+ %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
+ %0 = arith.index_cast %intptr : index to i32
+ %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
+ %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
+ return %2 : vector<8xf32>
+}
+// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256:
+// CHECK: vcvtneobf162ps{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index df0be7bce83be..e1969481c845c 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
+func.func @avxbf16_bsct_bf16_to_f32_packed_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
+func.func @avxbf16_bsct_bf16_to_f32_packed_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 0d00448c63da8..d36628588190e 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
+func.func @avxbf16_bcst_bf16_to_f32_128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
+func.func @avxbf16_bcst_bf16_to_f32_256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK-SAME: !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 85dad36334b1d..095375839d282 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,6 +109,60 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
+func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256
+func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128
+func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256
+func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128
+func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
+ %a: !llvm.ptr) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256
+func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
+ %a: !llvm.ptr) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
>From cc4553879bd452dd983b2d7c262342be4748ac5d Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 10 Apr 2025 19:56:58 -0700
Subject: [PATCH 2/7] Fixed a couple of clang-format issues
---
mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h | 2 +-
.../X86Vector/Transforms/LegalizeForLLVMExport.cpp | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
index f2f8d36fdfd01..5f487c8e6a9af 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
@@ -21,7 +22,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
/// Include the generated interface declarations.
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 668888eab1c2a..598c30810a38d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
- Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
- CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp,
- DotOp>();
+ target.addIllegalOp<
+ MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
+ CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
+ CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
}
>From 486ec2d363f12ac77e77ab6c19fbd9c0bf4deef9 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 11 Apr 2025 01:49:48 -0700
Subject: [PATCH 3/7] Fixed a typo in the op description
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index a235685f773f8..c05bd1c3640b4 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -425,7 +425,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16>
+ %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];
let arguments = (ins LLVM_AnyPointer:$a);
@@ -457,7 +457,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xbf16>
+ %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];
let arguments = (ins LLVM_AnyPointer:$a);
@@ -494,7 +494,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
Example:
```mlir
- %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xbf16>
+ %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
```
}];
let arguments = (ins LLVM_AnyPointer:$a);
>From 0a80bbc31f3264c1bd8c93b47e7b8104c2f2ed81 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 13 Apr 2025 19:01:47 -0700
Subject: [PATCH 4/7] removing tests related to assembly check
---
.../bcst-avx-bf16-to-f32-packed.mlir | 22 ---------
.../X86Vector/cvt-packed-avx-bf16-to-f32.mlir | 48 -------------------
2 files changed, 70 deletions(-)
delete mode 100644 mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
delete mode 100644 mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
diff --git a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir b/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
deleted file mode 100644
index 8243e628f7e2b..0000000000000
--- a/mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// REQUIRES: target=x86{{.*}}
-
-// RUN: mlir-opt %s \
-// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
-// RUN: mlir-translate --mlir-to-llvmir | \
-// RUN: llc -mcpu=sierraforest | \
-// RUN: FileCheck %s
-
-func.func @avxbf16_bcst_bf16_to_f32_packed_128(%arg0: !llvm.ptr) -> vector<4xf32> {
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<4xf32>
- return %0 : vector<4xf32>
-}
-// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_128:
-// CHECK: vbcstnebf162ps{{.*}}%xmm
-
-func.func @avxbf16_bcst_bf16_to_f32_packed_256(%arg0: !llvm.ptr) -> vector<8xf32> {
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %arg0 : !llvm.ptr -> vector<8xf32>
- return %0 : vector<8xf32>
-}
-// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_256:
-// CHECK: vbcstnebf162ps{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
deleted file mode 100644
index 08ad9c1c4a8d0..0000000000000
--- a/mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
+++ /dev/null
@@ -1,48 +0,0 @@
-// REQUIRES: target=x86{{.*}}
-
-// RUN: mlir-opt %s \
-// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
-// RUN: mlir-translate --mlir-to-llvmir | \
-// RUN: llc -mcpu=sierraforest | \
-// RUN: FileCheck %s
-
-func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
- %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
- %0 = arith.index_cast %intptr : index to i32
- %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
- %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
- return %2 : vector<4xf32>
-}
-// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_128:
-// CHECK: vcvtneebf162ps{{.*}}%xmm
-
-func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
- %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
- %0 = arith.index_cast %intptr : index to i32
- %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
- %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
- return %2 : vector<8xf32>
-}
-// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_256:
-// CHECK: vcvtneebf162ps{{.*}}%ymm
-
-func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
- %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
- %0 = arith.index_cast %intptr : index to i32
- %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
- %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
- return %2 : vector<4xf32>
-}
-// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128:
-// CHECK: vcvtneobf162ps{{.*}}%xmm
-
-func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
- %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
- %0 = arith.index_cast %intptr : index to i32
- %1 = llvm.inttoptr %0 : i32 to !llvm.ptr
- %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
- return %2 : vector<8xf32>
-}
-// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256:
-// CHECK: vcvtneobf162ps{{.*}}%ymm
>From a9df22e3f8a29ba143491c02c33fafac68cdc5c3 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 22 Apr 2025 05:52:22 -0700
Subject: [PATCH 5/7] Changed the input type to accept memref
---
.../mlir/Dialect/X86Vector/X86Vector.td | 29 ++++++++----
.../mlir/Dialect/X86Vector/X86VectorDialect.h | 1 +
.../Dialect/X86Vector/X86VectorInterfaces.td | 2 +-
.../Dialect/X86Vector/IR/X86VectorDialect.cpp | 46 +++++++++++++++++--
.../Transforms/LegalizeForLLVMExport.cpp | 4 +-
.../Dialect/X86Vector/legalize-for-llvm.mlir | 24 +++++-----
mlir/test/Dialect/X86Vector/roundtrip.mlir | 36 +++++++--------
mlir/test/Target/LLVMIR/x86vector.mlir | 26 +++++------
8 files changed, 111 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index c05bd1c3640b4..31971a46e7475 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
}
@@ -404,7 +404,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
}
@@ -413,7 +413,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
//----------------------------------------------------------------------------//
-def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [Pure,
+def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
let description = [{
@@ -428,7 +428,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];
- let arguments = (ins LLVM_AnyPointer:$a);
+ let arguments = (ins AnyMemRef:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -443,9 +443,13 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
return intr;
}
}];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
}
-def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [Pure,
+def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
let description = [{
@@ -460,7 +464,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
```
}];
- let arguments = (ins LLVM_AnyPointer:$a);
+ let arguments = (ins AnyMemRef:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -475,13 +479,17 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
return intr;
}
}];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
}
//----------------------------------------------------------------------------//
// AVX: Convert BF16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
-def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
+def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
let description = [{
@@ -497,7 +505,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
```
}];
- let arguments = (ins LLVM_AnyPointer:$a);
+ let arguments = (ins AnyMemRef:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -512,6 +520,11 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
return intr;
}
}];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+
}
#endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
index 5f487c8e6a9af..308adfa5b9021 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 98d5ca70b4a7d..5176f4a447b6e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
/*methodBody=*/"",
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
>,
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 5bb4dcfd60d83..555603e99f4a8 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +33,26 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
+static SmallVector<Value>
+getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
+ RewriterBase &rewriter,
+ const LLVMTypeConverter &typeConverter) {
+ SmallVector<Value> operands;
+ auto opType = memrefVal.getType();
+
+ Type llvmStructType = typeConverter.convertType(opType);
+ Value llvmStruct =
+ rewriter
+ .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
+ .getResult(0);
+ MemRefDescriptor memRefDescriptor(llvmStruct);
+
+ Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
+ operands.push_back(ptr);
+
+ return operands;
+}
+
LogicalResult x86vector::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src");
@@ -45,8 +67,8 @@ LogicalResult x86vector::MaskCompressOp::verify() {
return success();
}
-SmallVector<Value>
-x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
+SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
auto loc = getLoc();
auto opType = getA().getType();
@@ -64,7 +86,8 @@ x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
}
SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
+x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
+ const LLVMTypeConverter &typeConverter) {
SmallVector<Value> operands(getOperands());
// Dot product of all elements, broadcasted to all elements.
Value scale =
@@ -74,5 +97,22 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
return operands;
}
+SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 598c30810a38d..d2297554a1012 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
PatternRewriter &rewriter) const override {
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
- op.getIntrinsicOperands(rewriter), typeConverter,
- rewriter);
+ op.getIntrinsicOperands(rewriter, typeConverter),
+ typeConverter, rewriter);
}
private:
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index e1969481c845c..93b304c44de8e 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -97,55 +97,55 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
func.func @avxbf16_bsct_bf16_to_f32_packed_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
func.func @avxbf16_bsct_bf16_to_f32_packed_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index d36628588190e..b783cc869b981 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -96,61 +96,61 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
func.func @avxbf16_bcst_bf16_to_f32_128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<4xf32>
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
func.func @avxbf16_bcst_bf16_to_f32_256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
- // CHECK-SAME: !llvm.ptr -> vector<8xf32>
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 095375839d282..a8bc180d1d0ac 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm \
+// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s
@@ -111,55 +111,55 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256
func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128
func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256
func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128
func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
- %a: !llvm.ptr) -> vector<4xf32>
+ %a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256
func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
- %a: !llvm.ptr) -> vector<8xf32>
+ %a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
>From 5dfcee7dc33be0dabe3a0b20e2a944de0f0e4f95 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 22 Apr 2025 05:57:52 -0700
Subject: [PATCH 6/7] Removed header include
---
mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 555603e99f4a8..f5e5070c74f8f 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -11,8 +11,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
>From 0ae2dc537b0516a74c2782a4a8ffab90ad64a1d0 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 22 Apr 2025 07:03:49 -0700
Subject: [PATCH 7/7] added MemoryEffect<MemRead> instead of Pure in td
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 31971a46e7475..5ae72e63c6b93 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -413,7 +413,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
//----------------------------------------------------------------------------//
-def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [
+def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
let description = [{
@@ -449,7 +449,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
}];
}
-def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [
+def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
let description = [{
@@ -489,7 +489,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
// AVX: Convert BF16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
-def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [
+def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
let description = [{
More information about the Mlir-commits
mailing list