[Mlir-commits] [mlir] [mlir][SME] Add vector.splat -> SME conversion (PR #67659)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Sep 28 05:02:08 PDT 2023
================
@@ -240,6 +240,63 @@ struct BroadcastOpToArmSMELowering
}
};
+/// Conversion pattern for vector.splat.
+///
+/// Example:
+///
+/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// }
+///
+/// This should, in practice, be identical to vector.broadcast when
+/// broadcasting a scalar.
+struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+ auto tileType = splatOp.getResult().getType();
+ if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+ return failure();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = splatOp.getLoc();
+
+ auto srcType = splatOp.getOperand().getType();
+ auto tileElementType = tileType.getElementType();
+
+ assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
+
+ // First, broadcast the scalar to a 1-d vector.
+ auto tileSliceType =
+ VectorType::get(tileType.getShape().drop_front(), tileElementType,
+ /*scalableDims=*/{true});
+ Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+ loc, tileSliceType, splatOp.getOperand());
----------------
c-rhodes wrote:
It seems `getInput()` is typically used here, since `input` is the operand's name in the op definition and `getInput()` is its generated accessor:
```suggestion
loc, tileSliceType, splatOp.getInput());
```
https://github.com/llvm/llvm-project/pull/67659
More information about the Mlir-commits
mailing list