[Mlir-commits] [mlir] [mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions (PR #137917)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 1 19:21:48 PDT 2025
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/137917
>From fd523e58f9ff41ab997ce836fa47fceb9f4b1acc Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 29 Apr 2025 03:26:41 -0700
Subject: [PATCH 1/6] new AVX-NE-CONVERT f16 ops in x86vector dialect to handle f16
conversions to f32
---
.../mlir/Dialect/X86Vector/X86Vector.td | 122 ++++++++++++++++++
.../Dialect/X86Vector/IR/X86VectorDialect.cpp | 17 +++
2 files changed, 139 insertions(+)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 126fa0e352656..37bdfc18a17a3 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -527,4 +527,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
+//----------------------------------------------------------------------------//
+// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneeph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneoph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert F16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Broadcasts F16 into packed F32 Data.";
+
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert scalar F16 (16-bit) floating-point element stored at memory locations
+ starting at location `__A` to a single-precision (32-bit) floating-point,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.bcst.f16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vbcstnesh2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+
+}
+
#endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index f5e5070c74f8f..2e01a11921950 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -112,5 +112,22 @@ x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
>From 96e1debc70ac168cd84736951d7bd332cbfd741f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 29 Apr 2025 19:48:47 -0700
Subject: [PATCH 2/6] adding new test-cases
---
.../Dialect/X86Vector/legalize-for-llvm.mlir | 54 +++++++++++++++++
mlir/test/Dialect/X86Vector/roundtrip.mlir | 60 +++++++++++++++++++
mlir/test/Target/LLVMIR/x86vector.mlir | 54 +++++++++++++++++
3 files changed, 168 insertions(+)
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 93b304c44de8e..3888ec05ad866 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -149,6 +149,60 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128
+func.func @avxf16_bsct_f16_to_f32_packed_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256
+func.func @avxf16_bsct_f16_to_f32_packed_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index b783cc869b981..a2fdb0cf6d457 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -154,6 +154,66 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
+func.func @avxf16_bcst_f16_to_f32_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
+func.func @avxf16_bcst_f16_to_f32_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index a8bc180d1d0ac..f474ae281ece3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -163,6 +163,60 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
+func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
+func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
+func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
+func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
+func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
+func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
>From 7a2b6dc90724b21790300834e4f95ba873013d4f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 29 Apr 2025 21:54:53 -0700
Subject: [PATCH 3/6] corrected typo in example: llvm.ptr -> memref<*>
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 37bdfc18a17a3..75b07f01e70f1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -544,7 +544,7 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
@@ -581,7 +581,7 @@ def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32",
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : !llvm.ptr -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
@@ -624,7 +624,7 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
Example:
```mlir
- %dst = x86vector.avx.bcst.f16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
+ %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
>From d804786f16fad174a3f2ced0e6f6c4904a52b89b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 30 Apr 2025 05:45:32 -0700
Subject: [PATCH 4/6] generalization to cover both bf16/f16
---
.../mlir/Dialect/X86Vector/X86Vector.td | 187 +++++-------------
.../Dialect/X86Vector/IR/X86VectorDialect.cpp | 23 +--
.../Transforms/LegalizeForLLVMExport.cpp | 4 +-
.../Dialect/X86Vector/legalize-for-llvm.mlir | 24 +--
mlir/test/Dialect/X86Vector/roundtrip.mlir | 48 ++---
mlir/test/Target/LLVMIR/x86vector.mlir | 24 +--
6 files changed, 98 insertions(+), 212 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 75b07f01e70f1..4246f9d59d0c6 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -408,101 +408,27 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}
-
//----------------------------------------------------------------------------//
-// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
-def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+ let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:
- Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
- memory locations starting at location `__A` to packed single-precision
- (32-bit) floating-point elements, and store the results in `dst`.
-
- Example:
- ```mlir
- %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
- ```
- }];
- let arguments = (ins AnyMemRef:$a);
- let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
- let assemblyFormat =
- "$a attr-dict`:` type($a)`->` type($dst)";
-
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneebf162ps";
- VectorType vecType = getDst().getType();
- unsigned elemBitWidth = vecType.getElementTypeBitWidth();
- unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
- intr += std::to_string(opBitWidth);
- return intr;
- }
- }];
-
- let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
- }];
-}
-
-def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
- let description = [{
- #### From the Intel Intrinsics Guide:
-
- Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
- memory locations starting at location `__A` to packed single-precision
- (32-bit) floating-point elements, and store the results in `dst`.
-
- Example:
- ```mlir
- %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
- ```
- }];
- let arguments = (ins AnyMemRef:$a);
- let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
- let assemblyFormat =
- "$a attr-dict`:` type($a)`->` type($dst)";
-
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneobf162ps";
- VectorType vecType = getDst().getType();
- unsigned elemBitWidth = vecType.getElementTypeBitWidth();
- unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
- intr += std::to_string(opBitWidth);
- return intr;
- }
- }];
-
- let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
- }];
-}
-
-//----------------------------------------------------------------------------//
-// AVX: Convert BF16 to F32 and broadcast into packed F32
-//----------------------------------------------------------------------------//
-
-def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
- let description = [{
- #### From the Intel Intrinsics Guide:
-
- Convert scalar BF16 (16-bit) floating-point element stored at memory locations
+ Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
starting at location `__A` to a single-precision (32-bit) floating-point,
broadcast it to packed single-precision (32-bit) floating-point elements,
and store the results in `dst`.
Example:
```mlir
- %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ ```
+ ```mlir
+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
@@ -512,7 +438,13 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vbcstnebf162ps";
+ auto elementType =
+ (cast<MemRefType>(getA().getType())).getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vbcstnebf162ps";
+ if (elementType.isF16())
+ intr += "vbcstnesh2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -527,24 +459,26 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
-//----------------------------------------------------------------------------//
-// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
-//----------------------------------------------------------------------------//
+//------------------------------------------------------------------------------//
+// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
+//------------------------------------------------------------------------------//
-def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+ let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
-
#### From the Intel Intrinsics Guide:
- Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+ Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ ```
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
@@ -554,7 +488,13 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneeph2ps";
+ auto elementType =
+ (cast<MemRefType>(getA().getType())).getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vcvtneebf162ps";
+ if (elementType.isF16())
+ intr += "vcvtneeph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -568,63 +508,22 @@ def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32"
}];
}
-def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+ let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
-
#### From the Intel Intrinsics Guide:
- Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+ Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.
Example:
```mlir
- %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
```
- }];
- let arguments = (ins AnyMemRef:$a);
- let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
- let assemblyFormat =
- "$a attr-dict`:` type($a)`->` type($dst)";
-
- let extraClassDefinition = [{
- std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneoph2ps";
- VectorType vecType = getDst().getType();
- unsigned elemBitWidth = vecType.getElementTypeBitWidth();
- unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
- intr += std::to_string(opBitWidth);
- return intr;
- }
- }];
-
- let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
- }];
-}
-
-//----------------------------------------------------------------------------//
-// AVX: Convert F16 to F32 and broadcast into packed F32
-//----------------------------------------------------------------------------//
-
-def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
- DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Broadcasts F16 into packed F32 Data.";
-
- let description = [{
-
- #### From the Intel Intrinsics Guide:
-
- Convert scalar F16 (16-bit) floating-point element stored at memory locations
- starting at location `__A` to a single-precision (32-bit) floating-point,
- broadcast it to packed single-precision (32-bit) floating-point elements,
- and store the results in `dst`.
-
- Example:
```mlir
- %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
@@ -634,7 +533,13 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vbcstnesh2ps";
+ auto elementType =
+ (cast<MemRefType>(getA().getType())).getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vcvtneobf162ps";
+ if (elementType.isF16())
+ intr += "vcvtneoph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -643,10 +548,8 @@ def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemR
}
}];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
-
}
-
#endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 2e01a11921950..03430558dba7e 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -95,36 +95,19 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
return operands;
}
-SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
SmallVector<Value>
-x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
SmallVector<Value>
-x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value>
-x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value>
-x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
- return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
-}
-
-SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index d2297554a1012..7e2f4c6c879da 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -116,6 +116,6 @@ void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addIllegalOp<
MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
- CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
- CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
+ CvtPackedF32ToBF16Op, CvtPackedEvenIndexedToF32Op,
+ CvtPackedOddIndexedToF32Op, BcstToPackedF32Op, RsqrtOp, DotOp>();
}
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 3888ec05ad866..63f06624ef897 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -145,7 +145,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -154,7 +154,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -163,7 +163,7 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -172,7 +172,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -181,7 +181,7 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -190,7 +190,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -199,7 +199,7 @@ func.func @avxf16_bsct_f16_to_f32_packed_256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index a2fdb0cf6d457..7dcab3eb4dcb8 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -98,9 +98,9 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -108,9 +108,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -118,9 +118,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -128,9 +128,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -138,9 +138,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
func.func @avxbf16_bcst_bf16_to_f32_128(
%a: memref<1xbf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -148,9 +148,9 @@ func.func @avxbf16_bcst_bf16_to_f32_128(
func.func @avxbf16_bcst_bf16_to_f32_256(
%a: memref<1xbf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -158,9 +158,9 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -168,9 +168,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -178,9 +178,9 @@ func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<8xf16> -> vector<4xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -188,9 +188,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} :
// CHECK-SAME: memref<16xf16> -> vector<8xf32>
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -198,9 +198,9 @@ func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
func.func @avxf16_bcst_f16_to_f32_128(
%a: memref<1xf16>) -> vector<4xf32>
{
- // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<4xf32>
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -208,9 +208,9 @@ func.func @avxf16_bcst_f16_to_f32_128(
func.func @avxf16_bcst_f16_to_f32_256(
%a: memref<1xf16>) -> vector<8xf32>
{
- // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} :
// CHECK-SAME: memref<1xf16> -> vector<8xf32>
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index f474ae281ece3..d11dc89bdc7c9 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -114,7 +114,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -123,7 +123,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
- %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -132,7 +132,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -141,7 +141,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
- %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -150,7 +150,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -159,7 +159,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
- %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -168,7 +168,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -177,7 +177,7 @@ func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
- %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -186,7 +186,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -195,7 +195,7 @@ func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
- %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
@@ -204,7 +204,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}
@@ -213,7 +213,7 @@ func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
- %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}
>From d8205f95ba46a0177eb3964a6c807a96167485df Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 30 Apr 2025 07:46:25 -0700
Subject: [PATCH 5/6] fixed clang-format errors
---
mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp | 6 ++----
.../X86Vector/Transforms/LegalizeForLLVMExport.cpp | 8 ++++----
2 files changed, 6 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 03430558dba7e..8d383b1f8103b 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -100,14 +100,12 @@ SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
-SmallVector<Value>
-x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
-SmallVector<Value>
-x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 7e2f4c6c879da..9ee44a63ba2e4 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addIllegalOp<
- MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
- CvtPackedF32ToBF16Op, CvtPackedEvenIndexedToF32Op,
- CvtPackedOddIndexedToF32Op, BcstToPackedF32Op, RsqrtOp, DotOp>();
+ target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
+ Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
+ CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
+ BcstToPackedF32Op, RsqrtOp, DotOp>();
}
>From 8aade1e935242729ec4181e37d2b7d65ec813030 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 1 May 2025 19:21:33 -0700
Subject: [PATCH 6/6] updated from AnyMemRef to MemRefOf<[BF16, F16]> and a few
clean-ups
---
.../mlir/Dialect/X86Vector/X86Vector.td | 20 +++++++------------
1 file changed, 7 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4246f9d59d0c6..4f8301f9380b8 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -426,12 +426,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
Example:
```mlir
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
- ```
- ```mlir
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -439,7 +437,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
auto elementType =
- (cast<MemRefType>(getA().getType())).getElementType();
+ getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vbcstnebf162ps";
@@ -453,7 +451,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
}
}];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
@@ -476,12 +474,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
Example:
```mlir
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
- ```
- ```mlir
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -489,7 +485,7 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
auto elementType =
- (cast<MemRefType>(getA().getType())).getElementType();
+ getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vcvtneebf162ps";
@@ -521,12 +517,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
Example:
```mlir
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
- ```
- ```mlir
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -534,7 +528,7 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
auto elementType =
- (cast<MemRefType>(getA().getType())).getElementType();
+ getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vcvtneobf162ps";
More information about the Mlir-commits
mailing list