[Mlir-commits] [mlir] [MLIR][XeVM] blockload and blockstore ops should use scalar types (PR #161708)
Sang Ik Lee
llvmlistbot at llvm.org
Thu Oct 2 10:37:13 PDT 2025
https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/161708
The blockload and blockstore ops should use scalar types instead of single-element vectors.
From 646260d3764b7debdb2b4a14723670d712dd3e21 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 2 Oct 2025 17:35:22 +0000
Subject: [PATCH] [MLIR][XeVM] blockload and blockstore ops should use scalar
types instead of single element vectors.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 9 ++++++---
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 18 +++++++++++-------
mlir/test/Dialect/LLVMIR/invalid.mlir | 4 ++--
3 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 4f7a8421c07b9..2dd612139fa2d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr
def XeVM_BlockLoadOp
: XeVM_Op<"blockload">,
- Results<(
- outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
+ Results<(outs AnyTypeOf<
+ [XeVM_1DBlockElemType,
+ FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
let summary = "subgroup block load";
@@ -228,7 +229,9 @@ def XeVM_BlockLoadOp
def XeVM_BlockStoreOp
: XeVM_Op<"blockstore">,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
- FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
+ AnyTypeOf<[XeVM_1DBlockElemType,
+ FixedVectorOfRankAndType<[1],
+ [XeVM_1DBlockElemType]>]>:$val,
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
let summary = "subgroup block store";
let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 8295492ad73a8..04e8836c00359 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
- VectorType vTy;
+ Type srcOrDstTy;
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
- vTy = op.getResult().getType();
+ srcOrDstTy = op.getResult().getType();
else
- vTy = op.getVal().getType();
+ srcOrDstTy = op.getVal().getType();
+ VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
+ // A scalar value is always valid; only vector lengths are checked below.
+ if (!vTy)
+ return success();
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
if (elemTySize == 1) {
- llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
+ llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
+ "vector size must be 2, 4, 8 or 16 for 8-bit element type");
} else {
- llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
+ llvm::SmallSet<int, 3> validSizes{2, 4, 8};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
+ "vector size must be 2, 4 or 8 for element type > 8 bits");
}
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 627abd0665d8c..7ef56b52f1d83 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1943,14 +1943,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
// -----
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
+ // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
llvm.return
}
// -----
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
+ // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
llvm.return
}
More information about the Mlir-commits mailing list