[Mlir-commits] [mlir] [mlir][Vector] Add support for poison indices to `Extract/InsertOp` (PR #123488)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 18 15:24:03 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

Following up on #<!-- -->122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`.

---
Full diff: https://github.com/llvm/llvm-project/pull/123488.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/Vector.td (+10-1) 
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+5-18) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+32-3) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Transforms/Canonicalizer.cpp (+1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+42-1) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index c439ca083e2e09..1922cc63ef3538 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -26,6 +26,15 @@ def Vector_Dialect : Dialect {
 
 // Base class for Vector dialect ops.
 class Vector_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Vector_Dialect, mnemonic, traits>;
+    Op<Vector_Dialect, mnemonic, traits> {
+
+  // Includes definitions for operations that support the use of poison values
+  // within positive index ranges.
+  code extraPoisonClassDeclaration = [{
+    // Integer to represent a poison index within a static and positive integer
+    // range.
+    static constexpr int64_t kPoisonIndex = -1;
+  }];
+}
 
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4331eda1661960..c57e3dd13233c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -469,10 +469,7 @@ def Vector_ShuffleOp
     ```
   }];
 
-  let extraClassDeclaration = [{
-    // Integer to represent a poison value in a vector shuffle mask.
-    static constexpr int64_t kMaskPoisonValue = -1;
-
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
     }
@@ -706,8 +703,6 @@ def Vector_ExtractOp :
     %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
     %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -724,7 +719,7 @@ def Vector_ExtractOp :
     OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
@@ -898,8 +893,6 @@ def Vector_InsertOp :
     %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
     %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let arguments = (ins
@@ -917,7 +910,7 @@ def Vector_InsertOp :
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     Type getSourceType() { return getSource().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -990,15 +983,13 @@ def Vector_ScalableInsertOp :
     ```mlir
     %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1043,15 +1034,13 @@ def Vector_ScalableExtractOp :
     ```mlir
     %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
     $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
   }];
 
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getSource().getType());
     }
@@ -1089,8 +1078,6 @@ def Vector_InsertStridedSliceOp :
         {offsets = [0, 0, 2], strides = [1, 1]}:
       vector<2x4xf32> into vector<16x4x8xf32>
     ```
-
-    TODO: Implement support for poison indices.
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index c4a8e7a81fa483..a39ab77fc8fb3b 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> {
     details.
   }];
   let constructor = "mlir::createCanonicalizerPass()";
+  let dependentDialects = ["ub::UBDialect"];
   let options = [
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..c30569eb4d2ac8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
   return srcElements[posIdx];
 }
 
+// Returns `true` if `index` is either within [0, maxIndex) or equal to
+// `poisonValue`.
+static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
+                                         int64_t maxIndex) {
+  return index == poisonValue || (index >= 0 && index < maxIndex);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
@@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = dyn_cast<Attribute>(pos)) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(
+              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
@@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          resultType.getNumElements()));
   return success();
 }
+
+/// Fold an insert or extract operation into a poison value when a poison index
+/// is found at any dimension of the static position.
+template <typename OpTy>
+LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
+                                             PatternRewriter &rewriter) {
+  auto hasPoisonIndex = [](int64_t index) {
+    return index == OpTy::kPoisonIndex;
+  };
+
+  if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex))
+    return failure();
+
+  rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());
+  return success();
+}
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
+  results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() {
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
   for (auto [idx, maskPos] : llvm::enumerate(mask)) {
-    if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize))
+    if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
       return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();
@@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() {
   for (auto [idx, pos] : llvm::enumerate(position)) {
     if (auto attr = pos.dyn_cast<Attribute>()) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
-      if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
+      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
+                                        destVectorType.getDimSize(idx))) {
         return emitOpError("expected position attribute #")
                << (idx + 1)
                << " to be a non-negative integer smaller than the "
@@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
               InsertOpConstantFolder>(context);
+  results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e2..3a8088bccf2994 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
+  MLIRUBDialect
   )
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 5f469605070367..7ccd503fb02882 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 89af0f7332f5c4..a010ee32e9d7e0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,26 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
 
 // -----
 
+// CHECK-LABEL: @extract_scalar_poison_idx
+func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : f32
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison_idx
+func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+  //  CHECK-NOT: vector.extract
+  // CHECK-NEXT: ub.poison : vector<5xf32>
+  %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
+  return %0 : vector<5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
 func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2778,7 +2798,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @vector_insert_const_regression(
@@ -2792,6 +2811,28 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: @insert_scalar_poison_idx
+func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_vector_poison_idx
+func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
+    -> vector<4x5xf32> {
+  //  CHECK-NOT: vector.insert
+  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
+  return %0 : vector<4x5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1a70791fae1257..9416f4787eefbb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -187,7 +187,7 @@ func.func @extract_0d(%arg0: vector<f32>) {
 
 func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
-  %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32>
+  %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
 }
 
 // -----
@@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
 
 func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}}
-  %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32>
+  %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index cd6f3f518a1c07..67484e06f456dc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector<f32>) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: @extract_poison_idx
+func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32>
+  %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
+  return %0 : f32
+}
+
 // CHECK-LABEL: @insert_element_0d
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
@@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f
   return %1, %2 : vector<f32>, vector<2x3xf32>
 }
 
+// CHECK-LABEL: @insert_poison_idx
+func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) {
+  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32>
+  vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
+  return
+}
+
 // CHECK-LABEL: @outerproduct
 func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> {
   // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/123488


More information about the Mlir-commits mailing list