[Mlir-commits] [mlir] 781883e - [mlir][ArmSME] Split lowering of arith.constant from vector.transfer_write

Cullen Rhodes llvmlistbot at llvm.org
Thu Aug 3 01:58:31 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-03T08:57:33Z
New Revision: 781883ea624a6c0cd426af0043b0287b116466ee

URL: https://github.com/llvm/llvm-project/commit/781883ea624a6c0cd426af0043b0287b116466ee
DIFF: https://github.com/llvm/llvm-project/commit/781883ea624a6c0cd426af0043b0287b116466ee.diff

LOG: [mlir][ArmSME] Split lowering of arith.constant from vector.transfer_write

An 'arith.constant dense<0>' is currently lowered to 'arm_sme.zero' as
part of the 'vector.transfer_write' lowering during
'-convert-vector-to-arm-sme' conversion. This patch makes this lowering
independent of the 'vector.transfer_write' by adding a dedicated
'arith.constant' pattern, which can later be extended to further tile
types and non-zero constants.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D156802

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4106b04877ec52..e069a7601c3c68 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -28,8 +28,7 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
 
 namespace {
 
-/// Look at `vector.transfer_write` operations and convert suitable candidates
-/// to ArmSME operations, e.g.:
+/// Conversion pattern for vector.transfer_write. Currently only supports:
 ///
 ///    %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
 ///    vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -40,6 +39,8 @@ namespace {
 ///    arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
 ///    vector<[16]x[16]xi8>
 ///
+/// The conversion from arith.constant dense<0> to arm_sme.zero is done in
+/// ConstantOpToArmSMELowering.
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -56,8 +57,6 @@ struct TransferWriteToArmSMELowering
     if (vType.getScalableDims().size() != 2)
       return failure();
 
-    auto loc = writeOp.getLoc();
-
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
@@ -69,10 +68,9 @@ struct TransferWriteToArmSMELowering
     if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
       return failure();
 
-    auto zero = rewriter.create<arm_sme::ZeroOp>(loc, vType);
-
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, zero, writeOp.getSource(), writeOp.getIndices());
+        writeOp, writeOp.getVector(), writeOp.getSource(),
+        writeOp.getIndices());
     return success();
   }
 };
@@ -109,10 +107,38 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
   }
 };
 
+/// Lower dense<0> : vector<[16]x[16]xi8> arith.constants to arm_sme.zero.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto vType = dyn_cast<VectorType>(constantOp.getType());
+    if (!vType)
+      return failure();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    // getScalableDims() has one flag per dim; require both to be set.
+    if (!llvm::all_of(vType.getScalableDims(), [](bool d) { return d; }))
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, vType);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering>(&ctx);
+               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index f3440e4fc61bf5..a6a4c3a651c10e 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file | mlir-opt | FileCheck %s
-
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL:   func.func @transfer_write_2d_zero(
 // CHECK-SAME:      %[[ARG_0:.*]]: memref<?x?xi8>) {
@@ -16,6 +15,16 @@ func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
 
 // -----
 
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+func.func @arith_constant_dense_2d_zero_i8() {
+  %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
+  "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -70,6 +79,7 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
 
 // CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
 // CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.zero
 // CHECK-NOT: arm_sme.tile_store
 func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list