[Mlir-commits] [mlir] 9f5300c - [mlir][spirv] Fix storing bool with proper storage capabilities

Lei Zhang llvmlistbot at llvm.org
Fri Jul 30 15:07:26 PDT 2021


Author: Lei Zhang
Date: 2021-07-30T18:06:10-04:00
New Revision: 9f5300c8be4576d93256b88d195a1eb44de189f6

URL: https://github.com/llvm/llvm-project/commit/9f5300c8be4576d93256b88d195a1eb44de189f6
DIFF: https://github.com/llvm/llvm-project/commit/9f5300c8be4576d93256b88d195a1eb44de189f6.diff

LOG: [mlir][spirv] Fix storing bool with proper storage capabilities

If the source value to store is a bool, then even when we have native
storage capability support for the target bitwidth, we still cannot store
it directly; we need to cast it to match the bitwidth of the target memref
element.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D107114

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 86eff5572919..24b2cf1dc422 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -304,6 +304,11 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
   let summary = "Convert MemRef dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertMemRefToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
+  let options = [
+    Option<"boolNumBits", "bool-num-bits",
+           "int", /*default=*/"8",
+           "The number of bits to store a boolean value">
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index ddc312eb0914..899c582796a6 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -119,6 +119,17 @@ static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
   return {};
 }
 
+/// Casts i1 `srcBool` to `dstType` via select(srcBool, 1, 0); no-op for i1.
+static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
+                            OpBuilder &builder) {
+  assert(srcBool.getType().isInteger(1));
+  if (dstType.isInteger(1))
+    return srcBool;
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
+  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -336,9 +347,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
     dstType = typeConverter.convertType(loadOp.getType());
     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
     Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
-    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-    result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
+    result = castBoolToIntN(loc, isOne, dstType, rewriter);
   } else if (result.getType().getIntOrFloatBitWidth() !=
              static_cast<unsigned>(dstBits)) {
     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
@@ -392,6 +401,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
+
   Type pointeeType = typeConverter.convertType(memrefType)
                          .cast<spirv::PointerType>()
                          .getPointeeType();
@@ -406,8 +416,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
+    Value storeVal = storeOperands.value();
+    if (isBool)
+      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
-        storeOp, accessChainOp.getResult(), storeOperands.value());
+        storeOp, accessChainOp.getResult(), storeVal);
     return success();
   }
 
@@ -435,12 +448,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
 
   Value storeVal = storeOperands.value();
-  if (isBool) {
-    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
-    storeVal =
-        rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
-  }
+  if (isBool)
+    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
   storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index 79d137fd8a7e..cd0d3429c6b5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -34,7 +34,9 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter typeConverter(targetAttr);
+  SPIRVTypeConverter::Options options;
+  options.boolNumBits = this->boolNumBits;
+  SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index bb8b8a5e86d0..daf0ea0797a5 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,9 +1,17 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" %s -o - | FileCheck %s
+
+// Check that with proper compute and storage extensions, we don't need to
+// perform special tricks.
 
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Shader, Int8, Int16, Int64, Float16, Float64],
-             [SPV_KHR_storage_buffer_storage_class]>, {}>
+    #spv.vce<v1.0,
+      [
+        Shader, Int8, Int16, Int64, Float16, Float64,
+        StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
+        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+      ],
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
@@ -57,6 +65,27 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
   return
 }
 
+// CHECK-LABEL: func @store_i1
+//  CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
+//  CHECK-SAME: %[[IDX:.+]]: index
+func @store_i1(%dst: memref<4xi1>, %i: index) {
+  %true = constant true
+  // CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
+  // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
+  // CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
+  memref.store %true, %dst[%i]: memref<4xi1>
+  return
+}
+
 } // end module
 
 // -----
@@ -88,10 +117,7 @@ func @load_i1(%arg0: memref<i1>) -> i1 {
   //     CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   // Convert to i1 type.
   //     CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
-  //     CHECK: %[[ISONE:.+]]  = spv.IEqual %[[T4]], %[[ONE]] : i32
-  //     CHECK: %[[FALSE:.+]] = spv.Constant false
-  //     CHECK: %[[TRUE:.+]] = spv.Constant true
-  //     CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
+  //     CHECK: %[[RES:.+]]  = spv.IEqual %[[T4]], %[[ONE]] : i32
   //     CHECK: return %[[RES]]
   %0 = memref.load %arg0[] : memref<i1>
   return %0 : i1


        


More information about the Mlir-commits mailing list