[Mlir-commits] [mlir] [mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns (PR #86005)
Kojo Acquah
llvmlistbot at llvm.org
Thu Mar 21 16:53:35 PDT 2024
================
@@ -150,16 +150,30 @@ class LowerContractionToSMMLAPattern
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+ // rows along dimM. Broadcast both to the full width
+ if (isVecmat) {
+ auto lhsBroadcastType = VectorType::get(
+ {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+ tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+ tiledLhs);
+ auto accBroadcastType = VectorType::get(
+ {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+ tiledAcc);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
tiledLhs.getType().cast<ShapedType>().getNumElements(),
tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType = VectorType::get(
+ tiledAcc.getType().cast<ShapedType>().getNumElements(),
+ tiledAcc.getType().cast<ShapedType>().getElementType());
----------------
KoolJBlack wrote:
I don't follow this comment. Can you elaborate?
https://github.com/llvm/llvm-project/pull/86005
More information about the Mlir-commits mailing list