[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` fold for poison inputs (PR #125608)

Diego Caballero llvmlistbot at llvm.org
Mon Feb 3 16:37:04 PST 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/125608

https://github.com/llvm/llvm-project/pull/124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.

From ece1c6dcc7803b8502bb8a901827b6637545f19b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Mon, 3 Feb 2025 15:37:05 -0800
Subject: [PATCH] [mlir][Vector] Add `vector.shuffle` fold for poison inputs

We recently added folding support for poison indices to `vector.shuffle`.
This PR adds folding support for shuffles whose input vectors are poison.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 43 ++++++++++++++++------
 mlir/test/Dialect/Vector/canonicalize.mlir | 39 ++++++++++++++++++++
 2 files changed, 71 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6b..8d5691f38f273c2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -26,7 +26,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/ADT/bit.h"
 
 #include <cassert>
 #include <cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   if (!v1Attr || !v2Attr)
     return {};
 
+  // Fold shuffle poison, poison -> poison.
+  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
+  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
+  if (isV1Poison && isV2Poison)
+    return ub::PoisonAttr::get(getContext());
+
   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
   // manipulation.
   if (v1Type.getRank() != 1)
     return {};
 
-  int64_t v1Size = v1Type.getDimSize(0);
+  // Poison input attributes need special handling as they are not
+  // DenseElementsAttr. If an index is poison, we select the first element of
+  // the first non-poison input.
+  SmallVector<Attribute> v1Elements, v2Elements;
+  Attribute poisonElement;
+  if (!isV2Poison) {
+    v2Elements =
+        to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
+    poisonElement = v2Elements[0];
+  }
+  if (!isV1Poison) {
+    v1Elements =
+        to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
+    poisonElement = v1Elements[0];
+  }
 
   SmallVector<Attribute> results;
-  auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
-  auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+  int64_t v1Size = v1Type.getDimSize(0);
   for (int64_t maskIdx : mask) {
     Attribute indexedElm;
-    // Select v1[0] for poison indices.
     // TODO: Return a partial poison vector when supported by the UB dialect.
     if (maskIdx == ShuffleOp::kPoisonIndex) {
-      indexedElm = v1Elements[0];
+      indexedElm = poisonElement;
     } else {
-      indexedElm =
-          maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
+      if (maskIdx < v1Size)
+        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
+      else
+        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
     }
 
     results.push_back(indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
         !destVector.hasOneUse())
       return failure();
 
-    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
     TypedValue<VectorType> sourceValue = op.getSource();
     Attribute sourceCst;
     if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
       return failure();
 
+    // TODO: Support poison.
+    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
+      return failure();
+
     // TODO: Handle non-unit strides when they become available.
     if (op.hasNonUnitStrides())
       return failure();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
     // increasing linearized position indices.
     // Because the destination may have higher dimensionality then the slice,
     // we keep track of two overlapping sets of positions and offsets.
+    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
     auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
     auto sliceValuesIt = denseSlice.value_begin<Attribute>();
     auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6858f0d56e64128..65c3ab264283d2d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
 
 // -----
 
+// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = ub.poison : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_rhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
+  %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_lhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
   // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>



More information about the Mlir-commits mailing list