[Mlir-commits] [mlir] 27158ed - [MLIR][SPIRV] Update cast from IntN to Bool (#113329)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 23 01:47:37 PDT 2024
Author: Dmitriy Smirnov
Date: 2024-10-23T09:47:33+01:00
New Revision: 27158edaa4e18a7d7275c77e8c483dd29145c3c4
URL: https://github.com/llvm/llvm-project/commit/27158edaa4e18a7d7275c77e8c483dd29145c3c4
DIFF: https://github.com/llvm/llvm-project/commit/27158edaa4e18a7d7275c77e8c483dd29145c3c4.diff
LOG: [MLIR][SPIRV] Update cast from IntN to Bool (#113329)
This PR updates the cast from IntN to bool so that any non-zero value is
treated as TRUE (previously only the value 1 was treated as TRUE). This
makes the cast resilient to non-canonical TRUE values, i.e. values other than 1.
Signed-off-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 285398311fd197..49a391938eaf69 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
return srcInt;
- auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
+ auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
+ return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 6dd5b1988e2a2f..8906de9db37249 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
// CHECK: return %[[BOOL]]
return %0: i1
@@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[BOOL]]
return %0: i1
More information about the Mlir-commits
mailing list