[Mlir-commits] [mlir] [mlir][spirv] Support poison index when converting vector.insert/extract (PR #125560)

Andrea Faulds llvmlistbot at llvm.org
Tue Feb 4 10:48:27 PST 2025


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/125560

From 03a81e7489606e475ac1aee856caf3f50b6bdcda Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Tue, 4 Feb 2025 19:47:27 +0100
Subject: [PATCH] [mlir][spirv] Support poison index when converting
 vector.insert/extract

This modifies the conversion patterns so that, in the case where the
index is known statically to be poison, the insertion/extraction is
replaced by an arbitrary junk constant value, and in the dynamic case,
the index is sanitized at runtime. This avoids triggering UB in both
cases. The dynamic case is definitely a pessimisation of the generated
code, but the use of dynamic indexes is expected to be very rare and
already slow on real-world GPU compilers ingesting SPIR-V, so the impact
should be negligible.

Resolves #124162.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 77 +++++++++++++++----
 .../VectorToSPIRV/vector-to-spirv.mlir        | 47 ++++++++++-
 2 files changed, 107 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..2c8bc149dc708d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
   }
 };
 
+// SPIR-V does not have a concept of a poison index for certain instructions,
+// which creates a UB hazard when lowering from otherwise equivalent Vector
+// dialect instructions, because this index will be considered out-of-bounds.
+// To avoid this, the index is sanitized at runtime: for power-of-two vector
+// sizes it is masked with (size - 1), wrapping any index into range; for other
+// sizes an index equal to the poison value is replaced by 0 (always in-bounds).
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value dynamicIndex,
+                                  int64_t kPoisonIndex, unsigned vectorSize) {
+  if (llvm::isPowerOf2_32(vectorSize)) {
+    Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
+        loc, dynamicIndex.getType(),
+        rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
+    return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
+                                                inBoundsMask);
+  }
+  Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+      loc, dynamicIndex.getType(),
+      rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+  Value cmpResult =
+      rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+  return rewriter.create<spirv::SelectOp>(
+      loc, cmpResult,
+      spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+      dynamicIndex);
+}
+
 struct VectorExtractOpConvert final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(extractOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-          extractOp, dstType, adaptor.getVector(),
-          rewriter.getI32ArrayAttr(id.value()));
-    else
+            getConstantIntValue(extractOp.getMixedPosition()[0])) {
+      // TODO: ExtractOp::fold() can already fold a static poison index to
+      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
+      if (id == vector::ExtractOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero =
+            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+        rewriter.replaceOp(extractOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+            extractOp, dstType, adaptor.getVector(),
+            rewriter.getI32ArrayAttr(id.value()));
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::ExtractOp::kPoisonIndex,
+          extractOp.getSourceVectorType().getNumElements());
       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-          extractOp, dstType, adaptor.getVector(),
-          adaptor.getDynamicPosition()[0]);
+          extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+    }
     return success();
   }
 };
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(insertOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
-    else
+            getConstantIntValue(insertOp.getMixedPosition()[0])) {
+      // TODO: InsertOp::fold() can already fold a static poison index to
+      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
+      if (id == vector::InsertOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero = spirv::ConstantOp::getZero(adaptor.getDest().getType(),
+                                                insertOp.getLoc(), rewriter);
+        rewriter.replaceOp(insertOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::InsertOp::kPoisonIndex,
+          insertOp.getDestVectorType().getNumElements());
       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, insertOp.getDest(), adaptor.getSource(),
-          adaptor.getDynamicPosition()[0]);
+          insertOp, adaptor.getDest(), adaptor.getSource(), sanitizedIndex);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..5fd7324b1d3c73 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 // -----
 
 func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
-  // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+  // CHECK: return %[[ZERO]]
   %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
   return %0: f32
 }
@@ -208,12 +209,31 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
 // CHECK-LABEL: @extract_dynamic
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
 //       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[MASK:.+]] = spirv.Constant 3 :
+//       CHECK:   %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[MASKED]]] : vector<4xf32>, i32
 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
   %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
   return %0: f32
 }
 
+// -----
+
+// CHECK-LABEL: @extract_dynamic_non_pow2
+//  CHECK-SAME: %[[V:.*]]: vector<3xf32>, %[[ARG1:.*]]: index
+//       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @extract_dynamic_non_pow2(%arg0 : vector<3xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<3xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_dynamic_cst
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>
 //       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -264,8 +284,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
 func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
-  // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
   %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
   return %1: vector<4xf32>
 }
@@ -306,7 +328,9 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
 // CHECK-LABEL: @insert_dynamic
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
 //       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[MASK:.+]] = spirv.Constant 3 :
+//       CHECK:   %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[MASKED]]] : vector<4xf32>, i32
 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
   %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
   return %0: vector<4xf32>
@@ -314,6 +338,21 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
 
 // -----
 
+// CHECK-LABEL: @insert_dynamic_non_pow2
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<3xf32>, %[[ARG2:.*]]: index
+//       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @insert_dynamic_non_pow2(%val: f32, %arg0 : vector<3xf32>, %id : index) -> vector<3xf32> {
+  %0 = vector.insert %val, %arg0[%id] : f32 into vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_dynamic_cst
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
 //       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>



More information about the Mlir-commits mailing list