[Mlir-commits] [mlir] b270fbe - [mlir][gpu] Relax MMA load/store to allow vector memref

Lei Zhang llvmlistbot at llvm.org
Tue Nov 1 08:38:21 PDT 2022


Author: Lei Zhang
Date: 2022-11-01T11:38:14-04:00
New Revision: b270fbe0353275d3e5bc7453d4453daf17d9e053

URL: https://github.com/llvm/llvm-project/commit/b270fbe0353275d3e5bc7453d4453daf17d9e053
DIFF: https://github.com/llvm/llvm-project/commit/b270fbe0353275d3e5bc7453d4453daf17d9e053.diff

LOG: [mlir][gpu] Relax MMA load/store to allow vector memref

This is useful for converting to SPIR-V, where we'd like to have
memref of vector element types.

Reviewed By: ThomasRaoux, bondhugula

Differential Revision: https://reviews.llvm.org/D137143

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 28bc3190c1450..d5b36e495493d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -68,6 +68,9 @@ def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
 def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
+// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
+def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
   "$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 0cf2029a232df..aef31aff45954 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1114,7 +1114,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MemRefOf<[F16, F32]>, "", [MemRead]>:$srcMemref,
+  let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 
@@ -1153,7 +1153,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   }];
 
   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
-                  Arg<MemRefOf<[F16, F32]>, "",[MemWrite]>:$dstMemref,
+                  Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 

diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 6390de3e3a10d..7a11acbc2d239 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -515,6 +515,14 @@ func.func @mmaLoadOp_identity_layout(){
 
 // -----
 
+func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
+    // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+    %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    return
+}
+
+// -----
+
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
 func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 9b31a326aa919..b68a109d34223 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -265,8 +265,8 @@ module attributes {gpu.container_module} {
     return
   }
 
-  func.func @mmamatrix_valid_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
-    // CHECK-LABEL: func @mmamatrix_valid_element_type
+  func.func @mmamatrix_valid_scalar_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
+    // CHECK-LABEL: func @mmamatrix_valid_scalar_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     // CHECK: %[[wg:.*]] = memref.alloca()
     %i = arith.constant 16 : index
@@ -285,6 +285,15 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL: func @mmamatrix_valid_vector_element_type
+  func.func @mmamatrix_valid_vector_element_type(%src : memref<32x4xvector<4xf32>>, %i : index) {
+    // CHECK: gpu.subgroup_mma_load_matrix
+    %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK: gpu.subgroup_mma_store_matrix
+    gpu.subgroup_mma_store_matrix %s, %src[%i, %i] {leadDimension = 4 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x4xvector<4xf32>>
+    return
+  }
+
   // CHECK-LABEL: func @set_default_device
   func.func @set_default_device(%arg0: i32) {
     // CHECK: gpu.set_default_device


        


More information about the Mlir-commits mailing list