[Mlir-commits] [mlir] c6eef00 - [mlir][Vector] Add `vector.shuffle` fold for poison inputs (#125608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 4 18:03:30 PST 2025
Author: Diego Caballero
Date: 2025-02-04T18:03:26-08:00
New Revision: c6eef00a096e6f3176b8fce84ce4cef6c6e2af5f
URL: https://github.com/llvm/llvm-project/commit/c6eef00a096e6f3176b8fce84ce4cef6c6e2af5f
DIFF: https://github.com/llvm/llvm-project/commit/c6eef00a096e6f3176b8fce84ce4cef6c6e2af5f.diff
LOG: [mlir][Vector] Add `vector.shuffle` fold for poison inputs (#125608)
https://github.com/llvm/llvm-project/pull/124863 added folding support
for poison indices to `vector.shuffle`. This PR adds support for folding
`vector.shuffle` ops with one or two poison input vectors.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2ec1b97f2f241d1..7a10d2f2c0dfc3f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -26,7 +26,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/ADT/bit.h"
#include <cassert>
#include <cstdint>
@@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
if (!v1Attr || !v2Attr)
return {};
+ // Fold shuffle poison, poison -> poison.
+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
+ if (isV1Poison && isV2Poison)
+ return ub::PoisonAttr::get(getContext());
+
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (v1Type.getRank() != 1)
return {};
- int64_t v1Size = v1Type.getDimSize(0);
+ // Poison input attributes need special handling as they are not
+ // DenseElementsAttr. If an index is poison, we select the first element of
+ // the first non-poison input.
+ SmallVector<Attribute> v1Elements, v2Elements;
+ Attribute poisonElement;
+ if (!isV2Poison) {
+ v2Elements =
+ to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
+ poisonElement = v2Elements[0];
+ }
+ if (!isV1Poison) {
+ v1Elements =
+ to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
+ poisonElement = v1Elements[0];
+ }
SmallVector<Attribute> results;
- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+ int64_t v1Size = v1Type.getDimSize(0);
for (int64_t maskIdx : mask) {
Attribute indexedElm;
- // Select v1[0] for poison indices.
// TODO: Return a partial poison vector when supported by the UB dialect.
if (maskIdx == ShuffleOp::kPoisonIndex) {
- indexedElm = v1Elements[0];
+ indexedElm = poisonElement;
} else {
- indexedElm =
- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
+ if (maskIdx < v1Size)
+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
+ else
+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
}
results.push_back(indexedElm);
@@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
TypedValue<VectorType> sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();
+ // TODO: Support poison.
+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
+ return failure();
+
// TODO: Handle non-unit strides when they become available.
if (op.hasNonUnitStrides())
return failure();
@@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final
// increasing linearized position indices.
// Because the destination may have higher dimensionality then the slice,
// we keep track of two overlapping sets of positions and offsets.
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6858f0d56e64128..61e858f5f226a13 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2012,17 +2012,56 @@ func.func @shuffle_1d() -> vector<4xi32> {
-// input vector. That is, %v[0] (i.e., 5) in this test.
+// input vector. That is, %v0[0] (i.e., 10) in this test.
// CHECK-LABEL: func @shuffle_1d_poison_idx
-// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
+// CHECK: %[[V:.+]] = arith.constant dense<[13, 10, 15, 10]> : vector<4xi32>
// CHECK: return %[[V]]
func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
- %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
- %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+ %v0 = arith.constant dense<[10, 11, 12]> : vector<3xi32>
+ %v1 = arith.constant dense<[13, 14, 15]> : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
return %shuffle : vector<4xi32>
}
// -----
+// Both operands are poison: the whole shuffle folds to ub.poison.
+// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
+ %v0 = ub.poison : vector<3xi32>
+ %v1 = ub.poison : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// The LHS (%v0) is poison: lanes selected from it fold to %v1[0] (i.e., 11).
+// CHECK-LABEL: func @shuffle_1d_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[11, 11, 13, 12]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
+ %v0 = ub.poison : vector<3xi32>
+ %v1 = arith.constant dense<[11, 12, 13]> : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// The RHS (%v1) is poison: lanes selected from it fold to %v0[0] (i.e., 11).
+// CHECK-LABEL: func @shuffle_1d_rhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[11, 12, 11, 11]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
+ %v0 = arith.constant dense<[11, 12, 13]> : vector<3xi32>
+ %v1 = ub.poison : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
More information about the Mlir-commits
mailing list