[Mlir-commits] [mlir] dfa10ec - [mlir][ArmSME] Extend arm_sme.zero for all types
Andrzej Warzynski
llvmlistbot at llvm.org
Fri Aug 11 05:46:11 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-11T12:44:56Z
New Revision: dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60
URL: https://github.com/llvm/llvm-project/commit/dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60
DIFF: https://github.com/llvm/llvm-project/commit/dfa10ec2e6006ce5366a1e9013f0dbf4344f3e60.diff
LOG: [mlir][ArmSME] Extend arm_sme.zero for all types
The arm_sme.zero op currently only supports 8-bit element tiles. This
patch extends both the op and the 'arith.constant dense<0>' ->
'arm_sme.zero' lowering to support all tile types.
The lowering from arm_sme.zero to intrinsics is not updated as part of
this patch and will be done separately.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D156980
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Dialect/ArmSME/roundtrip.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 11b96f20acdfae..95c5f899bdb52d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -203,19 +203,22 @@ def GetTileID : ArmSME_Op<"get_tile_id"> {
def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let summary = "Initialize the two-dimensional ZA array with 0s";
- let results = (outs nxnxv16i8:$res);
+ let results = (outs SMETile:$res);
let description = [{
Initialise ZA with 0. This operation is a convenient wrapper for the SME
`zero` intrinsic and instruction.
- NOTE: At the moment it is assumed that the element type is `i8` and that
- there's only one "virtual tile".
-
- Example:
+ Example 1: Zero an 8-bit element ZA tile.
```mlir
%0 = arm_sme.zero : vector<[16]x[16]xi8>
```
+
+ Example 2: Zero a 64-bit element ZA tile.
+
+ ```mlir
+ %0 = arm_sme.zero : vector<[2]x[2]xi64>
+ ```
}];
let extraClassDeclaration = [{
VectorType getVectorType() {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 25cecb67fbfd7f..f9ce4ba94f03eb 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -15,8 +15,6 @@
using namespace mlir;
-static constexpr unsigned kMinNumElts = 16;
-
/// Returns true if 'val' is a splat of zero, false otherwise.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
@@ -96,15 +94,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
PatternRewriter &rewriter) const final {
auto vType = dyn_cast<VectorType>(constantOp.getType());
- if (!vType)
- return failure();
- if (vType.getRank() != 2)
- return failure();
- if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
- return failure();
- if (vType.getElementType() != rewriter.getI8Type())
- return failure();
- if (vType.getScalableDims().size() != 2)
+ if (!vType || !arm_sme::isValidSMETileVectorType(vType))
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 022ae272c4a35a..93c4eba0531786 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -184,7 +184,7 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
// -----
-func.func @arm_sme_zero() {
+func.func @arm_sme_zero_i8() {
// CHECK: arm_sme.zero : vector<[16]x[16]xi8>
%0 = arm_sme.zero : vector<[16]x[16]xi8>
return
@@ -192,6 +192,70 @@ func.func @arm_sme_zero() {
// -----
+func.func @arm_sme_zero_i16() {
+ // CHECK: arm_sme.zero : vector<[8]x[8]xi16>
+ %0 = arm_sme.zero : vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_i32() {
+ // CHECK: arm_sme.zero : vector<[4]x[4]xi32>
+ %0 = arm_sme.zero : vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_i64() {
+ // CHECK: arm_sme.zero : vector<[2]x[2]xi64>
+ %0 = arm_sme.zero : vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_i128() {
+ // CHECK: arm_sme.zero : vector<[1]x[1]xi128>
+ %0 = arm_sme.zero : vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_f16() {
+ // CHECK: arm_sme.zero : vector<[8]x[8]xf16>
+ %0 = arm_sme.zero : vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_bf16() {
+ // CHECK: arm_sme.zero : vector<[8]x[8]xbf16>
+ %0 = arm_sme.zero : vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_f32() {
+ // CHECK: arm_sme.zero : vector<[4]x[4]xf32>
+ %0 = arm_sme.zero : vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_zero_f64() {
+ // CHECK: arm_sme.zero : vector<[2]x[2]xf64>
+ %0 = arm_sme.zero : vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index dc990f7879d8c2..87ed185947adc6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -104,16 +104,6 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
// -----
-// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
-// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
-func.func @arith_constant_dense_2d_zero_i8() {
- %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
- "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
- return
-}
-
-// -----
-
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
@@ -163,3 +153,87 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
+
+// =============================================================================
+// arith.constant dense<0> to arm_sme.zero
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+func.func @arith_constant_dense_2d_zero_i8() {
+ %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
+ "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xi16>
+func.func @arith_constant_dense_2d_zero_i16() {
+ %zero = arith.constant dense<0> : vector<[8]x[8]xi16>
+ "prevent.dce"(%zero) : (vector<[8]x[8]xi16>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i32
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+func.func @arith_constant_dense_2d_zero_i32() {
+ %zero = arith.constant dense<0> : vector<[4]x[4]xi32>
+ "prevent.dce"(%zero) : (vector<[4]x[4]xi32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i64
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xi64>
+func.func @arith_constant_dense_2d_zero_i64() {
+ %zero = arith.constant dense<0> : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero) : (vector<[2]x[2]xi64>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xf16>
+func.func @arith_constant_dense_2d_zero_f16() {
+ %zero = arith.constant dense<0.0> : vector<[8]x[8]xf16>
+ "prevent.dce"(%zero) : (vector<[8]x[8]xf16>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_bf16
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xbf16>
+func.func @arith_constant_dense_2d_zero_bf16() {
+ %zero = arith.constant dense<0.0> : vector<[8]x[8]xbf16>
+ "prevent.dce"(%zero) : (vector<[8]x[8]xbf16>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f32
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32>
+func.func @arith_constant_dense_2d_zero_f32() {
+ %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+ "prevent.dce"(%zero) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arith_constant_dense_2d_zero_f64
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xf64>
+func.func @arith_constant_dense_2d_zero_f64() {
+ %zero = arith.constant dense<0.0> : vector<[2]x[2]xf64>
+ "prevent.dce"(%zero) : (vector<[2]x[2]xf64>) -> ()
+ return
+}
More information about the Mlir-commits
mailing list