[Mlir-commits] [mlir] [mlir][ArmSME] Add arm_sme.streaming_vl operation (PR #77321)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 9 05:51:04 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/77321

>From 58a5ca66d8f6a411315c5adb4cdc492febe72d6b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 8 Jan 2024 14:12:25 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add `arm_sme.streaming_vl` operation

This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This is most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.

Example:
```mlir
%svl_w = arm_sme.streaming_vl <words>
```
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 43 +++++++++++++++++
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 47 +++++++++++++++++--
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         | 42 +++++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 36 ++++++++++++++
 4 files changed, 165 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..4060407d81c0fa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,6 +223,21 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
   let defaultValue = "CombiningKind::Add";
 }
 
+def TypeSize : I32EnumAttr<"TypeSize", "Size of vector type", [
+  I32EnumAttrCase<"Bytes"      , 0, "bytes">,
+  I32EnumAttrCase<"HalfWords"  , 1, "half_words">,
+  I32EnumAttrCase<"Words"      , 2, "words">,
+  I32EnumAttrCase<"DoubleWords", 3, "double_words">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
+                                          "type_size"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -768,4 +783,32 @@ let arguments = (ins
   }];
 }
 
+def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
+{
+  let summary = "Query the streaming vector length";
+
+  let description = [{
+    This operation returns the streaming vector length for a type size. Unlike
+    `vector.vscale` the value returned is invariant to the streaming mode.
+
+    Example:
+    ```mlir
+    // Streaming vector length in:
+    // - bytes (8-bit)
+    %svl_b = arm_sme.streaming_vl <bytes>
+    // - half words (16-bit)
+    %svl_h = arm_sme.streaming_vl <half_words>
+    // - words (32-bit)
+    %svl_w = arm_sme.streaming_vl <words>
+    // - double words (64-bit)
+    %svl_d = arm_sme.streaming_vl <double_words>
+    ```
+  }];
+
+  let arguments = (ins ArmSME_TypeSizeAttr: $type_size);
+  let results = (outs Index);
+
+  let assemblyFormat = "$type_size attr-dict";
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0c6e2e80b88a3b..c5fdba6d00cc0f 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -518,6 +518,45 @@ struct OuterProductOpConversion
   }
 };
 
