[Mlir-commits] [mlir] [mlir][spirv] Support poison index when converting vector.insert/extract (PR #125560)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 3 11:22:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrea Faulds (andfau-amd)
<details>
<summary>Changes</summary>
This modifies the conversion patterns as follows. When the index is statically known to be poison, the insertion/extraction is replaced by an arbitrary junk constant value (zero). When the index is dynamic, it is sanitized at runtime: a poison index is replaced with index 0. This avoids triggering undefined behavior (UB) in both cases. The dynamic case is admittedly a pessimization of the generated code. However, dynamic indices are expected to be very rare, and they are already slow in real-world GPU compilers that ingest SPIR-V, so the impact should be negligible.
Resolves #124162.
---
Full diff: https://github.com/llvm/llvm-project/pull/125560.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+57-13)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+15-4)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..3481a2e8b7733c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
}
};
+// SPIR-V has no notion of a poison index for dynamic vector element access:
+// an index equal to the Vector dialect's poison sentinel is simply
+// out-of-bounds in SPIR-V, which is UB. To avoid this, emit IR that replaces
+// the sentinel with index 0 (always in-bounds) and forwards any other index
+// unchanged. `poisonIndexValue` is the sentinel to look for (e.g.
+// vector::ExtractOp::kPoisonIndex); it is materialized as a constant of the
+// same integer type as `dynamicIndex`, and the returned sanitized index has
+// that type too.
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value dynamicIndex,
+                                  int64_t poisonIndexValue) {
+  Type indexType = dynamicIndex.getType();
+  Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+      loc, indexType, rewriter.getIntegerAttr(indexType, poisonIndexValue));
+  Value isPoison =
+      rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+  Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
+  return rewriter.create<spirv::SelectOp>(loc, isPoison, zero, dynamicIndex);
+}
+
struct VectorExtractOpConvert final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(extractOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
- else
+            getConstantIntValue(extractOp.getMixedPosition()[0])) {
+      // TODO: Apply the ub.poison folding for this case unconditionally and
+      // give it a dedicated SPIR-V lowering, rather than handling it here.
+      if (id == vector::ExtractOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero =
+            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+        rewriter.replaceOp(extractOp, zero);
+      } else {
+        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+            extractOp, dstType, adaptor.getVector(),
+            rewriter.getI32ArrayAttr(id.value()));
+      }
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::ExtractOp::kPoisonIndex);
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(),
- adaptor.getDynamicPosition()[0]);
+          extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+    }
return success();
}
};
@@ -266,13 +298,29 @@ struct VectorInsertOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(insertOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
- else
+            getConstantIntValue(insertOp.getMixedPosition()[0])) {
+      // TODO: Apply the ub.poison folding for this case unconditionally and
+      // give it a dedicated SPIR-V lowering, rather than handling it here.
+      if (id == vector::InsertOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out. Its type
+        // comes from the converted destination operand rather than from the
+        // original, possibly not-yet-converted, vector type.
+        Value zero = spirv::ConstantOp::getZero(adaptor.getDest().getType(),
+                                                insertOp.getLoc(), rewriter);
+        rewriter.replaceOp(insertOp, zero);
+      } else {
+        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+      }
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::InsertOp::kPoisonIndex);
+      // Take the destination from the adaptor (type-converted operand), as
+      // the static-index path does.
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(),
- adaptor.getDynamicPosition()[0]);
+          insertOp, adaptor.getDest(), adaptor.getSource(), sanitizedIndex);
+    }
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..35ef759cf24168 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,9 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
+// CHECK-LABEL: @extract_poison_idx
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00 : f32
+  // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
@@ -208,7 +209,11 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
// CHECK-LABEL: @extract_dynamic
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 : i32
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<4xf32>, i32
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
return %0: f32
@@ -264,8 +269,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: return %[[ZERO]]
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
- // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
}
@@ -306,7 +313,11 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
// CHECK-LABEL: @insert_dynamic
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 : i32
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<4xf32>, i32
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
return %0: vector<4xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/125560
More information about the Mlir-commits
mailing list