[Mlir-commits] [mlir] [mlir][tensor] Remove hard-coded types from `ConstantOpExtractSliceFolder` (PR #184013)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 1 08:08:58 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

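Use the generic `Attribute` API (`value_begin<Attribute>()`) to slice the constant's elements. Unlike the previous hard-coded `DenseIntElementsAttr`/`DenseFPElementsAttr` cases, it works with arbitrary element types.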


---
Full diff: https://github.com/llvm/llvm-project/pull/184013.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-23) 
- (modified) mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d77d8cb1cc00..ce0f8540d884a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2700,29 +2700,14 @@ class ConstantOpExtractSliceFolder final
       counts.push_back(count);
     }
 
-    // New attribute constructed by the sliced values.
-    DenseElementsAttr newAttr;
-
-    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
-      SmallVector<APInt> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
-      SmallVector<APFloat> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    }
-
-    if (newAttr) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
-      return success();
-    }
-
-    return failure();
+    // Slice the elements and construct a new attribute.
+    SmallVector<Attribute> outValues;
+    outValues.reserve(resultType.getNumElements());
+    sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
+                  strides, &outValues);
+    auto newAttr = DenseElementsAttr::get(resultType, outValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+    return success();
   }
 
 private:
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
index 38df4f03669cd..ae1e0d4d481f1 100644
--- a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
@@ -37,3 +37,15 @@ func.func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32
   return %slice : tensor<2x2xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @slice_constant_dense_element_type
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<tensor<2x!test.dense_element> : [9 : i32, 8 : i32]>
+//       CHECK:   return %[[CONST]]
+func.func @slice_constant_dense_element_type() -> tensor<2x!test.dense_element>
+{
+  %cst = arith.constant dense<tensor<4x!test.dense_element> : [10 : i32, 9 : i32, 8 : i32, 7 : i32]>
+  %slice = tensor.extract_slice %cst[1] [2] [1] : tensor<4x!test.dense_element> to tensor<2x!test.dense_element>
+  return %slice : tensor<2x!test.dense_element>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/184013


More information about the Mlir-commits mailing list