[Mlir-commits] [mlir] 7dcca62 - [mlir][ArmSVE] Add `arm_sve.zip.x2` and `arm_sve.zip.x4` ops (#81278)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 16 03:34:37 PST 2024


Author: Benjamin Maxwell
Date: 2024-02-16T11:34:34Z
New Revision: 7dcca6213258bf2df2dd8a7d555c9a12c1484759

URL: https://github.com/llvm/llvm-project/commit/7dcca6213258bf2df2dd8a7d555c9a12c1484759
DIFF: https://github.com/llvm/llvm-project/commit/7dcca6213258bf2df2dd8a7d555c9a12c1484759.diff

LOG: [mlir][ArmSVE] Add `arm_sve.zip.x2` and `arm_sve.zip.x4` ops (#81278)

This adds ops for the two-way and four-way SME 2 multi-vector zips.

See:

- https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en
- https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSVE/invalid.mlir
    mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
    mlir/test/Dialect/ArmSVE/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f237f232487e50..f2d330c98e7d6d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -49,6 +49,12 @@ def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
 def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
   [16, 8, 4, 2, 1], [I1]>;
 
+// A constraint for a 1-D scalable vector of `length`.
+class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedContainerType<
+  elementTypes, And<[IsVectorOfShape<[length]>, IsVectorTypeWithAnyDimScalablePred]>,
+  "a 1-D scalable vector with length " # length,
+  "::mlir::VectorType">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -321,6 +327,121 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
   let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
+// Inputs valid for the multi-vector zips (not including the 128-bit element zipqs)
+def ZipInputVectorType : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
+def ZipX2Op  : ArmSVE_Op<"zip.x2", [
+  Pure,
+  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
+> {
+  let summary = "Multi-vector two-way zip op";
+
+  let description = [{
+    This operation interleaves elements from two input SVE vectors, returning
+    two new SVE vectors (`resultV1` and `resultV2`), which contain the low and
+    high halves of the result respectively.
+
+    Example:
+    ```mlir
+    // sourceV1 = [ A1, A2, A3, ... An ]
+    // sourceV2 = [ B1, B2, B3, ... Bn ]
+    // (resultV1, resultV2) = [ A1, B1, A2, B2, A3, B3, ... An, Bn ]
+    %resultV1, %resultV2 = arm_sve.zip.x2 %sourceV1, %sourceV2 : vector<[16]xi8>
+    ```
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2), [{
+      build($_builder, $_state, v1.getType(), v1.getType(), v1, v2);
+  }]>];
+
+  let assemblyFormat = "$sourceV1 `,` $sourceV2 attr-dict `:` type($sourceV1)";
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
+def ZipX4Op  : ArmSVE_Op<"zip.x4", [
+  Pure,
+  AllTypesMatch<[
+    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
+    "resultV1", "resultV2", "resultV3", "resultV4"]>]
+> {
+  let summary = "Multi-vector four-way zip op";
+
+  let description = [{
+    This operation interleaves elements from four input SVE vectors, returning
+    four new SVE vectors, each of which contains a quarter of the result. The
+    first quarter will be in `resultV1`, second in `resultV2`, third in
+    `resultV3`, and fourth in `resultV4`.
+
+    ```mlir
+    // sourceV1 = [ A1, A2, ... An ]
+    // sourceV2 = [ B1, B2, ... Bn ]
+    // sourceV3 = [ C1, C2, ... Cn ]
+    // sourceV4 = [ D1, D2, ... Dn ]
+    // (resultV1, resultV2, resultV3, resultV4)
+    //   = [ A1, B1, C1, D1, A2, B2, C2, D2, ... An, Bn, Cn, Dn ]
+    %resultV1, %resultV2, %resultV3, %resultV4 = arm_sve.zip.x4
+      %sourceV1, %sourceV2, %sourceV3, %sourceV4 : vector<[16]xi8>
+    ```
+
+    **Warning:** The result of this op is undefined for 64-bit elements on
+    hardware with less than 256-bit vectors!
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2,
+                       ZipInputVectorType:$sourceV3,
+                       ZipInputVectorType:$sourceV4);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2,
+                      ZipInputVectorType:$resultV3,
+                      ZipInputVectorType:$resultV4);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2, "Value":$v3, "Value":$v4), [{
+      build($_builder, $_state,
+        v1.getType(), v1.getType(),
+        v1.getType(), v1.getType(),
+        v1, v2, v3, v4);
+  }]>];
+
+  let assemblyFormat = [{
+    $sourceV1 `,` $sourceV2 `,` $sourceV3 `,` $sourceV4 attr-dict
+      `:` type($sourceV1)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 

diff  --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 32c87c1b824074..387937e811ced8 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -137,6 +137,9 @@ using ConvertToSvboolOpLowering =
 using ConvertFromSvboolOpLowering =
     SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
 
+using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
+using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -163,7 +166,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering,
                ConvertToSvboolOpLowering,
-               ConvertFromSvboolOpLowering>(converter);
+               ConvertFromSvboolOpLowering,
+               ZipX2OpLowering,
+               ZipX4OpLowering>(converter);
   // clang-format on
 }
 
@@ -184,7 +189,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ScalableMaskedUDivIIntrOp,
                     ScalableMaskedDivFIntrOp,
                     ConvertToSvboolIntrOp,
-                    ConvertFromSvboolIntrOp>();
+                    ConvertFromSvboolIntrOp,
+                    ZipX2IntrOp,
+                    ZipX4IntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
@@ -199,6 +206,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedUDivIOp,
                       ScalableMaskedDivFOp,
                       ConvertToSvboolOp,
-                      ConvertFromSvboolOp>();
+                      ConvertFromSvboolOp,
+                      ZipX2Op,
+                      ZipX4Op>();
   // clang-format on
 }

diff  --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index a1fa0d0292b7b7..1258d3532c049c 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -49,3 +49,18 @@ func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8
 }
 
 
+// -----
+
+func.func @arm_sve_zip_x2_bad_vector_type(%a : vector<[7]xi8>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[7]xi8>'}}
+  arm_sve.zip.x2 %a, %a : vector<[7]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[5]xf64>'}}
+  arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
+  return
+}

diff  --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8e76fb7119b844..8d11c2bcaa8d51 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -187,3 +187,27 @@ func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[
   // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
   return %mask : vector<3x[1]xi1>
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(%a: vector<[8]xi16>, %b: vector<[8]xi16>)
+  -> (vector<[8]xi16>, vector<[8]xi16>)
+{
+  // CHECK: arm_sve.intr.zip.x2
+  %0, %1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
+  return %0, %1 : vector<[8]xi16>, vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %a: vector<[16]xi8>,
+  %b: vector<[16]xi8>,
+  %c: vector<[16]xi8>,
+  %d: vector<[16]xi8>
+) -> (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+{
+  // CHECK: arm_sve.intr.zip.x4
+  %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
+  return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
+}

diff  --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index c9a0b6db8fa803..f7b79aa2f275c4 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -163,3 +163,65 @@ func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
 
   return
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xi64>
+  %a1, %b1 = arm_sve.zip.x2 %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xf64>
+  %a2, %b2 = arm_sve.zip.x2 %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xi32>
+  %a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xf32>
+  %a4, %b4 = arm_sve.zip.x2 %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xi16>
+  %a5, %b5 = arm_sve.zip.x2 %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xf16>
+  %a6, %b6 = arm_sve.zip.x2 %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7 = arm_sve.zip.x2 %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[16]xi8>
+  %a8, %b8 = arm_sve.zip.x2 %v8, %v8 : vector<[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xi64>
+  %a1, %b1, %c1, %d1 = arm_sve.zip.x4 %v1, %v1, %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xf64>
+  %a2, %b2, %c2, %d2 = arm_sve.zip.x4 %v2, %v2, %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xi32>
+  %a3, %b3, %c3, %d3 = arm_sve.zip.x4 %v3, %v3, %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xf32>
+  %a4, %b4, %c4, %d4 = arm_sve.zip.x4 %v4, %v4, %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xi16>
+  %a5, %b5, %c5, %d5 = arm_sve.zip.x4 %v5, %v5, %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xf16>
+  %a6, %b6, %c6, %d6 = arm_sve.zip.x4 %v6, %v6, %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7, %c7, %d7 = arm_sve.zip.x4 %v7, %v7, %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[16]xi8>
+  %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
+  return
+}


        


More information about the Mlir-commits mailing list