[Mlir-commits] [mlir] 37a3c01 - [mlir][ArithToSPIRV] Fix uitofp/sitofp for emulated narrow integer types (#186136)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 08:55:16 PDT 2026


Author: Quinn Dawkins
Date: 2026-03-12T11:55:11-04:00
New Revision: 37a3c013f2506886a00b26831146eedab3fb4cd9

URL: https://github.com/llvm/llvm-project/commit/37a3c013f2506886a00b26831146eedab3fb4cd9
DIFF: https://github.com/llvm/llvm-project/commit/37a3c013f2506886a00b26831146eedab3fb4cd9.diff

LOG: [mlir][ArithToSPIRV] Fix uitofp/sitofp for emulated narrow integer types (#186136)

When a SPIR-V target lacks Int8/Int16 capabilities, narrow integers are
emulated as i32. The upper bits of the i32 container may contain garbage
(e.g., sign-extended bits from packed byte extraction).

Previously, arith.uitofp and arith.sitofp on these emulated types would
use the generic TypeCastingOpPattern, which either forwards the operand
unchanged (when src/dst types match after conversion) or creates a plain
spirv.ConvertUToF/ConvertSToF without cleaning the upper bits. This
produces incorrect results.

This was exposed by arith canonicalization patterns (UIToFPOfExtUI,
SIToFPOfExtSI) that fold uitofp(extui(x)) -> uitofp(x) and
sitofp(extsi(x)) -> sitofp(x), eliminating the ext operations which were
incidentally cleaning the upper bits.

Replace TypeCastingOpPattern for UIToFP/SIToFP with IntToFPPattern, a
single template parameterized on signedness that handles both widening
and non-widening cases:
- Unsigned: masks with BitwiseAnd before ConvertUToF.
- Signed: sign-extends via ShiftLeftLogical + ShiftRightArithmetic
before ConvertSToF.

---------

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 0bc001b5d576a..d6b1e9552fbc5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -601,6 +601,64 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF.
+/// Fails to match when isBoolScalarOrVector holds for the converted source. If
+/// type conversion widened the source integer (e.g., i8 emulated as i32), the
+/// upper bits of the widened value may contain garbage, so they are cleaned
+/// before converting; otherwise the converted operands are used as-is:
+/// - IsSigned=false: zero-extend by masking with BitwiseAnd.
+/// - IsSigned=true: sign-extend via ShiftLeftLogical + ShiftRightArithmetic.
+template <typename ArithOp, typename SPIRVOp, bool IsSigned>
+struct IntToFPPattern final : public OpConversionPattern<ArithOp> {
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // Detect widening: compare element bitwidths before/after type conversion.
+    unsigned originalBitwidth =
+        getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+    unsigned convertedBitwidth =
+        getElementTypeOrSelf(srcType).getIntOrFloatBitWidth();
+
+    if (originalBitwidth >= convertedBitwidth) {
+      rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
+      return success();
+    }
+
+    // Widened source: the upper bits must be cleaned before converting.
+    Location loc = op.getLoc();
+    Value cleaned;
+    if constexpr (IsSigned) {
+      // Move the narrow sign bit to the top, then arithmetic-shift it back.
+      unsigned shiftAmount = convertedBitwidth - originalBitwidth;
+      Value shiftSize =
+          getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc);
+      Value shifted = spirv::ShiftLeftLogicalOp::create(
+          rewriter, loc, srcType, adaptor.getIn(), shiftSize);
+      cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
+                                                      shifted, shiftSize);
+    } else {
+      // Zero-extend: keep only the low originalBitwidth bits.
+      Value mask = getScalarOrVectorConstInt(
+          srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
+          loc);
+      cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
+                                            adaptor.getIn(), mask);
+    }
+    rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
@@ -1376,8 +1434,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
     TruncIPattern, TruncII1Pattern,
     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
-    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
-    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
+    IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
+    UIToFPI1Pattern,
+    IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 3cb5294598994..9c726b8643a46 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1491,6 +1491,66 @@ func.func @float_scalar(%arg0: f16) {
   return
 }
 
+// When i8 is emulated as i32 (no Int8 capability), the upper 24 bits of the
+// i32 container may hold garbage (e.g., from sign-extension during packed byte
+// extraction), so uitofp must mask the value down to its low 8 bits first.
+// CHECK-LABEL: @uitofp_i8_emulated_f32
+func.func @uitofp_i8_emulated_f32(%arg0: i8) -> f32 {
+  // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
+  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
+  // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32
+  %0 = arith.uitofp %arg0 : i8 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i16_emulated_f32
+func.func @uitofp_i16_emulated_f32(%arg0: i16) -> f32 {
+  // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
+  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
+  // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32
+  %0 = arith.uitofp %arg0 : i16 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_vec_i8_emulated_f32
+func.func @uitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> {
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<255> : vector<4xi32>
+  // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32>
+  // CHECK: spirv.ConvertUToF %[[MASKED]] : vector<4xi32> to vector<4xf32>
+  %0 = arith.uitofp %arg0 : vector<4xi8> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @sitofp_i8_emulated_f32
+func.func @sitofp_i8_emulated_f32(%arg0: i8) -> f32 {
+  // CHECK: %[[SHIFT:.+]] = spirv.Constant 24 : i32
+  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32
+  // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32
+  // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32
+  %0 = arith.sitofp %arg0 : i8 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @sitofp_i16_emulated_f32
+func.func @sitofp_i16_emulated_f32(%arg0: i16) -> f32 {
+  // CHECK: %[[SHIFT:.+]] = spirv.Constant 16 : i32
+  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32
+  // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32
+  // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32
+  %0 = arith.sitofp %arg0 : i16 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @sitofp_vec_i8_emulated_f32
+func.func @sitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> {
+  // CHECK: %[[SHIFT:.+]] = spirv.Constant dense<24> : vector<4xi32>
+  // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : vector<4xi32>, vector<4xi32>
+  // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : vector<4xi32>, vector<4xi32>
+  // CHECK: spirv.ConvertSToF %[[SHR]] : vector<4xi32> to vector<4xf32>
+  %0 = arith.sitofp %arg0 : vector<4xi8> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list