[Mlir-commits] [mlir] [mlir][ArmSVE] Add convert_to/from_svbool ops (PR #68586)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Oct 9 06:43:20 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/68586

This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

Depends on #68418

>From 0908c28eb19d840f3603ff1c5a2a88ccf46047d1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Oct 2023 12:56:37 +0000
Subject: [PATCH 1/2] [mlir][ArmSVE] Add convert.from/to.svbool intrinsics

These will be used in a future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 24 ++++++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 44 +++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 58dec6091f27f6e..d4294b4dd9fd4e8 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVBool : ScalableVectorOfRankAndLengthAndType<
+  [1], [16], [I1]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
   Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
+def ConvertFromSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.from.svbool",
+    [TypeIs<"res", SVEPredicate>],
+    /*overloadedOperands=*/[],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins SVBool:$svbool)>;
+
+def ConvertToSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.to.svbool",
+    [TypeIs<"res", SVBool>],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[]>,
+    Arguments<(ins SVEPredicate:$mask)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 999df8079e0727a..172a2f7d12d440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
   %0 = "llvm.intr.vscale"() : () -> i64
   llvm.return %0 : i64
 }
+
+// CHECK-LABEL: @arm_sve_convert_from_svbool(
+// CHECK-SAME:                               <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
+llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
+  // CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[8]xi1>
+  // CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[4]xi1>
+  // CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[2]xi1>
+  // CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[1]xi1>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_convert_to_svbool(
+// CHECK-SAME:                            <vscale x 8 x i1> %[[P8:[0-9]+]],
+// CHECK-SAME:                            <vscale x 4 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:                            <vscale x 2 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:                            <vscale x 1 x i1> %[[P1:[0-9]+]])
+llvm.func @arm_sve_convert_to_svbool(
+                                       %nxv8i1  : vector<[8]xi1>,
+                                       %nxv4i1  : vector<[4]xi1>,
+                                       %nxv2i1  : vector<[2]xi1>,
+                                       %nxv1i1  : vector<[1]xi1>
+) {
+  // CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
+  %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
+    : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
+  %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
+    : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
+  %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
+    : (vector<[2]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
+  %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
+    : (vector<[1]xi1>) -> vector<[16]xi1>
+  llvm.return
+}

>From ee738b982de3765591dc0caf93d6621236eb2a20 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Oct 2023 14:40:00 +0000
Subject: [PATCH 2/2] [mlir][ArmSVE] Add convert_to/from_svbool ops

This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 67 +++++++++++++++++
 mlir/include/mlir/IR/CommonTypeConstraints.td | 43 +++++++++++
 mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp  |  1 +
 mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt     |  1 +
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/LegalizeForLLVMExport.cpp      | 62 +++++++++++++++-
 mlir/test/Dialect/ArmSVE/invalid.mlir         | 51 +++++++++++++
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     | 73 ++++++++++++++++++-
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       | 49 ++++++++++++-
 9 files changed, 343 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/ArmSVE/invalid.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d4294b4dd9fd4e8..fa7f6d080d5a91c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
     This dialect contains the definitions necessary to target specific Arm SVE
     scalable vector operations.
   }];
