[Mlir-commits] [mlir] [mlir][ArmSME] Merge consecutive `arm_sme.intr.zero` ops (PR #106215)
llvmlistbot at llvm.org
Tue Aug 27 05:51:28 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-sme
Author: Benjamin Maxwell (MacDue)
Changes
This merges consecutive SME `zero` intrinsics within a basic block by bitwise OR-ing their tile masks, which avoids the backend eventually emitting multiple `zero` instructions where a single one would do.
Note: This kind of peephole optimization could also be implemented in the backend.
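As a minimal sketch of the effect (using the ZA0.S/ZA1.S tile masks that appear in the tests below):

```mlir
// Before: two consecutive zero intrinsics in the same block.
"arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
"arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()

// After: a single intrinsic whose mask is the bitwise OR (17 | 34 = 51).
"arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
```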
---
Full diff: https://github.com/llvm/llvm-project/pull/106215.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+37)
- (modified) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+1)
- (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+48)
``````````diff
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 4d96091a637cf0..8cdf83e431b69b 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ScopeExit.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
@@ -481,6 +482,9 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
loc, rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
+ // Note: Place the `get_tile` op at the start of the block. This ensures
+ // that if there are multiple `zero` ops, the intrinsics will be consecutive.
+ rewriter.setInsertionPointToStart(zero->getBlock());
rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
return success();
@@ -855,6 +859,37 @@ struct StreamingVLOpConversion
}
};
+/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
+/// or-ing the zero masks. Note: In future the backend _should_ handle this.
+static void mergeConsecutiveTileZerosInBlock(Block *block) {
+ uint32_t mergedZeroMask = 0;
+ SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
+ auto replaceMergedZeroOps = [&] {
+ auto cleanup = llvm::make_scope_exit([&] {
+ mergedZeroMask = 0;
+ zeroOpsToMerge.clear();
+ });
+ if (zeroOpsToMerge.size() <= 1)
+ return;
+ IRRewriter rewriter(zeroOpsToMerge.front());
+ rewriter.setInsertionPoint(zeroOpsToMerge.front());
+ rewriter.create<arm_sme::aarch64_sme_zero>(
+ zeroOpsToMerge.front().getLoc(),
+ rewriter.getI32IntegerAttr(mergedZeroMask));
+ for (auto zeroOp : zeroOpsToMerge)
+ rewriter.eraseOp(zeroOp);
+ };
+ for (Operation &op : *block) {
+ if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
+ mergedZeroMask |= zeroOp.getTileMask();
+ zeroOpsToMerge.push_back(zeroOp);
+ } else {
+ replaceMergedZeroOps();
+ }
+ }
+ replaceMergedZeroOps();
+}
+
} // namespace
namespace {
@@ -879,6 +914,8 @@ struct ConvertArmSMEToLLVMPass
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
+ function->walk(mergeConsecutiveTileZerosInBlock);
+
// Walk the function and fail if there are unexpected operations on SME
// tile types after conversion.
function->walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 1dced0fcd18c7a..2a183cb4d056a9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -53,6 +53,7 @@
/// These are obviously redundant, but there are no checks to avoid this.
func.func @use_too_many_tiles() {
%0 = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
%1 = arm_sme.zero : vector<[4]x[4]xi32>
// expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
%2 = arm_sme.zero : vector<[8]x[8]xi16>
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index ca339be5fb56f1..6e229b4a7de53a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -19,6 +19,7 @@ func.func @zero_za_b() {
func.func @zero_za_h() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
"test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
@@ -32,10 +33,13 @@ func.func @zero_za_h() {
func.func @zero_za_s() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
"test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
@@ -51,18 +55,25 @@ func.func @zero_za_s() {
func.func @zero_za_d() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
"test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
@@ -75,3 +86,40 @@ func.func @zero_za_d() {
"test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops
+func.func @merge_consecutive_tile_zero_ops() {
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
+ %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+/// arm_sme.intr.zero intrinsics are not merged when there is an op other than
+/// arm_sme.intr.zero between them.
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops_with_barrier
+func.func @merge_consecutive_tile_zero_ops_with_barrier() {
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
+ %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 204 : i32}> : () -> ()
+ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
``````````
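For reference, the expected masks in the new tests follow directly from OR-ing the per-tile masks: 17 | 34 | 68 | 136 = 255 in `merge_consecutive_tile_zero_ops`, and 17 | 34 = 51 and 68 | 136 = 204 on either side of the barrier in `merge_consecutive_tile_zero_ops_with_barrier`.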
https://github.com/llvm/llvm-project/pull/106215