[Mlir-commits] [mlir] [mlir][Vector] Fold `vector.extract` from poison vector (PR #126122)
Diego Caballero
llvmlistbot at llvm.org
Fri Feb 7 10:03:21 PST 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/126122
From bcf0d6f96587fdb62c4dfba0ade13834f22aec5b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Thu, 6 Feb 2025 12:14:33 -0800
Subject: [PATCH 1/2] [mlir][Vector] Fold `vector.extract` from poison vector
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`.
It also replaces `create` with `createOrFold` for the insert/extract
ops built in the vector unroll and transpose lowering patterns, so that
the recently introduced poison foldings are triggered.
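For example, as exercised by the new tests in canonicalize.mlir below, the
canonicalizer now rewrites an extract from a poison vector into a poison
value of the result type, roughly:

  %0 = ub.poison : vector<4x8xf32>
  %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>

into:

  %1 = ub.poison : vector<8xf32>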
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 18 ++++--
.../Transforms/LowerVectorTranspose.cpp | 8 +--
.../Vector/Transforms/VectorUnroll.cpp | 55 +++++++++++--------
mlir/test/Dialect/Vector/canonicalize.mlir | 53 ++++++++++++++++++
4 files changed, 102 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc34..b4a5461f4405dcf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
- ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+ ArrayRef<int64_t> staticPos,
+ int64_t poisonVal) {
if (!llvm::is_contained(staticPos, poisonVal))
- return ub::PoisonAttr();
+ return {};
return ub::PoisonAttr::get(context);
}
+/// Fold a vector extract from a poison source.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+ if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+ return srcAttr;
+
+ return {};
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..6135a1290d559f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
ImplicitLocOpBuilder b(source.getLoc(), builder);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
- vs.push_back(b.create<vector::ExtractOp>(source, i));
+ vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
// Interleave 32-bit lanes using
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
- rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
- result =
- rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+ result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+ insertIdxs);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbfd6..c1e3850f05c5ec7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
Value resultTensor;
for (SmallVector<int64_t> elementOffsets :
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
- Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
- slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, operand, operandOffets, operandShape, operandStrides);
+ slicesOperands[index] =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffets, operandShape, operandStrides);
};
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> operands;
SmallVector<int64_t> operandStrides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getSource(), offsets, *targetShape,
+ operandStrides);
operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
if (accIt != accCache.end())
acc = accIt->second;
else
- acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
operands.push_back(acc);
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
continue;
}
extractOperands.push_back(
- rewriter.create<vector::ExtractStridedSliceOp>(
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getVector(), offsets, *targetShape, strides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, transposeOp.getVector(), permutedOffsets, permutedShape,
- strides);
- Value transposedSlice =
- rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+ strides);
+ Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+ loc, slicedOperand, permutation);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// To get the unrolled gather, extract the same slice based on the
// decomposed shape from each of the index, mask, and pass-through
// vectors.
- Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
- Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
- Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+ Value passThruSubVec =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+ strides);
auto slicedGather = rewriter.create<vector::GatherOp>(
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedGather, result, elementOffsets, strides);
}
rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a13..d016c2efa142628 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,28 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// -----
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+ // CHECK-NEXT: ub.poison : f32
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+ // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+ return %1 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
// CHECK-NOT: vector.extract
@@ -2886,6 +2908,37 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
return %1 : vector<4xi8>
}
+// -----
+
+// Inserting a poison value shouldn't be folded away, as the resulting vector
+// is not fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : f32
+ %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Inserting a poison value shouldn't be folded away, as the resulting vector
+// is not fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : vector<8xf32>
+ %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+
// -----
// CHECK-LABEL: @insert_scalar_poison_idx
From 83a1404594f323cf96a85c10813612793d40a2c6 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 7 Feb 2025 10:02:48 -0800
Subject: [PATCH 2/2] Feedback
---
mlir/test/Dialect/Vector/canonicalize.mlir | 30 ++++++++++++++--------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d016c2efa142628..a74e562ad2f68d7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -134,8 +134,9 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// CHECK-LABEL: @extract_scalar_poison
func.func @extract_scalar_poison() -> f32 {
- // CHECK-NEXT: ub.poison : f32
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
// CHECK-NOT: vector.extract
+ // CHECK-NEXT: return %[[UB]] : f32
%0 = ub.poison : vector<4x8xf32>
%1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
return %1 : f32
@@ -145,8 +146,9 @@ func.func @extract_scalar_poison() -> f32 {
// CHECK-LABEL: @extract_vector_poison
func.func @extract_vector_poison() -> vector<8xf32> {
- // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
// CHECK-NOT: vector.extract
+ // CHECK-NEXT: return %[[UB]] : vector<8xf32>
%0 = ub.poison : vector<4x8xf32>
%1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
return %1 : vector<8xf32>
@@ -156,8 +158,9 @@ func.func @extract_vector_poison() -> vector<8xf32> {
// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
// CHECK-NOT: vector.extract
- // CHECK-NEXT: ub.poison : f32
+ // CHECK-NEXT: return %[[UB]] : f32
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
return %0 : f32
}
@@ -166,8 +169,9 @@ func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
// CHECK-LABEL: @extract_vector_poison_idx
func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<5xf32>
// CHECK-NOT: vector.extract
- // CHECK-NEXT: ub.poison : vector<5xf32>
+ // CHECK-NEXT: return %[[UB]] : vector<5xf32>
%0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
return %0 : vector<5xf32>
}
@@ -177,8 +181,9 @@ func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
// CHECK-LABEL: @extract_multiple_poison_idx
func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
-> vector<8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
// CHECK-NOT: vector.extract
- // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NEXT: return %[[UB]] : vector<8xf32>
%0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
return %0 : vector<8xf32>
}
@@ -2917,7 +2922,8 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
func.func @insert_scalar_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
- // CHECK-NEXT: vector.insert %[[UB]]
+ // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
+ // CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
%0 = ub.poison : f32
%1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
return %1 : vector<4x8xf32>
@@ -2932,7 +2938,8 @@ func.func @insert_scalar_poison(%a: vector<4x8xf32>)
func.func @insert_vector_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
- // CHECK-NEXT: vector.insert %[[UB]]
+ // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
+ // CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
%0 = ub.poison : vector<8xf32>
%1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
return %1 : vector<4x8xf32>
@@ -2944,8 +2951,9 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
// CHECK-LABEL: @insert_scalar_poison_idx
func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
-> vector<4x5xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
// CHECK-NOT: vector.insert
- // CHECK-NEXT: ub.poison : vector<4x5xf32>
+ // CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
%0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
return %0 : vector<4x5xf32>
}
@@ -2955,8 +2963,9 @@ func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
// CHECK-LABEL: @insert_vector_poison_idx
func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
-> vector<4x5xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
// CHECK-NOT: vector.insert
- // CHECK-NEXT: ub.poison : vector<4x5xf32>
+ // CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
%0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
return %0 : vector<4x5xf32>
}
@@ -2966,8 +2975,9 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
// CHECK-LABEL: @insert_multiple_poison_idx
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
-> vector<4x5x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5x8xf32>
// CHECK-NOT: vector.insert
- // CHECK-NEXT: ub.poison : vector<4x5x8xf32>
+ // CHECK-NEXT: return %[[UB]] : vector<4x5x8xf32>
%0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
return %0 : vector<4x5x8xf32>
}