[Mlir-commits] [mlir] d77f483 - [mlir][gpu] Relax restriction on mma load/store op

Thomas Raoux llvmlistbot at llvm.org
Thu Mar 24 21:04:09 PDT 2022


Author: Thomas Raoux
Date: 2022-03-25T04:03:40Z
New Revision: d77f4836401e492fa4aee549d6f677b6ab336a81

URL: https://github.com/llvm/llvm-project/commit/d77f4836401e492fa4aee549d6f677b6ab336a81
DIFF: https://github.com/llvm/llvm-project/commit/d77f4836401e492fa4aee549d6f677b6ab336a81.diff

LOG: [mlir][gpu] Relax restriction on mma load/store op

These ops can support more complex layouts as long as the innermost
dimension is contiguous (i.e., has unit stride).

Differential Revision: https://reviews.llvm.org/D122452
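
For example (adapted from the test added to mlir/test/Dialect/GPU/ops.mlir
below), a load from a memref with a non-identity strided layout is now
accepted because its innermost stride is 1, while a column-major layout such
as the one in invalid.mlir is still rejected:

    %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 64 : index}
        : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>
        -> !gpu.mma_matrix<16x16xf16, "AOp">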

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index e3dbfc9b86edc..196e4f6c2a952 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -65,7 +65,8 @@ getMemrefConstantHorizontalStride(ShapedType type) {
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
+      strides.back() != 1)
     return llvm::None;
   int64_t stride = strides[strides.size() - 2];
   if (stride == ShapedType::kDynamicStrideOrOffset)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 990a2dc40de04..d351abc883d97 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1068,6 +1068,17 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
 // GPU_SubgroupMmaLoadMatrixOp
 //===----------------------------------------------------------------------===//
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(type, strides, offset))) {
+    return false;
+  }
+  return strides.back() == 1;
+}
+
 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto srcType = srcMemref().getType();
   auto resType = res().getType();
@@ -1076,8 +1087,9 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto srcMemrefType = srcType.cast<MemRefType>();
   auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
 
-  if (!srcMemrefType.getLayout().isIdentity())
-    return emitError("expected identity layout map for source memref");
+  if (!isLastMemrefDimUnitStride(srcMemrefType))
+    return emitError(
+        "expected source memref most minor dim must have unit stride");
 
   if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
       srcMemSpace != kGlobalMemorySpace)
@@ -1102,8 +1114,10 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
   auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
   auto dstMemrefType = dstType.cast<MemRefType>();
   auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
-  if (!dstMemrefType.getLayout().isIdentity())
-    return emitError("expected identity layout map for destination memref");
+
+  if (!isLastMemrefDimUnitStride(dstMemrefType))
+    return emitError(
+        "expected destination memref most minor dim must have unit stride");
 
   if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
       dstMemSpace != kGlobalMemorySpace)
@@ -1232,15 +1246,6 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GPU_DeviceAsyncCopyOp
 //===----------------------------------------------------------------------===//
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 LogicalResult DeviceAsyncCopyOp::verify() {
   auto srcMemref = src().getType().cast<MemRefType>();
   auto dstMemref = dst().getType().cast<MemRefType>();

diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index bb309c5363421..4fe710b796963 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -151,3 +151,22 @@ func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2:
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_memref_strided
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]]] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16>
+func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
+  return
+}

diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index b30efdecef4cb..a6fd6dbb775d9 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -491,7 +491,7 @@ func @mmamatrix_invalid_element_type(){
 func @mmaLoadOp_identity_layout(){
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
     %i = arith.constant 16 : index
-    // expected-error @+1 {{expected identity layout map for source memref}}
+    // expected-error @+1 {{expected source memref most minor dim must have unit stride}}
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     return
 }
@@ -514,7 +514,7 @@ func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
-    // expected-error @+1 {{expected identity layout map for destination memref}}
+    // expected-error @+1 {{expected destination memref most minor dim must have unit stride}}
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
     return
 }

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 794f3c18f918c..140a806cfc92e 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -227,7 +227,7 @@ module attributes {gpu.container_module} {
     return
   }
 
-  func @mmamatrix_valid_element_type(){
+  func @mmamatrix_valid_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
     // CHECK-LABEL: func @mmamatrix_valid_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     // CHECK: %[[wg:.*]] = memref.alloca()
@@ -237,6 +237,8 @@ module attributes {gpu.container_module} {
     // CHECK: %[[cst:.*]] = arith.constant 1.000000e+00 : f32
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 64 : index} : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK: gpu.subgroup_mma_load_matrix %{{.*}}[%[[i]], %[[i]]] {leadDimension = 64 : index} : memref<32x32xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
     %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
     // CHECK: gpu.subgroup_mma_elementwise addf %{{.*}}, %{{.*}} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     %2 = gpu.subgroup_mma_elementwise addf %1, %1 : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">

