[Mlir-commits] [mlir] [mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt (PR #95531)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jun 17 02:27:03 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/95531

From 24a736a199957f9a35db100e962943a8cfc4cb77 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Jun 2024 11:57:04 +0000
Subject: [PATCH 1/3] [mlir][ArmSVE] Lower predicate-sized vector.create_masks
 to whilelt

This produces better/more canonical codegen than the generic LLVM lowering,
which results in a pattern the backend currently does not recognize. See:
https://github.com/llvm/llvm-project/issues/81840.
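
For context, the generic lowering expands a scalable `vector.create_mask`
into a stepvector-plus-signed-compare sequence. A minimal sketch of the LLVM
IR this produces for a `vector<[4]xi1>` mask (illustrative only; the exact
IR and element type may differ):

  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
  %n.ins = insertelement <vscale x 4 x i64> poison, i64 %n, i64 0
  %n.splat = shufflevector <vscale x 4 x i64> %n.ins, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
  %mask = icmp slt <vscale x 4 x i64> %step, %n.splat

The AArch64 backend does not currently fold this sequence into a single
`whilelt` instruction, hence the direct lowering added here.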
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td |  7 ++++
 .../Transforms/LegalizeForLLVMExport.cpp      | 40 ++++++++++++++++++-
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     | 30 +++++++++++++-
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 15 +++++++
 4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f2d330c98e7d6..aea55830c6607 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
+def WhileLTIntrOp :
+  ArmSVE_IntrOp<"whilelt",
+    [TypeIs<"res", SVEPredicate>, Pure],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins I64:$base, I64:$n)>;
+
 #endif // ARMSVE_OPS
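
(Aside: with operand 0 and result 0 overloaded, the op translates to the
type-mangled SVE intrinsic. For example,

  %m = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>

becomes a call to `@llvm.aarch64.sve.whilelt.nxv4i1.i64`, as exercised by
the arm-sve.mlir target test further down.)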
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 387937e811ced..7facb3f6b9da0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Converts `vector.create_mask` ops that match the size of an SVE predicate
+/// to the `whilelt` intrinsic. This produces more canonical codegen than the
+/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
+/// for more details. Note that we can't (the more general) get.active.lane.mask
+/// as its semantics don't neatly map on to `vector.create_mask`, as it does an
+/// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
+/// `n` is zero (whereas `create_mask` just returns an all-false mask).
+struct PredicateCreateMaskOpLowering
+    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                  vector::CreateMaskOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 1 || !maskType.isScalable())
+      return failure();
+
+    // TODO: Support masks which are multiples of SVE predicates.
+    auto maskBaseSize = maskType.getDimSize(0);
+    if (maskBaseSize < 2 || maskBaseSize > 16 ||
+        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+      return failure();
+
+    auto loc = createMaskOp.getLoc();
+    auto zero = rewriter.create<LLVM::ZeroOp>(
+        loc, typeConverter->convertType(rewriter.getI64Type()));
+    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
+                                               adaptor.getOperands()[0]);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering>(converter);
+  // Add predicate conversion with a high benefit as it produces much nicer code
+  // than the generic lowering.
+  patterns.add<PredicateCreateMaskOpLowering>(converter, /*benifit=*/4096);
   // clang-format on
 }
 
@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertToSvboolIntrOp,
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
-                    ZipX4IntrOp>();
+                    ZipX4IntrOp,
+                    WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
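
(Aside on the semantic gap noted in the doc comment above; a worked example,
assuming the documented `llvm.get.active.lane.mask` behaviour:

  %neg = arith.constant -1 : index
  %m = vector.create_mask %neg : vector<[4]xi1>  // all-false: i < -1 never holds for i >= 0

With `get.active.lane.mask(0, -1)` the bound is compared unsigned, so -1
wraps to 2^64-1 and every lane would be active; and for a bound of 0,
`create_mask` returns an all-false mask where `get.active.lane.mask` is
poison. Hence the dedicated `whilelt` lowering.)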
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8d11c2bcaa8d5..3fc5e6e9fcc96 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
   %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
   return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME:                                        %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+  // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+  %0 = vector.create_mask %index : vector<[2]xi1>
+  // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+  %1 = vector.create_mask %index : vector<[4]xi1>
+  // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+  %2 = vector.create_mask %index : vector<[8]xi1>
+  // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+  %3 = vector.create_mask %index : vector<[16]xi1>
+  return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>)  {
+  // CHECK-NOT: arm_sve.intr.whilelt
+  %0 = vector.create_mask %index : vector<[1]xi1>
+  %1 = vector.create_mask %index : vector<[7]xi1>
+  %2 = vector.create_mask %index : vector<[32]xi1>
+  return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index c7cd1b74ccdb5..34413d46b440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
      -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME:                  i64 %[[BASE:[0-9]+]],
+// CHECK-SAME:                  i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+  llvm.return
+}
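
(For reference, each of these intrinsics should select to a single `whilelt`
instruction. The following mapping is an assumption based on the usual SVE
predicate/element-size correspondence, not output produced by this patch:

  whilelt p0.d, xzr, x0   // vector<[2]xi1>
  whilelt p0.s, xzr, x0   // vector<[4]xi1>
  whilelt p0.h, xzr, x0   // vector<[8]xi1>
  whilelt p0.b, xzr, x0   // vector<[16]xi1>

with the zero base materializing as `xzr`.)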

From 3e2e2eca1961f3b8461ffa6bec8fa0fc36ba0a4f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 14 Jun 2024 15:06:48 +0000
Subject: [PATCH 2/3] Fixups

---
 .../ArmSVE/Transforms/LegalizeForLLVMExport.cpp  | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 7facb3f6b9da0..605cb44e80449 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -147,7 +147,7 @@ using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
 /// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
-struct PredicateCreateMaskOpLowering
+struct CreateMaskOpLowering
     : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -157,17 +157,17 @@ struct PredicateCreateMaskOpLowering
                   ConversionPatternRewriter &rewriter) const override {
     auto maskType = createMaskOp.getVectorType();
     if (maskType.getRank() != 1 || !maskType.isScalable())
-      return failure();
+      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
 
     // TODO: Support masks which are multiples of SVE predicates.
     auto maskBaseSize = maskType.getDimSize(0);
     if (maskBaseSize < 2 || maskBaseSize > 16 ||
         !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
-      return failure();
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "not SVE predicate-sized");
 
     auto loc = createMaskOp.getLoc();
-    auto zero = rewriter.create<LLVM::ZeroOp>(
-        loc, typeConverter->convertType(rewriter.getI64Type()));
+    auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
     rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
                                                adaptor.getOperands()[0]);
     return success();
@@ -203,9 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering>(converter);
-  // Add predicate conversion with a high benefit as it produces much nicer code
-  // than the generic lowering.
-  patterns.add<PredicateCreateMaskOpLowering>(converter, /*benifit=*/4096);
+  // Add vector.create_mask conversion with a high benefit as it produces much
+  // nicer code than the generic lowering.
+  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
   // clang-format on
 }
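
(Aside: a minimal sketch of how these two hooks are typically wired into a
conversion pass, assuming the usual partial-conversion setup; everything
other than the two Arm SVE entry points is standard MLIR API:

  // Inside some ConversionPass::runOnOperation():
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());
  populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
  configureArmSVELegalizeForExportTarget(target);
  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();

The high pattern benefit simply ensures `CreateMaskOpLowering` is tried
before the generic `vector.create_mask` lowering when both are present.)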
 

From 4d621383f81873b80be7ac408fe61fa777460aa6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 17 Jun 2024 09:26:02 +0000
Subject: [PATCH 3/3] Typo

---
 mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 605cb44e80449..ed4f4cc7f0718 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -143,7 +143,7 @@ using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
-/// for more details. Note that we can't (the more general) get.active.lane.mask
+/// for more details. Note that we can't use (the more general) active.lane.mask
 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
 /// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
 /// `n` is zero (whereas `create_mask` just returns an all-false mask).


