[Mlir-commits] [mlir] 9d273c0 - [mlir] Harden verifiers for DMA ops

Alex Zinenko llvmlistbot at llvm.org
Tue May 5 11:40:52 PDT 2020


Author: Alex Zinenko
Date: 2020-05-05T20:40:41+02:00
New Revision: 9d273c0ef032445402c332ff7c896d661fca5747

URL: https://github.com/llvm/llvm-project/commit/9d273c0ef032445402c332ff7c896d661fca5747
DIFF: https://github.com/llvm/llvm-project/commit/9d273c0ef032445402c332ff7c896d661fca5747.diff

LOG: [mlir] Harden verifiers for DMA ops

DMA operation classes in the Standard dialect (`DmaStartOp` and `DmaWaitOp`)
provide helper functions that make numerous assumptions about the number and
order of operands, and about their types. However, these assumptions were not
checked in the verifier, leading to assertion failures or crashes when helper
functions were used on ill-formed ops. Some of the assumptions were checked in
the custom parser (and thus violations in ops constructed programmatically,
e.g., during rewrites, went undetected) and others were not
checked at all. Introduce the verifiers for all these assumptions and drop
unnecessary checks in the parser that are now covered by the verifier.

Addresses PR45560.

Differential Revision: https://reviews.llvm.org/D79408

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 51228d3e8437..573f9b7c988f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -286,6 +286,7 @@ class DmaWaitOp
   void print(OpAsmPrinter &p);
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
+  LogicalResult verify(); // Checks operand count and operand types.
 };
 
 /// Prints dimension and symbol list.

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 8ef24e239152..972a37d20f97 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1444,49 +1444,82 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
     return failure();
 
-  auto memrefType0 = types[0].dyn_cast<MemRefType>();
-  if (!memrefType0)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected source to be of memref type");
-
-  auto memrefType1 = types[1].dyn_cast<MemRefType>();
-  if (!memrefType1)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected destination to be of memref type");
-
-  auto memrefType2 = types[2].dyn_cast<MemRefType>();
-  if (!memrefType2)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected tag to be of memref type");
-
   if (isStrided) {
     if (parser.resolveOperands(strideInfo, indexType, result.operands))
       return failure();
   }
 
