[Mlir-commits] [mlir] 657ec73 - [mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt (#95531)

llvmlistbot at llvm.org
Mon Jun 17 02:28:43 PDT 2024


Author: Benjamin Maxwell
Date: 2024-06-17T10:28:39+01:00
New Revision: 657ec7320d8a28171755ba0dd5afc570a5a16791

URL: https://github.com/llvm/llvm-project/commit/657ec7320d8a28171755ba0dd5afc570a5a16791
DIFF: https://github.com/llvm/llvm-project/commit/657ec7320d8a28171755ba0dd5afc570a5a16791.diff

LOG: [mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt (#95531)

This produces better/more canonical codegen than the generic LLVM
lowering, which emits a pattern the backend currently does not recognize.
See: https://github.com/llvm/llvm-project/issues/81840.
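
To illustrate the new lowering (mirroring the tests added below), a
predicate-sized mask such as:

    %0 = vector.create_mask %index : vector<[4]xi1>

now lowers to a single `whilelt` intrinsic with a zero base (the SSA
names here are illustrative):

    %zero = llvm.mlir.zero : i64
    %0 = "arm_sve.intr.whilelt"(%zero, %index) : (i64, i64) -> vector<[4]xi1>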

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
    mlir/test/Target/LLVMIR/arm-sve.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f2d330c98e7d6..aea55830c6607 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
+def WhileLTIntrOp :
+  ArmSVE_IntrOp<"whilelt",
+    [TypeIs<"res", SVEPredicate>, Pure],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins I64:$base, I64:$n)>;
+
 #endif // ARMSVE_OPS

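Since operand 0 and result 0 are marked as overloaded, each
`arm_sve.intr.whilelt` op translates to the matching type-suffixed
`llvm.aarch64.sve.whilelt.*` intrinsic when exported to LLVM IR, as the
target test below verifies. For example:

    %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>

becomes:

    call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %base, i64 %n)
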
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 387937e811ced..ed4f4cc7f0718 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Converts `vector.create_mask` ops that match the size of an SVE predicate
+/// to the `whilelt` intrinsic. This produces more canonical codegen than the
+/// generic LLVM lowering; see https://github.com/llvm/llvm-project/issues/81840
+/// for more details. Note that we can't use (the more general) active.lane.mask
+/// as its semantics don't neatly map onto `vector.create_mask`: it does an
+/// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
+/// `n` is zero (whereas `create_mask` just returns an all-false mask).
+struct CreateMaskOpLowering
+    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                  vector::CreateMaskOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 1 || !maskType.isScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
+
+    // TODO: Support masks which are multiples of SVE predicates.
+    auto maskBaseSize = maskType.getDimSize(0);
+    if (maskBaseSize < 2 || maskBaseSize > 16 ||
+        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "not SVE predicate-sized");
+
+    auto loc = createMaskOp.getLoc();
+    auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
+    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
+                                               adaptor.getOperands()[0]);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering>(converter);
+  // Add vector.create_mask conversion with a high benefit as it produces much
+  // nicer code than the generic lowering.
+  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
   // clang-format on
 }
 
@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertToSvboolIntrOp,
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
-                    ZipX4IntrOp>();
+                    ZipX4IntrOp,
+                    WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,

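For reference, the size check above accepts exactly the masks that fit a
single SVE predicate: 1-D scalable masks of [2], [4], [8], or [16]
elements. A sketch of what does and does not match (these cases mirror
the tests below):

    %a = vector.create_mask %n : vector<[8]xi1>   // matches: power of two in [2, 16]
    %b = vector.create_mask %n : vector<[1]xi1>   // no match: below the minimum size
    %c = vector.create_mask %n : vector<[7]xi1>   // no match: not a power of two
    %d = vector.create_mask %n : vector<[32]xi1>  // no match: larger than a predicate (see the TODO)
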
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8d11c2bcaa8d5..3fc5e6e9fcc96 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
   %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
   return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME:                                        %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+  // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+  %0 = vector.create_mask %index : vector<[2]xi1>
+  // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+  %1 = vector.create_mask %index : vector<[4]xi1>
+  // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+  %2 = vector.create_mask %index : vector<[8]xi1>
+  // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+  %3 = vector.create_mask %index : vector<[16]xi1>
+  return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>)  {
+  // CHECK-NOT: arm_sve.intr.whilelt
+  %0 = vector.create_mask %index : vector<[1]xi1>
+  %1 = vector.create_mask %index : vector<[7]xi1>
+  %2 = vector.create_mask %index : vector<[32]xi1>
+  return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}

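Note the `-cse` added to the RUN line above: each application of the new
pattern materializes its own `llvm.mlir.zero`, so CSE is presumably what
lets the four masks in the first test share the single `%[[ZERO]]`
constant the CHECK lines capture. Without it, the output would contain
one zero constant per mask, roughly:

    %zero_0 = llvm.mlir.zero : i64
    %zero_1 = llvm.mlir.zero : i64
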
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index c7cd1b74ccdb5..34413d46b440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
      -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME:                  i64 %[[BASE:[0-9]+]],
+// CHECK-SAME:                  i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+  llvm.return
+}
