[Mlir-commits] [mlir] [MLIR][XeGPU]: Reject `tensor_desc` types with unknown bitwidth (PR #173922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 29 11:56:52 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Stefan Weigl-Bosker (sweiglbosker)

<details>
<summary>Changes</summary>

Fixes https://github.com/llvm/llvm-project/issues/173851

1. Only allow `XeGPU_ScalarType` element types in `xegpu::TensorDescType`. The restriction is enforced in the verifier, so the API keeps its `mlir::Type` parameters.
2. Fix `VectorToXeGPU` so that vectors whose element type is invalid for `TensorDescType` are no longer lowered.
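
The guard-before-query pattern behind both changes can be sketched in standalone C++. This is a toy model only: `Kind`, `ElemType`, and `checkedBitWidth` are hypothetical names invented for illustration and are not part of MLIR or this patch. The sketch assumes, as the diff suggests, that asking for a bit width is only meaningful for integer and float types, so the type must be checked first.

```cpp
#include <cassert>
#include <optional>

// Toy stand-in for an element type; NOT the real mlir::Type.
enum class Kind { Integer, Float, Index };

struct ElemType {
  Kind kind;
  unsigned width; // meaningful only for Integer and Float

  bool isIntOrFloat() const {
    return kind == Kind::Integer || kind == Kind::Float;
  }

  // Models the precondition that the patch protects: a bit width is
  // only defined for integer and float types.
  unsigned getIntOrFloatBitWidth() const {
    assert(isIntOrFloat() && "only integers and floats have a bitwidth");
    return width;
  }
};

// Hypothetical helper: returns the bit width, or std::nullopt when the
// element type is rejected (the analogue of emitting a verifier error).
std::optional<unsigned> checkedBitWidth(const ElemType &ty) {
  if (!ty.isIntOrFloat())
    return std::nullopt; // reject before the bit width is ever queried
  return ty.getIntOrFloatBitWidth();
}
```

In this model, `checkedBitWidth(ElemType{Kind::Index, 0})` yields an empty optional instead of reaching the assertion, which mirrors how the new check returns an error before `getIntOrFloatBitWidth()` is called.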

---
Full diff: https://github.com/llvm/llvm-project/pull/173922.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1) 
- (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+4) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 716681fe9e187..5cb9110b3e4ad 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -74,7 +74,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
 
     ```
     TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
-    element-type ::= float-type | integer-type | index-type
+    element-type ::= float-type | integer-type
     dim-list := (static-dim-list `x`)?
     static-dim-list ::= decimal-literal `x` decimal-literal
     attr-list = (, encoding-attr)? (, layout-attr)?
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 55ade0ae8eeec..7d7f0a23848ad 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -57,6 +57,10 @@ static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
   if (!(vecRank == 1 || vecRank == 2))
     return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
 
+  if (!vecTy.getElementType().isIntOrFloat())
+    return rewriter.notifyMatchFailure(
+        op, "Expected scalar type with known bitwidth");
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..378e246c6808d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -786,6 +786,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
       return emitError() << "SLM is only supported for 1D block tensor";
   }
 
+  if (!elementType.isIntOrFloat())
+    return emitError() << "unsupported element type " << elementType
+                       << ": expected integer or float";
+
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
   int chunkAlignmentFactor =
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 67faa60f2835e..7c3c8e0d3fa35 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -64,6 +64,12 @@ func.func @create_nd_tdesc_9(%src: ui64) {
   return
 }
 
+// -----
+func.func @create_nd_tdesc_10(%src: memref<24xindex>) {
+  // expected-error @+1 {{unsupported element type 'index': expected integer or float}}
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24xindex> -> !xegpu.tensor_desc<24xindex>
+  return
+}
 
 // -----
 func.func @prefetch_nd_vc_1(%src: memref<24x32xf16>) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/173922

