[Mlir-commits] [mlir] c3c3262 - [mlir][Vector] Fix `vector.shuffle` folder for poison indices (#124863)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 31 15:47:51 PST 2025
Author: Diego Caballero
Date: 2025-01-31T15:47:47-08:00
New Revision: c3c326213e80abd6db9da83dbf0ab8452780705c
URL: https://github.com/llvm/llvm-project/commit/c3c326213e80abd6db9da83dbf0ab8452780705c
DIFF: https://github.com/llvm/llvm-project/commit/c3c326213e80abd6db9da83dbf0ab8452780705c.diff
LOG: [mlir][Vector] Fix `vector.shuffle` folder for poison indices (#124863)
This PR fixes the folder of a `vector.shuffle` with constant input
vectors in the presence of a poison index. Partially poison vectors are
currently not supported in UB, so the folder selects v1[0] for elements
indexed by poison.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6a329499c71109..93f89eda2da5a6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
- VectorType v1Type = getV1VectorType();
+ auto v1Type = getV1VectorType();
+ auto v2Type = getV2VectorType();
+
+ assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+ "Vector shuffle does not support scalable vectors");
+
// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
// but must be a canonicalization into a vector.broadcast.
if (v1Type.getRank() == 0)
return {};
- // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
- if (!v1Type.isScalable() &&
- isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
+ // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+ auto mask = getMask();
+ if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
return getV1();
- // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
- if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
- isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
- getV2VectorType().getDimSize(0)))
+ // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+ if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
return getV2();
- Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
- if (!lhs || !rhs)
+ Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+ if (!v1Attr || !v2Attr)
return {};
- auto lhsType =
- llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
- if (lhsType.getRank() != 1)
+ if (v1Type.getRank() != 1)
return {};
- int64_t lhsSize = lhsType.getDimSize(0);
+
+ int64_t v1Size = v1Type.getDimSize(0);
SmallVector<Attribute> results;
- auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
- auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
- for (int64_t i : this->getMask()) {
- if (i >= lhsSize) {
- results.push_back(rhsElements[i - lhsSize]);
+ auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
+ auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+ for (int64_t maskIdx : mask) {
+ Attribute indexedElm;
+ // Select v1[0] for poison indices.
+ // TODO: Return a partial poison vector when supported by the UB dialect.
+ if (maskIdx == ShuffleOp::kPoisonIndex) {
+ indexedElm = v1Elements[0];
} else {
- results.push_back(lhsElements[i]);
+ indexedElm =
+ maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
}
+
+ results.push_back(indexedElm);
}
return DenseElementsAttr::get(getResultVectorType(), results);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f9e3b772f9f0a2..6858f0d56e6412 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2006,6 +2006,23 @@ func.func @shuffle_1d() -> vector<4xi32> {
return %shuffle : vector<4xi32>
}
+// -----
+
+// Check that poison indices pick the first element of the first non-poison
+// input vector. That is, %v[0] (i.e., 5) in this test.
+
+// CHECK-LABEL: func @shuffle_1d_poison_idx
+// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
+ %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+ %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2013,6 +2030,8 @@ func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vect
return %shuffle : vector<1xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold1
// CHECK: %arg0 : vector<4xi32>
func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
@@ -2020,6 +2039,8 @@ func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi
return %shuffle : vector<4xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold2
// CHECK: %arg1 : vector<2xi32>
func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
@@ -2027,6 +2048,8 @@ func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi
return %shuffle : vector<2xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold3
// CHECK: return %arg0 : vector<4x5x6xi32>
func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
@@ -2034,6 +2057,8 @@ func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
return %shuffle : vector<4x5x6xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold4
// CHECK: return %arg1 : vector<2x5x6xi32>
func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
@@ -2041,6 +2066,8 @@ func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
return %shuffle : vector<2x5x6xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_nofold1
// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
// CHECK: return %[[V]]
More information about the Mlir-commits
mailing list