[Mlir-commits] [mlir] b4dff40 - [mlir][spirv] Fix math.ctlz for full zero bit cases
Lei Zhang
llvmlistbot at llvm.org
Tue Jun 14 16:41:40 PDT 2022
Author: Lei Zhang
Date: 2022-06-14T19:39:27-04:00
New Revision: b4dff404f37afa1fcad4a96c05462928eb0b89b1
URL: https://github.com/llvm/llvm-project/commit/b4dff404f37afa1fcad4a96c05462928eb0b89b1
DIFF: https://github.com/llvm/llvm-project/commit/b4dff404f37afa1fcad4a96c05462928eb0b89b1.diff
LOG: [mlir][spirv] Fix math.ctlz for full zero bit cases
If the integer has all zero bits, GLSL FindUMsb would return -1.
So theoretically (31 - FindUMsb) should still give us the correct
result. However, Adreno GPUs have issues with this:
https://buildkite.com/iree/iree-test-android/builds/6482#01815f05-3926-466f-822a-1e20299e5461
This looks like a driver bug. So handle the corner case explicitly
to work around it.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D127747
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 8fd07bfbd3c56..6b792124d269d 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -141,12 +141,20 @@ class CountLeadingZerosPattern final
return failure();
Location loc = countOp.getLoc();
+ Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc);
+ Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
Value msb =
rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
// We need to subtract from 31 given that the index is from the least
// significant bit.
- rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
+ Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+ // If the integer has all zero bits, GLSL FindUMsb would return -1. So
+ // theoretically (31 - FindUMsb) should still give the correct result.
+ // However, certain Vulkan implementations have driver bugs regarding it.
+ // So handle the corner case explicity to workaround it.
+ Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub);
return success();
}
};
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 8b179b22a7bd7..7940daab10d34 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -96,10 +96,14 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
// CHECK-LABEL: @ctlz_scalar
// CHECK-SAME: (%[[VAL:.+]]: i32)
func.func @ctlz_scalar(%val: i32) -> i32 {
- // CHECK: %[[V31:.+]] = spv.Constant 31 : i32
+ // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32
+ // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
+ // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
- // CHECK: return %[[SUB]]
+ // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32
+ // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32
+ // CHECK: return %[[R]]
%0 = math.ctlz %val : i32
return %0 : i32
}
@@ -108,6 +112,8 @@ func.func @ctlz_scalar(%val: i32) -> i32 {
func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
// CHECK: spv.GLSL.FindUMsb
// CHECK: spv.ISub
+ // CHECK: spv.IEqual
+ // CHECK: spv.Select
%0 = math.ctlz %val : vector<1xi32>
return %0 : vector<1xi32>
}
@@ -115,10 +121,14 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
// CHECK-LABEL: @ctlz_vector2
// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
+ // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32>
+ // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
// CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
- // CHECK: return %[[SUB]]
+ // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32>
+ // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32>
+ // CHECK: return %[[R]]
%0 = math.ctlz %val : vector<2xi32>
return %0 : vector<2xi32>
}
More information about the Mlir-commits
mailing list