[Mlir-commits] [mlir] 0065bd2 - [mlir][spirv] Fix loading bool with proper storage capabilities

Lei Zhang llvmlistbot at llvm.org
Fri Jul 30 15:07:28 PDT 2021


Author: Lei Zhang
Date: 2021-07-30T18:06:11-04:00
New Revision: 0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e

URL: https://github.com/llvm/llvm-project/commit/0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e
DIFF: https://github.com/llvm/llvm-project/commit/0065bd2ad59cd05f0ca762a1cb586d3bfe809f2e.diff

LOG: [mlir][spirv] Fix loading bool with proper storage capabilities

If the source value to load is a bool, then even when we have native storage
capability support for the source bitwidth, we still cannot directly replace
uses with the loaded value; we must first cast it to bool.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D107119

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 899c582796a6..9f9f11538177 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -119,6 +119,15 @@ static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
   return {};
 }
 
+/// Casts `srcInt` to a boolean via `srcInt == 1`; i1 values pass through.
+static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
+  if (srcInt.getType().isInteger(1))
+    return srcInt;
+
+  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
+  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+}
+
 /// Casts the given `srcBool` into an integer of `dstType`.
 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                             OpBuilder &builder) {
@@ -302,8 +311,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
   // If the rewrited load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
-                                               accessChainOp.getResult());
+    Value loadVal =
+        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+    if (isBool)
+      loadVal = castIntNToBool(loc, loadVal, rewriter);
+    rewriter.replaceOp(loadOp, loadVal);
     return success();
   }
 
@@ -346,8 +358,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
   if (isBool) {
     dstType = typeConverter.convertType(loadOp.getType());
     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
-    Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
-    result = castBoolToIntN(loc, isOne, dstType, rewriter);
+    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
   } else if (result.getType().getIntOrFloatBitWidth() !=
              static_cast<unsigned>(dstBits)) {
     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index daf0ea0797a5..5c8c3c17d838 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -65,6 +65,25 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
   return
 }
 
+// CHECK-LABEL: func @load_i1
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
+func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
+  // CHECK: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[SRC_CAST]][%[[ZERO_0]], %[[ADD]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  %0 = memref.load %src[%i] : memref<4xi1>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
 // CHECK-LABEL: func @store_i1
 //  CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
 //  CHECK-SAME: %[[IDX:.+]]: index
@@ -77,7 +96,7 @@ func @store_i1(%dst: memref<4xi1>, %i: index) {
   // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
   // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
   // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
+  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]]
   // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8


        


More information about the Mlir-commits mailing list