[Mlir-commits] [mlir] a4e1541 - [mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (#76168)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 21 09:46:16 PST 2023


Author: Benjamin Maxwell
Date: 2023-12-21T17:46:12Z
New Revision: a4e15416b41459b6f69086a22088520ee826f244

URL: https://github.com/llvm/llvm-project/commit/a4e15416b41459b6f69086a22088520ee826f244
DIFF: https://github.com/llvm/llvm-project/commit/a4e15416b41459b6f69086a22088520ee826f244.diff

LOG: [mlir][ArmSME] Move creation of load/store intrinsics to helpers (NFC) (#76168)

Also, for consistency, make the ZeroOp lowering switch on the ArmSMETileType
rather than the element bit width.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index f9d6f04a811f3e..0c6e2e80b88a3b 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,95 @@ using namespace mlir;
 
 namespace {
 
+/// Helper to create an 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+///
+/// The intrinsic is chosen by `layout` (Horizontal -> `*_horiz`, otherwise
+/// `*_vert`) and by the tile `type`, which fixes the element size suffix
+/// (ZAB -> ld1b, ZAH -> ld1h, ZAS -> ld1w, ZAD -> ld1d, ZAQ -> ld1q).
+/// `maskOp`, `ptr`, `tileId` and `tileSliceI32` are forwarded unchanged as the
+/// intrinsic's operands. Returns the newly created operation.
+static Operation *createLoadTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Both switches cover every ArmSMETileType enumerator, but compilers such as
+  // GCC still warn (-Wreturn-type) that control may reach the end of this
+  // non-void function; doing so would be undefined behavior.
+  llvm_unreachable("unknown ArmSMETileType in createLoadTileSliceIntrinsic");
+}
+
+/// Helper to create an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+///
+/// The intrinsic is chosen by `layout` (Horizontal -> `*_horiz`, otherwise
+/// `*_vert`) and by the tile `type`, which fixes the element size suffix
+/// (ZAB -> st1b, ZAH -> st1h, ZAS -> st1w, ZAD -> st1d, ZAQ -> st1q).
+/// `maskOp`, `ptr`, `tileId` and `tileSliceI32` are forwarded unchanged as the
+/// intrinsic's operands. Returns the newly created operation.
+static Operation *createStoreTileSliceIntrinsic(
+    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+    IntegerAttr tileId, Value tileSliceI32) {
+  if (layout == arm_sme::TileSliceLayout::Horizontal) {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  } else {
+    switch (type) {
+    case arm_sme::ArmSMETileType::ZAB:
+      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAH:
+      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAS:
+      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAD:
+      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    case arm_sme::ArmSMETileType::ZAQ:
+      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
+          loc, maskOp, ptr, tileId, tileSliceI32);
+    }
+  }
+  // Both switches cover every ArmSMETileType enumerator, but compilers such as
+  // GCC still warn (-Wreturn-type) that control may reach the end of this
+  // non-void function; doing so would be undefined behavior.
+  llvm_unreachable("unknown ArmSMETileType in createStoreTileSliceIntrinsic");
+}
+
 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   auto tileId = op.getTileId();
   if (!tileId)
@@ -75,9 +164,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = zero.getLoc();
 
-    unsigned tileElementWidth =
-        zero.getVectorType().getElementType().getIntOrFloatBitWidth();
-
     auto tileId = getTileIdOrError(zero);
     if (!tileId)
       return failure();
@@ -86,23 +172,24 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     // The base mask is just the mask to zero the first tile (of a size).
     // These masks are derived from:
     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+    arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
     auto baseMaskForSize = [&] {
-      switch (tileElementWidth) {
-      case 8:
+      switch (tileType) {
+      case arm_sme::ArmSMETileType::ZAB:
         // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
         // 64-bit element tiles named ZA0.D to ZA7.D.
         return 0b1111'1111;
-      case 16:
-        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
-        // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
-        // Shift this left once for ZA1.H.
+      case arm_sme::ArmSMETileType::ZAH:
+        // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
+        // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
+        // once for ZA1.H.
         return 0b0101'0101;
-      case 32:
+      case arm_sme::ArmSMETileType::ZAS:
         // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
         // element tiles named ZA0.D and ZA4.D.
         // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
         return 0b0001'0001;
-      case 64:
+      case arm_sme::ArmSMETileType::ZAD:
         // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
         // setting the bit for that tile.
         return 0b0000'0001;
@@ -172,63 +259,13 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+    auto tileVectorType = loadTileSliceOp.getVectorType();
+    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
-                                                         tileId, tileSliceI32);
-        break;
-      }
-    } else {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
-                                                        tileId, tileSliceI32);
-        break;
-      }
-    }
+    createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
+                                 tileId, tileSliceI32);
 
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
     // the input tile to preserve dataflow.
@@ -249,9 +286,7 @@ struct StoreTileSliceConversion
                   arm_sme::StoreTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = storeTileSliceOp.getLoc();
-    auto tileType = storeTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+    auto tileVectorType = storeTileSliceOp.getVectorType();
 
     auto tileId = getTileIdOrError(storeTileSliceOp);
     if (!tileId)
@@ -271,58 +306,12 @@ struct StoreTileSliceConversion
     auto maskOp = storeTileSliceOp.getMask();
 
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+    arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
 
-    if (layout == arm_sme::TileSliceLayout::Horizontal) {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      }
-    } else {
-      switch (tileElementWidth) {
-      default:
-        llvm_unreachable("unexpected element type!");
-      case 8:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 16:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 32:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 64:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      case 128:
-        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
-        break;
-      }
-    }
+    rewriter.replaceOp(storeTileSliceOp,
+                       createStoreTileSliceIntrinsic(rewriter, loc, tileType,
+                                                     layout, maskOp, ptr,
+                                                     tileId, tileSliceI32));
 
     return success();
   }


        


More information about the Mlir-commits mailing list