+/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+///
+/// Example:
+///
+///   %0 = arm_sme.streaming_vl <half_words>
+///
+/// is converted to:
+///
+///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
+///   %0 = arith.index_cast %cnt : i64 to index
+///
+struct StreamingVLOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::StreamingVLOp> {
+  using ConvertOpToLLVMPattern<arm_sme::StreamingVLOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
+                  arm_sme::StreamingVLOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = streamingVlOp.getLoc();
+    auto i64Type = rewriter.getI64Type();
+    auto *intrOp = [&]() -> Operation * {
+      switch (streamingVlOp.getTypeSize()) {
+      case arm_sme::TypeSize::Bytes:
+        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+      case arm_sme::TypeSize::HalfWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+      case arm_sme::TypeSize::Words:
+        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+      case arm_sme::TypeSize::DoubleWords:
+        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+      }
+    }();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -555,7 +594,9 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
+      arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
@@ -572,8 +613,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
 
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
-      converter);
+               OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
+               StreamingVLOpConversion>(converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bd88da37bdf966..bab4fd60518c54 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -559,3 +559,45 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_bytes
+// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
+// CHECK: return %[[INDEX_COUNT]] : index
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  %svl_b = arm_sme.streaming_vl <bytes>
+  return %svl_b : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_half_words
+// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  %svl_h = arm_sme.streaming_vl <half_words>
+  return %svl_h : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_words
+// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+func.func @arm_sme_streaming_vl_words() -> index {
+  %svl_w = arm_sme.streaming_vl <words>
+  return %svl_w : index
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_streaming_vl_double_words
+// CHECK: "arm_sme.intr.cntsd"() : () -> i64
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  %svl_d = arm_sme.streaming_vl <double_words>
+  return %svl_d : index
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 58ff7ef4d8340e..eb4c7149b61f1d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1095,3 +1095,39 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
   %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
   return %result : vector<[16]x[16]xi8>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.streaming_vl
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_streaming_vl_bytes() -> index {
+  // CHECK: arm_sme.streaming_vl <bytes>
+  %svl_b = arm_sme.streaming_vl <bytes>
+  return %svl_b : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_half_words() -> index {
+  // CHECK: arm_sme.streaming_vl <half_words>
+  %svl_h = arm_sme.streaming_vl <half_words>
+  return %svl_h : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_words() -> index {
+  // CHECK: arm_sme.streaming_vl <words>
+  %svl_w = arm_sme.streaming_vl <words>
+  return %svl_w : index
+}
+
+// -----
+
+func.func @arm_sme_streaming_vl_double_words() -> index {
+  // CHECK: arm_sme.streaming_vl <double_words>
+  %svl_d = arm_sme.streaming_vl <double_words>
+  return %svl_d : index
+}

>From 044e4745c77c9b34af7a5f1ab3b547e77541140d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 9 Jan 2024 13:48:59 +0000
Subject: [PATCH 2/2] Fixups

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 25 ++++++++++---------
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 10 ++++----
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |  8 +++---
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 16 ++++++------
 4 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 4060407d81c0fa..2d6399ebc0c929 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -223,18 +223,18 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
   let defaultValue = "CombiningKind::Add";
 }
 
-def TypeSize : I32EnumAttr<"TypeSize", "Size of vector type", [
-  I32EnumAttrCase<"Bytes"      , 0, "bytes">,
-  I32EnumAttrCase<"HalfWords"  , 1, "half_words">,
-  I32EnumAttrCase<"Words"      , 2, "words">,
-  I32EnumAttrCase<"DoubleWords", 3, "double_words">,
+def TypeSize : I32EnumAttr<"TypeSize", "Size of a vector element type", [
+  I32EnumAttrCase<"Byte"  , 0, "byte">,
+  I32EnumAttrCase<"Half"  , 1, "half">,
+  I32EnumAttrCase<"Word"  , 2, "word">,
+  I32EnumAttrCase<"Double", 3, "double">,
 ]> {
   let cppNamespace = "::mlir::arm_sme";
   let genSpecializedAttr = 0;
 }
 
 def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
-                                          "type_size"> {
+                                   "type_size"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
@@ -788,20 +788,21 @@ def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
   let summary = "Query the streaming vector length";
 
   let description = [{
-    This operation returns the streaming vector length for a type size. Unlike
-    `vector.vscale` the value returned is invariant to the streaming mode.
+    This operation returns the streaming vector length for a given type size.
+    Unlike `vector.vscale` the value returned is invariant to the streaming
+    mode.
 
     Example:
     ```mlir
     // Streaming vector length in:
     // - bytes (8-bit)
-    %svl_b = arm_sme.streaming_vl <bytes>
+    %svl_b = arm_sme.streaming_vl <byte>
     // - half words (16-bit)
-    %svl_h = arm_sme.streaming_vl <half_words>
+    %svl_h = arm_sme.streaming_vl <half>
     // - words (32-bit)
-    %svl_w = arm_sme.streaming_vl <words>
+    %svl_w = arm_sme.streaming_vl <word>
     // - double words (64-bit)
-    %svl_d = arm_sme.streaming_vl <double_words>
+    %svl_d = arm_sme.streaming_vl <double>
     ```
   }];
 
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index c5fdba6d00cc0f..0bb7ccb463e484 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -522,7 +522,7 @@ struct OuterProductOpConversion
 ///
 /// Example:
 ///
-///   %0 = arm_sme.streaming_vl <half_words>
+///   %0 = arm_sme.streaming_vl <half>
 ///
 /// is converted to:
 ///
@@ -541,13 +541,14 @@ struct StreamingVLOpConversion
     auto i64Type = rewriter.getI64Type();
     auto *intrOp = [&]() -> Operation * {
       switch (streamingVlOp.getTypeSize()) {
-      case arm_sme::TypeSize::Bytes:
+      case arm_sme::TypeSize::Byte:
         return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
-      case arm_sme::TypeSize::HalfWords:
+      case arm_sme::TypeSize::Half:
         return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
-      case arm_sme::TypeSize::Words:
+      case arm_sme::TypeSize::Word:
         return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
-      case arm_sme::TypeSize::DoubleWords:
+      case arm_sme::TypeSize::Double:
         return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
       }
+      llvm_unreachable("unknown type size in StreamingVLOpConversion");
     }();
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index bab4fd60518c54..f9cf77ca15ffb9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -571,7 +571,7 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index)
 // CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
 // CHECK: return %[[INDEX_COUNT]] : index
 func.func @arm_sme_streaming_vl_bytes() -> index {
-  %svl_b = arm_sme.streaming_vl <bytes>
+  %svl_b = arm_sme.streaming_vl <byte>
   return %svl_b : index
 }
 
@@ -580,7 +580,7 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
 // CHECK-LABEL: @arm_sme_streaming_vl_half_words
 // CHECK: "arm_sme.intr.cntsh"() : () -> i64
 func.func @arm_sme_streaming_vl_half_words() -> index {
-  %svl_h = arm_sme.streaming_vl <half_words>
+  %svl_h = arm_sme.streaming_vl <half>
   return %svl_h : index
 }
 
@@ -589,7 +589,7 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
 // CHECK-LABEL: @arm_sme_streaming_vl_words
 // CHECK: "arm_sme.intr.cntsw"() : () -> i64
 func.func @arm_sme_streaming_vl_words() -> index {
-  %svl_w = arm_sme.streaming_vl <words>
+  %svl_w = arm_sme.streaming_vl <word>
   return %svl_w : index
 }
 
@@ -598,6 +598,6 @@ func.func @arm_sme_streaming_vl_words() -> index {
 // CHECK-LABEL: @arm_sme_streaming_vl_double_words
 // CHECK: "arm_sme.intr.cntsd"() : () -> i64
 func.func @arm_sme_streaming_vl_double_words() -> index {
-  %svl_d = arm_sme.streaming_vl <double_words>
+  %svl_d = arm_sme.streaming_vl <double>
   return %svl_d : index
 }
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index eb4c7149b61f1d..2ad742493408b0 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1103,31 +1103,31 @@ func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: v
 // -----
 
 func.func @arm_sme_streaming_vl_bytes() -> index {
-  // CHECK: arm_sme.streaming_vl <bytes>
-  %svl_b = arm_sme.streaming_vl <bytes>
+  // CHECK: arm_sme.streaming_vl <byte>
+  %svl_b = arm_sme.streaming_vl <byte>
   return %svl_b : index
 }
 
 // -----
 
 func.func @arm_sme_streaming_vl_half_words() -> index {
-  // CHECK: arm_sme.streaming_vl <half_words>
-  %svl_h = arm_sme.streaming_vl <half_words>
+  // CHECK: arm_sme.streaming_vl <half>
+  %svl_h = arm_sme.streaming_vl <half>
   return %svl_h : index
 }
 
 // -----
 
 func.func @arm_sme_streaming_vl_words() -> index {
-  // CHECK: arm_sme.streaming_vl <words>
-  %svl_w = arm_sme.streaming_vl <words>
+  // CHECK: arm_sme.streaming_vl <word>
+  %svl_w = arm_sme.streaming_vl <word>
   return %svl_w : index
 }
 
 // -----
 
 func.func @arm_sme_streaming_vl_double_words() -> index {
-  // CHECK: arm_sme.streaming_vl <double_words>
-  %svl_d = arm_sme.streaming_vl <double_words>
+  // CHECK: arm_sme.streaming_vl <double>
+  %svl_d = arm_sme.streaming_vl <double>
   return %svl_d : index
 }



More information about the Mlir-commits mailing list