-  // Check that source/destination index list size matches associated rank.
-  if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
-      static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "memref rank not equal to indices count");
-  if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "tag memref rank not equal to indices count");

   return success();
 }

 LogicalResult DmaStartOp::verify() {
+  unsigned numOperands = getNumOperands();
+
+  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
+  // the number of elements.
+  if (numOperands < 4)
+    return emitOpError("expected at least 4 operands");
+
+  // Check types of operands. The order of these calls is important: the later
+  // calls rely on some type properties to compute the operand position.
+  // 1. Source memref.
+  if (!getSrcMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected source to be of memref type");
+  if (numOperands < getSrcMemRefRank() + 4)
+    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
+                         << " operands";
+  if (!getSrcIndices().empty() &&
+      !llvm::all_of(getSrcIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected source indices to be of index type");
+
+  // 2. Destination memref.
+  if (!getDstMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected destination to be of memref type");
+  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
+  if (numOperands < numExpectedOperands)
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getDstIndices().empty() &&
+      !llvm::all_of(getDstIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected destination indices to be of index type");
+
+  // 3. Number of elements.
+  if (!getNumElements().getType().isIndex())
+    return emitOpError("expected num elements to be of index type");
+
+  // 4. Tag memref.
+  if (!getTagMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected tag to be of memref type");
+  numExpectedOperands += getTagMemRefRank();
+  if (numOperands < numExpectedOperands)
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getTagIndices().empty() &&
+      !llvm::all_of(getTagIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError("expected tag indices to be of index type");
+
   // DMAs from different memory spaces supported.
   if (getSrcMemorySpace() == getDstMemorySpace())
     return emitOpError("DMA should be between different memory spaces");

-  if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                              getDstMemRefRank() + 3 + 1 &&
-      getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
-                              getDstMemRefRank() + 3 + 1 + 2) {
+  // Optional stride-related operands must be either both present or both
+  // absent.
+  if (numOperands != numExpectedOperands &&
+      numOperands != numExpectedOperands + 2)
     return emitOpError("incorrect number of operands");
+
+  // 5. Strides.
+  if (isStrided()) {
+    if (!getStride().getType().isIndex() ||
+        !getNumElementsPerStride().getType().isIndex())
+      return emitOpError(
+          "expected stride and num elements per stride to be of type index");
   }
+
   return success();
 }

@@ -1536,15 +1569,6 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperand(numElementsInfo, indexType, result.operands))
     return failure();

-  auto memrefType = type.dyn_cast<MemRefType>();
-  if (!memrefType)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected tag to be of memref type");
-
-  if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
-    return parser.emitError(parser.getNameLoc(),
-                            "tag memref rank not equal to indices count");
-
   return success();
 }

@@ -1554,6 +1578,32 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
   return foldMemRefCast(*this);
 }

+LogicalResult DmaWaitOp::verify() {
+  // Mandatory non-variadic operands are tag and the number of elements.
+  if (getNumOperands() < 2)
+    return emitOpError() << "expected at least 2 operands";
+
+  // Check types of operands. The order of these calls is important: the later
+  // calls rely on some type properties to compute the operand position.
+  if (!getTagMemRef().getType().isa<MemRefType>())
+    return emitOpError() << "expected tag to be of memref type";
+
+  if (getNumOperands() != 2 + getTagMemRefRank())
+    return emitOpError() << "expected " << 2 + getTagMemRefRank()
+                         << " operands";
+
+  if (!getTagIndices().empty() &&
+      !llvm::all_of(getTagIndices().getTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError() << "expected tag indices to be of index type";
+
+  if (!getNumElements().getType().isIndex())
+    return emitOpError()
+           << "expected the number of elements to be of index type";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 2145c1bbc172..2a14c3ae6c41 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -303,6 +303,13 @@ func @invalid_cmp_shape(%idx : () -> ()) {

 // -----

+func @dma_start_not_enough_operands() {
+  // expected-error@+1 {{expected at least 4 operands}}
+  "std.dma_start"() : () -> ()
+}
+
+// -----
+
 func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
   // expected-error@+1 {{expected source to be of memref type}}
   dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
@@ -310,6 +317,24 @@ func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {

 // -----

+func @dma_start_not_enough_operands_for_src(
+    %src: memref<2x2x2xf32>, %idx: index) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_src_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected source indices to be of index type}}
+  "std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
+      : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
+}
+
+// -----
+
 func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
   %mref = alloc() : memref<8 x f32>
   // expected-error@+1 {{expected destination to be of memref type}}
@@ -318,6 +343,36 @@ func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {

 // -----

+func @dma_start_not_enough_operands_for_dst(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected destination indices to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_start_num_elements_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected num elements to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
+}
+
+// -----
+
 func @dma_no_tag_memref(%tag : f32, %c0 : index) {
   %mref = alloc() : memref<8 x f32>
   // expected-error@+1 {{expected tag to be of memref type}}
@@ -326,9 +381,80 @@ func @dma_no_tag_memref(%tag : f32, %c0 : index) {

 // -----

+func @dma_start_not_enough_operands_for_tag(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>) {
+  // expected-error@+1 {{expected at least 8 operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
+}
+
+// -----
+
+func @dma_start_tag_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>, %flt: f32) {
+  // expected-error@+1 {{expected tag indices to be of index type}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
+}
+
+// -----
+
+func @dma_start_same_space(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{DMA should be between different memory spaces}}
+  dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
+}
+
+// -----
+
+func @dma_start_too_many_operands(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{incorrect number of operands}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
+}
+
+
+// -----
+
+func @dma_start_wrong_stride_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
+  "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
+}
+
+// -----
+
+func @dma_wait_not_enough_operands() {
+  // expected-error@+1 {{expected at least 2 operands}}
+  "std.dma_wait"() : () -> ()
+}
+
+// -----
+
 func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
   // expected-error@+1 {{expected tag to be of memref type}}
-  dma_wait %tag[%c0], %arg0 : f32
+  "std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
+  // expected-error@+1 {{expected tag indices to be of index type}}
+  "std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
+  // expected-error@+1 {{expected the number of elements to be of index type}}
+  "std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
 }

 // -----


        


More information about the Mlir-commits mailing list