[Mlir-commits] [mlir] [mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (PR #76168)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Dec 21 09:08:21 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/76168

>From 36c67252bb565407173ce2054ca55113d5c458c5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 16:46:44 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Move creation of load/store intrinsics to
 helpers (NFC)

Also, for consistency make the ZeroOp lowering switch on the ArmSMETileType,
rather than the element bit width.
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 230 +++++++++---------
 1 file changed, 110 insertions(+), 120 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index f9d6f04a811f3e..9b70057b25a243 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,95 @@ using namespace mlir;
 
 namespace {
 
+/// Helper to create a arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+/// The tile slice layout selects the `horiz` or `vert` variant, and the tile
+/// type selects the element size (ZAB->b, ZAH->h, ZAS->w, ZAD->d, ZAQ->q).
+static Operation *createLoadTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Every case above returns. If `type` matches none of them, control would
+  // otherwise fall off the end of this non-void function (undefined behavior,
+  // and a GCC -Wreturn-type warning), so mark that path unreachable.
+  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
+}
+
+/// Helper to create a arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+/// The tile slice layout selects the `horiz` or `vert` variant, and the tile
+/// type selects the element size (ZAB->b, ZAH->h, ZAS->w, ZAD->d, ZAQ->q).
+static Operation *createStoreTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Every case above returns. If `type` matches none of them, control would
+  // otherwise fall off the end of this non-void function (undefined behavior,
+  // and a GCC -Wreturn-type warning), so mark that path unreachable.
+  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
+}
+
 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   auto tileId = op.getTileId();
   if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = zero.getLoc();
 
-    unsigned tileElementWidth =
-        zero.getVectorType().getElementType().getIntOrFloatBitWidth();
-
     auto tileId = getTileIdOrError(zero);
     if (!tileId)
       return failure();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     // The base mask is just the mask to zero the first tile (of a size).
     // These masks are derived from:
     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+    arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
     auto baseMaskForSize = [&] {
-      switch (tileElementWidth) {
-      case 8:
+      switch (tileType) {
+      case arm_sme::ArmSMETileType::ZAB:
         // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
         // 64-bit element tiles named ZA0.D to ZA7.D.
         return 0b1111'1111;
-      case 16:
-        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
-        // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
-        // Shift this left once for ZA1.H.
+      case arm_sme::ArmSMETileType::ZAH:
+        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
+        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
+        // once for ZA1.H.
         return 0b0101'0101;
-      case 32:
+      case arm_sme::ArmSMETileType::ZAS:
         // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
         // element tiles named ZA0.D and ZA4.D.
         // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
         return 0b0001'0001;
-      case 64:
+      case arm_sme::ArmSMETileType::ZAD:
         // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
         // setting the bit for that tile.
         return 0b0000'0001;
@@ -172,63 +259,14 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+    auto tileVectorType = loadTileSliceOp.getVectorType();
+    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
 
-    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      }
-    } else {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      }
-    }
+    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile
+    // slice.
+    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
+                                 tileId, tileSliceI32);
 
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
     // the input tile to preserve dataflow.
@@ -249,9 +287,7 @@ struct StoreTileSliceConversion
                   arm_sme::StoreTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = storeTileSliceOp.getLoc();
-    auto tileType = storeTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+    auto tileVectorType = storeTileSliceOp.getVectorType();
 
     auto tileId = getTileIdOrError(storeTileSliceOp);
     if (!tileId)
@@ -271,58 +307,12 @@ struct StoreTileSliceConversion
     auto maskOp = storeTileSliceOp.getMask();
 
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
 
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      }
-    } else {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      }
-    }
+    rewriter.replaceOp(storeTileSliceOp,
+                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
+                                                     layout, maskOp, ptr,
+                                                     tileId, tileSliceI32));
 
     return success();
   }

>From 949fe5b0dcff0ce932d77d3153aeb8346e1a4055 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 17:07:03 +0000
Subject: [PATCH 2/2] `a` -> `an`

---
 mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 9b70057b25a243..8dd7cf453661f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,7 +32,7 @@ using namespace mlir;
 
 namespace {
 
-/// Helper to create a arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
 static Operation *createLoadTileSliceIntrinsic(
     RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
     arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
@@ -77,7 +77,7 @@ static Operation *createLoadTileSliceIntrinsic(
   }
 }
 
-/// Helper to create a arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
 static Operation *createStoreTileSliceIntrinsic(
     RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
     arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,



More information about the Mlir-commits mailing list