[Mlir-commits] [mlir] 3b4b6cb - [mlir][ArmSME] Add move vector to tile slice op and lowerings
Cullen Rhodes
llvmlistbot at llvm.org
Tue Aug 29 02:38:21 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-29T09:29:22Z
New Revision: 3b4b6cbba5e04fc6bbbba4a034176aa4039a73a6
URL: https://github.com/llvm/llvm-project/commit/3b4b6cbba5e04fc6bbbba4a034176aa4039a73a6
DIFF: https://github.com/llvm/llvm-project/commit/3b4b6cbba5e04fc6bbbba4a034176aa4039a73a6.diff
LOG: [mlir][ArmSME] Add move vector to tile slice op and lowerings
This adds a 'move_vector_to_tile_slice' op to the ArmSME dialect that
moves a 1-D scalable vector to a slice of a 2-D tile at a given index.
When lowering to LLVM, this op is lowered to the 'llvm.aarch64.sme.write.horiz'
intrinsic, which maps to the MOVA (vector to tile, single) SME instruction [1].
Like the SME load and store instructions, this operates on ZA tile slices,
which are 1-D vectors of horizontally or vertically contiguous elements
within a ZA tile.
This patch extends the lowering of 'arith.constant' to SME to support
non-zero splat constants using this new op. This requires materializing a
loop that broadcasts the constant to each tile slice with the
'move_vector_to_tile_slice' op. Unlike load and store, this is done during
conversion from Vector to ArmSME rather than from ArmSME to SCF. The latter
would require a higher-level custom op in the ArmSME dialect, like
'tile_load' and 'tile_store', which isn't necessary here. We may also
remove the load and store ops in the future in favour of lowering
straight from Vector, at which point the two approaches would converge.
Currently only horizontal tile slices are supported. A future patch will
extend this mechanism to support 'vector.broadcast'.
Depends on D156980 D157004
[1] https://developer.arm.com/documentation/ddi0602
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D157005
Added:
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
mlir/test/Dialect/ArmSME/invalid.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index b083baf03fa96a..7f02e723f3d91c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -414,6 +414,51 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
}
+def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
+ AllTypesMatch<["tile", "result"]>,
+ TypesMatchWith<
+ "type of 'vector' matches type of 'tile' slice",
+ "tile", "vector",
+ "VectorType::get("
+ "::llvm::cast<mlir::VectorType>($_self).getShape().drop_front(),"
+ "::llvm::cast<mlir::VectorType>($_self).getElementType(),"
+ "/*scalableDims=*/{true})">,
+]> {
+ let summary = "Move a 1-D scalable vector to a slice of a 2-D tile";
+ let description = [{
+ The `move_vector_to_tile_slice` operation moves a 1-D scalable vector to a
+ slice of a 2-D scalable vector tile at the given index. The type of the 1-D
+ scalable vector to be moved must match the type of the tile slice. A tile
+ slice is a 1-D vector of horizontally or vertically contiguous elements
+ within a ZA tile. Horizontal tile slices are currently assumed when
+ lowering to intrinsics. The updated tile is returned as the result.
+
+ Example 1: Move a vector<[16]xi8> into a tile slice at the given index.
+ ```mlir
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
+ ```
+
+ Example 2: Move a vector<[2]xf64> into a tile slice at the given index.
+ ```mlir
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+ ```
+ }];
+ let arguments = (ins
+ SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+ let results = (outs SMETile:$result);
+
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getTile().getType());
+ }
+ }];
+
+ let assemblyFormat = [{
+ $vector `,` $tile `,` $tile_slice_index
+ attr-dict `:` type($vector) `into` type($result)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ArmSME Intrinsic op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index f9ce4ba94f03eb..aedb0bbde4858a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -93,15 +93,67 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
PatternRewriter &rewriter) const final {
- auto vType = dyn_cast<VectorType>(constantOp.getType());
- if (!vType || !arm_sme::isValidSMETileVectorType(vType))
+ auto tileType = dyn_cast<VectorType>(constantOp.getType());
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
- if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+ if (!denseAttr || !denseAttr.isSplat())
return failure();
- rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, vType);
+ auto tileElementType = tileType.getElementType();
+
+ // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+ if (isSplatZero(tileElementType, denseAttr)) {
+ rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+ return success();
+ }
+
+ // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+ // ops that broadcast the constant to each tile slice.
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = constantOp.getLoc();
+
+ // Derive the scalable 1-D tile slice type by dropping the leading tile dim.
+ auto tileSliceType =
+ VectorType::get(tileType.getShape().drop_front(), tileElementType,
+ /*scalableDims=*/{true});
+ auto denseAttr1D = DenseElementsAttr::get(
+ tileSliceType, denseAttr.getSplatValue<Attribute>());
+ auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ // Create 'arm_sme.get_tile_id' op.
+ auto tileId = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+ // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+ // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
+ auto tile =
+ rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ // Create a loop that broadcasts the constant to each ZA tile slice.
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, tile, tileSliceIndex);
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ rewriter.replaceOp(constantOp, tile);
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e846b63b011a62..88beb80de934fe 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -308,6 +308,58 @@ struct StoreTileSliceToArmSMELowering
}
};
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
+/// tile slices are currently supported.
+struct MoveVectorToTileSliceToArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
+ using ConvertOpToLLVMPattern<
+ arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
+ arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = moveVectorToTileSliceOp.getLoc();
+ auto tileType = moveVectorToTileSliceOp.getTileType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
+ // written to.
+ auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+ loc, rewriter.getIntegerType(tileElementWidth),
+ moveVectorToTileSliceOp.getTile());
+
+ auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
+
+ // Cast tile slice from index to i32 for intrinsic.
+ auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI32Type(), tileSlice);
+
+ // Create an all-active predicate mask.
+ auto one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI1Type(),
+ rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+ auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
+ /*scalableDims=*/{true});
+ auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+ auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+
+ // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
+ rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+ loc, tileI32, tileSliceI32, allActiveMask,
+ moveVectorToTileSliceOp.getVector());
+
+ // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
+ // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
+ rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
+ moveVectorToTileSliceOp, tileType, tile);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -320,8 +372,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_za_enable,
- arm_sme::aarch64_sme_za_disable>();
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
// Mark 'func.func' ops as legal if either:
@@ -353,5 +405,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
- LoadTileSliceToArmSMELowering>(converter);
+ LoadTileSliceToArmSMELowering,
+ MoveVectorToTileSliceToArmSMELowering>(converter);
}
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index a08646cdca9aef..12f7e7333ebc97 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -83,3 +83,47 @@ func.func @arith_constant_dense_2d_zero_f64() {
"prevent.dce"(%zero) : (vector<[2]x[2]xf64>) -> ()
return
}
+
+// =============================================================================
+// Non-zero splat arith.constant dense to SME
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
+// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
+// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8>
+// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[16]x[16]xi8>) -> ()
+func.func @arith_constant_dense_2d_nonzero_i8() {
+ %two = arith.constant dense<2> : vector<[16]x[16]xi8>
+ "prevent.dce"(%two) : (vector<[16]x[16]xi8>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
+// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
+// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64>
+// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[2]x[2]xf64>) -> ()
+func.func @arith_constant_dense_2d_nonzero_f64() {
+ %two = arith.constant dense<2.0> : vector<[2]x[2]xf64>
+ "prevent.dce"(%two) : (vector<[2]x[2]xf64>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1609ed39e64164..f529b47cb98c3d 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -71,3 +71,21 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
%0 = arm_sme.get_tile_id : i1
return %0 : i1
}
+
+// -----
+
+// The tile slice type of vector<[16]x[16]xi8> is vector<[16]xi8>, not vector<[8]xi8>.
+func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
+ // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+ %0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xi8> into vector<[16]x[16]xi8>
+ return %0 : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// The tile slice type of vector<[4]x[4]xf32> is vector<[4]xf32>, not vector<[8]xf32>.
+func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+ // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+ %0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf32> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 93c4eba0531786..bae48be87b2dcd 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -577,3 +577,75 @@ func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_i16(%vector : vector<[8]xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xi16> into vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_i64(%vector : vector<[2]xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xi64> into vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_i128(%vector : vector<[1]xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[1]xi128> into vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_f16(%vector : vector<[8]xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf16> into vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_f32(%vector : vector<[4]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xf32> into vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64>
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
new file mode 100644
index 00000000000000..a407b13b541839
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -march=aarch64 -mattr=+sve,+sme \
+// RUN: -e entry -entry-point-result=i32 \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+// Integration test demonstrating filling a 32-bit element ZA tile with a
+// non-zero constant via vector to tile (MOVA) ops.
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+ %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @printTileEnd() {
+ %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @entry() -> i32 {
+ %c0 = arith.constant 0 : index
+
+ %min_elts_s = arith.constant 4 : index
+ %vscale = vector.vscale
+
+ // "svl" refers to the Streaming Vector Length and "svl_s" the number of
+ // 32-bit elements in a vector of SVL bits.
+ %svl_s = arith.muli %min_elts_s, %vscale : index
+
+ // Allocate memory.
+ %tilesize = arith.muli %svl_s, %svl_s : index
+ %mem = memref.alloca(%tilesize) : memref<?xi32>
+
+ // Fill a tile with '123'. This will get lowered to a 1-d vector splat of
+ // '123' and a loop that writes this vector to each tile slice in the ZA
+ // tile.
+ %tile = arith.constant dense<123> : vector<[4]x[4]xi32>
+
+ // Store tile to memory so it can be dumped.
+ vector.store %tile, %mem[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Dump "mem". The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %tilesize step %svl_s {
+ %tileslice = vector.load %mem[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.print %tileslice : vector<[4]xi32>
+ }
+ func.call @printTileEnd() : () -> ()
+
+ %c0_i32 = arith.constant 0 : i32
+ return %c0_i32 : i32
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
More information about the Mlir-commits
mailing list