[Mlir-commits] [mlir] [MLIR][SPIRV] Update cast from IntN to Bool (PR #113329)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 22 07:54:39 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
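This PR updates the IntN-to-bool cast to treat any non-zero value as TRUE: the value is now compared against zero with `spirv.INotEqual` instead of against one with `spirv.IEqual`. This makes the cast more resilient to non-canonical TRUE values, i.e. values other than 1.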
---
Full diff: https://github.com/llvm/llvm-project/pull/113329.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+2-2)
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+4-4)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 285398311fd197..49a391938eaf69 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
return srcInt;
- auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
+ auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
+ return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 6dd5b1988e2a2f..8906de9db37249 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
// CHECK: return %[[BOOL]]
return %0: i1
@@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[BOOL]]
return %0: i1
``````````
</details>
https://github.com/llvm/llvm-project/pull/113329
More information about the Mlir-commits mailing list