[Mlir-commits] [mlir] [MLIR][XeVM] Add initial support for mma_mx or scaled mma. (PR #190989)
Sang Ik Lee
llvmlistbot at llvm.org
Fri Apr 10 13:57:23 PDT 2026
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/190989
>From a84e8d9df5cf536456e8a44b22e190758d0d8d58 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 30 Mar 2026 16:46:07 +0000
Subject: [PATCH 1/2] [MLIR][XeVM] Add initial support for mma_mx or scaled mma.
Add xevm truncf and mma_mx lowering to OCL.
Note: truncf support is limited for now.
Add op conversion test.
Add integration test with bf8 type.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 8 +-
mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 188 ++++++++++++++++--
.../XeVMToLLVM/xevm_mx-to-llvm.mlir | 41 ++++
.../XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir | 156 +++++++++++++++
4 files changed, 372 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
create mode 100644 mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 898d65e4e7ec2..fbd2e07fdb894 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -646,8 +646,8 @@ def XeVM_MMAMxOp
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
- FixedVectorOfRankAndType<[1], [I8]>:$scale_a,
- FixedVectorOfRankAndType<[1], [I8]>:$scale_b,
+ AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
+ AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
@@ -669,8 +669,8 @@ def XeVM_MMAMxOp
Parameters:
* `a` - vector of matrix A elements.
* `b` - vector of matrix B elements.
- * `scale_a` - vector of scaling factors for matrix A.
- * `scale_b` - vector of scaling factors for matrix B.
+ * `scale_a` - vector or scalar of scaling factors for matrix A.
+ * `scale_b` - vector or scalar of scaling factors for matrix B.
* `c` - (optional) vector of matrix C elements.
* `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
* `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index e6acc0525fdd5..6a7ec879d916e 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -495,6 +495,28 @@ static LLVM::CallOp createDeviceFunctionCall(
return callOp;
}
+static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+ switch (pTy) {
+ case xevm::ElemType::F32:
+ case xevm::ElemType::TF32:
+ return 1;
+ case xevm::ElemType::BF16:
+ case xevm::ElemType::F16:
+ return 2;
+ case xevm::ElemType::U8:
+ case xevm::ElemType::S8:
+ case xevm::ElemType::BF8:
+ case xevm::ElemType::F8:
+ return 4;
+ case xevm::ElemType::E2M1:
+ case xevm::ElemType::U4:
+ case xevm::ElemType::S4:
+ return 8;
+ default:
+ llvm_unreachable("unsupported xevm::ElemType");
+ }
+}
+
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -598,22 +620,6 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
rewriter.replaceOp(op, result);
return success();
}
-
-private:
- static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
- switch (pTy) {
- case xevm::ElemType::TF32:
- return 1;
- case xevm::ElemType::BF16:
- case xevm::ElemType::F16:
- return 2;
- case xevm::ElemType::U8:
- case xevm::ElemType::S8:
- return 4;
- default:
- llvm_unreachable("unsupported xevm::ElemType");
- }
- }
};
class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
@@ -1072,6 +1078,153 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
}
};
+class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Supported source and result types are restricted for now.
+ auto srcEtype = op.getSrcEtype().getEtype();
+ auto dstEtype = op.getDstEtype().getEtype();
+ if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
+ if (vecSrcTy.getNumElements() != 16)
+ return rewriter.notifyMatchFailure(
+ op, "Only vector src of 16 elements is supported");
+ } else {
+ return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
+ }
+ if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
+ if (vecDstTy.getNumElements() != 16)
+ return rewriter.notifyMatchFailure(
+ op, "Only vector dst of 16 elements is supported");
+ } else {
+ return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
+ }
+ if (srcEtype == TruncfSrcElemTypes::F16 &&
+ dstEtype == TruncfDstElemTypes::BF8) {
+ auto firstHalf =
+ LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
+ op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
+ auto secondHalf = LLVM::ShuffleVectorOp::create(
+ rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
+ {8, 9, 10, 11, 12, 13, 14, 15});
+ auto firstHalfCasted = LLVM::BitcastOp::create(
+ rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
+ firstHalf);
+ auto secondHalfCasted = LLVM::BitcastOp::create(
+ rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
+ secondHalf);
+ auto resFirstHalf = LLVM::ShuffleVectorOp::create(
+ rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
+ {1, 3, 5, 7, 9, 11, 13, 15});
+ auto resSecondHalf = LLVM::ShuffleVectorOp::create(
+ rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
+ {1, 3, 5, 7, 9, 11, 13, 15});
+ auto res = LLVM::ShuffleVectorOp::create(
+ rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+ rewriter.replaceOp(op, res);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported src, dst element type pair.");
+ }
+ return success();
+ }
+};
+
+class MMAMxToOCLPattern : public OpConversionPattern<MMAMxOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!op.getC()) {
+ return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+ }
+ auto precisionC = op.getTypes().getC();
+ auto precisionD = op.getTypes().getD();
+ if (precisionC != precisionD) {
+ return rewriter.notifyMatchFailure(op, "type of C and D need to match");
+ }
+
+ constexpr uint32_t bitWidthPackedA{16};
+ constexpr uint32_t bitWidthPackedB{32};
+ auto loc = op.getLoc();
+
+ auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+ VectorType origTy = cast<VectorType>(val.getType());
+ const uint32_t vecBitSize =
+ origTy.getNumElements() *
+ origTy.getElementType().getIntOrFloatBitWidth();
+ VectorType newTy = VectorType::get(
+ vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+ if (origTy != newTy)
+ val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
+ return val;
+ };
+
+ Value a = op.getA();
+ Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+ ? cast<Type>(rewriter.getF32Type())
+ : rewriter.getIntegerType(bitWidthPackedA);
+ a = castIfNeeded(a, packedAType);
+
+ Value b = op.getB();
+ Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+ ? cast<Type>(rewriter.getF32Type())
+ : rewriter.getIntegerType(bitWidthPackedB);
+ b = castIfNeeded(b, packedBType);
+
+ Value c = op.getC();
+ VectorType cOrigTy = cast<VectorType>(c.getType());
+ VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
+ assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
+ // OCL builtins encode bfloat16 as int16
+ VectorType cTy =
+ cOrigTy.getElementType().isBF16()
+ ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+ : cOrigTy;
+ VectorType resTy = cTy;
+ if (cOrigTy != cTy)
+ c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
+
+ constexpr int32_t systolicDepth{8};
+ std::string fnName =
+ llvm::formatv("intel_sub_group_{0}_{1}_scaled_matrix_mad_k{2}_{3}",
+ stringifyElemType(op.getTypes().getA()).str(),
+ stringifyElemType(op.getTypes().getB()).str(),
+ systolicDepth *
+ getNumOperandsPerDword(op.getTypes().getA()),
+ stringifyElemType(op.getTypes().getC()).str())
+ .str();
+ auto scaleA = op.getScaleA();
+ auto scaleB = op.getScaleB();
+ SmallVector<Type> argTypes{a.getType(), b.getType(), cTy, scaleA.getType(),
+ scaleB.getType()};
+ fnName = mangle(fnName, argTypes);
+ SmallVector<Value> args{a, b, c, scaleA, scaleB};
+
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+ funcAttrs.memEffectsAttr = memAttr;
+ Value result =
+ createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
+ funcAttrs, op.getOperation())
+ ->getResult(0);
+
+ if (resOrigTy != resTy)
+ result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -1357,5 +1510,6 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
- AllocaToGlobalPattern>(patterns.getContext());
+ TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
+ patterns.getContext());
}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
new file mode 100644
index 0000000000000..65726a57cb700
--- /dev/null
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: llvm.func @truncf_f16_to_bf8
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
+llvm.func @truncf_f16_to_bf8(%src: vector<16xf16>) -> vector<16xi8> {
+ // CHECK: %[[VAR0:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf16>
+ // CHECK: %[[VAR1:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>
+ // CHECK: %[[VAR2:.*]] = llvm.bitcast %[[VAR0]] : vector<8xf16> to vector<16xi8>
+ // CHECK: %[[VAR3:.*]] = llvm.bitcast %[[VAR1]] : vector<8xf16> to vector<16xi8>
+ // CHECK: %[[VAR4:.*]] = llvm.shufflevector %[[VAR2]], %[[VAR2]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
+ // CHECK: %[[VAR5:.*]] = llvm.shufflevector %[[VAR3]], %[[VAR3]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
+ // CHECK: %[[VAR6:.*]] = llvm.shufflevector %4, %5 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+ %dst = xevm.truncf %src { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
+ llvm.return %dst : vector<16xi8>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func spir_funccc @_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc
+// CHECK-SAME: (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8) -> vector<8xf32>
+// CHECK-SAME: attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
+// CHECK-SAME: argMem = none, inaccessibleMem = none, errnoMem = none,
+// CHECK-SAME: targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
+// CHECK: llvm.func @mma_mx_bf8_bf8_k32_f32
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xi16>, %[[ARG1:.*]]: vector<8xi32>
+// CHECK-SAME: %[[ARG2:.*]]: i8, %[[ARG3:.*]]: i8, %[[ARG4:.*]]: vector<8xf32>
+llvm.func @mma_mx_bf8_bf8_k32_f32(%a: vector<8xi16>, %b: vector<8xi32>, %scale_a: i8, %scale_b: i8, %c: vector<8xf32>) -> vector<8xf32> {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc
+ // CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG4]], %[[ARG2]], %[[ARG3]])
+ // CHECK-SAME: {convergent, function_type = !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8)>,
+ // CHECK-SAME: linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none,
+ // CHECK-SAME: argMem = none, inaccessibleMem = none, errnoMem = none,
+ // CHECK-SAME: targetMem0 = none, targetMem1 = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc",
+ // CHECK-SAME: visibility_ = 0 : i64, will_return} :
+ // CHECK-SAME: (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8) -> vector<8xf32>
+ %result = xevm.mma_mx %a, %b, %scale_a, %scale_b, %c
+ {shape=<m=8, n=16, k=32>, types=<d=f32, a=bf8, b=bf8, c=f32>}
+ : (vector<8xi16>, vector<8xi32>, i8, i8, vector<8xf32>) -> vector<8xf32>
+ llvm.return %result : vector<8xf32>
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir
new file mode 100644
index 0000000000000..39c58a31e2138
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane zebin-chip=cri" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// XFAIL: *
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @block_scaled_dpas_bf8(%a: !llvm.ptr<1>, %b: !llvm.ptr<1>, %c: !llvm.ptr<1>) kernel {
+ // TODO: some values are related and can be derived from others, as in the following.
+ // %M = arith.constant 8 : i32
+ // %N = arith.constant 16 : i32
+ // %K = arith.constant 8 : i32
+ // %load_a_elem_bitwidth = arith.constant 32 : i32
+ // %a_elem_bitwidth = arith.constant 16 : i32
+ // %mx_elem_bitwidth = arith.constant 8 : i32
+ // %load_a_pack_ratio = arith.divsi %load_a_elem_bitwidth, %a_elem_bitwidth : i32
+ // %mx_pack_ratio = arith.divsi %load_a_elem_bitwidth, %mx_elem_bitwidth : i32
+ // %load_a_K = arith.muli %K, %load_a_pack_ratio : i32
+ // %load_b_K = arith.muli %K, %mx_pack_ratio : i32
+
+ %base_width_a = arith.constant 16 : i32
+ %base_height_a = arith.constant 8 : i32
+ %base_pitch_a = arith.constant 16 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ // A is loaded as fp16, but it will be truncated to bf8 before MMA.
+ // The blockload2d op needs to be configured to load with double the width
+ // in number of elements or double the element bitwidth.
+ // block load does not support width of 32 elements of 16 bit,
+ // but it supports width of 16 elements of 32 bit.
+ // So the configuration is set to load 8 elements of 32 bits per lane and then
+ // bitcast to 16 elements of fp16 element type.
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded_a_casted = vector.bitcast %loaded_a : vector<8xi32> to vector<16xf16>
+ %a_trunc = xevm.truncf %loaded_a_casted { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
+ %a_trunc_casted = vector.bitcast %a_trunc : vector<16xi8> to vector<8xi16>
+
+ %base_width_b = arith.constant 16 : i32
+ %base_height_b = arith.constant 32 : i32
+ %base_pitch_b = arith.constant 16 : i32
+ // B is already in bf8, and it will be used as is for MMA.
+ // So the blockload2d op is configured to load normally with 8bit element bitwidth
+ // with pack_register request.
+ %loaded_b = xevm.blockload2d %b, %base_width_b, %base_height_b, %base_pitch_b, %x, %y
+ <{elem_size_in_bits=8 : i32, tile_width=16 : i32, tile_height=32 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+
+ // Note: scale is not computed. Constant values are used for simplifying the example
+ %scale_a = arith.constant 1.0 : f8E8M0FNU
+ %scale_b = arith.constant 1.0 : f8E8M0FNU
+ %scale_a_casted = arith.bitcast %scale_a : f8E8M0FNU to i8
+ %scale_b_casted = arith.bitcast %scale_b : f8E8M0FNU to i8
+ // Note: c is not loaded. A constant vector is used for simplifying the example
+ %loaded_c_casted = arith.constant dense<0.0> : vector<8xf32>
+
+ %c_result = xevm.mma_mx %a_trunc_casted, %loaded_b, %scale_a_casted, %scale_b_casted, %loaded_c_casted
+ {shape=<m=8, n=16, k=32>, types=<d=f32, a=bf8, b=bf8, c=f32>}
+ : (vector<8xi16>, vector<8xi32>, i8, i8, vector<8xf32>) -> vector<8xf32>
+ %c_result_casted = vector.bitcast %c_result : vector<8xf32> to vector<8xi32>
+
+ %base_width_c = arith.constant 16 : i32
+ %base_height_c = arith.constant 8 : i32
+ %base_pitch_c = arith.constant 16 : i32
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ gpu.return
+ }
+ }
+
+ func.func @test(%a : memref<8x32xf16>, %b : memref<32x16xf8E5M2>, %c : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+
+ %memref_a = gpu.alloc() : memref<8x32xf16>
+ gpu.memcpy %memref_a, %a : memref<8x32xf16>, memref<8x32xf16>
+ %a_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_a : memref<8x32xf16> -> index
+ %a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
+ %a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
+ %a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_b = gpu.alloc() : memref<32x16xf8E5M2>
+ gpu.memcpy %memref_b, %b : memref<32x16xf8E5M2>, memref<32x16xf8E5M2>
+ %b_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_b : memref<32x16xf8E5M2> -> index
+ %b_ptr_as_i64 = arith.index_cast %b_ptr_as_idx : index to i64
+ %b_ptr = llvm.inttoptr %b_ptr_as_i64 : i64 to !llvm.ptr
+ %b_ptr_casted = llvm.addrspacecast %b_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ %memref_c = gpu.alloc() : memref<8x16xf32>
+ gpu.memcpy %memref_c, %c : memref<8x16xf32>, memref<8x16xf32>
+ %c_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_c : memref<8x16xf32> -> index
+ %c_ptr_as_i64 = arith.index_cast %c_ptr_as_idx : index to i64
+ %c_ptr = llvm.inttoptr %c_ptr_as_i64 : i64 to !llvm.ptr
+ %c_ptr_casted = llvm.addrspacecast %c_ptr : !llvm.ptr to !llvm.ptr<1>
+
+ gpu.launch_func @kernel::@block_scaled_dpas_bf8 blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+ args(%a_ptr_casted : !llvm.ptr<1>, %b_ptr_casted : !llvm.ptr<1>, %c_ptr_casted : !llvm.ptr<1>)
+ gpu.dealloc %memref_a : memref<8x32xf16>
+ gpu.dealloc %memref_b : memref<32x16xf8E5M2>
+ %res = memref.alloc() : memref<8x16xf32>
+ gpu.memcpy %res, %memref_c : memref<8x16xf32>, memref<8x16xf32>
+ gpu.dealloc %memref_c : memref<8x16xf32>
+ return %res : memref<8x16xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1f16 = arith.constant 1.0 : f16
+ %c1bf8 = arith.constant 1.0 : f8E5M2
+ %c0f32 = arith.constant 0.0 : f32
+
+ %A = memref.alloc() : memref<8x32xf16>
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c1f16, %A[%i, %j] : memref<8x32xf16>
+ }
+ }
+
+ %B = memref.alloc() : memref<32x16xf8E5M2>
+ scf.for %i = %c0 to %c32 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c1bf8, %B[%i, %j] : memref<32x16xf8E5M2>
+ }
+ }
+
+ %C = memref.alloc() : memref<8x16xf32>
+ scf.for %i = %c0 to %c8 step %c1 {
+ scf.for %j = %c0 to %c16 step %c1 {
+ memref.store %c0f32, %C[%i, %j] : memref<8x16xf32>
+ }
+ }
+
+ %C_res = call @test(%A, %B, %C) : (memref<8x32xf16>, memref<32x16xf8E5M2>, memref<8x16xf32>) -> memref<8x16xf32>
+ %C_cast = memref.cast %C_res : memref<8x16xf32> to memref<*xf32>
+ call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()
+
+ memref.dealloc %A : memref<8x32xf16>
+ memref.dealloc %B : memref<32x16xf8E5M2>
+ memref.dealloc %C : memref<8x16xf32>
+ memref.dealloc %C_res : memref<8x16xf32>
+ return
+ }
+ func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
+
+}
>From b20d9d137d82c19c19f95a62920023761f08066e Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Fri, 10 Apr 2026 20:06:11 +0000
Subject: [PATCH 2/2] Use FixedVectorOfRankAndType
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index fbd2e07fdb894..5929df98188f2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -622,9 +622,9 @@ def XeVM_TruncfDstElemTypeAttr : XeVM_Attr<"TruncfDstElemType", "dst_etype"> {
def XeVM_TruncfOp
: XeVM_Op<"truncf">,
- Results<(outs AnyTypeOf<[VectorOfRankAndType<[1], [I8, I<4>]>, I8,
+ Results<(outs AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8, I<4>]>, I8,
I<4>]>:$dst)>,
- Arguments<(ins AnyTypeOf<[VectorOfRankAndType<[1], [F16, BF16]>, F16,
+ Arguments<(ins AnyTypeOf<[FixedVectorOfRankAndType<[1], [F16, BF16]>, F16,
BF16]>:$src,
XeVM_TruncfSrcElemTypeAttr:$src_etype,
XeVM_TruncfDstElemTypeAttr:$dst_etype)> {
@@ -646,8 +646,8 @@ def XeVM_MMAMxOp
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
- AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
- AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
+ AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
+ AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
More information about the Mlir-commits
mailing list