[Mlir-commits] [mlir] 887e1aa - [mlir][spirv] Fix sub-word `memref.store` conversion

Jakub Kuderski llvmlistbot at llvm.org
Tue Sep 5 11:37:00 PDT 2023


Author: Jakub Kuderski
Date: 2023-09-05T14:35:27-04:00
New Revision: 887e1aa330ddf8f03b6908f3dd57a0d1c18c20ec

URL: https://github.com/llvm/llvm-project/commit/887e1aa330ddf8f03b6908f3dd57a0d1c18c20ec
DIFF: https://github.com/llvm/llvm-project/commit/887e1aa330ddf8f03b6908f3dd57a0d1c18c20ec.diff

LOG: [mlir][spirv] Fix sub-word `memref.store` conversion

Support environments where logical types do not necessarily correspond to allowed storage access types.

Also make pattern match failures more descriptive.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D159386
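
For context, the emulation rewrites a sub-word store into an atomic clear of
the target bits followed by an atomic OR of the shifted value. Below is a
minimal host-side C++ sketch of that scheme, assuming elements pack with
element 0 in the low bits of a 32-bit storage word; the function and test
values are hypothetical and only illustrate the generated
spirv.AtomicAnd/spirv.AtomicOr sequence.

  #include <cassert>
  #include <cstdint>

  // Store an 8-bit value into byte `index` of a 32-bit storage word, using the
  // same clear-then-set steps as the emitted AtomicAnd/AtomicOr pair.
  uint32_t emulateSubWordStore(uint32_t word, uint8_t value, unsigned index) {
    unsigned offset = (index % 4) * 8;            // bit offset within the word
    uint32_t mask = 0xFFu << offset;              // bits occupied by the element
    word &= ~mask;                                // AtomicAnd with the inverted mask
    word |= (uint32_t(value) & 0xFFu) << offset;  // AtomicOr with the shifted value
    return word;
  }

  int main() {
    // Overwrite byte 2; the other bytes of the word are preserved.
    assert(emulateSubWordStore(0xAABBCCDD, 0x11, 2) == 0xAA11CCDD);
  }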

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 8a03e01d0ccb098..acddb3c4da461f3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -77,11 +77,37 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
 
-/// Returns the shifted `targetBits`-bit value with the given offset.
+/// Casts the given `srcBool` into an integer of `dstType`.
+static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
+                            OpBuilder &builder) {
+  assert(srcBool.getType().isInteger(1));
+  if (dstType.isInteger(1))
+    return srcBool;
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
+  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+}
+
+/// Returns the given `value`, cast to the destination integer type, masked,
+/// and shifted left by the given `offset`.
 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                        int targetBits, OpBuilder &builder) {
-  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), result,
+                        OpBuilder &builder) {
+  IntegerType dstType = cast<IntegerType>(mask.getType());
+  int targetBits = static_cast<int>(dstType.getWidth());
+  int valueBits = value.getType().getIntOrFloatBitWidth();
+  assert(valueBits <= targetBits);
+
+  if (valueBits == 1) {
+    value = castBoolToIntN(loc, value, dstType, builder);
+  } else {
+    if (valueBits < targetBits) {
+      value = builder.create<spirv::UConvertOp>(
+          loc, builder.getIntegerType(targetBits), value);
+    }
+
+    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+  }
+  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
                                                    offset);
 }
 
@@ -136,17 +162,6 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
   return builder.create<spirv::IEqualOp>(loc, srcInt, one);
 }
 
-/// Casts the given `srcBool` into an integer of `dstType`.
-static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
-                            OpBuilder &builder) {
-  assert(srcBool.getType().isInteger(1));
-  if (dstType.isInteger(1))
-    return srcBool;
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
-  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
-}
-
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -553,7 +568,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (!memrefType.getElementType().isSignlessInteger())
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "element type is not a signless int");
 
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
@@ -562,7 +578,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                            adaptor.getIndices(), loc, rewriter);
 
   if (!accessChain)
-    return failure();
+    return rewriter.notifyMatchFailure(
+        storeOp, "failed to convert element pointer type");
 
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
 
@@ -576,23 +593,28 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                        "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  Type dstType;
+  IntegerType dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = arrayType.getElementType();
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
     else
-      dstType = pointeeType;
+      dstType = dyn_cast<IntegerType>(pointeeType);
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =
         cast<spirv::StructType>(pointeeType).getElementType(0);
     if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = arrayType.getElementType();
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
     else
-      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
+      dstType = dyn_cast<IntegerType>(
+          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
   }
 
-  int dstBits = dstType.getIntOrFloatBitWidth();
+  if (!dstType)
+    return rewriter.notifyMatchFailure(
+        storeOp, "failed to determine destination element type");
+
+  int dstBits = static_cast<int>(dstType.getWidth());
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
@@ -612,17 +634,17 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   if (!accessChainOp)
     return failure();
 
-  // Since there are multi threads in the processing, the emulation will be done
-  // with atomic operations. E.g., if the storing value is i8, rewrite the
-  // StoreOp to
+  // Since multiple threads may write to the same storage word concurrently,
+  // the emulation is done with atomic operations. E.g., if the stored value is
+  // i8, rewrite the StoreOp to:
   // 1) load a 32-bit integer
-  // 2) clear 8 bits in the loading value
-  // 3) store 32-bit value back
-  // 4) load a 32-bit integer
-  // 5) modify 8 bits in the loading value
-  // 6) store 32-bit value back
-  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
-  // 4 to step 6 are done by AtomicOr as another atomic step.
+  // 2) clear 8 bits in the loaded value
+  // 3) set 8 bits in the loaded value
+  // 4) store 32-bit value back
+  //
+  // Step 2 is done with an AtomicAnd (using the inverted mask), and step 3
+  // with an AtomicOr of the shifted store value; each atomic op performs its
+  // own load-modify-store, so steps 1 and 4 are implicit.
   assert(accessChainOp.getIndices().size() == 2);
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
@@ -635,15 +657,13 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
 
-  Value storeVal = adaptor.getValue();
-  if (isBool)
-    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
+  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
   if (!scope)
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
+
   Value result = rewriter.create<spirv::AtomicAndOp>(
       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
       clearBitsMask);
@@ -740,13 +760,13 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "signless int");
   auto storePtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), storeOp.getLoc(), rewriter);
 
   if (!storePtr)
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
 
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               adaptor.getValue());
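
The reworked shiftValue() above now also widens stored values that are
narrower than the storage integer, which is the case this patch enables. A
hedged host-side C++ analogue of its control flow (function and parameter
names are hypothetical): a bool becomes 0/1 via a select and needs no masking,
while a narrower integer is zero-extended and masked before the shift.

  #include <cstdint>

  // Mirrors the branches in shiftValue(): widen to the 32-bit storage width,
  // then mask and shift into position within the storage word.
  uint32_t shiftValueSketch(uint32_t src, bool isBool, uint32_t mask,
                            unsigned offset) {
    uint32_t value;
    if (isBool) {
      value = src ? 1u : 0u;  // spirv.Select %src, 1, 0 -- already fits the mask
    } else {
      value = src;            // spirv.UConvert (zero-extend) for narrower ints
      value &= mask;          // spirv.BitwiseAnd with the unshifted element mask
    }
    return value << offset;   // spirv.ShiftLeftLogical by the bit offset
  }

This also explains the @store_i1 test update below: the bool path skips the
spirv.BitwiseAnd because the select result is already 0 or 1.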

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index d4d535080d6b356..928bd82c2cb2d81 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -119,8 +119,7 @@ func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %val
   //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
   //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   //     CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
-  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
   //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
@@ -270,3 +269,96 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v
 }
 
 } // end module
+
+// -----
+
+// Check that we can access i8 storage with i8 types available but without
+// 8-bit storage capabilities.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, Int64, Int8], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i8
+// INDEX64-LABEL: @load_i8
+func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
+  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
+  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  //     CHECK: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
+  //     CHECK: return %[[CAST]] : i8
+
+  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
+  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
+  //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
+  //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  //   INDEX64: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
+  //   INDEX64: return %[[CAST]] : i8
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return %0 : i8
+}
+
+// CHECK-LABEL: @store_i8
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
+// INDEX64-LABEL: @store_i8
+//       INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
+func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
+  //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+
+  //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
+  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
+  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
+  //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
+  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
+  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //   INDEX64: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+} // end module

More information about the Mlir-commits mailing list