[Mlir-commits] [mlir] 2320a4a - [mlir][spirv] Workaround driver bug in math.ctlz conversion again

Lei Zhang llvmlistbot at llvm.org
Thu Jun 16 07:53:56 PDT 2022


Author: Lei Zhang
Date: 2022-06-16T10:53:49-04:00
New Revision: 2320a4ae907fd0006389d3176e1b9fdbf891b729

URL: https://github.com/llvm/llvm-project/commit/2320a4ae907fd0006389d3176e1b9fdbf891b729
DIFF: https://github.com/llvm/llvm-project/commit/2320a4ae907fd0006389d3176e1b9fdbf891b729.diff

LOG: [mlir][spirv] Workaround driver bug in math.ctlz conversion again

The previous approach does not work because the Adreno driver is
clever enough to optimize away the selection. So now handle two
input values (zero and one) together in the select.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127930

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 07c99f06ab2c3..ea367b1faa201 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -141,20 +141,25 @@ class CountLeadingZerosPattern final
       return failure();
 
     Location loc = countOp.getLoc();
-    Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc);
-    Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+    Value input = adaptor.getOperand();
+    Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
-    Value msb =
-        rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
-    // We need to subtract from 31 given that the index is from the least
-    // significant bit.
-    Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb);
-    // If the integer has all zero bits, GLSL FindUMsb would return -1. So
-    // theoretically (31 - FindUMsb) should still give the correct result.
-    // However, certain Vulkan implementations have driver bugs regarding it.
-    // So handle the corner case explicity to workaround it.
-    Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub);
+    Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+
+    Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input);
+    // We need to subtract from 31 given that the index returned by GLSL
+    // FindUMsb is counted from the least significant bit. Theoretically this
+    // also gives the correct result even if the integer has all zero bits, in
+    // which case GLSL FindUMsb would return -1.
+    Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+    // However, certain Vulkan implementations have driver bugs for the corner
+    // case where the input is zero. And.. it can be smart to optimize a select
+    // only involving the corner case. So separately compute the result when the
+    // input is either zero or one.
+    Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
+    Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
+                                                 subMsb);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index d8126d4e956c6..a3067af661f64 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -82,13 +82,14 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
 // CHECK-LABEL: @ctlz_scalar
 //  CHECK-SAME: (%[[VAL:.+]]: i32)
 func.func @ctlz_scalar(%val: i32) -> i32 {
-  // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32
-  // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
+  // CHECK-DAG: %[[V1:.+]] = spv.Constant 1 : i32
   // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
-  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
-  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32
-  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32
+  // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+  // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : i32
+  // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : i32
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : i1, i32
   // CHECK: return %[[R]]
   %0 = math.ctlz %val : i32
   return %0 : i32
@@ -98,7 +99,7 @@ func.func @ctlz_scalar(%val: i32) -> i32 {
 func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
   // CHECK: spv.GLSL.FindUMsb
   // CHECK: spv.ISub
-  // CHECK: spv.IEqual
+  // CHECK: spv.ULessThanEqual
   // CHECK: spv.Select
   %0 = math.ctlz %val : vector<1xi32>
   return %0 : vector<1xi32>
@@ -107,14 +108,14 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
 // CHECK-LABEL: @ctlz_vector2
 //  CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
 func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
-  // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32>
-  // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
+  // CHECK-DAG: %[[V1:.+]] = spv.Constant dense<1> : vector<2xi32>
   // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
-  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
-  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32>
-  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32>
-  // CHECK: return %[[R]]
+  // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+  // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : vector<2xi32>
+  // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : vector<2xi1>, vector<2xi32>
   %0 = math.ctlz %val : vector<2xi32>
   return %0 : vector<2xi32>
 }


        


More information about the Mlir-commits mailing list