[Mlir-commits] [mlir] [MLIR][SPIRV] Update cast from IntN to Bool (PR #113329)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 22 07:54:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Dmitriy Smirnov (d-smirnov)

<details>
<summary>Changes</summary>

This PR updates the IntN-to-bool cast to treat any non-zero value as TRUE, instead of only the value 1. This makes the cast robust to non-canonical TRUE values, i.e. values other than 1.
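
As a rough illustration of the behavioral difference, the following standalone C++ sketch models the old and new comparisons on a plain 8-bit integer. It does not use MLIR or the SPIR-V dialect; the function names `castEqualsOne` and `castNotEqualsZero` are hypothetical and exist only for this example, standing in for the `spirv.IEqual`-against-1 and `spirv.INotEqual`-against-0 patterns in the diff.

```cpp
#include <cstdint>

// Hypothetical model of the old lowering: compare against the constant 1.
// Only the exact value 1 is treated as TRUE.
constexpr bool castEqualsOne(std::uint8_t value) { return value == 1; }

// Hypothetical model of the new lowering: compare against the constant 0.
// Any non-zero value is treated as TRUE.
constexpr bool castNotEqualsZero(std::uint8_t value) { return value != 0; }

// Both models agree on the canonical encodings 0 and 1.
static_assert(!castEqualsOne(0) && !castNotEqualsZero(0));
static_assert(castEqualsOne(1) && castNotEqualsZero(1));

// They disagree on a non-canonical TRUE such as 0xFF: the old comparison
// reports FALSE, the new one reports TRUE.
static_assert(!castEqualsOne(0xFF));
static_assert(castNotEqualsZero(0xFF));
```

The two models differ only for stored values other than 0 and 1, which is the case the PR description calls out.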

---
Full diff: https://github.com/llvm/llvm-project/pull/113329.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+2-2) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+4-4) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 285398311fd197..49a391938eaf69 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
   if (srcInt.getType().isInteger(1))
     return srcInt;
 
-  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
-  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
+  auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
+  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 6dd5b1988e2a2f..8906de9db37249 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
   // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
   // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
-  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
-  // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
   %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
   // CHECK: return %[[BOOL]]
   return %0: i1
@@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
-  // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
-  // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
   %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
   // CHECK: return %[[BOOL]]
   return %0: i1

``````````

</details>


https://github.com/llvm/llvm-project/pull/113329

