[Mlir-commits] [mlir] 3af6438 - Revert "[WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast to 2-D"
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Dec 1 02:57:29 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-01T02:57:03-08:00
New Revision: 3af6438372ad28c3c2c632a67b15fb68f9c3d52b
URL: https://github.com/llvm/llvm-project/commit/3af6438372ad28c3c2c632a67b15fb68f9c3d52b
DIFF: https://github.com/llvm/llvm-project/commit/3af6438372ad28c3c2c632a67b15fb68f9c3d52b.diff
LOG: Revert "[WIP] Add support for MMA conversion for 1-D vector.transfer followed by a broadcast to 2-D"
This reverts commit 7db25f78db807da171f23bcbaff258c5677901d1.
This was mistakenly stacked below (and committed) along with an NFC change.
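For context, the reverted WIP patch recognized a 1-D vector.transfer_read followed by a vector.broadcast to 2-D in place of a single broadcasting 2-D read. A minimal sketch of the two IR forms, adapted from the test updates in the diff below (assuming %arg1 : memref<16xf16>, %c0 : index, and %cst : f16 are defined as in those tests):

    // Form handled by the reverted patch: 1-D read, then broadcast to 2-D.
    %b = vector.transfer_read %arg1[%c0], %cst
           {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]}
           : memref<16xf16>, vector<16xf16>
    %B = vector.broadcast %b : vector<16xf16> to vector<16x16xf16>

    // Equivalent form restored by this revert: a single broadcasting 2-D read.
    %B2 = vector.transfer_read %arg1[%c0], %cst
            {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]}
            : memref<16xf16>, vector<16x16xf16>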
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1da8dc4a872c..2734b5f1660b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -150,26 +150,6 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
return true;
}
-// Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadFollowedByBroadcastSupportsMMAMatrixType(
- vector::TransferReadOp readOp, bool useNvGpu) {
- bool res = true;
- if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
- readOp.getVectorType().getRank() != 1)
- res = false;
- if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
- res = false;
- AffineMap map = readOp.getPermutationMap();
- OpBuilder b(readOp.getContext());
-
- if (res && !useNvGpu)
- return map.isMinorIdentity() || isTransposeMatrixLoadMap(b, map);
-
- llvm::errs() << "RES transferReadFollowedByBroadcastSupportsMMAMatrixType: "
- << res << "\n";
- return res;
-}
-
// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
@@ -199,27 +179,8 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
- auto res = broadcastOp.getVectorType().getRank() == 2 &&
- broadcastOp.getSource().getType().isa<FloatType>();
- llvm::errs() << "RES broadcastSupportsMMAMatrixType: " << res << "\n";
- return res;
-}
-
-/// Return true if this is a broadcast from 1-D to a 2-D vector and the 1-D
-/// vector comes from a TransferReadOp.
-static bool
-broadcastFromTransferReadSupportsMMAMatrixType(vector::BroadcastOp broadcastOp,
- bool useNvGpu) {
- auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
- auto sourceVectorType =
- broadcastOp.getSource().getType().dyn_cast<VectorType>();
- auto res =
- !broadcastSupportsMMAMatrixType(broadcastOp) && sourceVectorType &&
- sourceVectorType.getRank() == 1 &&
- transferReadFollowedByBroadcastSupportsMMAMatrixType(readOp, useNvGpu);
- llvm::errs() << "RES broadcastFromTransferReadSupportsMMAMatrixType: " << res
- << "\n";
- return res;
+ return broadcastOp.getVectorType().getRank() == 2 &&
+ broadcastOp.getSource().getType().isa<FloatType>();
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -258,10 +219,9 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
if (failed(contractOp))
return false;
- // Handle vector.extract_strided_slice on registers
- // containing matrixB and matrixC operands.
- // vector.extract_strided_slice op is not supported on
- // registers containing matrixA operands.
+ // Handle vector.extract_strided_slice on registers containing
+ // matrixB and matrixC operands. vector.extract_strided_slice op
+ // is not supported on registers containing matrixA operands.
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
return (op->getResult(0).getType().cast<VectorType>() ==
(*contractOp).getRhs().getType().cast<VectorType>());
@@ -276,9 +236,7 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
- return transferReadSupportsMMAMatrixType(transferRead, useNvGpu) ||
- transferReadFollowedByBroadcastSupportsMMAMatrixType(transferRead,
- useNvGpu);
+ return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
@@ -288,10 +246,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
- if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
- return broadcastSupportsMMAMatrixType(broadcast) ||
- broadcastFromTransferReadSupportsMMAMatrixType(broadcast, useNvGpu);
- }
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcastSupportsMMAMatrixType(broadcast);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -308,20 +264,17 @@ static SetVector<Operation *> getSliceContract(Operation *op,
SetVector<Operation *> forwardSlice;
while (currentIndex != slice.size()) {
auto *currentOp = (slice)[currentIndex];
- // Compute and insert the backwardSlice starting from
- // currentOp.
+ // Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
- // Compute and insert the forwardSlice starting from
- // currentOp.
+ // Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
- // Special case for ForOp, we don't want to include the
- // whole region but only the value using the region
- // arguments.
- // TODO: We should refine this to only care about the
- // region arguments being converted to matrix type.
+ // Special case for ForOp, we don't want to include the whole region but
+ // only the value using the region arguments.
+ // TODO: We should refine this to only care about the region arguments being
+ // converted to matrix type.
if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
for (Value forOpResult : forOp.getResults())
getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
@@ -354,20 +307,16 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
return;
SetVector<Operation *> dependentOps =
getSliceContract(contract, hasVectorDest, hasVectorSrc);
- // If any instruction cannot use MMA matrix type drop the
- // whole chain. MMA matrix are stored in an opaque type so
- // they cannot be used by all operations.
+ // If any instruction cannot use MMA matrix type drop the whole
+ // chain. MMA matrix are stored in an opaque type so they cannot be used
+ // by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
- auto res = !supportsMMaMatrixType(op, useNvGpu);
- if (res)
- llvm::errs() << "DOES NOT SUPPORT: " << *op << "\n";
- return res;
+ return !supportsMMaMatrixType(op, useNvGpu);
}))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
- // Sort the operations so that we can convert them in
- // topological order.
+ // Sort the operations so that we can convert them in topological order.
return topologicalSort(opToConvert);
}
@@ -494,12 +443,7 @@ static const char *inferFragType(OpTy op) {
static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
- if (!transferReadSupportsMMAMatrixType(op,
- /*useNvGpu=*/false))
- return;
- // Only transfers that return 2-D vectors are supported.
- if (op.getVectorType().getRank() != 2)
- return;
+ assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
std::optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
AffineMap map = op.getPermutationMap();
@@ -591,11 +535,10 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
*warpMatrixInfo,
/*transpose=*/!op.getPermutationMap().isMinorIdentity());
if (failed(params)) {
- return op->emitError() << "failed to convert vector.transfer_read to "
- "ldmatrix; this op "
- "likely "
- "should not be converted to a nvgpu.ldmatrix "
- "call.";
+ return op->emitError()
+ << "failed to convert vector.transfer_read to ldmatrix; this op "
+ "likely "
+ "should not be converted to a nvgpu.ldmatrix call.";
}
// Adjust the load offset.
@@ -629,8 +572,7 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
op->emitError() << "Failed to deduce register fragment type during "
- "conversion to distributed non-ldmatrix compatible "
- "load";
+ "conversion to distributed non-ldmatrix compatible load";
return failure();
}
@@ -648,8 +590,8 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
- // If we are not transposing, then we can use vectorized
- // loads. Otherwise, we must load each element individually.
+ // If we are not transposing, then we can use vectorized loads. Otherwise, we
+ // must load each element individually.
if (!isTransposeLoad) {
if (!loadedElType.isa<VectorType>()) {
loadedElType = VectorType::get({1}, loadedElType);
@@ -723,9 +665,9 @@ convertTransferReadToLoads(vector::TransferReadOp op,
VectorType vecTy = op.getVectorType();
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
- // When we are transposing the B operand, ldmatrix will only
- // work if we have at least 8 rows to read and the width to
- // read for the transpose is 128 bits.
+ // When we are transposing the B operand, ldmatrix will only work if we have
+ // at least 8 rows to read and the width to read for the transpose is 128
+ // bits.
if (!op.getPermutationMap().isMinorIdentity() &&
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
vecTy.getDimSize(0) * bitWidth < 128))
@@ -798,8 +740,7 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
if (failed(mmaSyncFragmentInfo))
return failure();
- // Find the vector.transer_read whose result vector is being
- // sliced.
+ // Find the vector.transer_read whose result vector is being sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
@@ -813,13 +754,12 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
if (failed(ldFragmentInfo))
return failure();
- assert((mmaSyncFragmentInfo->elementsPerRegister ==
- ldFragmentInfo->elementsPerRegister) &&
- "Number of elements per register should be same for "
- "load and mma.sync");
+ assert(
+ (mmaSyncFragmentInfo->elementsPerRegister ==
+ ldFragmentInfo->elementsPerRegister) &&
+ "Number of elements per register should be same for load and mma.sync");
- // Create vector.extract_strided_slice op for thread-owned
- // fragments.
+ // Create vector.extract_strided_slice op for thread-owned fragments.
std::array<int64_t, 2> strides = {1,
1}; // stride for extract slice is always 1.
std::array<int64_t, 2> sliceShape = {
@@ -835,11 +775,9 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
- // Compute offset in vector registers. Note that the mma.sync
- // vector registers are shaped as numberOfFragments x
- // numberOfRegistersPerfFragment. The vector registers can
- // only be sliced along numberOfFragments, i.e.,
- // sliceOffset[0].
+ // Compute offset in vector registers. Note that the mma.sync vector registers
+ // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
+ // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
std::array<int64_t, 2> sliceOffset = {0, 0};
if (offsets[0] && offsets[1])
@@ -904,10 +842,7 @@ static void convertConstantOp(arith::ConstantOp op,
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
- // This op only catches the broadcasts that can directly
- // convert to an MMA op.
- if (!broadcastSupportsMMAMatrixType(op))
- return;
+ assert(broadcastSupportsMMAMatrixType(op));
OpBuilder b(op);
const char *fragType = inferFragType(op);
auto vecType = op.getVectorType();
@@ -918,39 +853,11 @@ static void convertBroadcastOp(vector::BroadcastOp op,
valueMapping[op.getResult()] = matrix;
}
-/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void
-convertBroadcastFromTransferReadOp(vector::BroadcastOp broadcastOp,
- llvm::DenseMap<Value, Value> &valueMapping) {
- // This op catches the broadcasts that cannot directly convert to an MMA
- // op.
- if (broadcastSupportsMMAMatrixType(broadcastOp))
- return;
- if (!broadcastFromTransferReadSupportsMMAMatrixType(broadcastOp,
- /*useNvGpu=*/false))
- return;
- auto readOp = broadcastOp.getSource().getDefiningOp<vector::TransferReadOp>();
- assert(readOp && readOp.getVectorType().getRank() == 1);
- // Handle broadcast by setting the stride to 0, unconditionally.
- int64_t stride = 0;
- const char *fragType = inferFragType(readOp);
- gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
- broadcastOp.getVectorType().getShape(),
- broadcastOp.getVectorType().getElementType(), fragType);
- OpBuilder b(readOp);
- bool isTranspose = false;
- Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
- readOp.getLoc(), type, readOp.getSource(), readOp.getIndices(),
- b.getIndexAttr(stride), isTranspose ? b.getUnitAttr() : UnitAttr());
- valueMapping[broadcastOp.getResult()] = load;
-}
-
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separatly for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
ValueRange newIterOperands) {
- // Create a new loop before the existing one, with the extra
- // operands.
+ // Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getIterOperands());
@@ -1005,8 +912,8 @@ static void convertYieldOp(scf::YieldOp op,
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end())
continue;
- // Replace the yield of old value with the for op argument
- // to make it easier to remove the dead code.
+ // Replace the yield of old value with the for op argument to make it easier
+ // to remove the dead code.
yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
yieldOperands.push_back(it->second);
}
@@ -1052,7 +959,6 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
convertConstantOp(constantOp, valueMapping);
} else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
convertBroadcastOp(broadcastOp, valueMapping);
- convertBroadcastFromTransferReadOp(broadcastOp, valueMapping);
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
convertForOp(forOp, valueMapping);
} else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
@@ -1121,8 +1027,6 @@ struct ConvertVectorToGPUPass
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
- getOperation()->dump();
-
if (useNvGpu.getValue()) {
if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
return signalPassFailure();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 9a0f4c9f0837..fa2a40f7334d 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -4,6 +4,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @matmul
@@ -117,21 +118,6 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
-// func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
-// %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
-// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-// %c0 = arith.constant 0 : index
-// %cst = arith.constant 0.000000e+00 : f16
-// %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-// %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
-// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-// %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
-// {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
-// : memref<16x16x16x16xf16>, vector<16x16xf16>
-// %F = arith.divf %D, %E : vector<16x16xf16>
-// vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
-// return
-// }
func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
%arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
@@ -140,10 +126,9 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
- %Eread = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
- {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3)->(d3)>}
- : memref<16x16x16x16xf16>, vector<16xf16>
- %E = vector.broadcast %Eread: vector<16xf16> to vector<16x16xf16>
+ %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+ {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+ : memref<16x16x16x16xf16>, vector<16x16xf16>
%F = arith.divf %D, %E : vector<16x16xf16>
vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
@@ -156,24 +141,12 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-// %c0 = arith.constant 0 : index
-// %cst = arith.constant 0.000000e+00 : f16
-// %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-// %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-// %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-// vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-// return
-// }
func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
- %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
- %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
@@ -187,24 +160,12 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
-// func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
-// %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
-// %c0 = arith.constant 0 : index
-// %cst = arith.constant 0.000000e+00 : f16
-// %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
-// %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0, 0)>, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
-// %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
-// %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
-// vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
-// return
-// }
func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
- %Bread = vector.transfer_read %arg1[%c0], %cst {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : memref<16xf16>, vector<16xf16>
- %B = vector.broadcast %Bread: vector<16xf16> to vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>