+
+  let dependentDialects = ["vector::VectorDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
 def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
   [1], [16, 8, 4, 2, 1], [I1]>;
 
+// Generalizations of SVBool and SVEPredicate to ranks >= 1.
+// These are masks with a single trailing scalable dimension.
+def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
+def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla",
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+
+class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
+      "expected corresponding svbool type widened to [16]xi1",
+      lhsArg, rhsArg,
+      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
+
+def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
+                            [Pure, SvboolTypeContraint<"result", "source">]>
+{
+  let summary = "Convert a svbool type to a SVE predicate type";
+  let description = [{
+    Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
+    `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
+    dimension can be scalable.
+
+    Example 1: Convert a 1-D svbool mask to a SVE predicate.
+    ```mlir
+    %svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
+    %mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
+    ```
+
+    Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
+    ```mlir
+    %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
+    %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
+    ```
+  }];
+  let arguments = (ins SVBoolMask:$source);
+  let results = (outs SVEMask:$result);
+  let assemblyFormat = "$source attr-dict `:` type($result)";
+}
+
+def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
+                            [Pure, SvboolTypeContraint<"source", "result">]>
+{
+  let summary = "Convert a predicate type to a svbool type";
+  let description = [{
+    Converts SVE predicate types (or vectors of predicate types, e.g.
+    `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
+    be scalable.
+
+    Example 1: Convert a 1-D SVE predicate to a svbool mask.
+    ```mlir
+    %mask = vector.create_mask %dim_size : vector<[4]xi1>
+    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+    // => Results in vector<[16]xi1>
+    ```
+
+    Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
+    ```mlir
+    %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
+    %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
+    // => Results in vector<2x[16]xi1>
+    ```
+  }];
+  let arguments = (ins SVEMask:$source);
+  let results = (outs SVBoolMask:$result);
+  let assemblyFormat = "$source attr-dict `:` type($source)";
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..54a5a97fe2b6425 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
 def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                    ::llvm::cast<VectorType>($_self).isScalable()}]>;
 
+// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
+def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+                            CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+                            CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
+                            CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>;
+
 // Whether a type is a VectorType and all dimensions are scalable.
 def allDimsScalableVectorTypePred : And<[
   IsVectorTypePred,
@@ -404,6 +410,10 @@ class ScalableVectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
           "scalable vector", "::mlir::VectorType">;
 
+class TrailingScalableVectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTrailingScalableVectorTypePred,
+          "trailing scalable vector", "::mlir::VectorType">;
+
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
 class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,10 +491,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
                            == }]
                          # allowedlength>)>]>;
 
+class abs<int value> {
+  int ret = !if(!lt(value, 0), !sub(0, value), value);
+}
+
+// Whether the n-th (starting from 1) dim size is one of `allowedSizes`.
+// Negative values index in reverse.
+class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
+  : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
+        CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
+          # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
+          #   !if(!lt(n, 0),
+                "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
+                "" # !sub(n, 1))
+          # "))">]>;
+
 // Whether the shape of a vector matches the given `shape` list.
 class IsVectorOfShape<list<int> shape>
   : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
 
+// Any ShapedType where the size of the n-th dim is in `allowedSizes`.
+// Negative values index in reverse.
+class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
+  IsNthDimSizeIsOneOfPred<n, allowedSizes>,
+  " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
+  "::mlir::ShapedType">;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<
@@ -546,6 +578,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any scalable vector with a single trailing scalable dimension, where the
+// size of the trailing dimension is in `allowedTrailingSizes` list, and the
+// type is in the `allowedTypes` list.
+class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes,
+                                           list<Type> allowedTypes> : AllOfType<
+  [TrailingScalableVectorOf<allowedTypes>,
+   ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
+   TrailingScalableVectorOf<allowedTypes>.summary #
+   ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
+  "::mlir::VectorType">;
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b7f1020deba1e40..594c9b4c270f218 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index fffc77245d12c93..9ef7384fc54925a 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRVectorDialect
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 7031ab4f799c4d2..2f1c43fae240d51 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
   LINK_LIBS PUBLIC
   MLIRArmSVEDialect
   MLIRFuncDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index abbb978304068e2..d280d2415ecdbfd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -66,6 +68,54 @@ using ScalableMaskedDivFOpLowering =
     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                  ScalableMaskedDivFIntrOp>;
 
+namespace {
+
+template <typename Op, typename IntrOp>
+struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
+  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Op convertOp, typename Op::Adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = convertOp.getLoc();
+
+    auto source = convertOp.getSource();
+    VectorType sourceType = source.getType();
+    VectorType resultType = convertOp.getResult().getType();
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
+    tileShape.back() = sourceType.getShape().back();
+
+    for (SmallVector<int64_t> index :
+         StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
+      auto extractOrInsertPosition = ArrayRef(index).drop_back();
+      auto sourceVector = rewriter.create<vector::ExtractOp>(
+          loc, source, extractOrInsertPosition);
+      auto convertedType =
+          VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
+              .setDim(0, resultType.getShape().back());
+      auto convertedVector =
+          rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
+      result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
+                                                 extractOrInsertPosition);
+    }
+
+    rewriter.replaceOp(convertOp, result);
+    return success();
+  }
+};
+
+using ConvertToSvboolOpLowering =
+    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
+
+using ConvertFromSvboolOpLowering =
+    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+
+} // namespace
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +138,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedMulFOpLowering,
                ScalableMaskedSDivIOpLowering,
                ScalableMaskedUDivIOpLowering,
