[Mlir-commits] [mlir] 7db25f7 - [WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast to 2-D
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Dec 1 02:50:00 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-01T02:49:47-08:00
New Revision: 7db25f78db807da171f23bcbaff258c5677901d1
URL: https://github.com/llvm/llvm-project/commit/7db25f78db807da171f23bcbaff258c5677901d1
DIFF: https://github.com/llvm/llvm-project/commit/7db25f78db807da171f23bcbaff258c5677901d1.diff
LOG: [WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast to 2-D
Differential Revision: https://reviews.llvm.org/D139040
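For context, the updated tests below illustrate the pattern this patch teaches the conversion to recognize: a 1-D vector.transfer_read whose result is broadcast to a 2-D vector, instead of a single 2-D transfer_read with a broadcasting permutation map. A minimal before/after sketch, adapted from the @matmul_fused_broadcast test in this diff (SSA names are illustrative):

  // Previously expressed as one 2-D transfer_read with a broadcasting map:
  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
       {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d3)>}
       : memref<16x16x16x16xf16>, vector<16x16xf16>

  // Now expressed as a 1-D read followed by an explicit broadcast to 2-D:
  %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
           {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3)>}
           : memref<16x16x16x16xf16>, vector<16xf16>
  %E = vector.broadcast %Eread : vector<16xf16> to vector<16x16xf16>

Both forms are expected to lower to a gpu.subgroup_mma_load_matrix with leadDimension = 0; the new convertBroadcastFromTransferReadOp helper handles the broadcast case by forcing the load stride to 0.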
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 2734b5f1660b..1da8dc4a872c 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -150,6 +150,26 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
return true;
}
+// Return true if a 1-D transfer read followed by a broadcast to 2-D can be
+// converted to an MMA matrix load.
+static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
+ vector::TransferReadOp readOp, bool useNvGpu) {
+ bool res = true;
+ if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
+ readOp.getVectorType().getRank() != 1)
+ res = false;
+ if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+ res = false;
+ AffineMap map = readOp.getPermutationMap();
+ OpBuilder b(readOp.getContext());
+
+ if (res && !useNvGpu)
+ return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
+
+ llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
+ << res << "\n";
+ return res;
+}
+
// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -179,8 +199,27 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
- return broadcastOp.getVectorType().getRank() == 2 &&
- broadcastOp.getSource().getType().isa<FloatType>();
+ auto res = broadcastOp.getVectorType().getRank() == 2 &&
+ broadcastOp.getSource().getType().isa<FloatType>();
+ llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
+ return res;
+}
+
+/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
+/// vector comes from a TransferReadOp.
+static bool
+broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
+ bool useNvGpu) {
+ auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+ auto sourceVectorType =
+ broadcastOp.getSource().getType().dyn_cast<VectorType>();
+ auto res =
+ !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
+ sourceVectorType.getRank() == 1 &&
+ transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
+ llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
+ << "\n";
+ return res;
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -219,9 +258,10 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
if (failed(contractOp))
return false;
- // Handle vector.extract_strided_slice on registers containing
- // matrixB and matrixC operands. vector.extract_strided_slice op
- // is not supported on registers containing matrixA operands.
+ // Handle vector.extract_strided_slice on registers
+ // containing matrixB and matrixC operands.
+ // vector.extract_strided_slice op is not supported on
+ // registers containing matrixA operands.
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
return (op->getResult(0).getType().cast<VectorType>() ==
(*contractOp).getRhs().getType().cast<VectorType>());
@@ -236,7 +276,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
- return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+ return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
+ transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
+ useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -246,8 +288,10 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
- if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
- return broadcastSupportsMMAMatrixType(broadcast);
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+ return broadcastSupportsMMAMatrixType(broadcast) ||
+ broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
+ }
return elementwiseSupportsMMAMatrixType(op);
}
@@ -264,17 +308,20 @@ static SetVector<Operation *> getSliceContract(Operation *op,
SetVector<Operation *> forwardSlice;
while (currentIndex != slice.size()) {
auto *currentOp = (slice)[currentIndex];
- // Compute and insert the backwardSlice starting from currentOp.
+ // Compute and insert the backwardSlice starting from
+ // currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
- // Compute and insert the forwardSlice starting from currentOp.
+ // Compute and insert the forwardSlice starting from
+ // currentOp.
forwardSlice.clear();
- // Special case for ForOp, we don't want to include the whole region but
- // only the value using the region arguments.
- // TODO: We should refine this to only care about the region arguments being
- // converted to matrix type.
+ // Special case for ForOp, we don't want to include the
+ // whole region but only the value using the region
+ // arguments.
+ // TODO: We should refine this to only care about the
+ // region arguments being converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
for (Value forOpResult : forOp.getResults())
getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -307,16 +354,20 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
return;
SetVector<Operation *> dependentOps =
getSliceContract(contract, hasVectorDest, hasVectorSrc);
- // If any instruction cannot use MMA matrix type drop the whole
- // chain. MMA matrix are stored in an opaque type so they cannot be used
- // by all operations.
+ // If any instruction cannot use MMA matrix type drop the
+ // whole chain. MMA matrix are stored in an opaque type so
+ // they cannot be used by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
- return !supportsMMaMatrixType(op, useNvGpu);
+ auto res = !supportsMMaMatrixType(op, useNvGpu);
+ if (res)
+ llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
+ return res;
}))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
- // Sort the operations so that we can convert them in topological order.
+ // Sort the operations so that we can convert them in
+ // topological order.
return topologicalSort(opToConvert);
}
@@ -443,7 +494,12 @@ static const char *inferFragType(OpTy op) {
static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
- assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+ if (!transferReadSupportsMMAMatrixType(op,
+ /*useNvGpu=*/false))
+ return;
+ // Only transfers that return 2-D vectors are supported.
+ if (op.getVectorType().getRank() != 2)
+ return;
std::optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
AffineMap map = op.getPermutationMap();
@@ -535,10 +591,11 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
*warpMatrixInfo,
/*transpose=*/!op.getPermutationMap().isMinorIdentity());
if (failed(params)) {
- return op->emitError()
- << "failed to convert vector.transfer_read to ldmatrix; this op "
- "likely "
- "should not be converted to a nvgpu.ldmatrix call.";
+ return op->emitError() << "failed to convert vector.transfer_read to "
+ "ldmatrix; this op "
+ "likely "
+ "should not be converted to a nvgpu.ldmatrix "
+ "call.";
}
// Adjust the load offset.
@@ -572,7 +629,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
op->emitError() << "Failed to deduce register fragment type during "
- "conversion to distributed non-ldmatrix compatible load";
+ "conversion to distributed non-ldmatrix compatible "
+ "load";
return failure();
}
@@ -590,8 +648,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
- // If we are not transposing, then we can use vectorized loads. Otherwise, we
- // must load each element individually.
+ // If we are not transposing, then we can use vectorized
+ // loads. Otherwise, we must load each element individually.
if (!isTransposeLoad) {
if (!loadedElType.isa<VectorType>()) {
loadedElType = VectorType::get({1}, loadedElType);
@@ -665,9 +723,9 @@ convertTransferReadToLoads(vector::TransferReadOp op,
VectorType vecTy = op.getVectorType();
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
- // When we are transposing the B operand, ldmatrix will only work if we have
- // at least 8 rows to read and the width to read for the transpose is 128
- // bits.
+ // When we are transposing the B operand, ldmatrix will only
+ // work if we have at least 8 rows to read and the width to
+ // read for the transpose is 128 bits.
if (!op.getPermutationMap().isMinorIdentity() &&
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
vecTy.getDimSize(0) * bitWidth < 128))
@@ -740,7 +798,8 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
if (failed(mmaSyncFragmentInfo))
return failure();
- // Find the vector.transer_read whose result vector is being sliced.
+  // Find the vector.transfer_read whose result vector is being
+ // sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
@@ -754,12 +813,13 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
if (failed(ldFragmentInfo))
return failure();
- assert(
- (mmaSyncFragmentInfo->elementsPerRegister ==
- ldFragmentInfo->elementsPerRegister) &&
- "Number of elements per register should be same for load and mma.sync");
+ assert((mmaSyncFragmentInfo->elementsPerRegister ==
+ ldFragmentInfo->elementsPerRegister) &&
+ "Number of elements per register should be same for "
+ "load and mma.sync");
- // Create vector.extract_strided_slice op for thread-owned fragments.
+ // Create vector.extract_strided_slice op for thread-owned
+ // fragments.
std::array<int64_t, 2> strides = {1,
1}; // stride for extract slice is always 1.
std::array<int64_t, 2> sliceShape = {
@@ -775,9 +835,11 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
- // Compute offset in vector registers. Note that the mma.sync vector registers
- // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
- // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
+ // Compute offset in vector registers. Note that the mma.sync
+ // vector registers are shaped as numberOfFragments x
+  // numberOfRegistersPerFragment. The vector registers can
+ // only be sliced along numberOfFragments, i.e.,
+ // sliceOffset[0].
std::array<int64_t, 2> sliceOffset = {0, 0};
if (offsets[0] && offsets[1])
@@ -842,7 +904,10 @@ static void convertConstantOp(arith::ConstantOp op,
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- assert(broadcastSupportsMMAMatrixType(op));
+  // This function only handles broadcasts that can be converted
+  // directly to an MMA op.
+ if (!broadcastSupportsMMAMatrixType(op))
+ return;
OpBuilder b(op);
const char *fragType = inferFragType(op);
auto vecType = op.getVectorType();
@@ -853,11 +918,39 @@ static void convertBroadcastOp(vector::BroadcastOp op,
valueMapping[op.getResult()] = matrix;
}
+/// Convert a vector.broadcast fed by a 1-D transfer_read into a
+/// SubgroupMmaLoadMatrix op.
+static void
+convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+  // This function handles the broadcasts that cannot be converted
+  // directly to an MMA op.
+ if (broadcastSupportsMMAMatrixType(broadcastOp))
+ return;
+ if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
+ /*useNvGpu=*/false))
+ return;
+ auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
+ assert(readOp && readOp.getVectorType().getRank() == 1);
+ // Handle broadcast by setting the stride to 0, unconditionally.
+ int64_t stride = 0;
+ const char *fragType = inferFragType(readOp);
+ gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+ broadcastOp.getVectorType().getShape(),
+ broadcastOp.getVectorType().getElementType(), fragType);
+ OpBuilder b(readOp);
+ bool isTranspose = false;
+ Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+ readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
+ b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+ valueMapping[broadcastOp.getResult()] = load;
+}
+
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
ValueRange newIterOperands) {
- // Create a new loop before the existing one, with the extra operands.
+ // Create a new loop before the existing one, with the extra
+ // operands.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -912,8 +1005,8 @@ static void convertYieldOp(scf::YieldOp op,
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end())
continue;
- // Replace the yield of old value with the for op argument to make it easier
- // to remove the dead code.
+ // Replace the yield of old value with the for op argument
+ // to make it easier to remove the dead code.
yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
yieldOperands.push_back(it->second);
}
@@ -959,6 +1052,7 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
convertConstantOp(constantOp, valueMapping);
} else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
convertBroadcastOp(broadcastOp, valueMapping);
+ convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
convertForOp(forOp, valueMapping);
} else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1027,6 +1121,8 @@ struct ConvertVectorToGPUPass
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
+ getOperation()->dump();
+
if (useNvGpu.getValue()) {
if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
return signalPassFailure();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index fa2a40f7334d..9a0f4c9f0837 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -4,7 +4,6 @@
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map4 = affine_map<(d0) -> (d0, 0)>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @matmul
@@ -118,6 +117,21 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+// %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+// %c0 = arith.constant 0 : index
+// %cst = arith.constant 0.000000e+00 : f16
+// %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+// %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+// {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+// : memref<16x16x16x16xf16>, vector<16x16xf16>
+// %F = arith.divf %D, %E : vector<16x16xf16>
+// vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+// return
+// }
func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
%arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -126,9 +140,10 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
- %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
- {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
- : memref<16x16x16x16xf16>, vector<16x16xf16>
+ %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+ {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
+ : memref<16x16x16x16xf16>, vector<16xf16>
+ %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
%F = arith.divf %D, %E : vector<16x16xf16>
vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
@@ -141,12 +156,24 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+// %c0 = arith.constant 0 : index
+// %cst = arith.constant 0.000000e+00 : f16
+// %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+// %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+// %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+// return
+// }
func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
- %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+ %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+ %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -160,12 +187,24 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+// %c0 = arith.constant 0 : index
+// %cst = arith.constant 0.000000e+00 : f16
+// %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
+// %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+// %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+// return
+// }
func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
- %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+ %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
+ %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>