[Mlir-commits] [mlir] [mlir][llvm] add experimental.vector.interleave2 intrinsic (PR #79270)

Cullen Rhodes llvmlistbot at llvm.org
Fri Jan 26 08:12:37 PST 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/79270

>From f1fe6a22df1ef9ac6d54e4d26afcdb03c2faa173 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 24 Jan 2024 10:15:14 +0000
Subject: [PATCH 1/3] [mlir][ArmSVE] add zip1 intrinsic

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 4 ++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 7 +++++++
 2 files changed, 11 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index e3f3d9e62e8fb3..754413a1ad491e 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -410,4 +410,8 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
     Arguments<(ins SVEPredicate:$mask)>;
 
+def Zip1IntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"zip1">,
+  Arguments<(ins AnyScalableVector, AnyScalableVector)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index b63d3f06515690..002b1f9d804a7c 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,3 +314,10 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_zip1
+// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
+llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
+  %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+  llvm.return %0 : vector<[8]xf16>
+}

>From 0686a46ea03ae0683d42204cce05ec2d4a10ae29 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 13:51:13 +0000
Subject: [PATCH 2/3] replace with interleave2 intrinsic

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td   |  4 ----
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td     | 14 ++++++++++++++
 mlir/test/Dialect/LLVMIR/invalid.mlir           | 17 +++++++++++++++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir         |  7 +++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir            |  7 -------
 5 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 754413a1ad491e..e3f3d9e62e8fb3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -410,8 +410,4 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
     Arguments<(ins SVEPredicate:$mask)>;
 
-def Zip1IntrOp :
-  ArmSVE_IntrBinaryOverloadedOp<"zip1">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector)>;
-
 #endif // ARMSVE_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index a4f08fb92da903..b0c9f0d2d03009 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -969,6 +969,20 @@ def LLVM_vector_extract
   }];
 }
 
+class HalfElementsVectorTypeConstraint<string lhs, string rhs>
+    : TypesMatchWith<
+        "'" # rhs # "'" # " has half elements of '" # lhs # "'",
+      lhs, rhs,
+      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
+        ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">;
+
+def LLVM_experimental_vector_interleave2
+    : LLVM_OneResultIntrOp<"experimental.vector.interleave2",
+                           /*overloadedResults=*/[0], /*overloadedOperands=*/[],
+                           /*traits=*/[Pure, AllTypesMatch<["vec1", "vec2"]>,
+                                       HalfElementsVectorTypeConstraint<"res", "vec1">]>,
+      Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
+
 //
 // LLVM Vector Predication operations.
 //
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index d72ff8ca3c3aa7..c8fad3c7059472 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1218,6 +1218,23 @@ func.func @extract_scalable_from_fixed_length_vector(%arg0 : vector<16xf32>) {
   %0 = llvm.intr.vector.extract %arg0[0] : vector<[8]xf32> from vector<16xf32>
 }
 
+
+// -----
+
+func.func @experimental_vector_interleave2_bad_type0(%vec1: vector<[2]xf16>, %vec2 : vector<[4]xf16>) {
+  // expected-error @+1 {{op failed to verify that all of {vec1, vec2} have same type}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  return
+}
+
+// -----
+
+func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
+  // expected-error @+1 {{op failed to verify that 'vec1' has half elements of 'res'}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[8]xf16>
+  return
+}
+
 // -----
 
 func.func @invalid_bitcast_ptr_to_i64(%arg : !llvm.ptr) {
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1958dd56bab7ad..b157cf00141842 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -342,6 +342,13 @@ func.func @mixed_vect(%arg0: vector<8xf32>, %arg1: vector<4xf32>, %arg2: vector<
   return
 }
 
+// CHECK-LABEL: @experimental_vector_interleave2
+func.func @experimental_vector_interleave2(%vec1: vector<[4]xf16>, %vec2 : vector<[4]xf16>) {
+  // CHECK: = "llvm.intr.experimental.vector.interleave2"({{.*}}) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func.func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 002b1f9d804a7c..b63d3f06515690 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,10 +314,3 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
-
-// CHECK-LABEL: @arm_sve_zip1
-// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
-llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
-  %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
-  llvm.return %0 : vector<[8]xf16>
-}

>From a8ee8ed19f329af037e0ba6018f5f533576920ae Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 16:11:34 +0000
Subject: [PATCH 3/3] check result has even number of elements

---
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   | 24 ++++++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 10 +++++++-
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b0c9f0d2d03009..6b55b66cafb606 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -969,19 +969,21 @@ def LLVM_vector_extract
   }];
 }
 
-class HalfElementsVectorTypeConstraint<string lhs, string rhs>
-    : TypesMatchWith<
-        "'" # rhs # "'" # " has half elements of '" # lhs # "'",
-      lhs, rhs,
-      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
-        ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">;
-
 def LLVM_experimental_vector_interleave2
     : LLVM_OneResultIntrOp<"experimental.vector.interleave2",
-                           /*overloadedResults=*/[0], /*overloadedOperands=*/[],
-                           /*traits=*/[Pure, AllTypesMatch<["vec1", "vec2"]>,
-                                       HalfElementsVectorTypeConstraint<"res", "vec1">]>,
-      Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
+        /*overloadedResults=*/[0], /*overloadedOperands=*/[],
+        /*traits=*/[
+          Pure, AllTypesMatch<["vec1", "vec2"]>,
+          PredOpTrait<
+            "result has an even number of elements",
+            CPred<"::llvm::cast<::mlir::VectorType>($res.getType()).getNumElements() % 2 == 0">>,
+          TypesMatchWith<
+            "'vec1' has half as many elements as result",
+            "res", "vec1",
+            "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
+              ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">
+        ]>,
+        Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
 //
 // LLVM Vector Predication operations.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c8fad3c7059472..950de64ce9bbf7 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1230,13 +1230,21 @@ func.func @experimental_vector_interleave2_bad_type0(%vec1: vector<[2]xf16>, %ve
 // -----
 
 func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
-  // expected-error @+1 {{op failed to verify that 'vec1' has half elements of 'res'}}
+  // expected-error @+1 {{op failed to verify that 'vec1' has half as many elements as result}}
   %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[8]xf16>
   return
 }
 
 // -----
 
+func.func @experimental_vector_interleave2_bad_type2(%vec1: vector<[1]xf16>, %vec2 : vector<[1]xf16>) {
+  // expected-error @+1 {{op failed to verify that result has an even number of elements}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[1]xf16>, vector<[1]xf16>) -> vector<[3]xf16>
+  return
+}
+
+// -----
+
 func.func @invalid_bitcast_ptr_to_i64(%arg : !llvm.ptr) {
   // expected-error @+1 {{can only cast pointers from and to pointers}}
   %1 = llvm.bitcast %arg : !llvm.ptr to i64



More information about the Mlir-commits mailing list