[Mlir-commits] [mlir] [mlir][Vector] Move vector.extract canonicalizers for DenseElementsAttr to folders (PR #127995)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Feb 20 15:46:58 PST 2025
================
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
return {};
}
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+ Attribute srcAttr) {
+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+ if (!denseAttr) {
+ return {};
+ }
+
+ if (denseAttr.isSplat()) {
+ Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+ newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+ return newAttr;
+ }
+
+ auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+ if (vecTy.isScalable())
+ return {};
+
+ if (extractOp.hasDynamicPosition()) {
+ return {};
+ }
+
+ // Calculate the linearized position of the continuous chunk of elements to
+ // extract.
+ llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+ copy(extractOp.getStaticPosition(), completePositions.begin());
+ int64_t elemBeginPosition =
+ linearize(completePositions, computeStrides(vecTy.getShape()));
+ auto denseValuesBegin =
+ denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+ TypedAttr newAttr;
+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+ SmallVector<Attribute> elementValues(
+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
----------------
kuhar wrote:
Something like this (pseudo-IR):
```
X = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]>
A = extract %X[0:5] // [0, 1, 2, 3, 4]
B = extract %X[1:6] // [1, 2, 3, 4, 5]
C = extract %X[2:7] // [2, 3, 4, 5, 6]
D = extract %X[3:8] // [3, 4, 5, 6, 7]
E = extract %X[4:9] // [4, 5, 6, 7, 8]
F = extract %X[5:10] // [5, 6, 7, 8, 9]
```
Folding this sequence would replace one dense elements attr of size 10 with 6 unique dense elements arrays of size 5 (30 elements in total, i.e. 3× the original constant data). But, as Kunwar said, I think this only comes up in the extract_strided_slice case.
https://github.com/llvm/llvm-project/pull/127995
More information about the Mlir-commits
mailing list