[Mlir-commits] [mlir] e7cf723 - [mlir] Added strides check to rank reducing subview verification

Jakub Lichman llvmlistbot@llvm.org
Thu Oct 8 01:39:39 PDT 2020


Author: Jakub Lichman
Date: 2020-10-08T08:39:07Z
New Revision: e7cf723051cd4638cf5d2c407b756312292e7c18

URL: https://github.com/llvm/llvm-project/commit/e7cf723051cd4638cf5d2c407b756312292e7c18
DIFF: https://github.com/llvm/llvm-project/commit/e7cf723051cd4638cf5d2c407b756312292e7c18.diff

LOG: [mlir] Added strides check to rank reducing subview verification

Added a missing strides check to the verification method of rank-reducing
subview, which enforces a strides specification for the resulting type.

Differential Revision: https://reviews.llvm.org/D88879
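
For illustration, a minimal sketch of the new behavior (in the style of the
updated tests; function and value names are hypothetical): a rank-reducing
subview whose kept source stride is dynamic now verifies only if the reduced
result type spells out its stride and offset, while a plain memref<?xf32>
result is rejected with a stride mismatch.

  func @rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
    // The unit dimension is dropped and its stride is irrelevant, but the
    // kept (dynamic) stride and the dynamic offset must appear in the
    // result type.
    %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1]
      : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
    return
  }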

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f2823c564cce..f445a0cce242 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2823,19 +2823,30 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+enum SubViewVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  StrideMismatch,
+  ElemTypeMismatch,
+  MemSpaceMismatch,
+  AffineMapMismatch
+};
+
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static bool isRankReducedType(Type originalType, Type reducedType) {
+static SubViewVerificationResult isRankReducedType(Type originalType,
+                                                   Type reducedType) {
   if (originalType == reducedType)
-    return true;
+    return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
-    return true;
+    return SubViewVerificationResult::Success;
   if (originalType.isa<RankedTensorType>() &&
       !reducedType.isa<RankedTensorType>())
-    return true;
+    return SubViewVerificationResult::Success;
   if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
-    return true;
+    return SubViewVerificationResult::Success;
 
   ShapedType originalShapedType = originalType.cast<ShapedType>();
   ShapedType reducedShapedType = reducedType.cast<ShapedType>();
@@ -2846,7 +2857,7 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
   unsigned originalRank = originalShape.size(),
            reducedRank = reducedShape.size();
   if (reducedRank > originalRank)
-    return false;
+    return SubViewVerificationResult::RankTooLarge;
 
   unsigned reducedIdx = 0;
   SmallVector<bool, 4> keepMask(originalRank);
@@ -2858,41 +2869,78 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
       reducedIdx++;
     // 1 is the only non-matching allowed.
     else if (originalShape[originalIdx] != 1)
-      return false;
+      return SubViewVerificationResult::SizeMismatch;
   }
   // Must match the reduced rank.
   if (reducedIdx != reducedRank)
-    return false;
+    return SubViewVerificationResult::SizeMismatch;
 
   // We are done for the tensor case.
   if (originalType.isa<RankedTensorType>())
-    return true;
+    return SubViewVerificationResult::Success;
 
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
   MemRefType reduced = reducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
-  int64_t originalOffset, symCounter = 0, dimCounter = 0;
-  SmallVector<int64_t, 4> originalStrides;
+  int64_t originalOffset, reducedOffset;
+  SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
   getStridesAndOffset(original, originalStrides, originalOffset);
-  auto getSymbolOrConstant = [&](int64_t offset) {
-    return offset == ShapedType::kDynamicStrideOrOffset
-               ? getAffineSymbolExpr(symCounter++, c)
-               : getAffineConstantExpr(offset, c);
-  };
-
-  AffineExpr expr = getSymbolOrConstant(originalOffset);
-  for (unsigned i = 0, e = originalStrides.size(); i < e; i++) {
-    if (keepMask[i])
-      expr = expr + getSymbolOrConstant(originalStrides[i]) *
-                        getAffineDimExpr(dimCounter++, c);
+  getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+
+  // Filter strides based on the mask and check that they are the same
+  // as reduced ones.
+  reducedIdx = 0;
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    if (keepMask[originalIdx]) {
+      if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+        return SubViewVerificationResult::StrideMismatch;
+      keepStrides.push_back(originalStrides[originalIdx]);
+    }
   }
 
-  auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c);
-  return original.getElementType() == reduced.getElementType() &&
-         original.getMemorySpace() == reduced.getMemorySpace() &&
-         (reduced.getAffineMaps().empty() ||
-          reducedMap == reduced.getAffineMaps().front());
+  if (original.getElementType() != reduced.getElementType())
+    return SubViewVerificationResult::ElemTypeMismatch;
+
+  if (original.getMemorySpace() != reduced.getMemorySpace())
+    return SubViewVerificationResult::MemSpaceMismatch;
+
+  auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
+  if (!reduced.getAffineMaps().empty() &&
+      reducedMap != reduced.getAffineMaps().front())
+    return SubViewVerificationResult::AffineMapMismatch;
+
+  return SubViewVerificationResult::Success;
+}
+
+template <typename OpTy>
+static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
+                                            OpTy op, Type expectedType) {
+  auto memrefType = expectedType.cast<ShapedType>();
+  switch (result) {
+  case SubViewVerificationResult::Success:
+    return success();
+  case SubViewVerificationResult::RankTooLarge:
+    return op.emitError("expected result rank to be smaller or equal to ")
+           << "the source rank.";
+  case SubViewVerificationResult::SizeMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result sizes)";
+  case SubViewVerificationResult::StrideMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result strides)";
+  case SubViewVerificationResult::ElemTypeMismatch:
+    return op.emitError("expected result element type to be ")
+           << memrefType.getElementType();
+  case SubViewVerificationResult::MemSpaceMismatch:
+    return op.emitError("expected result and source memory spaces to match.");
+  case SubViewVerificationResult::AffineMapMismatch:
+    return op.emitError("expected result type to be ")
+           << expectedType
+           << " or a rank-reduced version. (mismatch of result affine map)";
+  }
 }
 
 template <typename OpType>
@@ -2937,11 +2985,9 @@ static LogicalResult verify(SubViewOp op) {
       baseType, extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (!isRankReducedType(expectedType, subViewType))
-    return op.emitError("expected result type to be ")
-           << expectedType << " or a rank-reduced version.";
 
-  return success();
+  auto result = isRankReducedType(expectedType, subViewType);
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
@@ -3352,11 +3398,8 @@ static LogicalResult verify(SubTensorOp op) {
       op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (!isRankReducedType(expectedType, op.getType()))
-    return op.emitError("expected result type to be ")
-           << expectedType << " or a rank-reduced version.";
-
-  return success();
+  auto result = isRankReducedType(expectedType, op.getType());
+  return produceSubViewErrorMsg(result, op, expectedType);
 }
 
 void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 2590dc0105c4..219c3bc84d57 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -21,6 +21,7 @@
 // CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
 // CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)>
 // CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)>
+// CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
 
 // CHECK-LABEL: func @func_with_ops
 // CHECK-SAME: %[[ARG:.*]]: f32
@@ -811,11 +812,11 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
   %15 = alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>>
   // CHECK: subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1]  :
-  // CHECK-SAME: memref<1x?x5x1x?x1xf32,  #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32>
-  %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32>
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32,  #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32, #[[$BASE_MAP3]]>
+  %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32, offset: ?, strides: [?, ?, ?]>
   // CHECK: subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1]  :
-  // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?x1xf32>
-  %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] :  memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?x1xf32>
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?x1xf32, #[[$SUBVIEW_MAP8]]>
+  %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] :  memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?x1xf32, offset: ?, strides: [?, ?, ?, ?]>
 
   %18 = alloc() : memref<1x8xf32>
   // CHECK: subview %18[0, 0] [1, 8] [1, 1]  : memref<1x8xf32> to memref<8xf32>

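Note that fully static rank reductions still verify without an explicit
strides specification, because the kept source strides already match the
identity layout of the reduced type. A minimal sketch mirroring the %18 test
above (hypothetical function name):

  func @static_rank_reducing_subview() {
    %0 = alloc() : memref<1x8xf32>
    // The kept stride of the source is [1], which matches the identity
    // layout of memref<8xf32>, so no strides specification is required.
    %1 = subview %0[0, 0][1, 8][1, 1] : memref<1x8xf32> to memref<8xf32>
    return
  }
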
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 7356c07577db..b59353aa2f7c 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1011,7 +1011,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}}
+  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result strides)}}
   %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
     : memref<8x16x4xf32> to
       memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1020,9 +1020,31 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
+func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result element type to be 'f32'}}
+  %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to
+      memref<8x16x4xi32>
+  return
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}}
+  %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to
+      memref<8x16x4x3xi32>
+  return
+}
+
+// -----
+
 func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>'}}
+  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
   %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
     : memref<8x16x4xf32> to memref<16x4xf32>
   return
@@ -1030,6 +1052,14 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
+  %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
+  return
+}
+
+// -----
+
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
   // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
@@ -1259,7 +1289,7 @@ func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
 // -----
 
 func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}}
+      // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
   %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
     : tensor<8x16x4xf32> to tensor<?x4x4xf32>
 
@@ -1269,7 +1299,7 @@ func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
 // -----
 
 func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
-      // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>'}}
+      // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>' or a rank-reduced version. (mismatch of result sizes)}}
   %0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
     : tensor<8x16x4xf32> to tensor<4x4x4xf32>
 



