[Mlir-commits] [mlir] cfb955a - [mlir][spirv] Relax restriction on pointer type for CooperativeMatrix load/store

Thomas Raoux llvmlistbot at llvm.org
Fri Jul 31 08:16:30 PDT 2020


Author: Thomas Raoux
Date: 2020-07-31T08:02:21-07:00
New Revision: cfb955ac370cb724c51423a05694aaf5b70903a4

URL: https://github.com/llvm/llvm-project/commit/cfb955ac370cb724c51423a05694aaf5b70903a4
DIFF: https://github.com/llvm/llvm-project/commit/cfb955ac370cb724c51423a05694aaf5b70903a4.diff

LOG: [mlir][spirv] Relax restriction on pointer type for CooperativeMatrix load/store

This change allows CooperativeMatrix Load/Store operations to use a pointer type
that may not match the matrix element type. This allows us to declare buffers
whose element type is larger than the matrix element type (for example, a
pointer to vector<4xi32> used to load an i32 cooperative matrix). This follows
the SPIR-V spec and is needed to use cooperative matrices efficiently in
combination with shared local memory.

Differential Revision: https://reviews.llvm.org/D84993

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
index 9c3462a2e5bf..720cfd697c24 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -101,16 +101,17 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
 
     ``` {.ebnf}
     cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
-                              storage-class ssa-use `,` ssa-use `,` ssa-use
+                              ssa-use `,` ssa-use `,` ssa-use
                               (`[` memory-access `]`)? ` : `
+                              pointer-type `as`
                               cooperative-matrix-type
     ```
 
     For example:
 
     ```
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
-         : !spv.coopmatrix<i32, Workgroup, 16, 8>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %colMajor
+         : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<i32, Workgroup, 16, 8>
     ```
   }];
 
@@ -243,16 +244,17 @@ def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
 
     ``` {.ebnf}
     coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV `
-                              storage-class ssa-use `, ` ssa-use `, `
                               ssa-use `, ` ssa-use `, `
-                  (`[` memory-access `]`)? `:` spirv-element-type
+                              ssa-use `, ` ssa-use `, `
+                              (`[` memory-access `]`)? `:`
+                              pointer-type `,` spirv-element-type
     ```
 
     For example:
 
     ```
-      spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 :
-        !spv.coopmatrix<Workgroup, i32, 16, 8>
+      spv.CooperativeMatrixStoreNV %arg0, %arg2, %arg1, %arg3 :
+        !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<Workgroup, i32, 16, 8>
     ```
   }];
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b0235d419ebe..bac65a02f63d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2793,21 +2793,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
 
 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
                                                   OperationState &state) {
-  spirv::StorageClass storageClass;
   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
   Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
   Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) ||
-      parser.parseOperandList(operandInfo, 3) ||
+  if (parser.parseOperandList(operandInfo, 3) ||
       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
     return failure();
   }
-
-  auto ptrType = spirv::PointerType::get(
-      elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
-      storageClass);
   SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
   if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
                              state.operands)) {
@@ -2819,25 +2814,30 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
 }
 
 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
-  StringRef sc = stringifyStorageClass(
-      M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
-  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
-          << "\" " << M.pointer() << ", " << M.stride() << ", "
-          << M.columnmajor();
+  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
+          << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
   // Print optional memory access attribute.
   if (auto memAccess = M.memory_access())
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << M.getType();
+  printer << " : " << M.pointer().getType() << " as " << M.getType();
 }
 
 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
                                                     Type coopMatrix) {
-  if (pointer.cast<spirv::PointerType>().getPointeeType() !=
-      coopMatrix.cast<spirv::CooperativeMatrixNVType>().getElementType())
+  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
+  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
     return op->emitError(
-               "expected the same type for pointer and the cooperative matrix"
-               "element, bu provided ")
-           << pointer << " and " << coopMatrix;
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  spirv::StorageClass storage =
+      pointer.cast<spirv::PointerType>().getStorageClass();
+  if (storage != spirv::StorageClass::Workgroup &&
+      storage != spirv::StorageClass::StorageBuffer &&
+      storage != spirv::StorageClass::PhysicalStorageBuffer)
+    return op->emitError(
+               "Pointer storage class must be Workgroup, StorageBuffer or "
+               "PhysicalStorageBufferEXT but provided ")
+           << stringifyStorageClass(storage);
   return success();
 }
 
@@ -2847,21 +2847,17 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 
 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
                                                    OperationState &state) {
-  spirv::StorageClass storageClass;
   SmallVector<OpAsmParser::OperandType, 4> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
   Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
   Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) ||
-      parser.parseOperandList(operandInfo, 4) ||
+  if (parser.parseOperandList(operandInfo, 4) ||
       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+      parser.parseType(ptrType) || parser.parseComma() ||
       parser.parseType(elementType)) {
     return failure();
   }
-
-  auto ptrType = spirv::PointerType::get(
-      elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
-      storageClass);
   SmallVector<Type, 4> OperandType = {ptrType, elementType, strideType,
                                       columnMajorType};
   if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
@@ -2874,17 +2870,14 @@ static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
 
 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
                   OpAsmPrinter &printer) {
-  StringRef sc = stringifyStorageClass(coopMatrix.pointer()
-                                           .getType()
-                                           .cast<spirv::PointerType>()
-                                           .getStorageClass());
-  printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \""
-          << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object()
-          << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
+  printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
+          << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
+          << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
   // Print optional memory access attribute.
   if (auto memAccess = coopMatrix.memory_access())
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << coopMatrix.getOperand(1).getType();
+  printer << " : " << coopMatrix.pointer().getType() << ", "
+          << coopMatrix.getOperand(1).getType();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index ad913dfb1624..3f8c4bff4738 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -3,29 +3,29 @@
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
   // CHECK-LABEL: @cooperative_matrix_load
   spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_load_memaccess
   spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store
   spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
-    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
+    spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store_memaccess
   spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
     spv.Return
   }
 

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index 523bd6bfb030..f0bb50d10f58 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -2,30 +2,37 @@
 
 // CHECK-LABEL: @cooperative_matrix_load
 spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
   spv.Return
 }
 
 // -----
 // CHECK-LABEL: @cooperative_matrix_load_memaccess
 spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
   spv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store
 spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup>
-  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup>
+  // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
+  spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
   spv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store_memaccess
 spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
   spv.Return
 }
 
@@ -134,3 +141,18 @@ spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b
   spv.Return
 }
 
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}


        


More information about the Mlir-commits mailing list