[Mlir-commits] [mlir] 620425a - [mlir][tensor] Fix crash in tensor.from_elements fold with non-scalar element types (#183659)
llvmlistbot at llvm.org
Fri Feb 27 07:31:33 PST 2026
Author: Mehdi Amini
Date: 2026-02-27T16:31:29+01:00
New Revision: 620425a884388947ae691f3b80545e47c8687840
URL: https://github.com/llvm/llvm-project/commit/620425a884388947ae691f3b80545e47c8687840
DIFF: https://github.com/llvm/llvm-project/commit/620425a884388947ae691f3b80545e47c8687840.diff
LOG: [mlir][tensor] Fix crash in tensor.from_elements fold with non-scalar element types (#183659)
The fold for tensor.from_elements always attempted to produce a
DenseElementsAttr by calling DenseElementsAttr::get(type, elements).
However, DenseElementsAttr::get only handles basic scalar element types
(integer, index, float, complex) directly. For other element types, such
as vector types, it expects a StringAttr (raw bytes) for each element.
Folded constants never provide one, so the call triggers an assertion.
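To make the failure mode concrete, here is a minimal C++ model of the element-type dispatch described above. It is a sketch under simplifying assumptions: `ElemKind`, `IntAttr`, `StrAttr`, `Path`, `selectPath`, and `elementsAcceptable` are hypothetical names, not MLIR APIs, and the model only reports where the real `DenseElementsAttr::get` would assert rather than building an attribute.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR element types (not MLIR's API).
enum class ElemKind { Integer, Index, Float, Complex, Vector };

// Hypothetical attribute kinds: a folded scalar constant or raw string bytes.
struct IntAttr { long value; };
struct StrAttr { std::string bytes; };
using Attr = std::variant<IntAttr, StrAttr>;

// Which storage path this simplified model picks for a given element type.
enum class Path { Complex, Scalar, RawString };

Path selectPath(ElemKind kind) {
  // Complex is checked first, then int/index/float.
  if (kind == ElemKind::Complex) return Path::Complex;
  if (kind == ElemKind::Integer || kind == ElemKind::Index ||
      kind == ElemKind::Float)
    return Path::Scalar;
  // Every other element type (e.g. vector) falls through to the string path.
  return Path::RawString;
}

// Returns false where the real code would hit its assertion: on the string
// path, every element must be a string attribute.
bool elementsAcceptable(ElemKind kind, const std::vector<Attr>& elems) {
  if (selectPath(kind) != Path::RawString) return true;
  for (const Attr& a : elems)
    if (!std::holds_alternative<StrAttr>(a)) return false;
  return true;
}
```

In this model a vector element type routes to the raw-string path, so a list of folded (non-string) constants is rejected, which corresponds to the assertion behind the reported crash.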
Fix this by guarding the fold: only attempt the DenseElementsAttr fold
when the tensor element type is integer, index, float, or complex.
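The guarded fold can be sketched the same way. This is again a hypothetical, simplified model (`ElemKind`, `FoldedOperand`, `isDenseSupported`, and `foldFromElements` are illustrative names, not MLIR APIs): `std::nullopt` stands in for a null folded attribute, and the function returns the constant values only when both the element-type check and the all-constants check pass.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical, simplified element kinds (not MLIR's type hierarchy).
enum class ElemKind { Integer, Index, Float, Complex, Vector };

// A folded operand; std::nullopt models a null Attribute (operand not constant).
using FoldedOperand = std::optional<long>;

// Mirrors the shape of the new guard: only scalar-like element types qualify.
bool isDenseSupported(ElemKind k) {
  return k == ElemKind::Integer || k == ElemKind::Index ||
         k == ElemKind::Float || k == ElemKind::Complex;
}

// Returns the folded element values, or std::nullopt when the fold bails out.
std::optional<std::vector<long>> foldFromElements(
    ElemKind elemKind, const std::vector<FoldedOperand>& operands) {
  // Guard 1 (added by the fix): skip element types the dense path can't build.
  if (!isDenseSupported(elemKind)) return std::nullopt;
  // Guard 2 (pre-existing): every operand must have folded to a constant.
  std::vector<long> values;
  values.reserve(operands.size());
  for (const FoldedOperand& op : operands) {
    if (!op) return std::nullopt;
    values.push_back(*op);
  }
  return values;
}
```

Checking the element type before inspecting the operands follows the order used in the patch; the original null-operand check is left as it was.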
Fixes #180459
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c0ab7c9ec8a0..743bdabdd8542 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1456,6 +1456,13 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
 }
 
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+  // DenseElementsAttr::get requires StringAttr for element types that are not
+  // integer, index, float, or complex (e.g. vector types), but folded constants
+  // won't be StringAttr instances. Only fold for element types directly
+  // supported by DenseElementsAttr.
+  Type eltType = getType().getElementType();
+  if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
+    return {};
   if (!llvm::is_contained(adaptor.getElements(), nullptr))
     return DenseElementsAttr::get(getType(), adaptor.getElements());
   return {};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc499da5422fc..4b7e43ca84cec 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -411,6 +411,21 @@ func.func @from_elements_with_poison() -> tensor<1xindex> {
 
 // -----
 
+// Ensure tensor.from_elements with a vector element type doesn't crash
+// when the elements fold to constants (DenseElementsAttr does not support
+// non-scalar element types via the Attribute overload).
+// CHECK-LABEL: func @from_elements_with_vector_element_type
+func.func @from_elements_with_vector_element_type() -> tensor<1xvector<1xi1>> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[CST]] : tensor<1xvector<1xi1>>
+  // CHECK: return %[[TENSOR]]
+  %0 = vector.constant_mask [1] : vector<1xi1>
+  %1 = tensor.from_elements %0 : tensor<1xvector<1xi1>>
+  return %1 : tensor<1xvector<1xi1>>
+}
+
+// -----
+
 // Ensure the optimization doesn't segfault from bad constants
 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {