[Mlir-commits] [mlir] [mlir][tensor] Remove hard-coded types from `ConstantOpExtractSliceFolder` (PR #184013)

Matthias Springer llvmlistbot at llvm.org
Sun Mar 1 08:08:29 PST 2026


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/184013

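Use the generic `Attribute` API instead of the hard-coded `APInt`/`APFloat` element iterators (which only handled `DenseIntElementsAttr` and `DenseFPElementsAttr`), so that the folder works with arbitrary element types.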


>From 480ae86e292db2af9e3f8378ca6af795047c29e6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 1 Mar 2026 16:07:06 +0000
Subject: [PATCH] [mlir][tensor] Remove hard-coded types from
 `ConstantOpExtractSliceFolder`

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 31 +++++--------------
 .../Tensor/fold-constant-extract-slice.mlir   | 12 +++++++
 2 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7d77d8cb1cc00..ce0f8540d884a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2700,29 +2700,14 @@ class ConstantOpExtractSliceFolder final
       counts.push_back(count);
     }
 
-    // New attribute constructed by the sliced values.
-    DenseElementsAttr newAttr;
-
-    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
-      SmallVector<APInt> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
-      SmallVector<APFloat> outValues;
-      outValues.reserve(sourceType.getNumElements());
-      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
-          elems.begin(), counts, offsets, sizes, strides, &outValues);
-      newAttr = DenseElementsAttr::get(resultType, outValues);
-    }
-
-    if (newAttr) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
-      return success();
-    }
-
-    return failure();
+    // Slice the elements and construct a new attribute.
+    SmallVector<Attribute> outValues;
+    outValues.reserve(resultType.getNumElements());
+    sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
+                  strides, &outValues);
+    auto newAttr = DenseElementsAttr::get(resultType, outValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+    return success();
   }
 
 private:
diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
index 38df4f03669cd..ae1e0d4d481f1 100644
--- a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir
@@ -37,3 +37,15 @@ func.func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32
   return %slice : tensor<2x2xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @slice_constant_dense_element_type
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   %[[CONST:.+]] = arith.constant dense<tensor<2x!test.dense_element> : [9 : i32, 8 : i32]>
+//       CHECK:   return %[[CONST]]
+func.func @slice_constant_dense_element_type() -> tensor<2x!test.dense_element>
+{
+  %cst = arith.constant dense<tensor<4x!test.dense_element> : [10 : i32, 9 : i32, 8 : i32, 7 : i32]>
+  %slice = tensor.extract_slice %cst[1] [2] [1] : tensor<4x!test.dense_element> to tensor<2x!test.dense_element>
+  return %slice : tensor<2x!test.dense_element>
+}



More information about the Mlir-commits mailing list