[Mlir-commits] [mlir] [mlir][Math][SPIRV] fix `math.round` conversion for unit dimensional … (PR #182067)

Artem Gindinson llvmlistbot at llvm.org
Wed Feb 18 09:08:31 PST 2026


https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/182067

…vectors

In SPIR-V, unit-dimensional vectors, e.g. `vector<1xf32>`, are legalized as scalars (SPIR-V vectors have 2, 3, or 4 components, or 8 or 16 components with the `Vector16` capability). This PR fixes the `math.round` conversion pattern to build its replacement from the type-converted operand, so that these vectors are legalized during conversion.

>From c850d6fcd58c0e185799fa930711cfb399978cef Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Wed, 18 Feb 2026 17:07:26 +0000
Subject: [PATCH] [mlir][Math][SPIRV] fix `math.round` conversion for unit
 dimensional vectors

In SPIR-V, unit-dimensional vectors, e.g. `vector<1xf32>`, are legalized as scalars (SPIR-V vectors have 2, 3, or 4 components, or 8 or 16 components with the `Vector16` capability).
This PR fixes the `math.round` conversion pattern to build its replacement from the type-converted operand, so that these vectors are legalized during conversion.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: Ege Beysel <beysel at roofline.ai>
---
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    | 14 +++++++++----
 .../MathToSPIRV/math-to-gl-spirv.mlir         | 21 +++++++++++++++++++
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 610ce1f13c56b..489f76a0f5976 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -449,8 +449,13 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
       return res;
 
     Location loc = roundOp.getLoc();
-    Value operand = roundOp.getOperand();
-    Type ty = operand.getType();
+    auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
+    if (!ty)
+      return rewriter.notifyMatchFailure(
+          roundOp->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V",
+                        roundOp.getType()));
+
     Type ety = getElementTypeOrSelf(ty);
 
     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
@@ -466,14 +471,15 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
                                        rewriter.getFloatAttr(ety, 0.5));
     }
 
-    auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
+    auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
     auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
     auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
     auto greater =
         spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
     auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
     auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
-    rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
+    rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add,
+                                                  adaptor.getOperand());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index b8e001c9f6950..608abffd8bd82 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -257,6 +257,27 @@ func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
   return %0: vector<4xf32>
 }
 
+// Unit-dimensional vectors are converted to scalars by inserting
+// `builtin.unrealized_conversion_cast` ops.
+//
+// CHECK-LABEL: @round_vector_unit_dim
+//  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>) -> vector<1xf32>
+func.func @round_vector_unit_dim(%x: vector<1xf32>) -> vector<1xf32> {
+  // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<1xf32> to f32
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00
+  // CHECK: %[[HALF:.+]] = spirv.Constant 5.000000e-01
+  // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[CAST]] : f32
+  // CHECK: %[[FLOOR:.+]] = spirv.GL.Floor %[[ABS]]
+  // CHECK: %[[SUB:.+]] = spirv.FSub %[[ABS]], %[[FLOOR]]
+  // CHECK: %[[GE:.+]] = spirv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]]
+  // CHECK: %[[SEL:.+]] = spirv.Select %[[GE]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ADD:.+]] = spirv.FAdd %[[FLOOR]], %[[SEL]]
+  // CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ADD]] : f32 to i32
+  %0 = math.round %x : vector<1xf32>
+  return %0: vector<1xf32>
+}
+
 } // end module
 
 // -----



More information about the Mlir-commits mailing list