[Mlir-commits] [mlir] [mlir][ArmSVE] Add convert_to/from_svbool ops (PR #68586)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Oct 11 07:22:52 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/68586
>From 4f93456498d2c83c1f1d7fe1cfbfb5e0bcc33629 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Oct 2023 14:40:00 +0000
Subject: [PATCH 1/3] [mlir][ArmSVE] Add convert_to/from_svbool ops
This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is that these ops support vectors of masks (via unrolling).
E.g.
```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
: memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 67 +++++++++++++++++
mlir/include/mlir/IR/CommonTypeConstraints.td | 43 +++++++++++
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp | 1 +
mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt | 1 +
.../Dialect/ArmSVE/Transforms/CMakeLists.txt | 1 +
.../Transforms/LegalizeForLLVMExport.cpp | 62 +++++++++++++++-
mlir/test/Dialect/ArmSVE/invalid.mlir | 51 +++++++++++++
.../Dialect/ArmSVE/legalize-for-llvm.mlir | 73 ++++++++++++++++++-
mlir/test/Dialect/ArmSVE/roundtrip.mlir | 49 ++++++++++++-
9 files changed, 343 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/ArmSVE/invalid.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d4294b4dd9fd4e8..fa7f6d080d5a91c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
This dialect contains the definitions necessary to target specific Arm SVE
scalable vector operations.
}];
+
+ let dependentDialects = ["vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
@@ -40,6 +42,11 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;
+// Generalizations of SVBool and SVEPredicate to ranks >= 1.
+// These are masks with a single trailing scalable dimension.
+def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
+def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -236,6 +243,66 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+
+class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
+ "expected corresponding svbool type widened to [16]xi1",
+ lhsArg, rhsArg,
+ "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
+
+def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
+ [Pure, SvboolTypeContraint<"result", "source">]>
+{
+ let summary = "Convert a svbool type to a SVE predicate type";
+ let description = [{
+ Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
+ `vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
+ dimension can be scalable.
+
+ Example 1: Convert a 1-D svbool mask to a SVE predicate.
+ ```mlir
+ %svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
+ %mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
+ ```
+
+ Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
+ ```mlir
+ %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
+ %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
+ ```
+ }];
+ let arguments = (ins SVBoolMask:$source);
+ let results = (outs SVEMask:$result);
+ let assemblyFormat = "$source attr-dict `:` type($result)";
+}
+
+def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
+ [Pure, SvboolTypeContraint<"source", "result">]>
+{
+ let summary = "Convert a predicate type to a svbool type";
+ let description = [{
+ Converts SVE predicate types (or vectors of predicate types, e.g.
+ `vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
+ be scalable.
+
+ Example 1: Convert a 1-D SVE predicate to a svbool mask.
+ ```mlir
+ %mask = vector.create_mask %dim_size : vector<[4]xi1>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+ // => Results in vector<[16]xi1>
+ ```
+
+ Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
+ ```mlir
+ %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
+ // => Results in vector<2x[16]xi1>
+ ```
+ }];
+ let arguments = (ins SVEMask:$source);
+ let results = (outs SVBoolMask:$result);
+ let assemblyFormat = "$source attr-dict `:` type($source)";
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..54a5a97fe2b6425 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -37,6 +37,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
+// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
+def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
+ CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>;
+
// Whether a type is a VectorType and all dimensions are scalable.
def allDimsScalableVectorTypePred : And<[
IsVectorTypePred,
@@ -404,6 +410,10 @@ class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;
+class TrailingScalableVectorOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsTrailingScalableVectorTypePred,
+ "trailing scalable vector", "::mlir::VectorType">;
+
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
@@ -481,10 +491,32 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;
+class abs<int value> {
+ int ret = !if(!lt(value, 0), !sub(0, value), value);
+}
+
+// Whether the n-th (starting from 1) dim of the shape has one of the given `allowedSizes`.
+// Negative values index in reverse.
+class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
+ : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
+ CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
+ # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
+ # !if(!lt(n, 0),
+ "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
+ "" # !sub(n, 1))
+ # "))">]>;
+
// Whether the shape of a vector matches the given `shape` list.
class IsVectorOfShape<list<int> shape>
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
+// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
+// Negative values index in reverse.
+class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
+ IsNthDimSizeIsOneOfPred<n, allowedSizes>,
+ " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
+ "::mlir::ShapedType">;
+
// Any vector where the number of elements is from the given
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
@@ -546,6 +578,17 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Any scalable vector with a single trailing scalable dimension, where the
+// size of the trailing dimension is in `allowedTrailingSizes` list, and the
+// type is in the `allowedTypes` list.
+class TrailingScalableVectorOfSizeAndType<list<int> allowedTrailingSizes,
+ list<Type> allowedTypes> : AllOfType<
+ [TrailingScalableVectorOf<allowedTypes>,
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
+ TrailingScalableVectorOf<allowedTypes>.summary #
+ ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
+ "::mlir::VectorType">;
+
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b7f1020deba1e40..594c9b4c270f218 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index fffc77245d12c93..9ef7384fc54925a 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
+ MLIRVectorDialect
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 7031ab4f799c4d2..2f1c43fae240d51 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
LINK_LIBS PUBLIC
MLIRArmSVEDialect
MLIRFuncDialect
+ MLIRVectorDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index abbb978304068e2..d280d2415ecdbfd 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -66,6 +68,54 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
+namespace {
+
+template <typename Op, typename IntrOp>
+struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
+ using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(Op convertOp, typename Op::Adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = convertOp.getLoc();
+
+ auto source = convertOp.getSource();
+ VectorType sourceType = source.getType();
+ VectorType resultType = convertOp.getResult().getType();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+
+ SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
+ tileShape.back() = sourceType.getShape().back();
+
+ for (SmallVector<int64_t> index :
+ StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
+ auto extractOrInsertPosition = ArrayRef(index).drop_back();
+ auto sourceVector = rewriter.create<vector::ExtractOp>(
+ loc, source, extractOrInsertPosition);
+ auto convertedType =
+ VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
+ .setDim(0, resultType.getShape().back());
+ auto convertedVector =
+ rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
+ result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
+ extractOrInsertPosition);
+ }
+
+ rewriter.replaceOp(convertOp, result);
+ return success();
+ }
+};
+
+using ConvertToSvboolOpLowering =
+ SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
+
+using ConvertFromSvboolOpLowering =
+ SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+
+} // namespace
+
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
@@ -88,7 +138,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
- ScalableMaskedDivFOpLowering>(converter);
+ ScalableMaskedDivFOpLowering,
+ ConvertToSvboolOpLowering,
+ ConvertFromSvboolOpLowering>(converter);
// clang-format on
}
@@ -107,7 +159,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFIntrOp,
ScalableMaskedSDivIIntrOp,
ScalableMaskedUDivIIntrOp,
- ScalableMaskedDivFIntrOp>();
+ ScalableMaskedDivFIntrOp,
+ ConvertToSvboolIntrOp,
+ ConvertFromSvboolIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
@@ -120,6 +174,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
- ScalableMaskedDivFOp>();
+ ScalableMaskedDivFOp,
+ ConvertToSvboolOp,
+ ConvertFromSvboolOp>();
// clang-format on
}
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
new file mode 100644
index 000000000000000..a1fa0d0292b7b76
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_type(%bool: vector<2x[16]xi1>) -> vector<2x[8]xi2> {
+ // expected-error @+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<2x[8]xi2>
+ return %mask : vector<2x[8]xi2>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_shape(%bool : vector<[16]xi1>) -> vector<[7]xi1> {
+ // expected-error @+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<[7]xi1>
+ return %mask : vector<[7]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool__bad_mask_scalability(%bool : vector<[4]x[16]xi1>) -> vector<[4]x[8]xi1> {
+ // expected-error @+1 {{'result' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+ %mask = arm_sve.convert_from_svbool %bool : vector<[4]x[8]xi1>
+ return %mask : vector<[4]x[8]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_type(%mask: vector<2x[8]xi2>) -> vector<2x[16]xi1> {
+ // expected-error @+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<2x[8]xi2>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi2>
+ return %bool : vector<2x[16]xi1>
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_shape(%mask : vector<[7]xi1>) -> vector<[16]xi1> {
+ // expected-error @+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[7]xi1>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<[7]xi1>
+ return
+}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8]xi1>) -> vector<[4]x[16]xi1> {
+ // expected-error @+1 {{'source' must be trailing scalable vector of 1-bit signless integer values with dim -1 having a size of {16, 8, 4, 2, 1}, but got 'vector<[4]x[8]xi1>'}}
+ %bool = arm_sve.convert_to_svbool %mask : vector<[4]x[8]xi1>
+ return
+}
+
+
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2d980db981034dd..04f2f43e6a5e78d 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
@@ -10,6 +10,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_smmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -20,6 +22,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_udot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -30,6 +34,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_ummla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>)
@@ -40,6 +46,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
@@ -65,6 +73,8 @@ func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
return %4 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
%b: vector<[4]xf32>,
%c: vector<[4]xf32>,
@@ -87,6 +97,8 @@ func.func @arm_sve_arithf_masked(%a: vector<[4]xf32>,
return %3 : vector<[4]xf32>
}
+// -----
+
func.func @arm_sve_abs_diff(%a: vector<[4]xi32>,
%b: vector<[4]xi32>)
-> vector<[4]xi32> {
@@ -111,8 +123,67 @@ func.func @arm_sve_abs_diff(%a: vector<[4]xi32>,
return %3 : vector<[4]xi32>
}
+// -----
+
func.func @get_vector_scale() -> index {
// CHECK: llvm.intr.vscale
%0 = vector.vscale
return %0 : index
}
+
+// -----
+
+func.func @convert_1d_mask_to_svbool(%mask: vector<[4]xi1>) -> vector<[16]xi1>
+{
+ // CHECK: "arm_sve.intr.convert.to.svbool"(%{{.*}}) : (vector<[4]xi1>) -> vector<[16]xi1>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+ return %svbool : vector<[16]xi1>
+}
+
+// -----
+
+func.func @convert_1d_mask_from_svbool(%svbool: vector<[16]xi1>) -> vector<[2]xi1>
+{
+ // CHECK: "arm_sve.intr.convert.from.svbool"(%{{.*}}) : (vector<[16]xi1>) -> vector<[2]xi1>
+ %mask = arm_sve.convert_from_svbool %svbool : vector<[2]xi1>
+ return %mask : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @convert_2d_mask_to_svbool(
+// CHECK-SAME: %[[MASK:.*]]: !llvm.array<2 x vector<[8]xi1>>)
+func.func @convert_2d_mask_to_svbool(%mask: vector<2x[8]xi1>) -> vector<2x[16]xi1>
+{
+ // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK0:.*]] = llvm.extractvalue %[[MASK]][0] : !llvm.array<2 x vector<[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL0:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[SVBOOL0]], %[[RES0]][0] : !llvm.array<2 x vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK1:.*]] = llvm.extractvalue %[[MASK]][1] : !llvm.array<2 x vector<[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL1:.*]] = "arm_sve.intr.convert.to.svbool"(%[[MASK1]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = llvm.insertvalue %[[SVBOOL1]], %[[RES1]][1] : !llvm.array<2 x vector<[16]xi1>>
+ %svbool = arm_sve.convert_to_svbool %mask : vector<2x[8]xi1>
+ // CHECK-NEXT: llvm.return %[[SVBOOL]] : !llvm.array<2 x vector<[16]xi1>>
+ return %svbool : vector<2x[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @convert_2d_mask_from_svbool(
+// CHECK-SAME: %[[SVBOOL:.*]]: !llvm.array<3 x vector<[16]xi1>>)
+func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[1]xi1>
+{
+ // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<3x[1]xi1>) : !llvm.array<3 x vector<[1]xi1>>
+ // CHECK-NEXT: %[[SVBOOL0:.*]] = llvm.extractvalue %[[SVBOOL]][0] : !llvm.array<3 x vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK0:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL0]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+ // CHECK-NEXT: %[[RES1:.*]] = llvm.insertvalue %[[MASK0]], %[[RES0]][0] : !llvm.array<3 x vector<[1]xi1>>
+ // CHECK-NEXT: %[[SVBOOL1:.*]] = llvm.extractvalue %[[SVBOOL]][1] : !llvm.array<3 x vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK1:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL1]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+ // CHECK-NEXT: %[[RES2:.*]] = llvm.insertvalue %[[MASK1]], %[[RES1]][1] : !llvm.array<3 x vector<[1]xi1>>
+ // CHECK-NEXT: %[[SVBOOL2:.*]] = llvm.extractvalue %[[SVBOOL]][2] : !llvm.array<3 x vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK2:.*]] = "arm_sve.intr.convert.from.svbool"(%[[SVBOOL2]]) : (vector<[16]xi1>) -> vector<[1]xi1>
+ // CHECK-NEXT: %[[MASK:.*]] = llvm.insertvalue %[[MASK2]], %[[RES2]][2] : !llvm.array<3 x vector<[1]xi1>>
+ %mask = arm_sve.convert_from_svbool %svbool : vector<3x[1]xi1>
+ // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
+ return %mask : vector<3x[1]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index d2ca035c17bfbcf..af390bb330a341d 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s | mlir-opt | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
@@ -9,6 +9,8 @@ func.func @arm_sve_sdot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_smmla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -18,6 +20,8 @@ func.func @arm_sve_smmla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_udot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -27,6 +31,8 @@ func.func @arm_sve_udot(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_ummla(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
%c: vector<[4]xi32>) -> vector<[4]xi32> {
@@ -36,6 +42,8 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
return %0 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
@@ -61,6 +69,8 @@ func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
return %2 : vector<[4]xi32>
}
+// -----
+
func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
%b: vector<[4]xf32>,
%c: vector<[4]xf32>,
@@ -82,3 +92,40 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
vector<[4]xf32>
return %3 : vector<[4]xf32>
}
+
+// -----
+
+func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>,
+ %b: vector<[2]xi1>,
+ %c: vector<[4]xi1>,
+ %d: vector<[8]xi1>) {
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1>
+ %1 = arm_sve.convert_to_svbool %a : vector<[1]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[2]xi1>
+ %2 = arm_sve.convert_to_svbool %b : vector<[2]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[4]xi1>
+ %3 = arm_sve.convert_to_svbool %c : vector<[4]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1>
+ %4 = arm_sve.convert_to_svbool %d : vector<[8]xi1>
+ return
+}
+
+// -----
+
+func.func @arm_sve_convert_from_svbool(%bool: vector<[16]xi1>) {
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1>
+ %1 = arm_sve.convert_from_svbool %bool : vector<[1]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1>
+ %2 = arm_sve.convert_from_svbool %bool : vector<[2]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1>
+ %3 = arm_sve.convert_from_svbool %bool : vector<[4]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1>
+ %4 = arm_sve.convert_from_svbool %bool : vector<[8]xi1>
+ return
+}
>From 9c656ca0ba0624b5840badd412ba2323d007e52a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 14:39:40 +0000
Subject: [PATCH 2/3] Fixups
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 47 ++++++++++++-------
mlir/include/mlir/IR/CommonTypeConstraints.td | 19 ++++++--
.../Transforms/LegalizeForLLVMExport.cpp | 23 +++++++++
.../Dialect/ArmSVE/legalize-for-llvm.mlir | 2 +-
4 files changed, 69 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index fa7f6d080d5a91c..cae87b764fc67dd 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -45,7 +45,7 @@ def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
// Generalizations of SVBool and SVEPredicate to ranks >= 1.
// These are masks with a single trailing scalable dimension.
def SVBoolMask : TrailingScalableVectorOfSizeAndType<[16], [I1]>;
-def SVEMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
+def SVEPredicateMask : TrailingScalableVectorOfSizeAndType<[16, 8, 4, 2, 1], [I1]>;
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
@@ -243,14 +243,13 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
-
-class SvboolTypeContraint<string lhsArg, string rhsArg> : TypesMatchWith<
+class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
- [Pure, SvboolTypeContraint<"result", "source">]>
+ [Pure, SvboolTypeConstraint<"result", "source">]>
{
let summary = "Convert a svbool type to a SVE predicate type";
let description = [{
@@ -260,25 +259,33 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
Example 1: Convert a 1-D svbool mask to a SVE predicate.
```mlir
- %svbool = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
- %mask = arm_sve.convert_from_svbool %svbool : vector<[4]xi1>
+ %source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
+ %result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
```
Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
```mlir
- %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
- %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
+ %source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
+ %result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
```
+
+ ---
+
+ A `svbool` is the smallest SVE predicate type that has an in-memory
+ representation (and maps to a full predicate register). In MLIR `svbool` is
+ represented as `vector<[16]xi1>`. Smaller SVE predicate types
+ (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to
+ a predicate after loading.
}];
let arguments = (ins SVBoolMask:$source);
- let results = (outs SVEMask:$result);
+ let results = (outs SVEPredicateMask:$result);
let assemblyFormat = "$source attr-dict `:` type($result)";
}
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
- [Pure, SvboolTypeContraint<"source", "result">]>
+ [Pure, SvboolTypeConstraint<"source", "result">]>
{
- let summary = "Convert a predicate type to a svbool type";
+ let summary = "Convert a SVE predicate type to a svbool type";
let description = [{
Converts SVE predicate types (or vectors of predicate types, e.g.
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
@@ -286,19 +293,27 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
Example 1: Convert a 1-D SVE predicate to a svbool mask.
```mlir
- %mask = vector.create_mask %dim_size : vector<[4]xi1>
- %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+ %source = vector.create_mask %dim_size : vector<[4]xi1>
+ %result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
// => Results in vector<[16]xi1>
```
Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
```mlir
- %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
- %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
+ %source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
+ %result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```
+
+ ---
+
+ A `svbool` is the smallest SVE predicate type that has an in-memory
+ representation (and maps to a full predicate register). In MLIR `svbool` is
+ represented as `vector<[16]xi1>`. Smaller SVE predicate types
+ (`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
+ stored.
}];
- let arguments = (ins SVEMask:$source);
+ let arguments = (ins SVEPredicateMask:$source);
let results = (outs SVBoolMask:$result);
let assemblyFormat = "$source attr-dict `:` type($source)";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 54a5a97fe2b6425..a7970e59de8c27e 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -38,10 +38,17 @@ def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &
::llvm::cast<VectorType>($_self).isScalable()}]>;
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
-def IsTrailingScalableVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
- CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">]>;
+// Examples:
+// Valid:
+// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
+// Invalid
+// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
+def IsOnlyTrailingDimScalablePred : And<[
+ CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
+ CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
+]>;
// Whether a type is a VectorType and all dimensions are scalable.
def allDimsScalableVectorTypePred : And<[
@@ -410,8 +417,10 @@ class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;
+// Any vector with a single trailing scalable dimension, with an element type in
+// the `allowedTypes` list.
class TrailingScalableVectorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsTrailingScalableVectorTypePred,
+ ShapedContainerType<allowedTypes, IsOnlyTrailingDimScalablePred,
"trailing scalable vector", "::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index d280d2415ecdbfd..ca9e280f510858c 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -70,6 +70,25 @@ using ScalableMaskedDivFOpLowering =
namespace {
+/// Unrolls a conversion to/from equivalent vector types, to allow using a
+/// conversion intrinsic that only supports 1-D vector types.
+///
+/// Example:
+/// ```
+/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
+/// ```
+/// is rewritten into:
+/// ```
+/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
+/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
+/// : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1>
+/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
+/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
+/// : (vector<[4]xi1>) -> vector<[16]xi1>
+/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1>
+/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
@@ -86,9 +105,13 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
Value result = rewriter.create<arith::ConstantOp>(
loc, resultType, rewriter.getZeroAttr(resultType));
+ // We want to iterate over the input vector in steps of the trailing
+ // dimension. So this creates a tile shape where all leading dimensions are
+ // 1, and the trailing dimension step is the size of the dimension.
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
tileShape.back() = sourceType.getShape().back();
+ // Iterate over all scalable mask/predicate slices of the source vector.
for (SmallVector<int64_t> index :
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
auto extractOrInsertPosition = ArrayRef(index).drop_back();
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 04f2f43e6a5e78d..8e76fb7119b844e 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -canonicalize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
>From 308f5f16529341cb2790fdf1edbeb556fc7a0f8f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 11 Oct 2023 14:20:31 +0000
Subject: [PATCH 3/3] Fixups
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 4 +-
mlir/include/mlir/IR/CommonTypeConstraints.td | 27 +++++++----
.../Transforms/LegalizeForLLVMExport.cpp | 4 +-
mlir/test/Dialect/ArmSVE/roundtrip.mlir | 46 ++++++++++++++++---
4 files changed, 61 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cae87b764fc67dd..826f7aac9b38055 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -274,8 +274,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
A `svbool` is the smallest SVE predicate type that has a in-memory
representation (and maps to a full predicate register). In MLIR `svbool` is
represented as `vector<[16]xi1>`. Smaller SVE predicate types
- (`vector<[1|2|4|8]xi1>`) must be stored as `svbool` then converted back to
- a predicate after loading.
+ (`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to
+ the original predicate type after loading.
}];
let arguments = (ins SVBoolMask:$source);
let results = (outs SVEPredicateMask:$result);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a7970e59de8c27e..0c5453ee1a068dd 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -500,20 +500,27 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;
-class abs<int value> {
- int ret = !if(!lt(value, 0), !sub(0, value), value);
+// Normalizes an index so it can be bounds checked.
+// Negative values are mapped to their absolute value.
+// - These are used to index in reverse (i.e. index -1 would be the last element)
+// Positive values are mapped to their value + 1.
+// - This results in the same range of values as the negative indices
+// This allows bounds checking to be: len(list) >= NormalizeIndex<idx>.ret.
+class NormalizeIndex<int value> {
+ int ret = !if(!lt(value, 0), !sub(0, value), !add(value, 1));
}
-// Whether the n-th (starting from 1) dim of the shape matches the given `size`.
+// Whether the n-th dim of the shape matches the given `size`.
// Negative values index in reverse.
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
- : And<[CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # abs<n>.ret>,
- CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
- # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
- # !if(!lt(n, 0),
- "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
- "" # !sub(n, 1))
- # "))">]>;
+ : And<[
+ CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
+ CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
+ # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
+ # !if(!lt(n, 0),
+ "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
+ "" # n)
+ # "))">]>;
// Whether the shape of a vector matches the given `shape` list.
class IsVectorOfShape<list<int> shape>
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index ca9e280f510858c..f54a26c27c2acad 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -83,11 +83,11 @@ namespace {
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
-/// %3 = vector.insert %2, %cst [0] : vector<[16]xi1> into vector<2x[16]xi1>
+/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
-/// %result = vector.insert %5, %3 [1] : vector<[16]xi1> into vector<2x[16]xi1>
+/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index af390bb330a341d..c9a0b6db8fa803d 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -98,7 +98,11 @@ func.func @arm_sve_masked_arithf(%a: vector<[4]xf32>,
func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>,
%b: vector<[2]xi1>,
%c: vector<[4]xi1>,
- %d: vector<[8]xi1>) {
+ %d: vector<[8]xi1>,
+ %e: vector<2x3x[1]xi1>,
+ %f: vector<4x[2]xi1>,
+ %g: vector<1x1x1x2x[4]xi1>,
+ %h: vector<100x[8]xi1>) {
// CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[1]xi1>
%1 = arm_sve.convert_to_svbool %a : vector<[1]xi1>
@@ -110,22 +114,52 @@ func.func @arm_sve_convert_to_svbool(%a: vector<[1]xi1>,
// CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<[8]xi1>
%4 = arm_sve.convert_to_svbool %d : vector<[8]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<2x3x[1]xi1>
+ %5 = arm_sve.convert_to_svbool %e : vector<2x3x[1]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<4x[2]xi1>
+ %6 = arm_sve.convert_to_svbool %f : vector<4x[2]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<1x1x1x2x[4]xi1>
+ %7 = arm_sve.convert_to_svbool %g : vector<1x1x1x2x[4]xi1>
+
+ // CHECK: arm_sve.convert_to_svbool %{{.*}} : vector<100x[8]xi1>
+ %8 = arm_sve.convert_to_svbool %h : vector<100x[8]xi1>
+
return
}
// -----
-func.func @arm_sve_convert_from_svbool(%bool: vector<[16]xi1>) {
+func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
+ %b: vector<2x3x[16]xi1>,
+ %c: vector<4x[16]xi1>,
+ %d: vector<1x1x1x1x[16]xi1>,
+ %e: vector<32x[16]xi1>) {
// CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[1]xi1>
- %1 = arm_sve.convert_from_svbool %bool : vector<[1]xi1>
+ %1 = arm_sve.convert_from_svbool %a : vector<[1]xi1>
// CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[2]xi1>
- %2 = arm_sve.convert_from_svbool %bool : vector<[2]xi1>
+ %2 = arm_sve.convert_from_svbool %a : vector<[2]xi1>
// CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[4]xi1>
- %3 = arm_sve.convert_from_svbool %bool : vector<[4]xi1>
+ %3 = arm_sve.convert_from_svbool %a : vector<[4]xi1>
// CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<[8]xi1>
- %4 = arm_sve.convert_from_svbool %bool : vector<[8]xi1>
+ %4 = arm_sve.convert_from_svbool %a : vector<[8]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<2x3x[1]xi1>
+ %5 = arm_sve.convert_from_svbool %b : vector<2x3x[1]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<4x[2]xi1>
+ %6 = arm_sve.convert_from_svbool %c : vector<4x[2]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<1x1x1x1x[4]xi1>
+ %7 = arm_sve.convert_from_svbool %d : vector<1x1x1x1x[4]xi1>
+
+ // CHECK: arm_sve.convert_from_svbool %{{.*}} : vector<32x[8]xi1>
+ %8 = arm_sve.convert_from_svbool %e : vector<32x[8]xi1>
+
return
}
More information about the Mlir-commits
mailing list