[Mlir-commits] [mlir] 781883e - [mlir][ArmSME] Split lowering of arith.constant from vector.transfer_write
Cullen Rhodes
llvmlistbot at llvm.org
Thu Aug 3 01:58:31 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-03T08:57:33Z
New Revision: 781883ea624a6c0cd426af0043b0287b116466ee
URL: https://github.com/llvm/llvm-project/commit/781883ea624a6c0cd426af0043b0287b116466ee
DIFF: https://github.com/llvm/llvm-project/commit/781883ea624a6c0cd426af0043b0287b116466ee.diff
LOG: [mlir][ArmSME] Split lowering of arith.constant from vector.transfer_write
An 'arith.constant dense<0>' is currently lowered to 'arm_sme.zero' as
part of the 'vector.transfer_write' lowering during the
'-convert-vector-to-arm-sme' conversion. This patch makes the constant
lowering independent of 'vector.transfer_write', so that it can later be
extended to further tile types and non-zero constants.
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D156802
Added:
Modified:
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4106b04877ec52..e069a7601c3c68 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -28,8 +28,7 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
namespace {
-/// Look at `vector.transfer_write` operations and convert suitable candidates
-/// to ArmSME operations, e.g.:
+/// Conversion pattern for vector.transfer_write. Currently only supports:
///
/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -40,6 +39,8 @@ namespace {
/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
/// vector<[16]x[16]xi8>
///
+/// The conversion from arith.constant dense<0> to arm_sme.zero is done in
+/// ConstantOpToArmSMELowering.
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -56,8 +57,6 @@ struct TransferWriteToArmSMELowering
if (vType.getScalableDims().size() != 2)
return failure();
- auto loc = writeOp.getLoc();
-
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();
@@ -69,10 +68,9 @@ struct TransferWriteToArmSMELowering
if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
return failure();
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, vType);
-
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- writeOp, zero, writeOp.getSource(), writeOp.getIndices());
+ writeOp, writeOp.getVector(), writeOp.getSource(),
+ writeOp.getIndices());
return success();
}
};
@@ -109,10 +107,51 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
}
};
+/// Conversion pattern for dense arith.constant.
+///
+/// Currently only supports a splat of zero whose type is
+/// vector<[kMinNumElts]x[kMinNumElts]xi8> with BOTH dimensions scalable, e.g.:
+///
+///   %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+///
+/// is converted to:
+///
+///   %0 = arm_sme.zero : vector<[16]x[16]xi8>
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto vType = dyn_cast<VectorType>(constantOp.getType());
+    if (!vType)
+      return failure();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    // `getScalableDims()` holds one flag per dimension, so checking its size
+    // only re-checks the rank. Require every dimension to be scalable so that
+    // fixed-size vectors (e.g. vector<16x16xi8>) are not turned into tiles.
+    if (!llvm::all_of(vType.getScalableDims(),
+                      [](bool isScalable) { return isScalable; }))
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, vType);
+
+    return success();
+  }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering>(&ctx);
+ VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index f3440e4fc61bf5..a6a4c3a651c10e 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file | mlir-opt | FileCheck %s
-
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func.func @transfer_write_2d_zero(
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xi8>) {
@@ -16,6 +15,16 @@ func.func @transfer_write_2d_zero(%arg0 : memref<?x?xi8>) {
// -----
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+func.func @arith_constant_dense_2d_zero_i8() {
+ %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
+ "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
+ return
+}
+
+// -----
+
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
@@ -70,6 +79,7 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.zero
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list