[Mlir-commits] [mlir] b270fbe - [mlir][gpu] Relax MMA load/store to allow vector memref
Lei Zhang
llvmlistbot at llvm.org
Tue Nov 1 08:38:21 PDT 2022
Author: Lei Zhang
Date: 2022-11-01T11:38:14-04:00
New Revision: b270fbe0353275d3e5bc7453d4453daf17d9e053
URL: https://github.com/llvm/llvm-project/commit/b270fbe0353275d3e5bc7453d4453daf17d9e053
DIFF: https://github.com/llvm/llvm-project/commit/b270fbe0353275d3e5bc7453d4453daf17d9e053.diff
LOG: [mlir][gpu] Relax MMA load/store to allow vector memref
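This is useful for converting to SPIR-V, where we'd like to use
memrefs with vector element types: gpu.subgroup_mma_load_matrix and
gpu.subgroup_mma_store_matrix now also accept memrefs whose element type
is a 1-D vector of f16 or f32, in addition to scalar f16 or f32.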
Reviewed By: ThomasRaoux, bondhugula
Differential Revision: https://reviews.llvm.org/D137143
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 28bc3190c1450..d5b36e495493d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -68,6 +68,9 @@ def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
def GPU_MMAMatrix : DialectType<
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
+// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
+def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+
class MMAMatrixOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
"$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 0cf2029a232df..aef31aff45954 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1114,7 +1114,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
```
}];
- let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
+ let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
@@ -1153,7 +1153,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
- Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
+ Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 6390de3e3a10d..7a11acbc2d239 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -515,6 +515,14 @@ func.func @mmaLoadOp_identity_layout(){
// -----
+func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
+ // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+ %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ return
+}
+
+// -----
+
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 9b31a326aa919..b68a109d34223 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -265,8 +265,8 @@ module attributes {gpu.container_module} {
return
}
- func.func @mmamatrix_valid_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
- // CHECK-LABEL: func @mmamatrix_valid_element_type
+ func.func @mmamatrix_valid_scalar_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
+ // CHECK-LABEL: func @mmamatrix_valid_scalar_element_type
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
// CHECK: %[[wg:.*]] = memref.alloca()
%i = arith.constant 16 : index
@@ -285,6 +285,15 @@ module attributes {gpu.container_module} {
return
}
+ // CHECK-LABEL: func @mmamatrix_valid_vector_element_type
+ func.func @mmamatrix_valid_vector_element_type(%src : memref<32x4xvector<4xf32>>, %i : index) {
+ // CHECK: gpu.subgroup_mma_load_matrix
+ %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: gpu.subgroup_mma_store_matrix
+ gpu.subgroup_mma_store_matrix %s, %src[%i, %i] {leadDimension = 4 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x4xvector<4xf32>>
+ return
+ }
+
// CHECK-LABEL: func @set_default_device
func.func @set_default_device(%arg0: i32) {
// CHECK: gpu.set_default_device
More information about the Mlir-commits mailing list