[Mlir-commits] [mlir] 620425a - [mlir][tensor] Fix crash in tensor.from_elements fold with non-scalar element types (#183659)

llvmlistbot at llvm.org
Fri Feb 27 07:31:33 PST 2026


Author: Mehdi Amini
Date: 2026-02-27T16:31:29+01:00
New Revision: 620425a884388947ae691f3b80545e47c8687840

URL: https://github.com/llvm/llvm-project/commit/620425a884388947ae691f3b80545e47c8687840
DIFF: https://github.com/llvm/llvm-project/commit/620425a884388947ae691f3b80545e47c8687840.diff

LOG: [mlir][tensor] Fix crash in tensor.from_elements fold with non-scalar element types (#183659)

The fold for tensor.from_elements always attempted to produce a
DenseElementsAttr by calling DenseElementsAttr::get(type, elements).
However, DenseElementsAttr::get only handles basic scalar element types
(integer, index, float, complex) directly. For any other element type,
such as a vector type, it expects a StringAttr (raw bytes) for each
element. Folded constants are not StringAttr instances, so the call
triggers an assertion.

Fix this by guarding the fold: only attempt the DenseElementsAttr fold
when the tensor element type is integer, index, float, or complex.
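The decision the guarded fold makes can be modeled outside MLIR with a small
standalone C++ sketch. Everything in it is hypothetical: `ElementKind`,
`FoldedOperand`, `isDenseFoldable`, and `wouldFold` are illustration-only
stand-ins, not MLIR API. The real patch checks `Type::isIntOrIndexOrFloat()`
and `isa<ComplexType>` on the tensor's element type.

```cpp
#include <optional>
#include <vector>

// Hypothetical stand-in for MLIR element types (not MLIR API).
enum class ElementKind { Integer, Index, Float, Complex, Vector };

// Models the new guard: only element kinds that DenseElementsAttr::get
// accepts directly from Attribute values are eligible for the fold.
bool isDenseFoldable(ElementKind kind) {
  switch (kind) {
  case ElementKind::Integer:
  case ElementKind::Index:
  case ElementKind::Float:
  case ElementKind::Complex:
    return true;
  case ElementKind::Vector:
    return false;
  }
  return false;
}

// Hypothetical model of one folded operand: std::nullopt plays the role of
// the nullptr entries in adaptor.getElements() (operand is not a constant).
using FoldedOperand = std::optional<int>;

// Returns true if the fold would build a dense constant.
bool wouldFold(ElementKind kind, const std::vector<FoldedOperand> &elements) {
  // Bail out before ever reaching DenseElementsAttr::get.
  if (!isDenseFoldable(kind))
    return false;
  // Mirrors the existing llvm::is_contained(..., nullptr) check.
  for (const FoldedOperand &e : elements)
    if (!e.has_value())
      return false;
  return true;
}
```

With this model, a vector element type never folds even when every operand
is constant, which corresponds to the case exercised by the new
`@from_elements_with_vector_element_type` test.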

Fixes #180459

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c0ab7c9ec8a0..743bdabdd8542 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1456,6 +1456,13 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
 }
 
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+  // DenseElementsAttr::get requires StringAttr for element types that are not
+  // integer, index, float, or complex (e.g. vector types), but folded constants
+  // won't be StringAttr instances. Only fold for element types directly
+  // supported by DenseElementsAttr.
+  Type eltType = getType().getElementType();
+  if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
+    return {};
   if (!llvm::is_contained(adaptor.getElements(), nullptr))
     return DenseElementsAttr::get(getType(), adaptor.getElements());
   return {};

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc499da5422fc..4b7e43ca84cec 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -411,6 +411,21 @@ func.func @from_elements_with_poison() -> tensor<1xindex> {
 
 // -----
 
+// Ensure tensor.from_elements with a vector element type doesn't crash
+// when the elements fold to constants (DenseElementsAttr does not support
+// non-scalar element types via the Attribute overload).
+// CHECK-LABEL: func @from_elements_with_vector_element_type
+func.func @from_elements_with_vector_element_type() -> tensor<1xvector<1xi1>> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[CST]] : tensor<1xvector<1xi1>>
+  // CHECK: return %[[TENSOR]]
+  %0 = vector.constant_mask [1] : vector<1xi1>
+  %1 = tensor.from_elements %0 : tensor<1xvector<1xi1>>
+  return %1 : tensor<1xvector<1xi1>>
+}
+
+// -----
+
 // Ensure the optimization doesn't segfault from bad constants
 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {