[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)
Renato Golin
llvmlistbot at llvm.org
Fri Oct 4 14:27:39 PDT 2024
================
@@ -146,19 +146,24 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType aType = op.getLhsVectorType();
- VectorType bType = op.getRhsVectorType();
- VectorType cType = op.getVectorType();
+ amx::TileType aType = op.getLhsTileType();
+ amx::TileType bType = op.getRhsTileType();
+ amx::TileType cType = op.getTileType();
// Determine m x n x k tile sizes.
std::pair<Value, Value> tsza =
getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
std::pair<Value, Value> tszb =
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
- op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
- adaptor.getLhs(), adaptor.getRhs());
+ if (aType.getElementType().isBF16())
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+ op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+ adaptor.getLhs(), adaptor.getRhs());
+ else
----------------
rengolin wrote:
Perhaps make sure this is still `fp16` and not something else that fell here by accident.
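The reviewer is asking for the element-type dispatch to be explicit, so that only bf16 and fp16 each reach their own intrinsic and nothing else falls into a bare `else`. The sketch below models that pattern in self-contained C++. The `ElemKind` enum and `selectTileMulFIntrinsic` function are hypothetical stand-ins, not MLIR API. In the real pattern the checks would presumably be `getElementType().isBF16()` / `isF16()`, and an unsupported type would return a match failure (for example via `rewriter.notifyMatchFailure(...)`) instead of emitting an intrinsic. The strings are the x86 AMX instruction mnemonics (TDPBF16PS from AMX-BF16, TDPFP16PS from AMX-FP16), not the MLIR op names.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for the tile element type; in MLIR this would be
// an mlir::Type queried with isBF16() / isF16().
enum class ElemKind { BF16, F16, F32, I8 };

// Returns the AMX instruction to lower to, or std::nullopt for element
// types the lowering does not handle. Returning "nothing" lets the caller
// report a match failure instead of silently emitting the fp16 variant
// for an unexpected type.
std::optional<std::string> selectTileMulFIntrinsic(ElemKind k) {
  switch (k) {
  case ElemKind::BF16:
    return "tdpbf16ps"; // AMX-BF16 dot product, f32 accumulate
  case ElemKind::F16:
    return "tdpfp16ps"; // AMX-FP16 dot product, f32 accumulate
  default:
    return std::nullopt; // anything else is rejected explicitly
  }
}
```

Making the fp16 case an explicit check rather than the fall-through branch means that if the op's verifier is ever relaxed to accept another element type, the lowering fails loudly instead of producing a wrong intrinsic.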
https://github.com/llvm/llvm-project/pull/111197