[Mlir-commits] [mlir] [mlir][SME] Add vector.splat -> SME conversion (PR #67659)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 28 04:23:42 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This conversion is identical to vector.broadcast when broadcasting a
scalar.
---
Full diff: https://github.com/llvm/llvm-project/pull/67659.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+59-1)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+38)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 264539b85c0ee23..a83c0e9cdafa521 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -240,6 +240,63 @@ struct BroadcastOpToArmSMELowering
}
};
+/// Conversion pattern for vector.splat.
+///
+/// Example:
+///
+/// %splat_to_tile = vector.splat %src : vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+///
+/// (%tile is the tile value obtained before the loop.) This should, in
+/// practice, be identical to vector.broadcast when broadcasting a scalar.
+struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = splatOp.getResult().getType();
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure(); // Result is not a valid SME tile vector type.
+
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = splatOp.getLoc();
+
+ auto srcType = splatOp.getOperand().getType(); // NOTE(review): only read by the assert below -- check for an unused-variable warning in NDEBUG builds.
+ auto tileElementType = tileType.getElementType();
+
+ assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
+
+ // First, broadcast the scalar to a 1-d vector.
+ auto tileSliceType = // Tile shape minus its leading dim, same element type.
+ VectorType::get(tileType.getShape().drop_front(), tileElementType,
+ /*scalableDims=*/{true});
+ Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+ loc, tileSliceType, splatOp.getOperand());
+
+ arm_sme::CastTileToVector tile =
+ getSMETileAndCastToVector(rewriter, loc, tileType);
+
+ // Next, create a loop over ZA tile slices and "move" the generated 1-d
+ // vector to each slice.
+ auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+ auto tileSliceIndex = forOp.getInductionVar();
+
+ rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+ rewriter.replaceOp(splatOp, tile); // The tile value stands in for the splat result.
+
+ return success();
+ }
+};
+
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -319,5 +376,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
- BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
+ BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransposeOpToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index a64753578a1c861..3c08b1deafad27d 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -220,6 +220,44 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
return
}
+//===----------------------------------------------------------------------===//
+// vector.splat
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func.func @splat_vec2d_from_i32(
+// CHECK-SAME: %[[SRC:.*]]: i32) {
+// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @splat_vec2d_from_i32(%arg0: i32) {
+ %0 = vector.splat %arg0 : vector<[4]x[4]xi32>
+ "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @splat_vec2d_from_f16(
+// CHECK-SAME: %[[SRC:.*]]: f16) {
+// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
+// CHECK: scf.for {{.*}} to %[[UB]] {{.*}} {
+// CHECK: arm_sme.move_vector_to_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
+func.func @splat_vec2d_from_f16(%arg0: f16) {
+ %0 = vector.splat %arg0 : vector<[8]x[8]xf16>
+ "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+ return
+}
+
//===----------------------------------------------------------------------===//
// vector.transpose
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/67659
More information about the Mlir-commits
mailing list