[Mlir-commits] [mlir] 12e1a9b - [mlir][ArmSME] Extend vector.transfer_write lowering

Andrzej Warzynski llvmlistbot at llvm.org
Fri Aug 11 05:33:21 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-11T12:33:09Z
New Revision: 12e1a9b876e87e80c70fc0e72b648ce15aaaf24e

URL: https://github.com/llvm/llvm-project/commit/12e1a9b876e87e80c70fc0e72b648ce15aaaf24e
DIFF: https://github.com/llvm/llvm-project/commit/12e1a9b876e87e80c70fc0e72b648ce15aaaf24e.diff

LOG: [mlir][ArmSME] Extend vector.transfer_write lowering

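Enables the lowering of all valid SME tile vector types (previously only
vector<[16]x[16]xi8>) and of arbitrary vector values (previously only a
zero splat constant), to match the vector.store -> arm_sme.tile_store
lowering.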

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D156976

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index e069a7601c3c68..25cecb67fbfd7f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -28,19 +28,15 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
 
 namespace {
 
-/// Conversion pattern for vector.transfer_write. Currently only supports:
+/// Conversion pattern for vector.transfer_write.
 ///
-///    %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
-///    vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
+///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
+///                                                      memref<?x?xi8>
 ///
 /// is converted to:
 ///
-///    %0 = arm_sme.zero : vector<[16]x[16]xi8>
-///    arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///    vector<[16]x[16]xi8>
-///
-/// The conversion from arith.constant dense<0> to arm_sme.zero is done in
-/// ConstantOpToArmSMELowering.
+///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
+///                                                   vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -48,26 +44,12 @@ struct TransferWriteToArmSMELowering
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const final {
     auto vType = writeOp.getVectorType();
-    if (vType.getRank() != 2)
-      return failure();
-    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
-      return failure();
-    if (vType.getElementType() != rewriter.getI8Type())
-      return failure();
-    if (vType.getScalableDims().size() != 2)
+    if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
-    auto constant = writeOp.getVector().getDefiningOp<arith::ConstantOp>();
-    if (!constant)
-      return failure();
-
-    auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
-    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
-      return failure();
-
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(),
         writeOp.getIndices());

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a6a4c3a651c10e..dc990f7879d8c2 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,15 +1,104 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// CHECK-LABEL:   func.func @transfer_write_2d_zero(
-// CHECK-SAME:      %[[ARG_0:.*]]: memref<?x?xi8>) {
-func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
-// CHECK:           %[[C_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
-// CHECK:           arm_sme.tile_store %[[ZERO]], %[[ARG_0]][%[[C_0]], %[[C_0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
-// CHECK:           return
+// CHECK-LABEL: func.func @transfer_write_2d_i8(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
-  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i16(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xi16>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi16>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref<?x?xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i32(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref<?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_i64(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f16(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf16>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_bf16(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xbf16>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f32(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf32>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_2d_f64(
+// CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
   return
 }
 
@@ -74,27 +163,3 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
   %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.zero
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<1> : vector<[16]x[16]xi8>
-  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
-  return
-}


        


More information about the Mlir-commits mailing list