[Mlir-commits] [mlir] 0065bd2 - [mlir][spirv] Fix loading bool with proper storage capabilities
Lei Zhang
llvmlistbot at llvm.org
Fri Jul 30 15:07:28 PDT 2021
Author: Lei Zhang
Date: 2021-07-30T18:06:11-04:00
New Revision: 0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e
URL: https://github.com/llvm/llvm-project/commit/0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e
DIFF: https://github.com/llvm/llvm-project/commit/0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e.diff
LOG: [mlir][spirv] Fix loading bool with proper storage capabilities
If the source value to load is a bool and we have native storage
capability support for the source bitwidth, we still cannot directly
replace uses with the loaded value; we must first cast it to bool.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D107119
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 899c582796a6..9f9f11538177 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -119,6 +119,15 @@ static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
return {};
}
+/// Casts `srcInt` to i1: returns it unchanged if already i1, else `srcInt == 1`.
+static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
+ if (srcInt.getType().isInteger(1)) // Already a bool; nothing to cast.
+ return srcInt;
+
+ auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
+ return builder.create<spirv::IEqualOp>(loc, srcInt, one); // NOTE(review): values other than 0/1 read as false; assumes 0/1 storage encoding.
+}
+
/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
OpBuilder &builder) {
@@ -302,8 +311,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
// If the rewrited load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
- accessChainOp.getResult());
+ Value loadVal =
+ rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+ if (isBool)
+ loadVal = castIntNToBool(loc, loadVal, rewriter);
+ rewriter.replaceOp(loadOp, loadVal);
return success();
}
@@ -346,8 +358,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
if (isBool) {
dstType = typeConverter.convertType(loadOp.getType());
mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
- result = castBoolToIntN(loc, isOne, dstType, rewriter);
+ result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
} else if (result.getType().getIntOrFloatBitWidth() !=
static_cast<unsigned>(dstBits)) {
result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index daf0ea0797a5..5c8c3c17d838 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -65,6 +65,25 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
return
}
+// CHECK-LABEL: func @load_i1
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
+func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
+ // CHECK: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+ // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+ // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+ // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+ // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[SRC_CAST]][%[[ZERO_0]], %[[ADD]]]
+ // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8
+ // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+ // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ %0 = memref.load %src[%i] : memref<4xi1> // i1 elements live in an i8 array; the loaded i8 must be compared with 1 to get back an i1.
+ // CHECK: return %[[BOOL]]
+ return %0: i1
+}
+
// CHECK-LABEL: func @store_i1
// CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
// CHECK-SAME: %[[IDX:.+]]: index
@@ -77,7 +96,7 @@ func @store_i1(%dst: memref<4xi1>, %i: index) {
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
// CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
+ // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]]
// CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
// CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
// CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
More information about the Mlir-commits
mailing list