[Mlir-commits] [mlir] 87c0260 - [AMDGPU] Add parameterization for optimized shared memory variables (#82508)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 27 20:28:15 PST 2024
Author: erman-gurses
Date: 2024-02-27T23:28:12-05:00
New Revision: 87c0260f45e5a02cb07722d089dae3f0f84c7b3d
URL: https://github.com/llvm/llvm-project/commit/87c0260f45e5a02cb07722d089dae3f0f84c7b3d
DIFF: https://github.com/llvm/llvm-project/commit/87c0260f45e5a02cb07722d089dae3f0f84c7b3d.diff
LOG: [AMDGPU] Add parameterization for optimized shared memory variables (#82508)
- This PR adds parameterization for the shared memory variables that are
used for optimization: `sharedMemoryLineSizeBytes` and
`defaultVectorSizeBits`.
- The default value of both variables is set to 128, since that setting gives
zero bank conflicts.
Added:
Modified:
mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 23873d86b495c6..0eb67050608630 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// ApplyOptimizeSharedMemoryReadsAndWritesOp
//===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
reads/writes with the goal of avoiding bank conflicts.
}];
- let arguments = (ins TransformHandleTypeInterface:$target);
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ DefaultValuedOptionalAttr<I64Attr, "128">:$sharedMemoryLineSizeBytes,
+ DefaultValuedOptionalAttr<I64Attr, "128">:$defaultVectorSizeBits);
let results = (outs);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index c8059e6d316e8a..67f951fd19d172 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -37,10 +37,17 @@ def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts.
}];
-
let dependentDialects = [
"memref::MemRefDialect", "vector::VectorDialect"
];
+ let options = [
+ Option<"sharedMemoryLineSizeBytes", "shared-memory-line-size-bytes", "int64_t",
+ /*default=*/"128",
+ "Shared memory line size in bytes">,
+ Option<"defaultVectorSizeBits", "default-vector-size-bits", "int64_t",
+ /*default=*/"128",
+ "Default vector size in bits">,
+ ];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 79f9ab71a2b430..843cea2c503b9a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -45,11 +45,15 @@ namespace amdgpu {
/// function that depends on the row Index. The permutation function is chosen
/// to ensure that sequential distributed+vectorized reads/writes down a single
/// dimension of the memref have minimal conflicts.
-LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
- Value memrefValue);
+LogicalResult
+optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue,
+ int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits);
std::optional<LogicalResult>
-optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+ int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits);
} // namespace amdgpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index ff29f9f6938535..b7e17a92897389 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
TransformState &state) {
- optimizeSharedMemoryReadsAndWritesOp(funcOp);
+ optimizeSharedMemoryReadsAndWritesOp(funcOp, getSharedMemoryLineSizeBytes(),
+ getDefaultVectorSizeBits());
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6bd03ed833898d..32fab265e03cc0 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -35,13 +35,6 @@ namespace amdgpu {
using namespace mlir;
using namespace mlir::amdgpu;
-/// The size of a shared memory line according to AMD documentation.
-/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
-/// We optimize for 64bit accesses, but this can be made an argument in the
-/// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
-
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
/// floordiv(tgtIdxVal,vectorSize)))
@@ -49,7 +42,9 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
ArrayRef<Value> indices, MemRefType memrefTy,
- int64_t srcDim, int64_t tgtDim) {
+ int64_t srcDim, int64_t tgtDim,
+ int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits) {
// Adjust the src index to change how often the permutation changes
// if necessary.
Value src = indices[srcDim];
@@ -57,9 +52,9 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
// We only want to permute every N iterations of the target dim where N is
// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
const int64_t permuteEveryN = std::max<int64_t>(
- 1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
- memrefTy.getElementTypeBitWidth()) /
- 8));
+ 1, sharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+ memrefTy.getElementTypeBitWidth()) /
+ 8));
// clang-format off
// Index bit representation (b0 = least significant bit) for dim(1)
@@ -71,7 +66,7 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
// bits[N:M] = vector index
// clang-format on
int64_t n =
- llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+ llvm::Log2_64(defaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
// Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
@@ -105,9 +100,11 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
static void transformIndices(OpBuilder &builder, Location loc,
SmallVector<Value, 4> &indices,
MemRefType memrefTy, int64_t srcDim,
- int64_t tgtDim) {
+ int64_t tgtDim, int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits) {
indices[tgtDim] =
- permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+ permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim,
+ sharedMemoryLineSizeBytes, defaultVectorSizeBits);
}
// Return all operations within `parentOp` that read from or write to
@@ -149,8 +146,9 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
return success();
}
-LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
- Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
+ Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits) {
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType ||
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -167,10 +165,10 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
const int64_t rowsPerLine =
- (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+ (8 * sharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
rowSize;
const int64_t threadGroupSize =
- 1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+ 1LL << (7 - llvm::Log2_64(defaultVectorSizeBits / 8));
if (rowsPerLine >= threadGroupSize)
return failure();
@@ -198,7 +196,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
auto indices = amdgpu::getIndices(shmWriteOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
- memRefType, srcDim, tgtDim);
+ memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+ defaultVectorSizeBits);
amdgpu::setIndices(shmWriteOp, transformedIndices);
}
@@ -210,7 +209,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
auto indices = amdgpu::getIndices(shmReadOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
- memRefType, srcDim, tgtDim);
+ memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+ defaultVectorSizeBits);
amdgpu::setIndices(shmReadOp, transformedIndices);
}
@@ -218,7 +218,9 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
}
std::optional<LogicalResult>
-amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+ int64_t sharedMemoryLineSizeBytes,
+ int64_t defaultVectorSizeBits) {
SmallVector<memref::AllocOp> shmAllocOps;
funcOp.walk([&](memref::AllocOp allocOp) {
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -226,8 +228,9 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
shmAllocOps.push_back(allocOp);
});
for (auto allocOp : shmAllocOps) {
- if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
- allocOp.getMemref())))
+ if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(
+ funcOp, allocOp.getMemref(), sharedMemoryLineSizeBytes,
+ defaultVectorSizeBits)))
return failure();
}
return success();
@@ -237,7 +240,8 @@ struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
OptimizeSharedMemoryPass() = default;
-
+ OptimizeSharedMemoryPass(const OptimizeSharedMemoryOptions &options)
+ : OptimizeSharedMemoryBase(options) {}
void runOnOperation() override {
Operation *op = getOperation();
SmallVector<memref::AllocOp> shmAllocOps;
@@ -248,8 +252,9 @@ struct OptimizeSharedMemoryPass
shmAllocOps.push_back(allocOp);
});
for (auto allocOp : shmAllocOps) {
- if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
- allocOp.getMemref())))
+ if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref(),
+ sharedMemoryLineSizeBytes,
+ defaultVectorSizeBits)))
return;
}
}
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
index a1de1ff87c229f..983eee732e2afe 100644
--- a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
- func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+ func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
- %fragRow: index, %fragCol: index,
+ %fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
- // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+ // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16
// CHECK: [[shmA:%.+]] = memref.alloc
@@ -15,42 +15,36 @@
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
- // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
- // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 143e7c2d270952..b1bb91ffc29721 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
- func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+ func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
- %fragRow: index, %fragCol: index,
+ %fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
%cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
@@ -48,7 +48,7 @@
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
- transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+ transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {sharedMemoryLineSizeBytes = 128, defaultVectorSizeBits = 128}: (!transform.any_op) -> ()
transform.yield
} // @__transform_main
} // module
More information about the Mlir-commits
mailing list