[Mlir-commits] [mlir] e37d6d2 - [mlir][ArmSME] Merge consecutive `arm_sme.intr.zero` ops (#106215)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 29 01:43:42 PDT 2024


Author: Benjamin Maxwell
Date: 2024-08-29T09:43:38+01:00
New Revision: e37d6d2a74d76fdc95f5c5d625e282ce600aad55

URL: https://github.com/llvm/llvm-project/commit/e37d6d2a74d76fdc95f5c5d625e282ce600aad55
DIFF: https://github.com/llvm/llvm-project/commit/e37d6d2a74d76fdc95f5c5d625e282ce600aad55.diff

LOG: [mlir][ArmSME] Merge consecutive `arm_sme.intr.zero` ops (#106215)

This merges consecutive SME zero intrinsics within a basic block, which
avoids the backend eventually emitting multiple zero instructions when
it could just use one.

Note: This kind of peephole optimization could be implemented in the
backend too.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
    mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
    mlir/test/Dialect/ArmSME/tile-zero-masks.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 4d96091a637cf0..1ad2ec6cee7f8c 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ScopeExit.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
@@ -481,6 +482,9 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
         loc, rewriter.getI32IntegerAttr(zeroMask));
 
     // Create a placeholder op to preserve dataflow.
+    // Note: Place the `get_tile` op at the start of the block. This ensures
+    // that if there are multiple `zero` ops the intrinsics will be consecutive.
+    rewriter.setInsertionPointToStart(zero->getBlock());
     rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
 
     return success();
@@ -855,6 +859,36 @@ struct StreamingVLOpConversion
   }
 };
 
+/// Merges each run of consecutive `arm_sme.intr.zero` ops in `block` into one
+/// op with the OR-ed mask. Note: In future the backend _should_ handle this.
+static void mergeConsecutiveTileZerosInBlock(Block *block) {
+  uint32_t mergedZeroMask = 0;
+  SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
+  auto replaceMergedZeroOps = [&] {
+    auto cleanup = llvm::make_scope_exit([&] { // Reset state on every return.
+      mergedZeroMask = 0;
+      zeroOpsToMerge.clear();
+    });
+    if (zeroOpsToMerge.size() <= 1) // A lone zero op is left as-is.
+      return;
+    IRRewriter rewriter(zeroOpsToMerge.front());
+    rewriter.create<arm_sme::aarch64_sme_zero>(
+        zeroOpsToMerge.front().getLoc(),
+        rewriter.getI32IntegerAttr(mergedZeroMask));
+    for (auto zeroOp : zeroOpsToMerge) // The merged op replaces all of these.
+      rewriter.eraseOp(zeroOp);
+  };
+  for (Operation &op : *block) {
+    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
+      mergedZeroMask |= zeroOp.getTileMask();
+      zeroOpsToMerge.push_back(zeroOp);
+    } else {
+      replaceMergedZeroOps(); // Any other op ends the current run.
+    }
+  }
+  replaceMergedZeroOps(); // Flush a run that reaches the end of the block.
+}
+
 } // namespace
 
 namespace {
@@ -879,6 +913,8 @@ struct ConvertArmSMEToLLVMPass
     if (failed(applyPartialConversion(function, target, std::move(patterns))))
       signalPassFailure();
 
+    function->walk(mergeConsecutiveTileZerosInBlock);
+
     // Walk the function and fail if there are unexpected operations on SME
     // tile types after conversion.
     function->walk([&](Operation *op) {

diff  --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 1dced0fcd18c7a..2a183cb4d056a9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -53,6 +53,7 @@
 /// These are obviously redundant, but there's no checks to avoid this.
 func.func @use_too_many_tiles() {
   %0 = arm_sme.zero : vector<[4]x[4]xi32>
+  "test.prevent_zero_merge"() : () -> ()
   %1 = arm_sme.zero : vector<[4]x[4]xi32>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
   %2 = arm_sme.zero : vector<[8]x[8]xi16>

diff  --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index ca339be5fb56f1..02128ed8731804 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -19,6 +19,7 @@ func.func @zero_za_b() {
 func.func @zero_za_h() {
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
   %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
   "test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
@@ -32,10 +33,13 @@ func.func @zero_za_h() {
 func.func @zero_za_s() {
   // CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
   %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
   %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
   %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
   "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
@@ -51,18 +55,25 @@ func.func @zero_za_s() {
 func.func @zero_za_d() {
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
   %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
   %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
   %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
   %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
   %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
   %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
+  "test.prevent_zero_merge"() : () -> ()
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
   %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
   "test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
@@ -75,3 +86,45 @@ func.func @zero_za_d() {
   "test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops
+func.func @merge_consecutive_tile_zero_ops() {
+  // CHECK-NOT: arm_sme.intr.zero
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
+  // CHECK-NOT: arm_sme.intr.zero
+  %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+  %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+  %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+  %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+  "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+/// arm_sme.intr.zero intrinsics are not merged when there is an op other than
+/// arm_sme.intr.zero between them.
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops_with_barrier
+func.func @merge_consecutive_tile_zero_ops_with_barrier() {
+  // CHECK-NOT: arm_sme.intr.zero
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
+  // CHECK-NOT: arm_sme.intr.zero
+  %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+  %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+  "test.prevent_zero_merge"() : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 204 : i32}> : () -> ()
+  // CHECK-NOT: arm_sme.intr.zero
+  %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+  %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+  "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list