[Mlir-commits] [mlir] [MLIR][XeVM] Add initial support for mma_mx (scaled MMA). (PR #190989)

Sang Ik Lee llvmlistbot at llvm.org
Fri Apr 10 13:57:23 PDT 2026


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/190989

>From a84e8d9df5cf536456e8a44b22e190758d0d8d58 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 30 Mar 2026 16:46:07 +0000
Subject: [PATCH 1/2] [MLIR][XeVM] Add initial support for mma_mx (scaled MMA).

Add xevm truncf and mma_mx lowering to OCL.
Note: truncf lowering currently supports only f16 -> bf8 on 16-element vectors.
Add op conversion test.
Add integration test with bf8 type.
---
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td   |   8 +-
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 188 ++++++++++++++++--
 .../XeVMToLLVM/xevm_mx-to-llvm.mlir           |  41 ++++
 .../XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir  | 156 +++++++++++++++
 4 files changed, 372 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
 create mode 100644 mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 898d65e4e7ec2..fbd2e07fdb894 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -646,8 +646,8 @@ def XeVM_MMAMxOp
       Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
       Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
           FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
-          FixedVectorOfRankAndType<[1], [I8]>:$scale_a,
-          FixedVectorOfRankAndType<[1], [I8]>:$scale_b,
+          AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
+          AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
           Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
           XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
 
@@ -669,8 +669,8 @@ def XeVM_MMAMxOp
     Parameters:
       * `a` - vector of matrix A elements.
       * `b` - vector of matrix B elements.
-      * `scale_a` - vector of scaling factors for matrix A.
-      * `scale_b` - vector of scaling factors for matrix B.
+      * `scale_a` - vector or scalar of scaling factors for matrix A.
+      * `scale_b` - vector or scalar of scaling factors for matrix B.
       * `c` - (optional) vector of matrix C elements.
       * `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
       * `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index e6acc0525fdd5..6a7ec879d916e 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -495,6 +495,28 @@ static LLVM::CallOp createDeviceFunctionCall(
   return callOp;
 }
 
+static unsigned getNumOperandsPerDword(xevm::ElemType elemTy) {
+  switch (elemTy) { // elements per 32-bit dword, grouped by element width
+  case xevm::ElemType::TF32:
+  case xevm::ElemType::F32:
+    return 1; // 32-bit elements
+  case xevm::ElemType::F16:
+  case xevm::ElemType::BF16:
+    return 2; // 16-bit elements
+  case xevm::ElemType::F8:
+  case xevm::ElemType::BF8:
+  case xevm::ElemType::S8:
+  case xevm::ElemType::U8:
+    return 4; // 8-bit elements
+  case xevm::ElemType::S4:
+  case xevm::ElemType::U4:
+  case xevm::ElemType::E2M1:
+    return 8; // 4-bit elements
+  default:
+    llvm_unreachable("unsupported xevm::ElemType");
+  }
+}
+
 class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -598,22 +620,6 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
     rewriter.replaceOp(op, result);
     return success();
   }
-
-private:
-  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
-    switch (pTy) {
-    case xevm::ElemType::TF32:
-      return 1;
-    case xevm::ElemType::BF16:
-    case xevm::ElemType::F16:
-      return 2;
-    case xevm::ElemType::U8:
-    case xevm::ElemType::S8:
-      return 4;
-    default:
-      llvm_unreachable("unsupported xevm::ElemType");
-    }
-  }
 };
 
 class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
@@ -1072,6 +1078,153 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
   }
 };
 
