[Mlir-commits] [mlir] [mlir][tensor] add tensor insert/extract op folders (PR #142458)
Ian Wood
llvmlistbot at llvm.org
Tue Jun 3 09:37:01 PDT 2025
================
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// %c0 = arith.constant 0 : index
+/// %c4_f32 = arith.constant 4.0 : f32
+/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ // Requires a ranked tensor type.
+ auto destType =
+ llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+ if (!destType)
+ return failure();
+
+ // Pattern requires constant indices
+ SmallVector<uint64_t, 8> indices;
+ for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+ auto indiceAttr = dyn_cast<Attribute>(indice);
+ if (!indiceAttr)
+ return failure();
+ indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+ }
+
+ // Requires a constant scalar to insert
+ OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+ Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+ if (!scalarAttr)
+ return failure();
+
+ if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+ insertOp.getDest().getDefiningOp())) {
+ if (auto sourceAttr =
+ llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+ // Update the attribute at the inserted index.
+ auto sourceValues = sourceAttr.getValues<Attribute>();
+ auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+ std::vector<Attribute> updatedValues;
+ updatedValues.reserve(sourceAttr.getNumElements());
+ for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
+ updatedValues.push_back(i == flattenedIndex ? scalarAttr
+ : sourceValues[i]);
----------------
IanWood1 wrote:
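Comparing the loop index `i` (an `int`, since it is declared as `auto i = 0`) against an unsigned value is causing a warning during the build. It probably just needs a `static_cast`: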
```
warning: comparison of integers of different signs: 'int' and 'uint64_t' (aka 'unsigned long')
```
https://github.com/llvm/llvm-project/pull/142458
More information about the Mlir-commits
mailing list