[Mlir-commits] [mlir] [MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (PR #111197)

Renato Golin llvmlistbot at llvm.org
Fri Oct 4 14:27:39 PDT 2024


================
@@ -146,19 +146,24 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType aType = op.getLhsVectorType();
-    VectorType bType = op.getRhsVectorType();
-    VectorType cType = op.getVectorType();
+    amx::TileType aType = op.getLhsTileType();
+    amx::TileType bType = op.getRhsTileType();
+    amx::TileType cType = op.getTileType();
     // Determine m x n x k tile sizes.
     std::pair<Value, Value> tsza =
         getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
     std::pair<Value, Value> tszb =
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
-        op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
-        adaptor.getLhs(), adaptor.getRhs());
+    if (aType.getElementType().isBF16())
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
+          op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
+          adaptor.getLhs(), adaptor.getRhs());
+    else
----------------
rengolin wrote:

Perhaps check explicitly that the element type here is still `fp16`, and not something else that fell into this `else` branch by accident.
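
To illustrate the suggestion, here is a minimal standalone sketch of dispatching on the element type with an explicit check for each supported case, so that an unexpected type is rejected rather than silently treated as `fp16`. The `ElemKind` and `TileDotIntrinsic` enums and the `selectTileDotIntrinsic` function are hypothetical stand-ins and not part of the AMX dialect. In the actual pattern the check would presumably use something like `aType.getElementType().isF16()` and report a match failure otherwise.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for the tile element type and the selected
// intrinsic; these are illustration-only names, not MLIR/AMX dialect names.
enum class ElemKind { BF16, F16, F32, I8 };
enum class TileDotIntrinsic { DotBF16, DotF16 };

// Select the intrinsic for a floating-point tile multiply. Every supported
// element type is matched explicitly, so an unexpected type yields
// std::nullopt instead of falling into the F16 path by accident.
inline std::optional<TileDotIntrinsic> selectTileDotIntrinsic(ElemKind lhsElem) {
  if (lhsElem == ElemKind::BF16)
    return TileDotIntrinsic::DotBF16;
  if (lhsElem == ElemKind::F16)
    return TileDotIntrinsic::DotF16;
  // Anything else is unsupported; a conversion pattern would report a
  // match failure at this point rather than emit an intrinsic.
  return std::nullopt;
}

// Small usage check: supported types map to an intrinsic, others do not.
inline void exampleUsage() {
  assert(selectTileDotIntrinsic(ElemKind::F16).has_value());
  assert(!selectTileDotIntrinsic(ElemKind::I8).has_value());
}
```

The point of the design is that the final fallthrough is a failure path, not a default intrinsic, so adding a new element type later cannot silently pick the wrong lowering.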

https://github.com/llvm/llvm-project/pull/111197

