[Mlir-commits] [mlir] [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (PR #85433)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 19 16:42:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-gpu

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

Investigate the lowering of MemRef Load/Store ops and implement additional folding of the created ops.

Aims to improve readability of generated lowered SPIR-V code.

Part of the work on llvm#<!-- -->70704.

---

Patch is 42.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85433.diff


8 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+28-24) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+5-4) 
- (modified) mlir/test/Conversion/GPUToSPIRV/load-store.mlir (+2-6) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir (+38-120) 
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+8-26) 
- (modified) mlir/test/Conversion/SCFToSPIRV/for.mlir (+2-10) 
- (modified) mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir (+4-6) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+6-14) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
   assert(targetBits % sourceBits == 0);
   Type type = srcIdx.getType();
   IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
   IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
-  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
-  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
-  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+  auto srcBitsValue =
+      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
 }
 
 /// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
   Value lastDim = op->getOperand(op.getNumOperands() - 1);
   Type type = lastDim.getType();
   IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
   auto indices = llvm::to_vector<4>(op.getIndices());
   // There are two elements if this is a 1-D tensor.
   assert(indices.size() == 2);
-  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
   Type t = typeConverter.convertType(op.getComponentPtr().getType());
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
     return srcBool;
   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+                                               zero);
 }
 
 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
           loc, builder.getIntegerType(targetBits), value);
     }
 
-    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
   }
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
-                                                   offset);
+  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+                                                         value, offset);
 }
 
 /// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
     return srcInt;
 
   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
-  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
 }
 
 //===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // ____XXXX________ -> ____________XXXX
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
-  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
       loc, spvLoadOp.getType(), spvLoadOp, offset);
 
   // Apply the mask to extract corresponding bits.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  result =
+      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
 
   // Apply sign extension on the loading value unconditionally. The signedness
   // semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   IntegerAttr shiftValueAttr =
       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
   Value shiftValue =
-      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
-  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
-                                                      shiftValue);
-  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
-                                                          shiftValue);
+      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+                                                            result, shiftValue);
+  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+      loc, dstType, result, shiftValue);
 
   rewriter.replaceOp(loadOp, result);
 
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   // Create a mask to clear the destination. E.g., if it is the second i8 in
   // i32, 0xFFFF00FF is created.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  Value clearBitsMask =
-      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
-  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+      loc, dstType, mask, offset);
+  clearBitsMask =
+      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
 
   Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 
     int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
     Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
-    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
   }();
 
   rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   // broken down into progressive small steps so we can have intermediate steps
   // using other dialects. At the moment SPIR-V is the final sink.
 
-  Value linearizedIndex = builder.create<spirv::ConstantOp>(
+  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
       loc, integerType, IntegerAttr::get(integerType, offset));
   for (const auto &index : llvm::enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
+    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
         loc, integerType,
         IntegerAttr::get(integerType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    Value update =
+        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
     linearizedIndex =
-        builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
   }
   return linearizedIndex;
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
       // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
       %13 = arith.addi %arg4, %3 : index
       // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
-      // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
       // CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
-      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
-      // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
-      // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
-      // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
-      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
       // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
       %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
 // CHECK-LABEL: @load_i1
 func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
 // INDEX64-LABEL: @load_i8
 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
 
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
   //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
-  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
   //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
   //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
 func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
   //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  //     CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
 func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
   //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   //     CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
   //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
   memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
-  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
 
   //   INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
   //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
-  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
   memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -177,19 +...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/85433


More information about the Mlir-commits mailing list