[Mlir-commits] [mlir] 53d4890 - [mlir][ArmSME] Add arm_sme.streaming_vl operation (#77321)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 10 02:11:49 PST 2024
Author: Benjamin Maxwell
Date: 2024-01-10T10:11:44Z
New Revision: 53d48902bc6b05cc284f767089fe070ada651910
URL: https://github.com/llvm/llvm-project/commit/53d48902bc6b05cc284f767089fe070ada651910
DIFF: https://github.com/llvm/llvm-project/commit/53d48902bc6b05cc284f767089fe070ada651910.diff
LOG: [mlir][ArmSME] Add arm_sme.streaming_vl operation (#77321)
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This is most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.
Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```
Created based on discussion here:
https://github.com/llvm/llvm-project/pull/76086#discussion_r1434226352
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..bb0db59add0094 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
let defaultValue = "CombiningKind::Add";
}
+// Element-size selector shared by ArmSME ops. Each case is
+// (C++ enumerator, integer value, assembly keyword).
+def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
+    I32EnumAttrCase<"Byte",   0, "byte">,
+    I32EnumAttrCase<"Half",   1, "half">,
+    I32EnumAttrCase<"Word",   2, "word">,
+    I32EnumAttrCase<"Double", 3, "double">
+  ]> {
+  // Only the C++ enum is emitted from this record; the attribute wrapper is
+  // declared by a separate EnumAttr definition.
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::arm_sme";
+}
+
+// Attribute wrapper for TypeSize; prints/parses as `<keyword>`, e.g. `<word>`.
+def ArmSME_TypeSizeAttr
+    : EnumAttr<ArmSME_Dialect, TypeSize, "type_size"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -768,4 +783,33 @@ let arguments = (ins
}];
}
+// Returns the streaming vector length as an `index`, counted in elements of
+// the requested type size (see the examples in the description).
+def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]> {
+ let summary = "Query the streaming vector length";
+
+ let arguments = (ins ArmSME_TypeSizeAttr:$type_size);
+ let results = (outs Index);
+ let assemblyFormat = "$type_size attr-dict";
+
+ let description = [{
+ This operation returns the streaming vector length (SVL) for a given type
+ size. Unlike `vector.vscale` the value returned is invariant to the
+ streaming mode.
+
+ Example:
+ ```mlir
+ // Streaming vector length in:
+ // - bytes (8-bit, SVL.B)
+ %svl_b = arm_sme.streaming_vl <byte>
+ // - half words (16-bit, SVL.H)
+ %svl_h = arm_sme.streaming_vl <half>
+ // - words (32-bit, SVL.W)
+ %svl_w = arm_sme.streaming_vl <word>
+ // - double words (64-bit, SVL.D)
+ %svl_d = arm_sme.streaming_vl <double>
+ ```
+ }];
+}
+
#endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3b..0bb7ccb463e484 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
}
};
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// The op's type size selects the intrinsic (byte -> cntsb, half -> cntsh,
+/// word -> cntsw, double -> cntsd). The intrinsic yields an i64, which is
+/// cast to `index` to match the op's result type.
+///
+/// Example:
+///
+///   %0 = arm_sme.streaming_vl <half>
+///
+/// is converted to:
+///
+///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
+///   %0 = arith.index_cast %cnt : i64 to index
+///
+struct StreamingVLOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+                  arm_sme::StreamingVLOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = streamingVlOp.getLoc();
+    auto i64Type = rewriter.getI64Type();
+    auto *intrOp = [&]() -> Operation * {
+      switch (streamingVlOp.getTypeSize()) {
+      case arm_sme::TypeSize::Byte:
+        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+      case arm_sme::TypeSize::Half:
+        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+      case arm_sme::TypeSize::Word:
+        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+      case arm_sme::TypeSize::Double:
+        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+      }
+      // Every TypeSize enumerator is handled above. Without this, compilers
+      // (e.g. GCC's -Wreturn-type) warn that control can reach the end of this
+      // value-returning lambda; an out-of-range value is reported here in
+      // assertion-enabled builds instead of silently falling off the end.
+      llvm_unreachable("unknown type size in StreamingVLOpConversion");
+    }();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    return success();
+  }
+};
+
} // namespace
namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
+ arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
- converter);
+ OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+ StreamingVLOpConversion>(converter);
}
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bd88da37bdf966..f9cf77ca15ffb9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Byte-sized SVL lowers to cntsb; its i64 result is cast to index and returned.
+// CHECK-LABEL: @arm_sme_streaming_vl_bytes
+// CHECK: %[[CNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
+// CHECK: %[[IDX:.*]] = arith.index_cast %[[CNT]] : i64 to index
+// CHECK: return %[[IDX]] : index
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  %vl = arm_sme.streaming_vl <byte>
+  return %vl : index
+}
+
+// -----
+
+// Half-word SVL lowers to cntsh; the i64 count is cast to index and returned.
+// CHECK-LABEL: @arm_sme_streaming_vl_half_words
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsh"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  %svl_h = arm_sme.streaming_vl <half>
+  return %svl_h : index
+}
+
+// -----
+
+// Word SVL lowers to cntsw; the i64 count is cast to index and returned.
+// CHECK-LABEL: @arm_sme_streaming_vl_words
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsw"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_words() -> index {
+  %svl_w = arm_sme.streaming_vl <word>
+  return %svl_w : index
+}
+
+// -----
+
+// Double-word SVL lowers to cntsd; the i64 count is cast to index and returned.
+// CHECK-LABEL: @arm_sme_streaming_vl_double_words
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  %svl_d = arm_sme.streaming_vl <double>
+  return %svl_d : index
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 58ff7ef4d8340e..2ad742493408b0 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
%result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
return %result : vector<[16]x[16]xi8>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Round-trips the `<byte>` type-size form.
+// CHECK: arm_sme.streaming_vl <byte>
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  %vl = arm_sme.streaming_vl <byte>
+  return %vl : index
+}
+
+// -----
+
+// Round-trips the `<half>` type-size form.
+// CHECK: arm_sme.streaming_vl <half>
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  %vl = arm_sme.streaming_vl <half>
+  return %vl : index
+}
+
+// -----
+
+// Round-trips the `<word>` type-size form.
+// CHECK: arm_sme.streaming_vl <word>
+func.func @arm_sme_streaming_vl_words() -> index {
+  %vl = arm_sme.streaming_vl <word>
+  return %vl : index
+}
+
+// -----
+
+// Round-trips the `<double>` type-size form.
+// CHECK: arm_sme.streaming_vl <double>
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  %vl = arm_sme.streaming_vl <double>
+  return %vl : index
+}
More information about the Mlir-commits
mailing list