+/// Lowers xevm.truncf with LLVM shuffles and bitcasts. Only f16 -> bf8 on
+/// 16-element vectors is supported for now. Each bf8 lane is the high byte of
+/// the corresponding f16 lane (little-endian byte order), so the mantissa is
+/// truncated without rounding. NOTE(review): assuming bf8 is E5M2, an f16 NaN
+/// whose payload lies only in its low byte becomes Inf -- confirm acceptable.
+class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Supported source and result types are restricted for now.
+    auto srcVecTy = dyn_cast<VectorType>(op.getSrc().getType());
+    if (!srcVecTy)
+      return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
+    if (srcVecTy.getNumElements() != 16)
+      return rewriter.notifyMatchFailure(
+          op, "Only vector src of 16 elements is supported");
+    auto dstVecTy = dyn_cast<VectorType>(op.getDst().getType());
+    if (!dstVecTy)
+      return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
+    if (dstVecTy.getNumElements() != 16)
+      return rewriter.notifyMatchFailure(
+          op, "Only vector dst of 16 elements is supported");
+    if (op.getSrcEtype().getEtype() != TruncfSrcElemTypes::F16 ||
+        op.getDstEtype().getEtype() != TruncfDstElemTypes::BF8)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported src, dst element type pair.");
+
+    Location loc = op.getLoc();
+    // Build from the adaptor's (type-converted) operand; the original operand
+    // may still refer to a value that is pending replacement in the conversion.
+    Value src = adaptor.getSrc();
+    auto i8x16Ty = VectorType::get(16, rewriter.getI8Type());
+    // Split the 16 x f16 source into two 128-bit halves, viewed as bytes.
+    auto half0 = LLVM::ShuffleVectorOp::create(rewriter, loc, src, src,
+                                               {0, 1, 2, 3, 4, 5, 6, 7});
+    auto half1 = LLVM::ShuffleVectorOp::create(
+        rewriter, loc, src, src, {8, 9, 10, 11, 12, 13, 14, 15});
+    auto bytes0 = LLVM::BitcastOp::create(rewriter, loc, i8x16Ty, half0);
+    auto bytes1 = LLVM::BitcastOp::create(rewriter, loc, i8x16Ty, half1);
+    // Odd byte indices hold the high byte of each f16 element.
+    auto res0 = LLVM::ShuffleVectorOp::create(rewriter, loc, bytes0, bytes0,
+                                              {1, 3, 5, 7, 9, 11, 13, 15});
+    auto res1 = LLVM::ShuffleVectorOp::create(rewriter, loc, bytes1, bytes1,
+                                              {1, 3, 5, 7, 9, 11, 13, 15});
+    // Concatenate the two 8 x i8 halves into the 16 x i8 result.
+    auto res = LLVM::ShuffleVectorOp::create(
+        rewriter, loc, res0, res1,
+        {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+
+/// Lowers xevm.mma_mx to the OCL builtin
+/// `intel_sub_group_<A>_<B>_scaled_matrix_mad_k<K>_<C>` (K = systolic depth 8
+/// times A elements per dword). bf16 C/D values are passed to it as i16.
+class MMAMxToOCLPattern : public OpConversionPattern<MMAMxOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getC())
+      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+    if (op.getTypes().getC() != op.getTypes().getD())
+      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
+    // Check C/D vector types before creating IR; an assert vanishes in release.
+    VectorType cOrigTy = cast<VectorType>(op.getC().getType());
+    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
+    if (cOrigTy != resOrigTy)
+      return rewriter.notifyMatchFailure(
+          op, "Accumulator and result type mismatch");
+
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    // Bitcasts `val` to a vector of `packedType` with the same total bit size.
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
+      return val;
+    };
+
+    // Build from the converted (adaptor) operands, not the original ones.
+    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
+    Value a = castIfNeeded(adaptor.getA(), packedAType);
+
+    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedB);
+    Value b = castIfNeeded(adaptor.getB(), packedBType);
+
+    Value c = adaptor.getC();
+    // OCL builtins encode bfloat16 as int16; the builtin returns this type too.
+    VectorType cTy =
+        cOrigTy.getElementType().isBF16()
+            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+            : cOrigTy;
+    if (cOrigTy != cTy)
+      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
+
+    // K = systolic depth * number of A elements packed per dword.
+    constexpr int32_t systolicDepth{8};
+    std::string fnName =
+        llvm::formatv("intel_sub_group_{0}_{1}_scaled_matrix_mad_k{2}_{3}",
+                      stringifyElemType(op.getTypes().getA()).str(),
+                      stringifyElemType(op.getTypes().getB()).str(),
+                      systolicDepth *
+                          getNumOperandsPerDword(op.getTypes().getA()),
+                      stringifyElemType(op.getTypes().getC()).str())
+            .str();
+    SmallVector<Value> args{a, b, c, adaptor.getScaleA(), adaptor.getScaleB()};
+    SmallVector<Type> argTypes = llvm::to_vector(ValueRange(args).getTypes());
+    fnName = mangle(fnName, argTypes);
+
+    // Declare the builtin as not accessing memory of any kind.
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+    Value result =
+        createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
+                                 funcAttrs, op.getOperation())
+            ->getResult(0);
+
+    if (resOrigTy != cTy)
+      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -1357,5 +1510,6 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
                SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
                SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
