[Mlir-commits] [mlir] 4ffc63a - [mlir][spirv] Fix 64-bit index for MemRef bitwidth emulation

Lei Zhang llvmlistbot at llvm.org
Tue Aug 22 08:33:39 PDT 2023


Author: Lei Zhang
Date: 2023-08-22T08:32:27-07:00
New Revision: 4ffc63ab71e501eb06664af3e3e0c410560f7ebe

URL: https://github.com/llvm/llvm-project/commit/4ffc63ab71e501eb06664af3e3e0c410560f7ebe
DIFF: https://github.com/llvm/llvm-project/commit/4ffc63ab71e501eb06664af3e3e0c410560f7ebe.diff

LOG: [mlir][spirv] Fix 64-bit index for MemRef bitwidth emulation

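We need to use the converted index type for the index offset calculation
logic, not the target bitwidth, which is typically 32-bit. Otherwise, when
index types are converted to 64-bit integers, the emitted constants and
offset arithmetic are 32-bit while the indices they combine with are 64-bit.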

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D158482

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
    mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d239cb12e03a5a..3c2962ab86f631 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -745,7 +745,10 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   let options = [
     Option<"boolNumBits", "bool-num-bits",
            "int", /*default=*/"8",
-           "The number of bits to store a boolean value">
+           "The number of bits to store a boolean value">,
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
   ];
 }
 

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index f024bdfda93888..8a03e01d0ccb09 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -42,15 +42,13 @@ using namespace mlir;
 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                   int targetBits, OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr idxAttr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
-  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
-  auto srcBitsValue =
-      builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
+  Type type = srcIdx.getType();
+  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
+  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
-  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
+  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
 }
 
 /// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -58,7 +56,7 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
 /// supported. During conversion if a memref of an unsupported type is used,
 /// load/stores to this memref need to be modified to use a supported higher
 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
-/// 1D array (spirv.array or spirv.rt_array), the last index is modified to load
+/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
 /// the bits needed. The extraction of the actual bits needed are handled
 /// separately. Note that this only works for a 1-D tensor.
 static Value
@@ -67,11 +65,10 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                              int targetBits, OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
   const auto loc = op.getLoc();
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr attr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
-  auto lastDim = op->getOperand(op.getNumOperands() - 1);
+  Value lastDim = op->getOperand(op.getNumOperands() - 1);
+  Type type = lastDim.getType();
+  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
   auto indices = llvm::to_vector<4>(op.getIndices());
   // There are two elements if this is a 1-D tensor.
   assert(indices.size() == 2);
@@ -83,9 +80,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
 /// Returns the shifted `targetBits`-bit value with the given offset.
 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                         int targetBits, OpBuilder &builder) {
-  Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
+  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), result,
                                                    offset);
 }
 

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index e2ce927cc8fdfd..2e383b916970be 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -41,6 +41,7 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.boolNumBits = this->boolNumBits;
+  options.use64bitIndex = this->use64bitIndex;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 84ec3e594ac44d..d4d535080d6b35 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -1,11 +1,12 @@
 // RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8 use-64bit-index" -cse %s -o - | FileCheck %s --check-prefix=INDEX64
 
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
 // TODO: Test i64 types.
 module attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
 } {
 
 // CHECK-LABEL: @load_i1
@@ -33,6 +34,7 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
 }
 
 // CHECK-LABEL: @load_i8
+// INDEX64-LABEL: @load_i8
 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
   //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
@@ -49,6 +51,22 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
+
+  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
+  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
+  //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
+  //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  //   INDEX64: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return %0 : i8
 }
@@ -113,6 +131,8 @@ func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %val
 
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
+// INDEX64-LABEL: @store_i8
+//       INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
 func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -130,6 +150,23 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+
+  //   INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
+  //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
+  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
+  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
+  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
+  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //   INDEX64: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
   memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -179,7 +216,7 @@ func.func @store_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %v
 // emulated via 32-bit types.
 module attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+    #spirv.vce<v1.0, [Shader, Int64], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
 } {
 
 // CHECK-LABEL: @load_i4


        


More information about the Mlir-commits mailing list