[Mlir-commits] [mlir] [mlir][spirv] Support poison index when converting vector.insert/extract (PR #125560)
Andrea Faulds
llvmlistbot at llvm.org
Tue Feb 4 10:48:27 PST 2025
https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/125560
>From 03a81e7489606e475ac1aee856caf3f50b6bdcda Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Tue, 4 Feb 2025 19:47:27 +0100
Subject: [PATCH] [mlir][spirv] Support poison index when converting
vector.insert/extract
This modifies the conversion patterns so that, in the case where the
index is known statically to be poison, the insertion/extraction is
replaced by an arbitrary junk constant value, and in the dynamic case,
the index is sanitized at runtime. This avoids undefined behavior (UB) in both
cases. The dynamic case is definitely a pessimisation of the generated
code, but the use of dynamic indexes is expected to be very rare and
already slow on real-world GPU compilers ingesting SPIR-V, so the impact
should be negligible.
Resolves #124162.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 77 +++++++++++++++----
.../VectorToSPIRV/vector-to-spirv.mlir | 47 ++++++++++-
2 files changed, 107 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..2c8bc149dc708d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
}
};
+// SPIR-V does not have a concept of a poison index for certain instructions,
+// which creates a UB hazard when lowering from otherwise equivalent Vector
+// dialect instructions, because this index will be considered out-of-bounds.
+// To avoid this, this function implements a dynamic sanitization that returns
+// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
+// (presumably more efficient); otherwise poison maps to 0 (always in-bounds).
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+ Location loc, Value dynamicIndex,
+ int64_t kPoisonIndex, unsigned vectorSize) {
+ if (llvm::isPowerOf2_32(vectorSize)) {
+ Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
+ loc, dynamicIndex.getType(),
+ rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
+ return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
+ inBoundsMask);
+ }
+ Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+ loc, dynamicIndex.getType(),
+ rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+ Value cmpResult =
+ rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+ return rewriter.create<spirv::SelectOp>(
+ loc, cmpResult,
+ spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+ dynamicIndex);
+}
+
struct VectorExtractOpConvert final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(extractOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
- else
+ getConstantIntValue(extractOp.getMixedPosition()[0])) {
+ // TODO: ExtractOp::fold() already can fold a static poison index to
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
+ if (id == vector::ExtractOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero =
+ spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+ rewriter.replaceOp(extractOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::ExtractOp::kPoisonIndex,
+ extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(),
- adaptor.getDynamicPosition()[0]);
+ extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ }
return success();
}
};
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(insertOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
- else
+ getConstantIntValue(insertOp.getMixedPosition()[0])) {
+ // TODO: ExtractOp::fold() already can fold a static poison index to
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
+ if (id == vector::InsertOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
+ insertOp.getLoc(), rewriter);
+ rewriter.replaceOp(insertOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::InsertOp::kPoisonIndex,
+ insertOp.getDestVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(),
- adaptor.getDynamicPosition()[0]);
+ insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..5fd7324b1d3c73 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+ // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
@@ -208,12 +209,31 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
// CHECK-LABEL: @extract_dynamic
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
+// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[MASKED]]] : vector<4xf32>, i32
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
return %0: f32
}
+// -----
+
+// CHECK-LABEL: @extract_dynamic_non_pow2
+// CHECK-SAME: %[[V:.*]]: vector<3xf32>, %[[ARG1:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @extract_dynamic_non_pow2(%arg0 : vector<3xf32>, %id : index) -> f32 {
+ %0 = vector.extract %arg0[%id] : f32 from vector<3xf32>
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @extract_dynamic_cst
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -264,8 +284,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
- // expected-error @+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
}
@@ -306,7 +328,9 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
// CHECK-LABEL: @insert_dynamic
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
+// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[MASKED]]] : vector<4xf32>, i32
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
return %0: vector<4xf32>
@@ -314,6 +338,21 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
// -----
+// CHECK-LABEL: @insert_dynamic_non_pow2
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<3xf32>, %[[ARG2:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @insert_dynamic_non_pow2(%val: f32, %arg0 : vector<3xf32>, %id : index) -> vector<3xf32> {
+ %0 = vector.insert %val, %arg0[%id] : f32 into vector<3xf32>
+ return %0: vector<3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @insert_dynamic_cst
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
More information about the Mlir-commits
mailing list