[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` fold for poison inputs (PR #125608)
Diego Caballero
llvmlistbot at llvm.org
Mon Feb 3 16:37:04 PST 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/125608
https://github.com/llvm/llvm-project/pull/124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
>From ece1c6dcc7803b8502bb8a901827b6637545f19b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 3 Feb 2025 15:37:05 -0800
Subject: [PATCH] [mlir][Vector] Add `vector.shuffle` fold for poison inputs
We recently added folding support for poison indices to `vector.shuffle`.
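This PR adds support for folding shuffles with poison inputs: if both inputs
are poison, the shuffle folds to poison; if only one input is poison, elements
selected from it (as well as poison indices) are replaced with the first element
of the non-poison input.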
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 43 ++++++++++++++++------
mlir/test/Dialect/Vector/canonicalize.mlir | 39 ++++++++++++++++++++
2 files changed, 71 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6b..8d5691f38f273c2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -26,7 +26,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/ADT/bit.h"
#include <cassert>
#include <cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
if (!v1Attr || !v2Attr)
return {};
+ // Fold shuffle poison, poison -> poison.
+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
+ if (isV1Poison && isV2Poison)
+ return ub::PoisonAttr::get(getContext());
+
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (v1Type.getRank() != 1)
return {};
- int64_t v1Size = v1Type.getDimSize(0);
+ // Poison input attributes need special handling as they are not
+ // DenseElementsAttr. If an index is poison, we select the first element of
+ // the first non-poison input.
+ SmallVector<Attribute> v1Elements, v2Elements;
+ Attribute poisonElement;
+ if (!isV2Poison) {
+ v2Elements =
+ to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
+ poisonElement = v2Elements[0];
+ }
+ if (!isV1Poison) {
+ v1Elements =
+ to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
+ poisonElement = v1Elements[0];
+ }
SmallVector<Attribute> results;
- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+ int64_t v1Size = v1Type.getDimSize(0);
for (int64_t maskIdx : mask) {
Attribute indexedElm;
- // Select v1[0] for poison indices.
// TODO: Return a partial poison vector when supported by the UB dialect.
if (maskIdx == ShuffleOp::kPoisonIndex) {
- indexedElm = v1Elements[0];
+ indexedElm = poisonElement;
} else {
- indexedElm =
- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
+ if (maskIdx < v1Size)
+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
+ else
+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
}
results.push_back(indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
TypedValue<VectorType> sourceValue = op.getSource();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();
+ // TODO: Support poison.
+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
+ return failure();
+
// TODO: Handle non-unit strides when they become available.
if (op.hasNonUnitStrides())
return failure();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
// increasing linearized position indices.
// Because the destination may have higher dimensionality then the slice,
// we keep track of two overlapping sets of positions and offsets.
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6858f0d56e64128..65c3ab264283d2d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
// -----
+// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_rhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
+  %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_lhs_poison
+// CHECK-NOT: vector.shuffle
+// CHECK: %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
More information about the Mlir-commits
mailing list