[Mlir-commits] [mlir] [mlir][spirv] Support poison index when converting vector.insert/extract (PR #125560)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 3 11:22:47 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

<details>
<summary>Changes</summary>

This modifies the conversion patterns so that, when the index is statically known to be poison, the insertion/extraction is replaced by an arbitrary junk constant value (currently zero), and when the index is dynamic, it is sanitized at runtime (a poison index is replaced with index 0). This avoids triggering undefined behavior (UB) in both cases. The dynamic case is definitely a pessimisation of the generated code, but dynamic indexes are expected to be very rare, and already slow on real-world GPU compilers that ingest SPIR-V, so the impact should be negligible.

Resolves #124162.

---
Full diff: https://github.com/llvm/llvm-project/pull/125560.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+57-13) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+15-4) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..3481a2e8b7733c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
   }
 };
 
+// Returns `dynamicIndex` with the Vector dialect poison sentinel
+// (`kPoisonIndex`, e.g. -1) replaced by index 0, via an IEqual + Select pair.
+// SPIR-V has no poison index for VectorExtractDynamic/VectorInsertDynamic, so
+// the sentinel would be an out-of-bounds access (UB). Only the exact sentinel
+// is rewritten; any other out-of-range index is passed through unchanged.
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value dynamicIndex,
+                                  int64_t kPoisonIndex) {
+  Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+      loc, dynamicIndex.getType(),
+      rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+  Value cmpResult =
+      rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+  Value sanitizedIndex = rewriter.create<spirv::SelectOp>(
+      loc, cmpResult,
+      spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+      dynamicIndex);
+  return sanitizedIndex;
+}
+
 struct VectorExtractOpConvert final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(extractOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-          extractOp, dstType, adaptor.getVector(),
-          rewriter.getI32ArrayAttr(id.value()));
-    else
+            getConstantIntValue(extractOp.getMixedPosition()[0])) {
+      // TODO: It would be better to apply the ub.poison folding for this case
+      //       unconditionally, and have a specific SPIR-V lowering for it,
+      //       rather than having to handle it here.
+      if (id == vector::ExtractOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero =
+            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+        rewriter.replaceOp(extractOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+            extractOp, dstType, adaptor.getVector(),
+            rewriter.getI32ArrayAttr(id.value()));
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::ExtractOp::kPoisonIndex);
       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-          extractOp, dstType, adaptor.getVector(),
-          adaptor.getDynamicPosition()[0]);
+          extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+    }
     return success();
   }
 };
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(insertOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
-    else
+            getConstantIntValue(insertOp.getMixedPosition()[0])) {
+      // TODO: It would be better to apply the ub.poison folding for this case
+      //       unconditionally, and have a specific SPIR-V lowering for it,
+      //       rather than having to handle it here.
+      if (id == vector::InsertOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
+                                                insertOp.getLoc(), rewriter);
+        rewriter.replaceOp(insertOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::InsertOp::kPoisonIndex);
       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, insertOp.getDest(), adaptor.getSource(),
-          adaptor.getDynamicPosition()[0]);
+          insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..35ef759cf24168 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 // -----
 
 func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
-  // expected-error at +1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+  // CHECK: return %[[ZERO]]
   %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
   return %0: f32
 }
@@ -208,7 +209,11 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
 // CHECK-LABEL: @extract_dynamic
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
 //       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<4xf32>, i32
 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
   %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
   return %0: f32
@@ -264,8 +269,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
 func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
-  // expected-error at +1 {{index -1 out of bounds for 'vector<4xf32>'}}
   %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
   return %1: vector<4xf32>
 }
@@ -306,7 +313,11 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
 // CHECK-LABEL: @insert_dynamic
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
 //       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<4xf32>, i32
 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
   %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
   return %0: vector<4xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/125560


More information about the Mlir-commits mailing list