-               AllocaToGlobalPattern>(patterns.getContext());
+               TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
new file mode 100644
index 0000000000000..65726a57cb700
--- /dev/null
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm_mx-to-llvm.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: llvm.func @truncf_f16_to_bf8
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
+llvm.func @truncf_f16_to_bf8(%src: vector<16xf16>) -> vector<16xi8> {
+  // CHECK:  %[[VAR0:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf16>
+  // CHECK:  %[[VAR1:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>
+  // CHECK:  %[[VAR2:.*]] = llvm.bitcast %[[VAR0]] : vector<8xf16> to vector<16xi8>
+  // CHECK:  %[[VAR3:.*]] = llvm.bitcast %[[VAR1]] : vector<8xf16> to vector<16xi8>
+  // CHECK:  %[[VAR4:.*]] = llvm.shufflevector %[[VAR2]], %[[VAR2]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
+  // CHECK:  %[[VAR5:.*]] = llvm.shufflevector %[[VAR3]], %[[VAR3]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
+  // CHECK:  %[[VAR6:.*]] = llvm.shufflevector %[[VAR4]], %[[VAR5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+  // CHECK:  llvm.return %[[VAR6]] : vector<16xi8>
+  %dst = xevm.truncf %src { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
+  llvm.return %dst : vector<16xi8>
+}
+
+// -----
+
+// Scalar i8 scales are passed to the builtin unchanged (the trailing `cc` in the mangled name).
+// CHECK-LABEL: llvm.func spir_funccc @_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc
+// CHECK-SAME: (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8) -> vector<8xf32>
+// CHECK-SAME:   attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
+// CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
+// CHECK: llvm.func @mma_mx_bf8_bf8_k32_f32
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xi16>, %[[ARG1:.*]]: vector<8xi32>
+// CHECK-SAME: %[[ARG2:.*]]: i8, %[[ARG3:.*]]: i8, %[[ARG4:.*]]: vector<8xf32>
+llvm.func @mma_mx_bf8_bf8_k32_f32(%a: vector<8xi16>, %b: vector<8xi32>, %scale_a: i8, %scale_b: i8, %c: vector<8xf32>) -> vector<8xf32> {
+  // CHECK:  %[[VAR0:.*]] = llvm.call spir_funccc @_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc
+  // CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG4]], %[[ARG2]], %[[ARG3]])
+  // CHECK-SAME: {convergent, function_type = !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8)>,
+  // CHECK-SAME: linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none,
+  // CHECK-SAME:   argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>,
+  // CHECK-SAME: no_unwind, sym_name = "_Z49intel_sub_group_bf8_bf8_scaled_matrix_mad_k32_f32Dv8_sDv8_iDv8_fcc",
+  // CHECK-SAME: visibility_ = 0 : i64, will_return} :
+  // CHECK-SAME: (vector<8xi16>, vector<8xi32>, vector<8xf32>, i8, i8) -> vector<8xf32>
+  %result = xevm.mma_mx %a, %b, %scale_a, %scale_b, %c
+          {shape=<m=8, n=16, k=32>, types=<d=f32, a=bf8, b=bf8, c=f32>}
+          : (vector<8xi16>, vector<8xi32>, i8, i8, vector<8xf32>) -> vector<8xf32>
+  llvm.return %result : vector<8xf32>
+}
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir
new file mode 100644
index 0000000000000..39c58a31e2138
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/xevm_block_scaled_dpas_bf8.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane zebin-chip=cri" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_levelzero_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// XFAIL: *
+module @gemm attributes {gpu.container_module} {
+  gpu.module @kernel {
+    gpu.func @block_scaled_dpas_bf8(%a: !llvm.ptr<1>, %b: !llvm.ptr<1>, %c: !llvm.ptr<1>) kernel {
+      // TODO: Several values below are related and could be derived, e.g.:
+      // %M = arith.constant 8 : i32
+      // %N = arith.constant 16 : i32
+      // %K = arith.constant 8 : i32
+      // %load_a_elem_bitwidth = arith.constant 32 : i32
+      // %a_elem_bitwidth = arith.constant 16 : i32
+      // %mx_elem_bitwidth = arith.constant 8 : i32
+      // %load_a_pack_ratio = arith.divsi %load_a_elem_bitwidth, %a_elem_bitwidth : i32
+      // %mx_pack_ratio = arith.divsi %load_a_elem_bitwidth, %mx_elem_bitwidth : i32
+      // %load_a_K = arith.muli %K, %load_a_pack_ratio : i32
+      // %load_b_K = arith.muli %K, %mx_pack_ratio : i32
+
+      %base_width_a = arith.constant 64 : i32 // bytes per row: 32 x f16
+      %base_height_a = arith.constant 8 : i32
+      %base_pitch_a = arith.constant 64 : i32 // bytes between rows
+      %x = arith.constant 0 : i32
+      %y = arith.constant 0 : i32
+      // A is loaded as fp16 and truncated to bf8 before the MMA, so the load
+      // must cover either twice the width in elements or twice the element
+      // bit width. blockload2d does not support 32 16-bit elements per row but
+      // does support 16 32-bit elements, so each lane loads 8 x 32-bit values
+      // and bitcasts them to 16 x fp16.
+      %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+          <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+            transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+      %loaded_a_casted = vector.bitcast %loaded_a : vector<8xi32> to vector<16xf16>
+      %a_trunc = xevm.truncf %loaded_a_casted { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
+      %a_trunc_casted = vector.bitcast %a_trunc : vector<16xi8> to vector<8xi16>
+
+      %base_width_b = arith.constant 16 : i32 // bytes per row: 16 x bf8
+      %base_height_b = arith.constant 32 : i32
+      %base_pitch_b = arith.constant 16 : i32
+      // B is already bf8 and is used as is for the MMA, so it is loaded with
+      // an 8-bit element size and with pack_register requested.
+      %loaded_b = xevm.blockload2d %b, %base_width_b, %base_height_b, %base_pitch_b, %x, %y
+          <{elem_size_in_bits=8 : i32, tile_width=16 : i32, tile_height=32 : i32, v_blocks=1 : i32,
+            transpose=false, pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+
+      // Note: scale is not computed. Constant values are used for simplifying the example
+      %scale_a = arith.constant 1.0 : f8E8M0FNU
+      %scale_b = arith.constant 1.0 : f8E8M0FNU
+      %scale_a_casted = arith.bitcast %scale_a : f8E8M0FNU to i8
+      %scale_b_casted = arith.bitcast %scale_b : f8E8M0FNU to i8
+      // Note: c is not loaded. constant vector is used for simplifying the example
+      %loaded_c_casted = arith.constant dense<0.0> : vector<8xf32>
+
+      %c_result = xevm.mma_mx %a_trunc_casted, %loaded_b, %scale_a_casted, %scale_b_casted, %loaded_c_casted
+          {shape=<m=8, n=16, k=32>, types=<d=f32, a=bf8, b=bf8, c=f32>}
+          : (vector<8xi16>, vector<8xi32>, i8, i8, vector<8xf32>) -> vector<8xf32>
+      %c_result_casted = vector.bitcast %c_result : vector<8xf32> to vector<8xi32>
+
+      %base_width_c = arith.constant 64 : i32 // bytes per row: 16 x f32
+      %base_height_c = arith.constant 8 : i32
+      %base_pitch_c = arith.constant 64 : i32
+      xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
+          <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+          : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+      gpu.return
+    }
+  }
+
+  func.func @test(%a : memref<8x32xf16>, %b : memref<32x16xf8E5M2>, %c : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+
+    %memref_a = gpu.alloc() : memref<8x32xf16>
+    gpu.memcpy %memref_a, %a : memref<8x32xf16>, memref<8x32xf16>
+    %a_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_a : memref<8x32xf16> -> index
+    %a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
+    %a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
+    %a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>
+
+    %memref_b = gpu.alloc() : memref<32x16xf8E5M2>
+    gpu.memcpy %memref_b, %b : memref<32x16xf8E5M2>, memref<32x16xf8E5M2>
+    %b_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_b : memref<32x16xf8E5M2> -> index
+    %b_ptr_as_i64 = arith.index_cast %b_ptr_as_idx : index to i64
+    %b_ptr = llvm.inttoptr %b_ptr_as_i64 : i64 to !llvm.ptr
+    %b_ptr_casted = llvm.addrspacecast %b_ptr : !llvm.ptr to !llvm.ptr<1>
+
+    %memref_c = gpu.alloc() : memref<8x16xf32>
+    gpu.memcpy %memref_c, %c : memref<8x16xf32>, memref<8x16xf32>
+    %c_ptr_as_idx = memref.extract_aligned_pointer_as_index %memref_c : memref<8x16xf32> -> index
+    %c_ptr_as_i64 = arith.index_cast %c_ptr_as_idx : index to i64
+    %c_ptr = llvm.inttoptr %c_ptr_as_i64 : i64 to !llvm.ptr
+    %c_ptr_casted = llvm.addrspacecast %c_ptr : !llvm.ptr to !llvm.ptr<1>
+
+    gpu.launch_func @kernel::@block_scaled_dpas_bf8 blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1)
+        args(%a_ptr_casted : !llvm.ptr<1>, %b_ptr_casted : !llvm.ptr<1>, %c_ptr_casted : !llvm.ptr<1>)
+    gpu.dealloc %memref_a : memref<8x32xf16>
+    gpu.dealloc %memref_b : memref<32x16xf8E5M2>
+    %res = memref.alloc() : memref<8x16xf32>
+    gpu.memcpy %res, %memref_c : memref<8x16xf32>, memref<8x16xf32>
+    gpu.dealloc %memref_c : memref<8x16xf32>
+    return %res : memref<8x16xf32>
+  }
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1f16 = arith.constant 1.0 : f16
+    %c1bf8 = arith.constant 1.0 : f8E5M2
+    %c0f32 = arith.constant 0.0 : f32
+
+    %A = memref.alloc() : memref<8x32xf16>
+    scf.for %i = %c0 to %c8 step %c1 {
+      scf.for %j = %c0 to %c32 step %c1 { // the kernel reads all 32 columns
+        memref.store %c1f16, %A[%i, %j] : memref<8x32xf16>
+      }
+    }
+
+    %B = memref.alloc() : memref<32x16xf8E5M2>
+    scf.for %i = %c0 to %c32 step %c1 {
+      scf.for %j = %c0 to %c16 step %c1 {
+        memref.store %c1bf8, %B[%i, %j] : memref<32x16xf8E5M2>
+      }
+    }
+
+    %C = memref.alloc() : memref<8x16xf32>
+    scf.for %i = %c0 to %c8 step %c1 {
+      scf.for %j = %c0 to %c16 step %c1 {
+        memref.store %c0f32, %C[%i, %j] : memref<8x16xf32>
+      }
+    }
+
+    %C_res = call @test(%A, %B, %C) : (memref<8x32xf16>, memref<32x16xf8E5M2>, memref<8x16xf32>) -> memref<8x16xf32>
+    %C_cast = memref.cast %C_res : memref<8x16xf32> to memref<*xf32>
+    call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()
+    // Each C element is the sum over K = 32 of 1.0 * 1.0 (both scales are 1.0).
+    // CHECK: Unranked Memref base@ = 0x{{[0-9a-fA-F]+}} rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data =
+    // CHECK-COUNT-8: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
+
+    memref.dealloc %A : memref<8x32xf16>
+    memref.dealloc %B : memref<32x16xf8E5M2>
+    memref.dealloc %C : memref<8x16xf32>
+    memref.dealloc %C_res : memref<8x16xf32>
+    return
+  }
+  func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
+
+}

>From b20d9d137d82c19c19f95a62920023761f08066e Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Fri, 10 Apr 2026 20:06:11 +0000
Subject: [PATCH 2/2] Use FixedVectorOfRankAndType

---
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index fbd2e07fdb894..5929df98188f2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -622,9 +622,9 @@ def XeVM_TruncfDstElemTypeAttr : XeVM_Attr<"TruncfDstElemType", "dst_etype"> {
 
 def XeVM_TruncfOp
     : XeVM_Op<"truncf">,
-      Results<(outs AnyTypeOf<[VectorOfRankAndType<[1], [I8, I<4>]>, I8,
+      Results<(outs AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8, I<4>]>, I8,
                                I<4>]>:$dst)>,
-      Arguments<(ins AnyTypeOf<[VectorOfRankAndType<[1], [F16, BF16]>, F16,
+      Arguments<(ins AnyTypeOf<[FixedVectorOfRankAndType<[1], [F16, BF16]>, F16,
                                 BF16]>:$src,
           XeVM_TruncfSrcElemTypeAttr:$src_etype,
           XeVM_TruncfDstElemTypeAttr:$dst_etype)> {
@@ -646,8 +646,8 @@ def XeVM_MMAMxOp
       Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
       Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
           FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
-          AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
-          AnyTypeOf<[VectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
+          AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8]>, I8]>:$scale_a,
+          AnyTypeOf<[FixedVectorOfRankAndType<[1], [I8]>, I8]>:$scale_b,
           Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
           XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
 



More information about the Mlir-commits mailing list