[Mlir-commits] [mlir] [mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt (PR #95531)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Jun 14 05:07:30 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/95531
This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: https://github.com/llvm/llvm-project/issues/81840.
>From 24a736a199957f9a35db100e962943a8cfc4cb77 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Jun 2024 11:57:04 +0000
Subject: [PATCH] [mlir][ArmSVE] Lower predicate-sized vector.create_masks to
whilelt
This produces better/more canonical codegen than the generic LLVM
lowering, which emits a pattern the backend currently does not recognize. See:
https://github.com/llvm/llvm-project/issues/81840.
---
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 7 ++++
.../Transforms/LegalizeForLLVMExport.cpp | 40 ++++++++++++++++++-
.../Dialect/ArmSVE/legalize-for-llvm.mlir | 30 +++++++++++++-
mlir/test/Target/LLVMIR/arm-sve.mlir | 15 +++++++
4 files changed, 90 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f2d330c98e7d6..aea55830c6607 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
Arg<AnyScalableVector, "v3">:$v3,
Arg<AnyScalableVector, "v3">:$v4)>;
+// Predicate with lane i active while base + i < n (signed comparison).
+def WhileLTIntrOp : ArmSVE_IntrOp<"whilelt",
+                                  [TypeIs<"res", SVEPredicate>, Pure],
+                                  /*overloadedOperands=*/[0],
+                                  /*overloadedResults=*/[0]>,
+                    Arguments<(ins I64:$base, I64:$n)>;
+
#endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 387937e811ced..7facb3f6b9da0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+/// Lowers predicate-sized `vector.create_mask` ops to the SVE `whilelt`
+/// intrinsic, giving more canonical codegen than the generic LLVM lowering
+/// (see https://github.com/llvm/llvm-project/issues/81840). The more general
+/// get.active.lane.mask can't be used: it compares unsigned (`create_mask` is
+/// signed) and is UB/poison if `n` is zero (`create_mask` gives all-false).
+struct PredicateCreateMaskOpLowering
+    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                  vector::CreateMaskOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 1 || !maskType.isScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
+
+    // TODO: Support masks which are multiples of SVE predicates.
+    auto maskBaseSize = maskType.getDimSize(0);
+    if (maskBaseSize < 2 || maskBaseSize > 16 ||
+        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+      return rewriter.notifyMatchFailure(createMaskOp, "not predicate-sized");
+
+    // `whilelt` takes i64 operands; bail if `index` lowered to another width.
+    Value n = adaptor.getOperands()[0];
+    if (!n.getType().isInteger(64))
+      return rewriter.notifyMatchFailure(createMaskOp, "operand is not i64");
+    auto zero = rewriter.create<LLVM::ZeroOp>(createMaskOp.getLoc(),
+                                              n.getType());
+    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, n);
+    return success();
+  }
+};
+
} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ConvertFromSvboolOpLowering,
ZipX2OpLowering,
ZipX4OpLowering>(converter);
+ // Add predicate conversion with a high benefit as it produces much nicer code
+ // than the generic lowering.
+ patterns.add<PredicateCreateMaskOpLowering>(converter, /*benefit=*/4096);
// clang-format on
}
@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ConvertToSvboolIntrOp,
ConvertFromSvboolIntrOp,
ZipX2IntrOp,
- ZipX4IntrOp>();
+ ZipX4IntrOp,
+ WhileLTIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8d11c2bcaa8d5..3fc5e6e9fcc96 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
%0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME: %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+ // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+ %0 = vector.create_mask %index : vector<[2]xi1>
+ // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+ %1 = vector.create_mask %index : vector<[4]xi1>
+ // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+ %2 = vector.create_mask %index : vector<[8]xi1>
+ // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+ %3 = vector.create_mask %index : vector<[16]xi1>
+ return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>) {
+ // CHECK-NOT: arm_sve.intr.whilelt
+ %0 = vector.create_mask %index : vector<[1]xi1> // Rejected: base size < 2.
+ %1 = vector.create_mask %index : vector<[7]xi1> // Rejected: base size is not a power of 2.
+ %2 = vector.create_mask %index : vector<[32]xi1> // Rejected: base size > 16 (multiple SVE predicates).
+ return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index c7cd1b74ccdb5..34413d46b440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
-> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
llvm.return
}
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME: i64 %[[BASE:[0-9]+]],
+// CHECK-SAME: i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+  llvm.return
+}
More information about the Mlir-commits
mailing list