-               ScalableMaskedDivFOpLowering>(converter);
+               ScalableMaskedDivFOpLowering,
+               ConvertToSvboolOpLowering,
+               ConvertFromSvboolOpLowering>(converter);
   // clang-format on
 }
 
@@ -107,7 +159,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ScalableMaskedMulFIntrOp,
                     ScalableMaskedSDivIIntrOp,
                     ScalableMaskedUDivIIntrOp,
-                    ScalableMaskedDivFIntrOp>();
+                    ScalableMaskedDivFIntrOp,
+                    ConvertToSvboolIntrOp,
+                    ConvertFromSvboolIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
@@ -120,6 +174,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedMulFOp,
                       ScalableMaskedSDivIOp,
                       ScalableMaskedUDivIOp,
-                      ScalableMaskedDivFOp>();
+                      ScalableMaskedDivFOp,
+                      ConvertToSvboolOp,
+                      ConvertFromSvboolOp>();
   // clang-format on
 }
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
new file mode 100644
index 000000000000000..a1fa0d0292b7b76
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2>
+  return %mask : vector<2x[8]xi2>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1>
+  return %mask : vector<[7]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> {
+  // expected-error@+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+  %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1>
+  return %mask : vector<[4]x[8]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2>
+  return %bool : vector<2x[16]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1>
+  return
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> {
+  // expected-error@+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+  %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1>
+  return
+}
+
+
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2d980db981034dd..04f2f43e6a5e78d 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_smmla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>)
@@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_udot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
                    %c: vector<[4]xi32>)
@@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_ummla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>)
@@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
@@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
   return %4 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
                             %b: vector<[4]xf32>,
                             %c: vector<[4]xf32>,
@@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
   return %3 : vector<[4]xf32>
 }
 
