[Mlir-commits] [mlir] 37a3c01 - [mlir][ArithToSPIRV] Fix uitofp/sitofp for emulated narrow integer types (#186136)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 08:55:16 PDT 2026
Author: Quinn Dawkins
Date: 2026-03-12T11:55:11-04:00
New Revision: 37a3c013f2506886a00b26831146eedab3fb4cd9
URL: https://github.com/llvm/llvm-project/commit/37a3c013f2506886a00b26831146eedab3fb4cd9
DIFF: https://github.com/llvm/llvm-project/commit/37a3c013f2506886a00b26831146eedab3fb4cd9.diff
LOG: [mlir][ArithToSPIRV] Fix uitofp/sitofp for emulated narrow integer types (#186136)
When a SPIR-V target lacks Int8/Int16 capabilities, narrow integers are
emulated as i32. The upper bits of the i32 container may contain garbage
(e.g., sign-extended bits from packed byte extraction).
Previously, arith.uitofp and arith.sitofp on these emulated types would
use the generic TypeCastingOpPattern, which either forwards the operand
unchanged (when src/dst types match after conversion) or creates a plain
spirv.ConvertUToF/ConvertSToF without cleaning the upper bits. This
produces incorrect results.
This was exposed by arith canonicalization patterns (UIToFPOfExtUI,
SIToFPOfExtSI) that fold uitofp(extui(x)) -> uitofp(x) and
sitofp(extsi(x)) -> sitofp(x), eliminating the ext operations which were
incidentally cleaning the upper bits.
Replace TypeCastingOpPattern for UIToFP/SIToFP with IntToFPPattern, a
single template parameterized on signedness that handles both widening
and non-widening cases:
- Unsigned: masks with BitwiseAnd before ConvertUToF.
- Signed: sign-extends via ShiftLeftLogical + ShiftRightArithmetic
before ConvertSToF.
---------
Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 0bc001b5d576a..d6b1e9552fbc5 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -601,6 +601,64 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
}
};
+/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF.
+/// When the source integer type was widened during type conversion (e.g., i8
+/// emulated as i32), the upper bits of the widened value may contain garbage.
+/// This pattern cleans the upper bits before the conversion:
+/// - For unsigned (IsSigned=false): mask with BitwiseAnd.
+/// - For signed (IsSigned=true): sign-extend via ShiftLeftLogical +
+/// ShiftRightArithmetic. Non-widened sources are converted directly; bool sources fail to match.
+template <typename ArithOp, typename SPIRVOp, bool IsSigned>
+struct IntToFPPattern final : public OpConversionPattern<ArithOp> {
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+ // Replaces `op` with SPIRVOp; fails for bool sources or an unconvertible result type.
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = adaptor.getOperands().front().getType(); // type of the adaptor's (converted) operand
+ if (isBoolScalarOrVector(srcType))
+ return failure(); // bool scalars/vectors are not handled by this pattern
+ // Map the result type through the type converter; bail out if that fails.
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // Check if the source integer type was widened during type conversion,
+ unsigned originalBitwidth =
+ getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); // element width before conversion
+ unsigned convertedBitwidth =
+ getElementTypeOrSelf(srcType).getIntOrFloatBitWidth(); // element width after conversion
+ // Not widened: every bit of the operand is meaningful, so convert it as-is.
+ if (originalBitwidth >= convertedBitwidth) {
+ rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
+ return success();
+ }
+
+ // The source was widened. Clean the upper bits before converting.
+ Location loc = op.getLoc();
+ Value cleaned;
+ if constexpr (IsSigned) {
+ // Sign-extend by shifting left then arithmetic right.
+ unsigned shiftAmount = convertedBitwidth - originalBitwidth; // number of extra high bits
+ Value shiftSize =
+ getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc);
+ Value shifted = spirv::ShiftLeftLogicalOp::create(
+ rewriter, loc, srcType, adaptor.getIn(), shiftSize);
+ cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
+ shifted, shiftSize);
+ } else {
+ // Zero-extend by masking off the upper bits.
+ Value mask = getScalarOrVectorConstInt(
+ srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter, // low originalBitwidth bits set
+ loc);
+ cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
+ adaptor.getIn(), mask);
+ }
+ rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned); // convert the cleaned value
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
@@ -1376,8 +1434,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
TruncIPattern, TruncII1Pattern,
TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
- TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
- TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
+ IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
+ UIToFPI1Pattern,
+ IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 3cb5294598994..9c726b8643a46 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1491,6 +1491,66 @@ func.func @float_scalar(%arg0: f16) {
return
}
+// When i8 is emulated as i32 (no Int8 capability), uitofp from i8 needs a
+// bitmask to clear upper bits that may contain garbage from sign-extension
+// during packed byte extraction. Expected: AND with 255 (0xFF), then ConvertUToF.
+// CHECK-LABEL: @uitofp_i8_emulated_f32
+func.func @uitofp_i8_emulated_f32(%arg0: i8) -> f32 {
+ // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
+ // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32
+ %0 = arith.uitofp %arg0 : i8 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i16_emulated_f32
+func.func @uitofp_i16_emulated_f32(%arg0: i16) -> f32 {
+ // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32
+ // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32
+ %0 = arith.uitofp %arg0 : i16 to f32 // i16 is carried in an i32 here: expect a 65535 (0xFFFF) mask before the conversion
+ return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_vec_i8_emulated_f32
+func.func @uitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> {
+ // CHECK: %[[MASK:.+]] = spirv.Constant dense<255> : vector<4xi32>
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32>
+ // CHECK: spirv.ConvertUToF %[[MASKED]] : vector<4xi32> to vector<4xf32>
+ %0 = arith.uitofp %arg0 : vector<4xi8> to vector<4xf32> // vector case: the mask is a dense<255> vector<4xi32> constant
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @sitofp_i8_emulated_f32
+func.func @sitofp_i8_emulated_f32(%arg0: i8) -> f32 {
+ // CHECK: %[[SHIFT:.+]] = spirv.Constant 24 : i32
+ // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32
+ // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32
+ // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32
+ %0 = arith.sitofp %arg0 : i8 to f32 // signed case: shift left then arithmetic-shift right by 24 (= 32 - 8) to sign-extend
+ return %0 : f32
+}
+
+// CHECK-LABEL: @sitofp_i16_emulated_f32
+func.func @sitofp_i16_emulated_f32(%arg0: i16) -> f32 {
+ // CHECK: %[[SHIFT:.+]] = spirv.Constant 16 : i32
+ // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32
+ // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32
+ // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32
+ %0 = arith.sitofp %arg0 : i16 to f32 // signed i16 case: the shift amount is 16 (= 32 - 16)
+ return %0 : f32
+}
+
+// CHECK-LABEL: @sitofp_vec_i8_emulated_f32
+func.func @sitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> {
+ // CHECK: %[[SHIFT:.+]] = spirv.Constant dense<24> : vector<4xi32>
+ // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : vector<4xi32>, vector<4xi32>
+ // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : vector<4xi32>, vector<4xi32>
+ // CHECK: spirv.ConvertSToF %[[SHR]] : vector<4xi32> to vector<4xf32>
+ %0 = arith.sitofp %arg0 : vector<4xi8> to vector<4xf32> // vector case: the shift amount is a dense<24> vector<4xi32> constant
+ return %0 : vector<4xf32>
+}
+
} // end module
// -----
More information about the Mlir-commits
mailing list