[Mlir-commits] [mlir] [mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V (PR #125905)
Andrea Faulds
llvmlistbot at llvm.org
Thu Feb 6 06:02:42 PST 2025
https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/125905
>From fbb3ac28899286c066901a6ca15c87ac4ade595e Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Thu, 6 Feb 2025 14:56:03 +0100
Subject: [PATCH] [mlir][spirv] Fix some issues related to converting ub.poison
to SPIR-V
This is a follow-up to 5df62bdc9be9c258c5ac45c8093b71e23777fa0e. That
commit should not have needed to make the vector.insert and
vector.extract conversions to SPIR-V directly handle the static poison
index case, as there is a fold from those to ub.poison, and a conversion
pattern from ub.poison to spirv.Undef, however:
- The ub.poison fold result could not be materialized by the vector
dialect (fixed as of d13940ee263ff50b7a71e21424913cc0266bf9d4).
- The conversion pattern wasn't being populated in VectorToSPIRVPass,
which is used by the tests. This commit now populates it there.
- The ub.poison to spirv.Undef pattern rejected non-scalar types, which
prevented its use for vector results. It is unclear why this
restriction existed; a remark in D156163 said this was to avoid
converting "user types", but it is not obvious why these shouldn't
be permitted (the SPIR-V specification allows OpUndef for all types
except OpTypeVoid). This commit removes this restriction.
With these fixed, this commit removes the redundant static poison index
handling, and updates the tests.
---
mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp | 5 ---
.../Conversion/VectorToSPIRV/CMakeLists.txt | 1 +
.../VectorToSPIRV/VectorToSPIRV.cpp | 32 +++++++------------
.../VectorToSPIRV/VectorToSPIRVPass.cpp | 3 ++
.../Conversion/UBToSPIRV/ub-to-spirv.mlir | 3 +-
.../VectorToSPIRV/vector-to-spirv.mlir | 10 +++---
6 files changed, 22 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index a3806189e406082..01c35cba48c4903 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
matchAndRewrite(ub::PoisonOp op, OpAdaptor,
ConversionPatternRewriter &rewriter) const override {
Type origType = op.getType();
- if (!origType.isIntOrIndexOrFloat())
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "unsupported type " << origType;
- });
-
Type resType = getTypeConverter()->convertType(origType);
if (!resType)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index bb9f793d7fe0ff5..f4cdb2cf95a30e1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
MLIRSPIRVConversion
MLIRVectorDialect
MLIRTransforms
+ MLIRUBToSPIRV
)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2c8bc149dc708de..1c70cb4d287d454 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -182,17 +182,12 @@ struct VectorExtractOpConvert final
if (std::optional<int64_t> id =
getConstantIntValue(extractOp.getMixedPosition()[0])) {
- // TODO: ExtractOp::fold() already can fold a static poison index to
- // ub.poison; remove this once ub.poison can be converted to SPIR-V.
- if (id == vector::ExtractOp::kPoisonIndex) {
- // Arbitrary choice of poison result, intended to stick out.
- Value zero =
- spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
- rewriter.replaceOp(extractOp, zero);
- } else
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
+ // Static use of the poison index is handled elsewhere (folded to poison).
+ if (id == vector::ExtractOp::kPoisonIndex)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
@@ -306,16 +301,11 @@ struct VectorInsertOpConvert final
if (std::optional<int64_t> id =
getConstantIntValue(insertOp.getMixedPosition()[0])) {
- // TODO: ExtractOp::fold() already can fold a static poison index to
- // ub.poison; remove this once ub.poison can be converted to SPIR-V.
- if (id == vector::InsertOp::kPoisonIndex) {
- // Arbitrary choice of poison result, intended to stick out.
- Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
- insertOp.getLoc(), rewriter);
- rewriter.replaceOp(insertOp, zero);
- } else
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ // Static use of the poison index is handled elsewhere (folded to poison).
+ if (id == vector::InsertOp::kPoisonIndex)
+ return failure();
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index cc115b1d3682626..2fff0644265bef8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
@@ -49,6 +50,8 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
RewritePatternSet patterns(context);
populateVectorToSPIRVPatterns(typeConverter, patterns);
+ // Used for folds, e.g. vector.extract[-1] -> ub.poison -> spirv.Undef.
+ ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index 771b53ad123b928..f497eb3bc552ca4 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -13,8 +13,7 @@ func.func @check_poison() {
%1 = ub.poison : i16
// CHECK: {{.*}} = spirv.Undef : f64
%2 = ub.poison : f64
-// TODO: vector is not covered yet
-// CHECK: {{.*}} = ub.poison : vector<4xf32>
+// CHECK: {{.*}} = spirv.Undef : vector<4xf32>
%3 = ub.poison : vector<4xf32>
return
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 5fd7324b1d3c738..9e69eb15bc4052c 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,15 +175,17 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
+// CHECK-LABEL: @extract_poison_idx
+// CHECK: %[[R:.+]] = spirv.Undef : f32
+// CHECK: return %[[R]]
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
- // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
// -----
+
// CHECK-LABEL: @extract_size1_vector
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -285,8 +287,8 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
// CHECK-LABEL: @insert_poison_idx
-// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
-// CHECK: return %[[ZERO]]
+// CHECK: %[[R:.+]] = spirv.Undef : vector<4xf32>
+// CHECK: return %[[R]]
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
More information about the Mlir-commits
mailing list