+// -----
+
 func.func @arm_sve_abs_diff(%a: vector<[4]xi32>,
                        %b: vector<[4]xi32>)
                        -> vector<[4]xi32> {
@@ -111,8 +123,67 @@ func.func @arm_sve_abs_diff(%a: vector<[4]xi32>,
   return %3 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @get_vector_scale() -> index {
   // CHECK: llvm.intr.vscale
   %0 = vector.vscale
   return %0 : index
 }
+
+// -----
+
+func.func @convert_1d_mask_to_svbool(%mask: vector<[4]xi1>) -> vector<[16]xi1>
+{
+  // CHECK: "arm_sve.intr.convert.to.svbool"(%{{.*}}) : (vector<[4]xi1>) -> vector<[16]xi1>
+  %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+  return %svbool : vector<[16]xi1>
+}
+
+// -----
+
+func.func @convert_1d_mask_from_svbool(%svbool: vector<[16]xi1>) -> vector<[2]xi1>
+{
+  // CHECK: "arm_sve.intr.convert.from.svbool"(%{{.*}}) : (vector<[16]xi1>) -> vector<[2]xi1>
+  %mask = arm_sve.convert_from_svbool %svbool : vector<[2]xi1>
+  return %mask : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @convert_2d_mask_to_svbool(
+// CHECK-SAME:                             %[[MASK:.*]]: !llvm.array<2 x vector<[8]xi1>>)
+func.func @convert_2d_mask_to_svbool(%mask: vector<2x[8]xi1>) -> vector<2x[16]xi1>
+{
+  // CHECK-NEXT:    %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>>
+  // CHECK-NEXT:   %[[MASK0:.*]] = llvm.extractvalue %[[MASK]][0] : !llvm.array<2 x vector<[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL0:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT:    %[[RES1:.*]] = llvm.insertvalue %[[SVBOOL0]], %[[RES0]][0] : !llvm.array<2 x vector<[16]xi1>>
+  // CHECK-NEXT:   %[[MASK1:.*]] = llvm.extractvalue %[[MASK]][1] : !llvm.array<2 x vector<[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL1:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK1]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT:  %[[SVBOOL:.*]] = llvm.insertvalue %[[SVBOOL1]], %[[RES1]][1] : !llvm.array<2 x vector<[16]xi1>>
+  %svbool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi1>
+  // CHECK-NEXT: llvm.return %[[SVBOOL]] : !llvm.array<2 x vector<[16]xi1>>
+  return %svbool : vector<2x[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @convert_2d_mask_from_svbool(
+// CHECK-SAME:                               %[[SVBOOL:.*]]: !llvm.array<3 x vector<[16]xi1>>)
+func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[1]xi1>
+{
+  // CHECK-NEXT:    %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<3x[1]xi1>) : !llvm.array<3 x vector<[1]xi1>>
+  // CHECK-NEXT: %[[SVBOOL0:.*]] = llvm.extractvalue %[[SVBOOL]][0] : !llvm.array<3 x vector<[16]xi1>>
+  // CHECK-NEXT:   %[[MASK0:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL0]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+  // CHECK-NEXT:    %[[RES1:.*]] = llvm.insertvalue %[[MASK0]], %[[RES0]][0] : !llvm.array<3 x vector<[1]xi1>>
+  // CHECK-NEXT: %[[SVBOOL1:.*]] = llvm.extractvalue %[[SVBOOL]][1] : !llvm.array<3 x vector<[16]xi1>>
+  // CHECK-NEXT:   %[[MASK1:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL1]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+  // CHECK-NEXT:    %[[RES2:.*]] = llvm.insertvalue %[[MASK1]], %[[RES1]][1] : !llvm.array<3 x vector<[1]xi1>>
+  // CHECK-NEXT: %[[SVBOOL2:.*]] = llvm.extractvalue %[[SVBOOL]][2] : !llvm.array<3 x vector<[16]xi1>>
+  // CHECK-NEXT:   %[[MASK2:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL2]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+  // CHECK-NEXT:    %[[MASK:.*]] = llvm.insertvalue %[[MASK2]], %[[RES2]][2] : !llvm.array<3 x vector<[1]xi1>>
+  %mask = arm_sve.convert_from_svbool %svbool : vector<3x[1]xi1>
+  // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
+  return %mask : vector<3x[1]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index d2ca035c17bfbcf..af390bb330a341d 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -9,6 +9,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_smmla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -18,6 +20,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_udot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -27,6 +31,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_ummla(%a: vector<[16]xi8>,
                     %b: vector<[16]xi8>,
                     %c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -36,6 +42,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
   return %0 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
@@ -61,6 +69,8 @@ func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
   return %2 : vector<[4]xi32>
 }
 
+// -----
+
 func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
                             %b: vector<[4]xf32>,
                             %c: vector<[4]xf32>,
@@ -82,3 +92,40 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
                                            vector<[4]xf32>
   return %3 : vector<[4]xf32>
 }
+
+// -----
+
+func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>,
+                                     %b: vector<[2]xi1>,
+                                     %c: vector<[4]xi1>,
+                                     %d: vector<[8]xi1>) {
+  // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1>
+  %1 = arm_sve.convert_to_svbool %a : vector<[1]xi1>
+
+  // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[2]xi1>
+  %2 = arm_sve.convert_to_svbool %b : vector<[2]xi1>
+
+  // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[4]xi1>
+  %3 = arm_sve.convert_to_svbool %c : vector<[4]xi1>
+
+  // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1>
+  %4 = arm_sve.convert_to_svbool %d : vector<[8]xi1>
+  return
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool(%bool: vector<[16]xi1>) {
+  // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1>
+  %1 = arm_sve.convert_from_svbool %bool : vector<[1]xi1>
+
+  // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1>
+  %2 = arm_sve.convert_from_svbool %bool : vector<[2]xi1>
+
+  // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1>
+  %3 = arm_sve.convert_from_svbool %bool : vector<[4]xi1>
+
+  // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1>
+  %4 = arm_sve.convert_from_svbool %bool : vector<[8]xi1>
+  return
+}



More information about the Mlir-commits mailing list