[Mlir-commits] [mlir] f39386b - [MLIR][XeVM] Update XeVM type converter (#189306)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 10 12:52:41 PDT 2026
Author: Sang Ik Lee
Date: 2026-04-10T12:52:36-07:00
New Revision: f39386b79eb002e7dd1376e11198ac60f802cf48
URL: https://github.com/llvm/llvm-project/commit/f39386b79eb002e7dd1376e11198ac60f802cf48
DIFF: https://github.com/llvm/llvm-project/commit/f39386b79eb002e7dd1376e11198ac60f802cf48.diff
LOG: [MLIR][XeVM] Update XeVM type converter (#189306)
Ideally, DLTI should be used to determine the Index type, since the Index type
is tied to the bitwidth of the pointer type, which can be expressed with DLTI.
Currently, however, many passes use a separate pass option for the bitwidth of
the Index type.
The GPU-to-XeVM lowering pipeline also uses passes with such options.
However, the XeVM type converter provides no way to reflect the choice of
Index type bitwidth and uses a hard-coded value instead.
This PR updates the XeVM type converter to take the Index type bitwidth from a
pass option. This is done by using the LLVM type converter to convert element
types instead of the previous custom logic.
In addition to handling the Index type properly, using the LLVM type converter
also ensures that low-precision float types are correctly converted to
LLVM-supported types.
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index c09faa86528eb..d401b56c7602d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1688,6 +1688,10 @@ def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
"memref::MemRefDialect", "arith::ArithDialect",
"LLVM::LLVMDialect", "index::IndexDialect",
"gpu::GPUDialect", "scf::SCFDialect"];
+ let options = [Option<"use64bitIndex", "use-64bit-index", "bool",
+ /*default=*/"true",
+ "Use 64-bit integers to convert index types">,
+ ];
}
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 4816ec693ae77..93da74e938c84 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1092,14 +1092,25 @@ struct ConvertXeGPUToXeVMPass
using Base::Base;
void runOnOperation() override {
- LLVMTypeConverter typeConverter(&getContext());
+ MLIRContext *context = &getContext();
+
+ // XeVM type converter is based on LLVM type converter with the
+ // following customizations.
+ // First, type conversion rules are added for xegpu custom types,
+ // TensorDescType and MemDescType.
+ // Second, MemRefType is lowered to single integer type
+ // Third, VectorType of single element or 0D is converted to vector
+ // element type. Otherwise, vector type is flatten to 1D.
+ LowerToLLVMOptions options(context);
+ options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
+ LLVMTypeConverter typeConverter(context, options);
+
+ Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
+ Type i32Type = IntegerType::get(context, 32);
typeConverter.addConversion([&](VectorType type) -> Type {
- unsigned rank = type.getRank();
- auto elemType = type.getElementType();
- // If the element type is index, convert it to i64.
- if (llvm::isa<IndexType>(elemType))
- elemType = IntegerType::get(&getContext(), 64);
+ auto elemType = typeConverter.convertType(type.getElementType());
// If the vector rank is 0 or has a single element, return the element
+ unsigned rank = type.getRank();
if (rank == 0 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
@@ -1111,17 +1122,21 @@ struct ConvertXeGPUToXeVMPass
if (type.isScattered())
return {};
if (type.getRank() == 1)
- return IntegerType::get(&getContext(), 64);
- auto i32Type = IntegerType::get(&getContext(), 32);
+ return xevmIndexType;
return VectorType::get(8, i32Type);
});
+ // SLM access related type conversions.
+ // TODO: LLVM DLTI provides clean way of representing different pointer size
+ // based on address space. Currently pointer size of SLM access is hard
+ // coded to 32bit. Update to use DLTI when switching overall XeGPU lowering
+ // to use DLTI instead of use64bitIndex option used above.
+
// Convert MemDescType into i32 for SLM
- typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- return IntegerType::get(&getContext(), 32);
- });
+ typeConverter.addConversion(
+ [&](xegpu::MemDescType type) -> Type { return i32Type; });
typeConverter.addConversion([&](MemRefType type) -> Type {
- return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
+ return isSharedMemRef(type) ? i32Type : xevmIndexType;
});
// LLVM type converter puts unrealized casts for the following cases:
@@ -1231,9 +1246,9 @@ struct ConvertXeGPUToXeVMPass
return {};
};
- // Materialization to convert
- // - bitcast vector of same rank
- // - shape vector of different rank but same element type
+ // Materialization to convert between vector types
+ // - Add shape cast for different shapes
+ // - Add bitcast for different element types
// Applies to both source and target materialization.
auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
@@ -1243,17 +1258,22 @@ struct ConvertXeGPUToXeVMPass
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
if (auto targetVecTy = dyn_cast<VectorType>(type)) {
- // If the target type is a vector of same rank,
- // bitcast to the target type.
- if (targetVecTy.getRank() == vecTy.getRank())
- return vector::BitCastOp::create(builder, loc, targetVecTy, input)
- .getResult();
- else if (targetVecTy.getElementType() == vecTy.getElementType()) {
- // If the target type is a vector of different rank but same element
- // type, reshape to the target type.
- return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
- .getResult();
+ Value cast = input;
+ // If the target type has a different shape, add a shape cast
+ // If the target type has a different element type, add a bitcast
+ if (targetVecTy.getShape() != vecTy.getShape()) {
+ cast = vector::ShapeCastOp::create(
+ builder, loc,
+ VectorType::get(targetVecTy.getShape(),
+ vecTy.getElementType()),
+ cast)
+ .getResult();
}
+ if (targetVecTy.getElementType() != vecTy.getElementType()) {
+ cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
+ .getResult();
+ }
+ return cast;
}
}
return {};
@@ -1269,26 +1289,31 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
- if (type == vecTy.getElementType() ||
- ((vecTy.getElementType() == builder.getIndexType()) &&
- type.isInteger())) {
- // If the vector rank is 0 or has a single element,
- // extract scalar of target type.
- auto rank = vecTy.getRank();
- Value cast;
- if (rank == 0) {
- cast =
- vector::ExtractOp::create(builder, loc, input, {}).getResult();
- } else {
- cast = vector::ExtractOp::create(builder, loc, input,
- SmallVector<int64_t>(rank, 0))
- .getResult();
- }
- if (type != vecTy.getElementType())
- cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
- .getResult();
- return cast;
+ // Source needs to be single element vector
+ auto rank = vecTy.getRank();
+ if (rank != 0 && vecTy.getNumElements() != 1)
+ return {};
+ auto inElemTy = vecTy.getElementType();
+ // extract scalar
+ Value cast = input;
+ if (rank == 0) {
+ cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
+ } else {
+ cast = vector::ExtractOp::create(builder, loc, cast,
+ SmallVector<int64_t>(rank, 0))
+ .getResult();
+ }
+ // Extracted element type may need conversion
+ // Two cases
+ // 1. Index type to integer type
+ // 2. Other element type mismatch
+ if (inElemTy.isIndex()) {
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
+ } else if (inElemTy != type) {
+ cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
}
+ return cast;
}
return {};
};
@@ -1297,7 +1322,8 @@ struct ConvertXeGPUToXeVMPass
// - single element of vector element type to single element vector
// If result type of original op is single element vector and lowered type
// is scalar. This materialization cast creates a single element vector by
- // broadcasting the scalar value.
+ // First convert element type if needed and then broadcast to single
+ // element vector.
// Applies only to source materialization.
auto singleElementToVectorMaterializationCast =
[](OpBuilder &builder, Type type, ValueRange inputs,
@@ -1305,21 +1331,26 @@ struct ConvertXeGPUToXeVMPass
if (inputs.size() != 1)
return {};
auto input = inputs.front();
+ auto inTy = input.getType();
+ if (!inTy.isIntOrFloat())
+ return {};
// If the target type is a vector of rank 0 or single element vector
// of element type matching input type, broadcast input to target type.
if (auto vecTy = dyn_cast<VectorType>(type)) {
- if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
- if (input.getType() == vecTy.getElementType()) {
- return vector::BroadcastOp::create(builder, loc, vecTy, input)
- .getResult();
- } else if (vecTy.getElementType() == builder.getIndexType()) {
- Value cast = arith::IndexCastUIOp::create(
- builder, loc, builder.getIndexType(), input)
- .getResult();
- return vector::BroadcastOp::create(builder, loc, vecTy, cast)
- .getResult();
- }
+ if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
+ return {};
+ auto outElemTy = vecTy.getElementType();
+ Value cast = input;
+ if (outElemTy.isIndex()) {
+ cast = arith::IndexCastUIOp::create(builder, loc,
+ builder.getIndexType(), cast)
+ .getResult();
+ } else if (inTy != outElemTy) {
+ cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
+ .getResult();
}
+ return vector::BroadcastOp::create(builder, loc, vecTy, cast)
+ .getResult();
}
return {};
};
@@ -1332,14 +1363,14 @@ struct ConvertXeGPUToXeVMPass
typeConverter.addTargetMaterialization(
vectorToSingleElementMaterializationCast);
typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
- ConversionTarget target(getContext());
+ ConversionTarget target(*context);
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
memref::MemRefDialect, gpu::GPUDialect,
index::IndexDialect>();
target.addIllegalDialect<xegpu::XeGPUDialect>();
- RewritePatternSet patterns(&getContext());
+ RewritePatternSet patterns(context);
populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
patterns, target);
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index f29ffabc1d94d..7600ec39fb3f5 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -98,7 +98,10 @@ void buildGPUPassPipeline(OpPassManager &pm,
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
}
pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM());
- pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
+ ConvertXeGPUToXeVMPassOptions xegpuToXeVMOptions;
+ xegpuToXeVMOptions.use64bitIndex = options.use64bitIndex;
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ createConvertXeGPUToXeVMPass(xegpuToXeVMOptions));
{
ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
More information about the Mlir-commits
mailing list