[Mlir-commits] [mlir] [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (PR #85433)
llvmlistbot at llvm.org
Tue Mar 19 16:42:29 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-gpu
Author: Finn Plummer (inbelic)
Changes:
Investigates the lowering of MemRef load/store ops and implements additional folding of the created ops by using `createOrFold` in place of `create`.

This aims to improve the readability of the lowered SPIR-V code.

Part of llvm#70704.
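For context, a minimal sketch (not from the patch) of the difference between `create` and `createOrFold`: the latter runs the op's folder at creation time, so ops that fold to a constant or an existing value are never materialized. The helper name below is illustrative; the builder calls mirror the ones in the diff, and the folds assume the op folders added by the llvm#70704 work.

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper, not in the patch: compute the i32 word index for a
// sub-word element, as adjustAccessChainForBitwidth does.
static Value wordIndex(OpBuilder &builder, Location loc, Value lastDim,
                       int64_t elemsPerWord) {
  Type i32 = builder.getIntegerType(32);
  IntegerAttr attr = builder.getIntegerAttr(i32, elemsPerWord);

  // builder.create<spirv::SDivOp>(...) would always materialize the op,
  // even when `lastDim` is itself the constant 0, leaving dead IR like:
  //   %c4 = spirv.Constant 4 : i32
  //   %q  = spirv.SDiv %c0, %c4 : i32
  Value divisor = builder.createOrFold<spirv::ConstantOp>(loc, i32, attr);

  // createOrFold invokes the SDiv folder first, so `0 / 4` collapses to
  // the existing `%c0` and no SDiv op is emitted (this is the change the
  // @load_i1/@load_i8 CHECK updates below reflect).
  return builder.createOrFold<spirv::SDivOp>(loc, lastDim, divisor);
}
```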
---
Patch is 42.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85433.diff
8 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+28-24)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+5-4)
- (modified) mlir/test/Conversion/GPUToSPIRV/load-store.mlir (+2-6)
- (modified) mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir (+38-120)
- (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+8-26)
- (modified) mlir/test/Conversion/SCFToSPIRV/for.mlir (+2-10)
- (modified) mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir (+4-6)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+6-14)
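One detail worth noting before the diff: in `SPIRVConversion.cpp`, the constant operands of `spirv.IMul` and `spirv.IAdd` move to the right-hand side. Presumably this matches the constant-on-the-right form that identity folds (`x * 1 -> x`, `x + 0 -> x`) are typically written against, so `createOrFold` can eliminate them. A hypothetical sketch under that assumption (the helper name is illustrative, not from the patch):

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper mirroring one step of spirv::linearizeIndex:
// accumulate `acc + index * stride`.
static Value accumulateIndex(OpBuilder &builder, Location loc, Value acc,
                             Value index, int64_t stride) {
  Type i32 = builder.getIntegerType(32);
  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
      loc, i32, builder.getIntegerAttr(i32, stride));
  // Constant stride on the RHS: `index * 1` can fold to `index`.
  Value update = builder.createOrFold<spirv::IMulOp>(loc, index, strideVal);
  // Accumulator on the RHS: when `acc` is still the initial constant 0,
  // `update + 0` can fold to `update`.
  return builder.createOrFold<spirv::IAddOp>(loc, update, acc);
}
```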
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
assert(targetBits % sourceBits == 0);
Type type = srcIdx.getType();
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
- auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
- auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
- return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+ auto srcBitsValue =
+ builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+ auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+ return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
Value lastDim = op->getOperand(op.getNumOperands() - 1);
Type type = lastDim.getType();
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
- indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+ indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
- return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+ return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+ zero);
}
/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
loc, builder.getIntegerType(targetBits), value);
}
- value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+ value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
}
- return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
- offset);
+ return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+ value, offset);
}
/// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
return srcInt;
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+ return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
- Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+ Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);
// Apply the mask to extract corresponding bits.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+ result =
+ rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
// Apply sign extension on the loading value unconditionally. The signedness
// semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
- rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
- result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
- shiftValue);
- result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
- shiftValue);
+ rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+ result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+ result, shiftValue);
+ result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+ loc, dstType, result, shiftValue);
rewriter.replaceOp(loadOp, result);
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- Value clearBitsMask =
- rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
- clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+ Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+ loc, dstType, mask, offset);
+ clearBitsMask =
+ rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
- return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+ return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
}();
rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// broken down into progressive small steps so we can have intermediate steps
// using other dialects. At the moment SPIR-V is the final sink.
- Value linearizedIndex = builder.create<spirv::ConstantOp>(
+ Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
loc, integerType, IntegerAttr::get(integerType, offset));
for (const auto &index : llvm::enumerate(indices)) {
- Value strideVal = builder.create<spirv::ConstantOp>(
+ Value strideVal = builder.createOrFold<spirv::ConstantOp>(
loc, integerType,
IntegerAttr::get(integerType, strides[index.index()]));
- Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ Value update =
+ builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
linearizedIndex =
- builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+ builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
}
return linearizedIndex;
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
// CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = arith.addi %arg4, %3 : index
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
- // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
- // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
- // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
- // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+ // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+ // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
// CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
%14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
// CHECK: builtin.unrealized_conversion_cast %[[SR]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
// INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
- // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
// INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
// INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
// CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
- // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
// CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
// CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
// CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
return
}
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
// CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
- // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+ // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
// INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
// INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
// INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
- // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+ // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
return
}
@@ -177,19 +...
[truncated]
``````````
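A quick note on the updated `@store_i1`/`@store_i8` CHECK lines in the diff above: the new `spirv.Constant -256` is the clear mask `~(255 << offset)` with `offset == 0`, now folded at build time instead of being computed with `spirv.ShiftLeftLogical` plus `spirv.Not`. A standalone sketch of that arithmetic (plain C++, values taken from the tests):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int srcBits = 8;                    // i8 (or emulated i1) elements
  const int32_t mask = (1 << srcBits) - 1;  // 255: selects the low byte
  const int32_t offset = 0;                 // element 0 of the packed word
  // Previously emitted as ShiftLeftLogical + Not ops; with createOrFold the
  // whole expression folds to a single constant.
  const int32_t clearBitsMask = ~(mask << offset);
  std::printf("%d\n", clearBitsMask);       // prints -256 (0xFFFFFF00)
  return 0;
}
```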
https://github.com/llvm/llvm-project/pull/85433
More information about the Mlir-commits mailing list