[Mlir-commits] [mlir] 38f8a3c - [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 21 08:49:31 PDT 2024
Author: Finn Plummer
Date: 2024-03-21T08:49:27-07:00
New Revision: 38f8a3cf0d75cd25e13d3757027f7356e4466cb9
URL: https://github.com/llvm/llvm-project/commit/38f8a3cf0d75cd25e13d3757027f7356e4466cb9
DIFF: https://github.com/llvm/llvm-project/commit/38f8a3cf0d75cd25e13d3757027f7356e4466cb9.diff
LOG: [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)
Investigate the lowering of MemRef Load/Store ops and implement
additional folding of created ops.
Aims to improve readability of generated lowered SPIR-V code.
Part of work llvm#70704.
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
mlir/test/Conversion/SCFToSPIRV/for.mlir
mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
assert(targetBits % sourceBits == 0);
Type type = srcIdx.getType();
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
- auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
- auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
- return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+ auto srcBitsValue =
+ builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+ auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+ return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}
/// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
Value lastDim = op->getOperand(op.getNumOperands() - 1);
Type type = lastDim.getType();
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
- auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+ auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
- indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+ indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
- return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+ return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+ zero);
}
/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
loc, builder.getIntegerType(targetBits), value);
}
- value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+ value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
}
- return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
- offset);
+ return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+ value, offset);
}
/// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
return srcInt;
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
- return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+ return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}
//===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
- Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+ Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);
// Apply the mask to extract corresponding bits.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+ result =
+ rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
// Apply sign extension on the loading value unconditionally. The signedness
// semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
- rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
- result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
- shiftValue);
- result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
- shiftValue);
+ rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+ result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+ result, shiftValue);
+ result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+ loc, dstType, result, shiftValue);
rewriter.replaceOp(loadOp, result);
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
- Value mask = rewriter.create<spirv::ConstantOp>(
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
- Value clearBitsMask =
- rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
- clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+ Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+ loc, dstType, mask, offset);
+ clearBitsMask =
+ rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
- return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+ return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
}();
rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// broken down into progressive small steps so we can have intermediate steps
// using other dialects. At the moment SPIR-V is the final sink.
- Value linearizedIndex = builder.create<spirv::ConstantOp>(
+ Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
loc, integerType, IntegerAttr::get(integerType, offset));
for (const auto &index : llvm::enumerate(indices)) {
- Value strideVal = builder.create<spirv::ConstantOp>(
+ Value strideVal = builder.createOrFold<spirv::ConstantOp>(
loc, integerType,
IntegerAttr::get(integerType, strides[index.index()]));
- Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ Value update =
+ builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
linearizedIndex =
- builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+ builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
}
return linearizedIndex;
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
// CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = arith.addi %arg4, %3 : index
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
- // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
- // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
- // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
- // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+ // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+ // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
// CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
%14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
// CHECK-LABEL: @load_i1
func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
// CHECK: builtin.unrealized_conversion_cast %[[SR]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
// INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
- // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
// INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
// INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
// CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
- // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
// CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
// CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
// CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
return
}
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
// CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
- // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+ // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
// INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
// INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
// INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
// INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
- // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+ // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+ // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
return
}
@@ -177,19 +134,16 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
- // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
// CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
// CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[MASK1:.+]] = spirv.Constant 65535 : i32
// CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32
// CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+ // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
// CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
// CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
@@ -222,15 +176,12 @@ module attributes {
func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
// CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
// CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
// CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
// CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
@@ -248,19 +199,16 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v
// CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
// CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
// CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant [[OFFSET]] : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
// CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
// CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
// CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
// CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
// CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
// CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
- // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
// CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
// CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
// CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
@@ -283,16 +231,10 @@ module attributes {
// INDEX64-LABEL: @load_i8
func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
// CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -300,16 +242,10 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
// CHECK: return %[[CAST]] : i8
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
// INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
- // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
// INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
// INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
// INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
// INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -326,37 +262,19 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
- // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
- // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
- // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
- // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // CHECK: %[[MASK1:.+]] = spirv.Constant -256 : i32
// CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
- // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
- // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
- // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
- // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+ // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
+ // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
// INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
- // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
- // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
- // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
- // INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
- // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+ // INDEX64: %[[MASK1:.+]] = spirv.Constant -256 : i32
// INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
- // INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
- // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
- // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
- // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
- // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
- // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+ // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+ // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
+ // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
return
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index feb6d4e924015f..10c03a270005f1 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -70,11 +70,8 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[ADD]]]
+ // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -90,15 +87,10 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i:
%true = arith.constant true
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[ADD]]]
- // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[IDX_CAST]]]
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
- // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
+ // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[ONE_I8]] : i8
memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<StorageBuffer>>
return
}
@@ -234,11 +226,7 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
// CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -254,15 +242,9 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
%true = arith.constant true
// CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
- // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
- // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[IDX_CAST]]]
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
- // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
- // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8
+ // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[ONE_I8]] : i8
memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
return
}
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 02558463b8662d..81661ec7a3a060 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -19,17 +19,9 @@ func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer
// CHECK: spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
// CHECK: ^[[BODY]]:
// CHECK: %[[ZERO1:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[OFFSET1:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[STRIDE1:.*]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE1:.*]] = spirv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
- // CHECK: %[[INDEX1:.*]] = spirv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
- // CHECK: spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
+ // CHECK: spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDVAR]]{{\]}}
// CHECK: %[[ZERO2:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[OFFSET2:.*]] = spirv.Constant 0 : i32
- // CHECK: %[[STRIDE2:.*]] = spirv.Constant 1 : i32
- // CHECK: %[[UPDATE2:.*]] = spirv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
- // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
- // CHECK: spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
+ // CHECK: spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDVAR]]]
// CHECK: %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
// CHECK: spirv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
// CHECK: ^[[MERGE]]
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 19de613bf5b073..32d0fbea65b164 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -14,14 +14,12 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.array<12 x i32>
// CHECK: %[[C0:.+]] = spirv.Constant 0 : i32
// CHECK: %[[C6:.+]] = spirv.Constant 6 : i32
- // CHECK: %[[MUL0:.+]] = spirv.IMul %[[C6]], %[[A]] : i32
- // CHECK: %[[ADD0:.+]] = spirv.IAdd %[[C0]], %[[MUL0]] : i32
+ // CHECK: %[[MUL0:.+]] = spirv.IMul %[[A]], %[[C6]] : i32
// CHECK: %[[C3:.+]] = spirv.Constant 3 : i32
- // CHECK: %[[MUL1:.+]] = spirv.IMul %[[C3]], %[[B]] : i32
- // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[ADD0]], %[[MUL1]] : i32
+ // CHECK: %[[MUL1:.+]] = spirv.IMul %[[B]], %[[C3]] : i32
+ // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[MUL1]], %[[MUL0]] : i32
// CHECK: %[[C1:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[MUL2:.+]] = spirv.IMul %[[C1]], %[[C]] : i32
- // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[ADD1]], %[[MUL2]] : i32
+ // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[C]], %[[ADD1]] : i32
// CHECK: %[[AC:.+]] = spirv.AccessChain %[[VAR]][%[[ADD2]]]
// CHECK: %[[VAL:.+]] = spirv.Load "Function" %[[AC]] : i32
%extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc6..cddc4ee385357d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -720,9 +720,7 @@ module attributes {
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
-// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
-// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
-// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
// CHECK: return %[[R0]] : vector<4xf32>
@@ -743,11 +741,9 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
-// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
-// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
-// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
-// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S2]], %[[S3]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
@@ -768,9 +764,7 @@ func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBu
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
-// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
-// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
-// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
@@ -790,11 +784,9 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
-// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
-// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
-// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
-// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S2]], %[[S3]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
More information about the Mlir-commits mailing list