[Mlir-commits] [mlir] 12e1a9b - [mlir][ArmSME] Extend vector.transfer_write lowering
Andrzej Warzynski
llvmlistbot at llvm.org
Fri Aug 11 05:33:21 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-11T12:33:09Z
New Revision: 12e1a9b876e87e80c70fc0e72b648ce15aaaf24e
URL: https://github.com/llvm/llvm-project/commit/12e1a9b876e87e80c70fc0e72b648ce15aaaf24e
DIFF: https://github.com/llvm/llvm-project/commit/12e1a9b876e87e80c70fc0e72b648ce15aaaf24e.diff
LOG: [mlir][ArmSME] Extend vector.transfer_write lowering
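Enables the lowering of other tile types (i16, i32, i64, f16, bf16, f32 and
f64, in addition to i8) and of arbitrary vector values (not only zero
constants), to match the vector.store -> arm_sme.tile_store lowering.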
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D156976
Added:
Modified:
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index e069a7601c3c68..25cecb67fbfd7f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -28,19 +28,15 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
namespace {
-/// Conversion pattern for vector.transfer_write. Currently only supports:
+/// Conversion pattern for vector.transfer_write.
///
-/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
-/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
+/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
+/// memref<?x?xi8>
///
/// is converted to:
///
-/// %0 = arm_sme.zero : vector<[16]x[16]xi8>
-/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-/// vector<[16]x[16]xi8>
-///
-/// The conversion from arith.constant dense<0> to arm_sme.zero is done in
-/// ConstantOpToArmSMELowering.
+/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
+/// vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -48,26 +44,12 @@ struct TransferWriteToArmSMELowering
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const final {
auto vType = writeOp.getVectorType();
- if (vType.getRank() != 2)
- return failure();
- if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
- return failure();
- if (vType.getElementType() != rewriter.getI8Type())
- return failure();
- if (vType.getScalableDims().size() != 2)
+ if (!arm_sme::isValidSMETileVectorType(vType))
return failure();
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();
- auto constant = writeOp.getVector().getDefiningOp<arith::ConstantOp>();
- if (!constant)
- return failure();
-
- auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
- if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
- return failure();
-
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(),
writeOp.getIndices());
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a6a4c3a651c10e..dc990f7879d8c2 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,15 +1,104 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
-// CHECK-LABEL: func.func @transfer_write_2d_zero(
-// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xi8>) {
-func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
-// CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
-// CHECK: arm_sme.tile_store %[[ZERO]], %[[ARG_0]][%[[C_0]], %[[C_0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
-// CHECK: return
+// CHECK-LABEL: func.func @transfer_write_2d_i8(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
- %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
- vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xi16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi16>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref<?x?xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i32(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref<?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf16>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_bf16(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f32(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf64>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
return
}
@@ -74,27 +163,3 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.zero
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
- %c0 = arith.constant 0 : index
- %cst = arith.constant dense<1> : vector<[16]x[16]xi8>
- vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
- return
-}
More information about the Mlir-commits mailing list