[Mlir-commits] [mlir] dfa10ec - [mlir][ArmSME] Extend arm_sme.zero for all types

Andrzej Warzynski llvmlistbot at llvm.org
Fri Aug 11 05:46:11 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-11T12:44:56Z
New Revision: dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60

URL: https://github.com/llvm/llvm-project/commit/dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60
DIFF: https://github.com/llvm/llvm-project/commit/dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60.diff

LOG: [mlir][ArmSME] Extend arm_sme.zero for all types

The arm_sme.zero op currently only supports 8-bit element tiles. This
patch extends the op, and the lowering from 'arith.constant dense<0>' to
'arm_sme.zero', to support all tile types.

The lowering from arm_sme.zero to intrinsics is not updated as part of
this patch; that will be done separately.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D156980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/roundtrip.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 11b96f20acdfae..95c5f899bdb52d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -203,19 +203,22 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
 
 def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let summary = "Initialize the two-dimensional ZA array with 0s";
-  let results = (outs nxnxv16i8:$res);
+  let results = (outs SMETile:$res);
   let description = [{
     Initialise ZA with 0. This operation is convenient wrapper for the SME
     `zero` intrinsic and instruction. 
 
-    NOTE: At the moment it is assumed that the element type is `i8` and that
-    there's only one "virtual tile".
-
-    Example:
+    Example 1: Zero an 8-bit element ZA tile.
 
     ```mlir
     %0 = arm_sme.zero : vector<[16]x[16]xi8>
     ```
+
+    Example 2: Zero a 64-bit element ZA tile.
+
+    ```mlir
+    %0 = arm_sme.zero : vector<[2]x[2]xi64>
+    ```
   }];
   let extraClassDeclaration = [{
     VectorType getVectorType() {

diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 25cecb67fbfd7f..f9ce4ba94f03eb 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -15,8 +15,6 @@
 
 using namespace mlir;
 
-static constexpr unsigned kMinNumElts = 16;
-
 /// Returns true if 'val' is a splat of zero, false otherwise.
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
@@ -96,15 +94,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
   LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
                                 PatternRewriter &rewriter) const final {
     auto vType = dyn_cast<VectorType>(constantOp.getType());
-    if (!vType)
-      return failure();
-    if (vType.getRank() != 2)
-      return failure();
-    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
-      return failure();
-    if (vType.getElementType() != rewriter.getI8Type())
-      return failure();
-    if (vType.getScalableDims().size() != 2)
+    if (!vType || !arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());

diff  --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 022ae272c4a35a..93c4eba0531786 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -184,7 +184,7 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
 
 // -----
 
-func.func @arm_sme_zero() {
+func.func @arm_sme_zero_i8() {
   // CHECK: arm_sme.zero : vector<[16]x[16]xi8>
   %0 = arm_sme.zero : vector<[16]x[16]xi8>
   return
@@ -192,6 +192,70 @@ func.func @arm_sme_zero() {
 
 // -----
 
+func.func @arm_sme_zero_i16() {
+  // CHECK: arm_sme.zero : vector<[8]x[8]xi16>
+  %0 = arm_sme.zero : vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_i32() {
+  // CHECK: arm_sme.zero : vector<[4]x[4]xi32>
+  %0 = arm_sme.zero : vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_i64() {
+  // CHECK: arm_sme.zero : vector<[2]x[2]xi64>
+  %0 = arm_sme.zero : vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_i128() {
+  // CHECK: arm_sme.zero : vector<[1]x[1]xi128>
+  %0 = arm_sme.zero : vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_f16() {
+  // CHECK: arm_sme.zero : vector<[8]x[8]xf16>
+  %0 = arm_sme.zero : vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_bf16() {
+  // CHECK: arm_sme.zero : vector<[8]x[8]xbf16>
+  %0 = arm_sme.zero : vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_f32() {
+  // CHECK: arm_sme.zero : vector<[4]x[4]xf32>
+  %0 = arm_sme.zero : vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_zero_f64() {
+  // CHECK: arm_sme.zero : vector<[2]x[2]xf64>
+  %0 = arm_sme.zero : vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index dc990f7879d8c2..87ed185947adc6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -104,16 +104,6 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
 
 // -----
 
-// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
-// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
-func.func @arith_constant_dense_2d_zero_i8() {
-  %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
-  "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
-  return
-}
-
-// -----
-
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -163,3 +153,87 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
   %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
+
+// =============================================================================
+// arith.constant dense<0> to arm_sme.zero
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+func.func @arith_constant_dense_2d_zero_i8() {
+  %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
+  "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xi16>
+func.func @arith_constant_dense_2d_zero_i16() {
+  %zero = arith.constant dense<0> : vector<[8]x[8]xi16>
+  "prevent.dce"(%zero) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i32
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+func.func @arith_constant_dense_2d_zero_i32() {
+  %zero = arith.constant dense<0> : vector<[4]x[4]xi32>
+  "prevent.dce"(%zero) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i64
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xi64>
+func.func @arith_constant_dense_2d_zero_i64() {
+  %zero = arith.constant dense<0> : vector<[2]x[2]xi64>
+  "prevent.dce"(%zero) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xf16>
+func.func @arith_constant_dense_2d_zero_f16() {
+  %zero = arith.constant dense<0.0> : vector<[8]x[8]xf16>
+  "prevent.dce"(%zero) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_bf16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xbf16>
+func.func @arith_constant_dense_2d_zero_bf16() {
+  %zero = arith.constant dense<0.0> : vector<[8]x[8]xbf16>
+  "prevent.dce"(%zero) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f32
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32>
+func.func @arith_constant_dense_2d_zero_f32() {
+  %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+  "prevent.dce"(%zero) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f64
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xf64>
+func.func @arith_constant_dense_2d_zero_f64() {
+  %zero = arith.constant dense<0.0> : vector<[2]x[2]xf64>
+  "prevent.dce"(%zero) : (vector<[2]x[2]xf64>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list