[Mlir-commits] [mlir] [mlir][Vector] Fold `vector.extract` from poison vector (PR #126122)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 6 12:25:45 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` insert/extract ops in vector unroll and transpose lowering patterns to trigger the poison foldings introduced recently.
---
Full diff: https://github.com/llvm/llvm-project/pull/126122.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+14-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+4-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+31-24)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+53)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc34..b4a5461f4405dcf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
- ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+ ArrayRef<int64_t> staticPos,
+ int64_t poisonVal) {
if (!llvm::is_contained(staticPos, poisonVal))
- return ub::PoisonAttr();
+ return {};
return ub::PoisonAttr::get(context);
}
+/// Fold a vector extract whose source is a poison value.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+ if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+ return srcAttr;
+
+ return {};
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..6135a1290d559f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
ImplicitLocOpBuilder b(source.getLoc(), builder);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
- vs.push_back(b.create<vector::ExtractOp>(source, i));
+ vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
// Interleave 32-bit lanes using
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
- rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
- result =
- rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+ result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+ insertIdxs);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbfd6..c1e3850f05c5ec7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
Value resultTensor;
for (SmallVector<int64_t> elementOffsets :
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
- Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
- slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, operand, operandOffets, operandShape, operandStrides);
+ slicesOperands[index] =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffets, operandShape, operandStrides);
};
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> operands;
SmallVector<int64_t> operandStrides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getSource(), offsets, *targetShape,
+ operandStrides);
operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
if (accIt != accCache.end())
acc = accIt->second;
else
- acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
operands.push_back(acc);
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
continue;
}
extractOperands.push_back(
- rewriter.create<vector::ExtractStridedSliceOp>(
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getVector(), offsets, *targetShape, strides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, transposeOp.getVector(), permutedOffsets, permutedShape,
- strides);
- Value transposedSlice =
- rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+ strides);
+ Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+ loc, slicedOperand, permutation);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// To get the unrolled gather, extract the same slice based on the
// decomposed shape from each of the index, mask, and pass-through
// vectors.
- Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
- Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
- Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+ Value passThruSubVec =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+ strides);
auto slicedGather = rewriter.create<vector::GatherOp>(
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedGather, result, elementOffsets, strides);
}
rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a13..d016c2efa142628 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,28 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// -----
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+ // CHECK-NEXT: ub.poison : f32
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+ // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+ return %1 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
// CHECK-NOT: vector.extract
@@ -2886,6 +2908,37 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
return %1 : vector<4xi8>
}
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is
+// not fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : f32
+ %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is
+// not fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : vector<8xf32>
+ %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+
// -----
// CHECK-LABEL: @insert_scalar_poison_idx
``````````
</details>
https://github.com/llvm/llvm-project/pull/126122
More information about the Mlir-commits
mailing list