[Mlir-commits] [mlir] [MLIR][SPIRV] Update cast from IntN to Bool (PR #113329)
Dmitriy Smirnov
llvmlistbot at llvm.org
Tue Oct 22 07:54:01 PDT 2024
https://github.com/d-smirnov created https://github.com/llvm/llvm-project/pull/113329
This PR updates the cast from IntN to bool so that any non-zero value is treated as TRUE. This makes the cast more resilient to non-canonical TRUE values (i.e. values other than 1).
From 1459efb2ff0f85488b005e6c9418af279138f186 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Tue, 22 Oct 2024 15:12:41 +0100
Subject: [PATCH] [MLIR][SPIRV] Update cast from IntN to Bool
The cast to bool from IntN has been updated to treat any non-zero value as
TRUE. This makes the cast more resilient to non-generic TRUE values.
Signed-off-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 4 ++--
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 8 ++++----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 285398311fd197..49a391938eaf69 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
return srcInt;
- auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
+ auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
+ return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 6dd5b1988e2a2f..8906de9db37249 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i :
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<StorageBuffer>>
// CHECK: return %[[BOOL]]
return %0: i1
@@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
- // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8
%0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
// CHECK: return %[[BOOL]]
return %0: i1
More information about the Mlir-commits
mailing list