[Mlir-commits] [mlir] 5c3d679 - [mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions (#137917)

llvmlistbot at llvm.org
Mon May 5 00:34:33 PDT 2025


Author: arun-thmn
Date: 2025-05-05T09:34:30+02:00
New Revision: 5c3d679516ce054d307abbfd0ad494d4924d70a6

URL: https://github.com/llvm/llvm-project/commit/5c3d679516ce054d307abbfd0ad494d4924d70a6
DIFF: https://github.com/llvm/llvm-project/commit/5c3d679516ce054d307abbfd0ad494d4924d70a6.diff

LOG: [mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions (#137917)

Adds AVX broadcast and conversion instructions from F16 to packed F32
(similar to PR https://github.com/llvm/llvm-project/pull/136830). The
added instructions are:

- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
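
At the dialect level these map onto the existing broadcast/convert ops,
which now also accept f16 memrefs. Below is a minimal sketch assembled
from the tests in this patch; the function name and SSA value names are
illustrative only:

```mlir
// Hypothetical function exercising the three F16-to-F32 forms added here;
// op syntax and types follow the tests in this patch.
func.func @f16_to_f32_examples(%src: memref<16xf16>, %scalar: memref<1xf16>)
    -> (vector<8xf32>, vector<8xf32>, vector<8xf32>) {
  // Even-indexed f16 elements converted to packed f32 (VCVTNEEPH2PS).
  %even = x86vector.avx.cvt.packed.even.indexed_to_f32 %src
      : memref<16xf16> -> vector<8xf32>
  // Odd-indexed f16 elements converted to packed f32 (VCVTNEOPH2PS).
  %odd = x86vector.avx.cvt.packed.odd.indexed_to_f32 %src
      : memref<16xf16> -> vector<8xf32>
  // Scalar f16 converted and broadcast to packed f32 (VBCSTNESH2PS).
  %bcst = x86vector.avx.bcst_to_f32.packed %scalar
      : memref<1xf16> -> vector<8xf32>
  return %even, %odd, %bcst : vector<8xf32>, vector<8xf32>, vector<8xf32>
}
```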

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir
    mlir/test/Target/LLVMIR/x86vector.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 126fa0e352656..4f8301f9380b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure,
   }];
 }
 
-
 //----------------------------------------------------------------------------//
-// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
 //----------------------------------------------------------------------------//
 
-def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, 
+def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
   DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+  let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
   let description = [{
     #### From the Intel Intrinsics Guide:
 
-    Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
-    memory locations starting at location `__A` to packed single-precision
-    (32-bit) floating-point elements, and store the results in `dst`.
+    Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
+    starting at location `__A` to a single-precision (32-bit) floating-point,
+    broadcast it to packed single-precision (32-bit) floating-point elements,
+    and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a  attr-dict`:` type($a)`->` type($dst)";
 
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneebf162ps";
+      auto elementType =
+        getA().getType().getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vbcstnebf162ps";
+      if (elementType.isF16())
+        intr += "vbcstnesh2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
   let extraClassDeclaration = [{
         SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
+
 }
 
-def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, 
+//------------------------------------------------------------------------------//
+// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
+//------------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>, 
   DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
+  let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
   let description = [{
     #### From the Intel Intrinsics Guide:
 
-    Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
+    Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
     memory locations starting at location `__A` to packed single-precision
     (32-bit) floating-point elements, and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a  attr-dict`:` type($a)`->` type($dst)";
 
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vcvtneobf162ps";
+      auto elementType =
+        getA().getType().getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vcvtneebf162ps";
+      if (elementType.isF16())
+        intr += "vcvtneeph2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
   }];
 }
 
-//----------------------------------------------------------------------------//
-// AVX: Convert BF16 to F32 and broadcast into packed F32
-//----------------------------------------------------------------------------//
-
-def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
+def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>, 
   DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
-  let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
+  let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
   let description = [{
     #### From the Intel Intrinsics Guide:
 
-    Convert scalar BF16 (16-bit) floating-point element stored at memory locations
-    starting at location `__A` to a single-precision (32-bit) floating-point,
-    broadcast it to packed single-precision (32-bit) floating-point elements,
-    and store the results in `dst`.
+    Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
+    memory locations starting at location `__A` to packed single-precision
+    (32-bit) floating-point elements, and store the results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
     ```
   }];
-  let arguments = (ins AnyMemRef:$a);
+  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
   let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
   let assemblyFormat =
     "$a  attr-dict`:` type($a)`->` type($dst)";
 
   let extraClassDefinition = [{
     std::string $cppClass::getIntrinsicName() {
-      std::string intr = "llvm.x86.vbcstnebf162ps";
+      auto elementType =
+        getA().getType().getElementType();
+      std::string intr = "llvm.x86.";
+      if (elementType.isBF16())
+        intr += "vcvtneobf162ps";
+      if (elementType.isF16())
+        intr += "vcvtneoph2ps";
       VectorType vecType = getDst().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
     }
   }];
 
-    let extraClassDeclaration = [{
+  let extraClassDeclaration = [{
         SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
-
 }
-
 #endif // X86VECTOR_OPS

diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index f5e5070c74f8f..8d383b1f8103b 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -95,19 +95,17 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
   return operands;
 }
 
-SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
-SmallVector<Value>
-x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }
 
-SmallVector<Value>
-x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
     RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
 }

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index d2297554a1012..9ee44a63ba2e4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
 
 void mlir::configureX86VectorLegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addIllegalOp<
-      MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
-      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
-      CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
+  target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
+                      Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
+                      CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
+                      BcstToPackedF32Op, RsqrtOp, DotOp>();
 }

diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 93b304c44de8e..63f06624ef897 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
   %a: memref<1xbf16>) -> vector<4xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -145,7 +145,61 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
   %a: memref<1xbf16>) -> vector<8xf32>
 {
   // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128
+func.func @avxf16_bsct_f16_to_f32_packed_128(
+  %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256
+func.func @avxf16_bsct_f16_to_f32_packed_256(
+  %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 

diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index b783cc869b981..7dcab3eb4dcb8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -98,9 +98,9 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -108,9 +108,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -118,9 +118,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -128,9 +128,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
 func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
   // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -138,9 +138,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
 func.func @avxbf16_bcst_bf16_to_f32_128(
   %a: memref<1xbf16>) -> vector<4xf32>
 {
-  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -148,9 +148,69 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
 func.func @avxbf16_bcst_bf16_to_f32_256(
   %a: memref<1xbf16>) -> vector<8xf32>
 {
-  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
   // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
+func.func @avxf16_bcst_f16_to_f32_128(
+  %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
+func.func @avxf16_bcst_f16_to_f32_256(
+  %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 

diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index a8bc180d1d0ac..d11dc89bdc7c9 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -114,7 +114,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -123,7 +123,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
-  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -132,7 +132,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
   %a: memref<8xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -141,7 +141,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
   %a: memref<16xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
-  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -150,7 +150,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
   %a: memref<1xbf16>) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
   return %0 : vector<4xf32>
 }
 
@@ -159,7 +159,61 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
   %a: memref<1xbf16>) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
-  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
+func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
+func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
+  %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
+func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
+  %a: memref<8xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
+func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
+  %a: memref<16xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
+  %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
+func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
+  %a: memref<1xf16>) -> vector<4xf32>
+{
+  // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
+func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
+  %a: memref<1xf16>) -> vector<8xf32>
+{
+  // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
+  %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
   return %0 : vector<8xf32>
 }
 


        

