[Mlir-commits] [mlir] [mlir][xegpu] cleanup the print format for TensorDesc (PR #149182)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Jul 18 08:58:18 PDT 2025
================
@@ -131,62 +131,51 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
}
- BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
- return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
- }
-
- ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
- return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+ template <typename T,
+ typename = std::enable_if_t<
+ std::is_same_v<T, BlockTensorDescAttr> ||
+ std::is_same_v<T, ScatterTensorDescAttr>>>
+ T getEncodingOfType() const {
+ return llvm::dyn_cast_if_present<T>(getEncoding());
}
LayoutAttr getLayoutAttr() const {
return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
}
xegpu::MemorySpace getMemorySpace() const {
- auto block_attr = getEncodingAsBlockTensorDescAttr();
- if (block_attr && block_attr.getMemorySpace())
- return block_attr.getMemorySpace().getValue();
+ if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
+ return attr.getMemorySpace().getValue();
- auto scatter_attr = getEncodingAsScatterTensorDescAttr();
- if (scatter_attr && scatter_attr.getMemorySpace())
- return scatter_attr.getMemorySpace().getValue();
+ if (auto attr = getEncodingOfType<ScatterTensorDescAttr>())
+ return attr.getMemorySpace().getValue();
- // return default value
+ llvm_unreachable("invalid encoding");
----------------
adam-smnk wrote:
Why make this unreachable?
Falling back to a default memory space, as the previous code did, seems reasonable here. Or is it impossible to create this type without an encoding?
https://github.com/llvm/llvm-project/pull/149182
More information about the Mlir-commits
mailing list