[Mlir-commits] [mlir] b4dff40 - [mlir][spirv] Fix math.ctlz for full zero bit cases

Lei Zhang llvmlistbot at llvm.org
Tue Jun 14 16:41:40 PDT 2022


Author: Lei Zhang
Date: 2022-06-14T19:39:27-04:00
New Revision: b4dff404f37afa1fcad4a96c05462928eb0b89b1

URL: https://github.com/llvm/llvm-project/commit/b4dff404f37afa1fcad4a96c05462928eb0b89b1
DIFF: https://github.com/llvm/llvm-project/commit/b4dff404f37afa1fcad4a96c05462928eb0b89b1.diff

LOG: [mlir][spirv] Fix math.ctlz for full zero bit cases

If the integer has all zero bits, GLSL FindUMsb would return -1.
So theoretically (31 - FindUMsb) should still give us the correct
result. However, Adreno GPUs have issues with this:
https://buildkite.com/iree/iree-test-android/builds/6482#01815f05-3926-466f-822a-1e20299e5461
This looks like a driver bug. So handle the corner case explicitly
to work around it.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D127747

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 8fd07bfbd3c56..6b792124d269d 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -141,12 +141,20 @@ class CountLeadingZerosPattern final
       return failure();
 
     Location loc = countOp.getLoc();
+    Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc);
+    Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
     Value msb =
         rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
     // We need to subtract from 31 given that the index is from the least
     // significant bit.
-    rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
+    Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+    // If the integer has all zero bits, GLSL FindUMsb would return -1. So
+    // theoretically (31 - FindUMsb) should still give the correct result.
+    // However, certain Vulkan implementations have driver bugs regarding it.
+    // So handle the corner case explicity to workaround it.
+    Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 8b179b22a7bd7..7940daab10d34 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -96,10 +96,14 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
 // CHECK-LABEL: @ctlz_scalar
 //  CHECK-SAME: (%[[VAL:.+]]: i32)
 func.func @ctlz_scalar(%val: i32) -> i32 {
-  // CHECK: %[[V31:.+]] = spv.Constant 31 : i32
+  // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
+  // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
   // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
-  // CHECK: return %[[SUB]]
+  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32
+  // CHECK: return %[[R]]
   %0 = math.ctlz %val : i32
   return %0 : i32
 }
@@ -108,6 +112,8 @@ func.func @ctlz_scalar(%val: i32) -> i32 {
 func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
   // CHECK: spv.GLSL.FindUMsb
   // CHECK: spv.ISub
+  // CHECK: spv.IEqual
+  // CHECK: spv.Select
   %0 = math.ctlz %val : vector<1xi32>
   return %0 : vector<1xi32>
 }
@@ -115,10 +121,14 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
 // CHECK-LABEL: @ctlz_vector2
 //  CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
 func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
+  // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32>
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
   // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
   // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
-  // CHECK: return %[[SUB]]
+  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32>
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32>
+  // CHECK: return %[[R]]
   %0 = math.ctlz %val : vector<2xi32>
   return %0 : vector<2xi32>
 }


        


More information about the Mlir-commits mailing list