[Mlir-commits] [mlir] [mlir][tensor] Remove hard-coded types from `ConstantOpExtractSliceFolder` (PR #184013)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 1 08:08:59 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Use the generic `Attribute` element API instead of the hard-coded `APInt`/`APFloat` code paths, so that the folder works with arbitrary element types.
---
Full diff: https://github.com/llvm/llvm-project/pull/184013.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-23)
- (modified) mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d77d8cb1cc00..ce0f8540d884a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2700,29 +2700,14 @@ class ConstantOpExtractSliceFolder final
counts.push_back(count);
}
- // New attribute constructed by the sliced values.
- DenseElementsAttr newAttr;
-
- if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
- SmallVector<APInt> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
- SmallVector<APFloat> outValues;
- outValues.reserve(sourceType.getNumElements());
- sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
- elems.begin(), counts, offsets, sizes, strides, &outValues);
- newAttr = DenseElementsAttr::get(resultType, outValues);
- }
-
- if (newAttr) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
- return success();
- }
-
- return failure();
+ // Slice the elements and construct a new attribute.
+ SmallVector<Attribute> outValues;
+ outValues.reserve(resultType.getNumElements());
+ sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
+ strides, &outValues);
+ auto newAttr = DenseElementsAttr::get(resultType, outValues);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+ return success();
}
private:
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
index 38df4f03669cd..ae1e0d4d481f1 100644
--- a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
@@ -37,3 +37,15 @@ func.func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32
return %slice : tensor<2x2xf32>
}
+// -----
+
+// CHECK-LABEL: func @slice_constant_dense_element_type
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<tensor<2x!test.dense_element> : [9 : i32, 8 : i32]>
+// CHECK: return %[[CONST]]
+func.func @slice_constant_dense_element_type() -> tensor<2x!test.dense_element>
+{
+ %cst = arith.constant dense<tensor<4x!test.dense_element> : [10 : i32, 9 : i32, 8 : i32, 7 : i32]>
+ %slice = tensor.extract_slice %cst[1] [2] [1] : tensor<4x!test.dense_element> to tensor<2x!test.dense_element>
+ return %slice : tensor<2x!test.dense_element>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/184013
More information about the Mlir-commits mailing list