[Mlir-commits] [mlir] [mlir][Vector] Fix `vector.shuffle` folder for poison indices (PR #124863)

Diego Caballero llvmlistbot at llvm.org
Fri Jan 31 14:44:24 PST 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/124863

>From d0d87adfcdaca6ba7d382e652b74cfeeafb590f8 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 28 Jan 2025 16:20:29 -0800
Subject: [PATCH 1/4] [mlir][Vector] Fix vector.shuffle folder for poison
 indices

This PR fixes the folder of a `vector.shuffle` with constant input vectors
in the presence of a poison index. Partially poison vectors are not currently
supported by the UB dialect, so the folder selects v1[0] for elements indexed by poison.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 50 +++++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir | 24 +++++++++++
 2 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6a329499c71109..f5b414bed50397 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
 }
 
 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
-  VectorType v1Type = getV1VectorType();
+  auto v1Type = getV1VectorType();
+  auto v2Type = getV2VectorType();
+
+  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+         "Vector shuffle does not support scalable vectors");
+
   // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
   // but must be a canonicalization into a vector.broadcast.
   if (v1Type.getRank() == 0)
     return {};
 
-  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
-  if (!v1Type.isScalable() &&
-      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
+  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+  if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
     return getV1();
-  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
-  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
-      isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
-                       getV2VectorType().getDimSize(0)))
+  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+  if (isStepIndexArray(getMask(), v1Type.getDimSize(0), v2Type.getDimSize(0)))
     return getV2();
 
-  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
-  if (!lhs || !rhs)
+  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+  if (!v1Attr || !v2Attr)
     return {};
 
-  auto lhsType =
-      llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
   // manipulation.
-  if (lhsType.getRank() != 1)
+  if (v1Type.getRank() != 1)
     return {};
-  int64_t lhsSize = lhsType.getDimSize(0);
+
+  int64_t v1Size = v1Type.getDimSize(0);
 
   SmallVector<Attribute> results;
-  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
-  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
-  for (int64_t i : this->getMask()) {
-    if (i >= lhsSize) {
-      results.push_back(rhsElements[i - lhsSize]);
-    } else {
-      results.push_back(lhsElements[i]);
+  auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
+  auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+  for (int64_t maskIdx : this->getMask()) {
+    Attribute indexedElm;
+    // Select v1[0] for poison indices.
+    // TODO: Return a partial poison vector when supported by the UB dialect.
+    if (maskIdx == ShuffleOp::kPoisonIndex) {
+      indexedElm = v1Elements[0];
+    }
+    else {
+      indexedElm =
+          maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
     }
+
+    results.push_back(indexedElm);
   }
 
   return DenseElementsAttr::get(getResultVectorType(), results);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f9e3b772f9f0a2..070135828de901 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2006,6 +2006,20 @@ func.func @shuffle_1d() -> vector<4xi32> {
   return %shuffle : vector<4xi32>
 }
 
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_poison_idx
+//       CHECK:   %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
+  %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+  %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
   // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2013,6 +2027,8 @@ func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vect
   return %shuffle : vector<1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold1
 //       CHECK:   %arg0 : vector<4xi32>
 func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
@@ -2020,6 +2036,8 @@ func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi
   return %shuffle : vector<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold2
 //       CHECK:   %arg1 : vector<2xi32>
 func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
@@ -2027,6 +2045,8 @@ func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi
   return %shuffle : vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold3
 //       CHECK:   return %arg0 : vector<4x5x6xi32>
 func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
@@ -2034,6 +2054,8 @@ func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
   return %shuffle : vector<4x5x6xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold4
 //       CHECK:   return %arg1 : vector<2x5x6xi32>
 func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
@@ -2041,6 +2063,8 @@ func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
   return %shuffle : vector<2x5x6xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_nofold1
 //       CHECK:   %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
 //       CHECK:   return %[[V]]

>From d8499c031fa9fc1a176dee655217a22d4213f5cd Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Thu, 30 Jan 2025 16:16:47 -0800
Subject: [PATCH 2/4] format

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f5b414bed50397..47d0da0260464a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2711,8 +2711,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     // TODO: Return a partial poison vector when supported by the UB dialect.
     if (maskIdx == ShuffleOp::kPoisonIndex) {
       indexedElm = v1Elements[0];
-    }
-    else {
+    } else {
       indexedElm =
           maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
     }

>From 81658b013a303f9ea899d74fa8762468b8c9a7fa Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Thu, 30 Jan 2025 16:20:32 -0800
Subject: [PATCH 3/4] Refactor getMask

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 47d0da0260464a..93f89eda2da5a6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2685,10 +2685,11 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     return {};
 
   // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
-  if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
+  auto mask = getMask();
+  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
     return getV1();
   // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
-  if (isStepIndexArray(getMask(), v1Type.getDimSize(0), v2Type.getDimSize(0)))
+  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
     return getV2();
 
   Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
@@ -2705,7 +2706,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   SmallVector<Attribute> results;
   auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
   auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
-  for (int64_t maskIdx : this->getMask()) {
+  for (int64_t maskIdx : mask) {
     Attribute indexedElm;
     // Select v1[0] for poison indices.
     // TODO: Return a partial poison vector when supported by the UB dialect.

>From 5df26fcce57e4c5c11f36739ad2c7cbec693de91 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 31 Jan 2025 14:27:39 -0800
Subject: [PATCH 4/4] Add comment

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 070135828de901..6858f0d56e6412 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2008,6 +2008,9 @@ func.func @shuffle_1d() -> vector<4xi32> {
 
 // -----
 
+// Check that poison indices pick the first element of the first input
+// vector, i.e., %v0[0] (5) in this test.
+
 // CHECK-LABEL: func @shuffle_1d_poison_idx
 //       CHECK:   %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
 //       CHECK:   return %[[V]]



More information about the Mlir-commits mailing list