[Mlir-commits] [mlir] [mlir][ArmSME][test] Make use of arm_sme.streaming_vl (NFC) (PR #77322)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Jan 8 06:46:38 PST 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/77322
Depends on: #77321
>From e8022d98cab11176ae8b20ae5ca246706923b7c1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 8 Jan 2024 14:04:09 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Add `arm_sme.intr.cnts(b|h|w|d)`
intrinsics
This adds MLIR versions of Arm streaming vector length intrinsics. These
allow reading the streaming vector length regardless of the streaming
mode.
---
.../mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 13 +++++++++++++
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 8 ++++++++
mlir/test/Target/LLVMIR/arm-sme.mlir | 14 ++++++++++++++
3 files changed, 35 insertions(+)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 4d96e04c886fa3..d85ef963ae5dc4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -187,4 +187,17 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+class ArmSME_IntrCountOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic,
+ /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[],
+ /*overloadedOperands=*/[],
+ /*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>],
+ /*numResults=*/1, /*overloadedResults=*/[]>;
+
+def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">;
+def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">;
+def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">;
+def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">;
+
#endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index 7c9976bed91273..14821da838726f 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -31,3 +31,11 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
(vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
llvm.return %res : vector<[4]xi32>
}
+
+// -----
+
+llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
+ // expected-error @+1 {{failed to verify that `res` is i64}}
+ %res = "arm_sme.intr.cntsb"() : () -> i32
+ llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index edc1f749130440..7a42033dc04bc0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -403,3 +403,17 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
: (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
llvm.return
}
+
+// -----
+
+llvm.func @arm_sme_streaming_vl() {
+ // CHECK: call i64 @llvm.aarch64.sme.cntsb()
+ %svl_b = "arm_sme.intr.cntsb"() : () -> i64
+ // CHECK: call i64 @llvm.aarch64.sme.cntsh()
+ %svl_h = "arm_sme.intr.cntsh"() : () -> i64
+ // CHECK: call i64 @llvm.aarch64.sme.cntsw()
+ %svl_w = "arm_sme.intr.cntsw"() : () -> i64
+ // CHECK: call i64 @llvm.aarch64.sme.cntsd()
+ %svl_d = "arm_sme.intr.cntsd"() : () -> i64
+ llvm.return
+}
>From d7ebafecad54d2631d8bf836469f58498ee566d2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 8 Jan 2024 14:12:25 +0000
Subject: [PATCH 2/3] [mlir][ArmSME] Add `arm_sme.streaming_vl` operation
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This is most useful for functions
that call or pass data to streaming functions but are not streaming
themselves.
Example:
```mlir
%svl_w = arm_sme.streaming_vl <words>
```
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 43 +++++++++++++++++
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 47 +++++++++++++++++--
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 42 +++++++++++++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 36 ++++++++++++++
4 files changed, 165 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..4060407d81c0fa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
let defaultValue = "CombiningKind::Add";
}
+def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
+  I32EnumAttrCase<"Bytes"      , 0, "bytes">,
+  I32EnumAttrCase<"HalfWords"  , 1, "half_words">,
+  I32EnumAttrCase<"Words"      , 2, "words">,
+  I32EnumAttrCase<"DoubleWords", 3, "double_words">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
+ "type_size"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -768,4 +783,32 @@ let arguments = (ins
}];
}
+def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
+{
+ let summary = "Query the streaming vector length";
+
+ let description = [{
+ This operation returns the streaming vector length for a type size. Unlike
+ `vector.vscale` the value returned is invariant to the streaming mode.
+
+ Example:
+ ```mlir
+ // Streaming vector length in:
+ // - bytes (8-bit)
+ %svl_b = arm_sme.streaming_vl <bytes>
+ // - half words (16-bit)
+ %svl_h = arm_sme.streaming_vl <half_words>
+ // - words (32-bit)
+ %svl_w = arm_sme.streaming_vl <words>
+ // - double words (64-bit)
+ %svl_d = arm_sme.streaming_vl <double_words>
+ ```
+ }];
+
+ let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
+ let results = (outs Index);
+
+ let assemblyFormat = "$type_size attr-dict";
+}
+
#endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3b..c5fdba6d00cc0f 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
}
};
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// Example:
+///
+///   %0 = arm_sme.streaming_vl <half_words>
+///
+/// is converted to:
+///
+///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
+///   %0 = arith.index_cast %cnt : i64 to index
+struct StreamingVLOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+                  arm_sme::StreamingVLOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = streamingVlOp.getLoc();
+    auto i64Type = rewriter.getI64Type();
+    auto *intrOp = [&]() -> Operation * {
+      switch (streamingVlOp.getTypeSize()) {
+      case arm_sme::TypeSize::Bytes:
+        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+      case arm_sme::TypeSize::HalfWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+      case arm_sme::TypeSize::Words:
+        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+      case arm_sme::TypeSize::DoubleWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+      }
+      llvm_unreachable("unknown type size in StreamingVLOpConversion");
+    }();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    return success();
+  }
+};
+
} // namespace
namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
+ arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
- converter);
+ OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+ StreamingVLOpConversion>(converter);
}
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bd88da37bdf966..bab4fd60518c54 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_bytes
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_bytes() -> index {
+ %svl_b = arm_sme.streaming_vl <bytes>
+ return %svl_b : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_half_words
+// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+func.func @arm_sme_streaming_vl_half_words() -> index {
+ %svl_h = arm_sme.streaming_vl <half_words>
+ return %svl_h : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_words
+// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+func.func @arm_sme_streaming_vl_words() -> index {
+ %svl_w = arm_sme.streaming_vl <words>
+ return %svl_w : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_double_words
+// CHECK: "arm_sme.intr.cntsd"() : () -> i64
+func.func @arm_sme_streaming_vl_double_words() -> index {
+ %svl_d = arm_sme.streaming_vl <double_words>
+ return %svl_d : index
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 58ff7ef4d8340e..eb4c7149b61f1d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
return %result : vector<[16]x[16]xi8>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_streaming_vl_bytes() -> index {
+ // CHECK: arm_sme.streaming_vl <bytes>
+ %svl_b = arm_sme.streaming_vl <bytes>
+ return %svl_b : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_half_words() -> index {
+ // CHECK: arm_sme.streaming_vl <half_words>
+ %svl_h = arm_sme.streaming_vl <half_words>
+ return %svl_h : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_words() -> index {
+ // CHECK: arm_sme.streaming_vl <words>
+ %svl_w = arm_sme.streaming_vl <words>
+ return %svl_w : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_double_words() -> index {
+ // CHECK: arm_sme.streaming_vl <double_words>
+ %svl_d = arm_sme.streaming_vl <double_words>
+ return %svl_d : index
+}
>From 67170f7784a945ef3e7f23cae0de21614fd57f1e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 8 Jan 2024 14:30:42 +0000
Subject: [PATCH 3/3] [mlir][ArmSME][test] Make use of `arm_sme.streaming_vl`
(NFC)
---
.../Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir | 5 +----
.../Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir | 4 +---
.../Vector/CPU/ArmSME/test-transfer-read-2d.mlir | 9 +--------
.../Vector/CPU/ArmSME/test-transfer-write-2d.mlir | 9 +--------
.../Dialect/Vector/CPU/ArmSME/test-transpose.mlir | 4 +---
.../Dialect/Vector/CPU/ArmSME/vector-load-store.mlir | 10 ++--------
.../Dialect/Vector/CPU/ArmSME/vector-ops.mlir | 5 +----
7 files changed, 8 insertions(+), 38 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 81546de6a3466b..55c9db21b246e5 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -20,12 +20,9 @@ func.func @entry() {
%c123_f32 = arith.constant 123.0 : f32
- %min_elts_s = arith.constant 4 : index
- %vscale = vector.vscale
-
// "svl" refers to the Streaming Vector Length and "svl_s" the number of
// 32-bit elements in a vector of SVL bits.
- %svl_s = arith.muli %min_elts_s, %vscale : index
+ %svl_s = arm_sme.streaming_vl <words>
%tile_init = bufferization.alloc_tensor(%svl_s, %svl_s) : tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 936163d1cd30d9..4f2adb228b30a7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -17,9 +17,7 @@ func.func @entry() {
%c1_i32 = arith.constant 1 : i32
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
+ %svl_s = arm_sme.streaming_vl <words>
%za_s_size = arith.muli %svl_s, %svl_s : index
// Allocate memory.
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 839aed2e840c90..dc990835dc2dfd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -134,12 +134,6 @@ func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
return %A : memref<?x?xf32>
}
-// This will be made a streaming function by enable-arm-streaming so return SVL.
-func.func @get_svl() -> index {
- %vscale = vector.vscale
- return %vscale : index
-}
-
func.func @entry() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -148,8 +142,7 @@ func.func @entry() {
// Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
// non-zero offsets while remaining inbounds.
- %svl = call @get_svl() : () -> index
- %svl_s = arith.muli %c4, %svl : index
+ %svl_s = arm_sme.streaming_vl <words>
%svl_s_plus_two = arith.addi %svl_s, %c2 : index
%A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 84246606daa8af..9964df079e760e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -96,12 +96,6 @@ func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
return %A : memref<?x?xf32>
}
-// This will be made a streaming function by enable-arm-streaming so return SVL.
-func.func @get_svl() -> index {
- %vscale = vector.vscale
- return %vscale : index
-}
-
func.func @entry() {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
@@ -111,8 +105,7 @@ func.func @entry() {
//
// Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
// non-zero offsets while remaining inbounds.
- %svl = call @get_svl() : () -> index
- %svl_s = arith.muli %c4, %svl : index
+ %svl_s = arm_sme.streaming_vl <words>
%svl_s_plus_two = arith.addi %svl_s, %c2 : index
%A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 2751c2d136485e..2803c284537f84 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -17,9 +17,7 @@ func.func @entry() {
%c1_i32 = arith.constant 1 : i32
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
+ %svl_s = arm_sme.streaming_vl <words>
%za_s_size = arith.muli %svl_s, %svl_s : index
// Allocate memory.
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 5d2d0a73992f1e..044ce58c4a61f2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -24,12 +24,9 @@ func.func @za0_d_f64() -> i32 {
%c1_f64 = arith.constant 1.0 : f64
%c1_index = arith.constant 1 : index
- %min_elts_d = arith.constant 2 : index
- %vscale = vector.vscale
-
// "svl" refers to the Streaming Vector Length and "svl_d" the number of
// 64-bit elements in a vector of SVL bits.
- %svl_d = arith.muli %min_elts_d, %vscale : index
+ %svl_d = arm_sme.streaming_vl <double_words>
// Allocate "mem1" and fill each "row" with row number.
//
@@ -170,13 +167,10 @@ func.func @load_store_two_za_s_tiles() -> i32 {
%c1_index = arith.constant 1 : index
%c2_index = arith.constant 2 : index
- %min_elts_s = arith.constant 4 : index
- %vscale = vector.vscale
-
// "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
// * the number of 32-bit elements in a vector of SVL bits.
// * the number of tile slices (1d vectors) in a 32-bit element tile.
- %svl_s = arith.muli %min_elts_s, %vscale : index
+ %svl_s = arm_sme.streaming_vl <words>
// Allocate memory for two 32-bit element tiles.
%size_of_tile = arith.muli %svl_s, %svl_s : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index af5fe236e5bd57..6bbf36e90d2d6d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -14,12 +14,9 @@ func.func @entry() -> i32 {
%c1_i8 = arith.constant 1 : i8
%c1_index = arith.constant 1 : index
- %c16 = arith.constant 16 : index
- %vscale = vector.vscale
-
// "svl" refers to the Streaming Vector Length and "svl_b" the number of
// 8-bit elements in a vector of SVL bits.
- %svl_b = arith.muli %c16, %vscale : index
+ %svl_b = arm_sme.streaming_vl <bytes>
// Allocate memory and fill with ones.
//
More information about the Mlir-commits
mailing list