[Mlir-commits] [mlir] [mlir][spirv] Fix some issues related to converting ub.poison to SPIR-V (PR #125905)

Andrea Faulds llvmlistbot at llvm.org
Thu Feb 6 07:55:03 PST 2025


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/125905

>From 208e5846c10878066940f556cfd30198a14a31b8 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Thu, 6 Feb 2025 16:54:33 +0100
Subject: [PATCH] [mlir][spirv] Fix some issues related to converting ub.poison
 to SPIR-V

This is a follow-up to 5df62bdc9be9c258c5ac45c8093b71e23777fa0e. That
commit should not have needed to make the vector.insert and
vector.extract conversions to SPIR-V directly handle the static poison
index case, as there is a fold from those ops to ub.poison and a
conversion pattern from ub.poison to spirv.Undef. However, this path
did not work because:

- The ub.poison fold result could not be materialized by the vector
  dialect (fixed as of d13940ee263ff50b7a71e21424913cc0266bf9d4).
- The conversion pattern wasn't being populated in VectorToSPIRVPass,
  which is used by the tests. This commit changes this.
- The ub.poison to spirv.Undef pattern rejected non-scalar types, which
  prevented its use for vector results. It is unclear why this
  restriction existed; a remark in D156163 said this was to avoid
  converting "user types", but it is not obvious why these shouldn't
  be permitted (the SPIR-V specification allows OpUndef for all types
  except OpTypeVoid). This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index
handling and updates the tests.
---
 mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp   |  5 ---
 .../Conversion/VectorToSPIRV/CMakeLists.txt   |  1 +
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 34 +++++++------------
 .../VectorToSPIRV/VectorToSPIRVPass.cpp       |  3 ++
 .../Conversion/UBToSPIRV/ub-to-spirv.mlir     |  3 +-
 .../VectorToSPIRV/vector-to-spirv.mlir        |  9 ++---
 6 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index a3806189e406082..01c35cba48c4903 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -29,11 +29,6 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
   matchAndRewrite(ub::PoisonOp op, OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type origType = op.getType();
-    if (!origType.isIntOrIndexOrFloat())
-      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-        diag << "unsupported type " << origType;
-      });
-
     Type resType = getTypeConverter()->convertType(origType);
     if (!resType)
       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index bb9f793d7fe0ff5..f4cdb2cf95a30e1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
   MLIRSPIRVConversion
   MLIRVectorDialect
   MLIRTransforms
+  MLIRUBToSPIRV
   )
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 2c8bc149dc708de..1ecb892a4ea9297 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -182,17 +182,13 @@ struct VectorExtractOpConvert final
 
     if (std::optional<int64_t> id =
             getConstantIntValue(extractOp.getMixedPosition()[0])) {
-      // TODO: ExtractOp::fold() already can fold a static poison index to
-      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
-      if (id == vector::ExtractOp::kPoisonIndex) {
-        // Arbitrary choice of poison result, intended to stick out.
-        Value zero =
-            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
-        rewriter.replaceOp(extractOp, zero);
-      } else
-        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-            extractOp, dstType, adaptor.getVector(),
-            rewriter.getI32ArrayAttr(id.value()));
+      if (id == vector::ExtractOp::kPoisonIndex)
+        return rewriter.notifyMatchFailure(
+            extractOp,
+            "Static use of poison index handled elsewhere (folded to poison)");
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
@@ -306,16 +302,12 @@ struct VectorInsertOpConvert final
 
     if (std::optional<int64_t> id =
             getConstantIntValue(insertOp.getMixedPosition()[0])) {
-      // TODO: ExtractOp::fold() already can fold a static poison index to
-      //       ub.poison; remove this once ub.poison can be converted to SPIR-V.
-      if (id == vector::InsertOp::kPoisonIndex) {
-        // Arbitrary choice of poison result, intended to stick out.
-        Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
-                                                insertOp.getLoc(), rewriter);
-        rewriter.replaceOp(insertOp, zero);
-      } else
-        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+      if (id == vector::InsertOp::kPoisonIndex)
+        return rewriter.notifyMatchFailure(
+            insertOp,
+            "Static use of poison index handled elsewhere (folded to poison)");
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index cc115b1d3682626..0735e1ee0c6779f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 
+#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -49,6 +50,8 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
 
   RewritePatternSet patterns(context);
   populateVectorToSPIRVPatterns(typeConverter, patterns);
+  // Used for folds, e.g. vector.extract[-1] -> ub.poison -> spirv.Undef.
+  ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index 771b53ad123b928..f497eb3bc552ca4 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -13,8 +13,7 @@ func.func @check_poison() {
   %1 = ub.poison : i16
 // CHECK: {{.*}} = spirv.Undef : f64
   %2 = ub.poison : f64
-// TODO: vector is not covered yet
-// CHECK: {{.*}} = ub.poison : vector<4xf32>
+// CHECK: {{.*}} = spirv.Undef : vector<4xf32>
   %3 = ub.poison : vector<4xf32>
   return
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 5fd7324b1d3c738..3f0bf1962e299b0 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -175,9 +175,10 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 
 // -----
 
+// CHECK-LABEL: @extract_poison_idx
+//       CHECK:   %[[R:.+]] = spirv.Undef : f32
+//       CHECK:   return %[[R]]
 func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
-  // CHECK: return %[[ZERO]]
   %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
   return %0: f32
 }
@@ -285,8 +286,8 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 // -----
 
 // CHECK-LABEL: @insert_poison_idx
-// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
-// CHECK: return %[[ZERO]]
+//       CHECK:   %[[R:.+]] = spirv.Undef : vector<4xf32>
+//       CHECK:   return %[[R]]
 func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
   %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
   return %1: vector<4xf32>



More information about the Mlir-commits mailing list