[Mlir-commits] [mlir] 14088a6 - [mlir] Added support for rank-reducing subviews

Jakub Lichman llvmlistbot at llvm.org
Wed Sep 30 04:16:19 PDT 2020


Author: Jakub Lichman
Date: 2020-09-30T11:15:18Z
New Revision: 14088a6f5d1ae597960833a366beb9acee8d65cb

URL: https://github.com/llvm/llvm-project/commit/14088a6f5d1ae597960833a366beb9acee8d65cb
DIFF: https://github.com/llvm/llvm-project/commit/14088a6f5d1ae597960833a366beb9acee8d65cb.diff

LOG: [mlir] Added support for rank-reducing subviews

This commit adds support for subviews that reduce the resulting rank by
dropping static dimensions of size 1.

Differential Revision: https://reviews.llvm.org/D88534
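
Rank reduction drops unit dimensions from a subview's result type. For
example (taken from the new core-ops.mlir tests below), a rank-3 source can
now yield a rank-2 result when the corresponding static size is 1:

    %21 = subview %20[0, 0, 0][1, 16, 4][1, 1, 1]
        : memref<8x16x4xf32> to memref<16x4xf32>

Previously the result type had to match the inferred rank-3 type exactly;
the verifier now also accepts any rank-reduced version of it.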

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 352b7d8fd3d6..ff1a82c26561 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2841,6 +2841,20 @@ def SubViewOp : Std_Op<"subview", [
       "ArrayRef<NamedAttribute> attrs = {}">,
     // Build a SubViewOp with all dynamic entries.
     OpBuilder<
+      "OpBuilder &b, OperationState &result, Value source, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubViewOp with mixed static and dynamic entries
+    // and custom result type.
+    OpBuilder<
+      "OpBuilder &b, OperationState &result, MemRefType resultType, "
+      "Value source, ArrayRef<int64_t> staticOffsets, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offsets, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubViewOp with all dynamic entries and custom result type.
+    OpBuilder<
+      "OpBuilder &b, OperationState &result, MemRefType resultType, "
       "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
       "ArrayRef<NamedAttribute> attrs = {}">
   ];
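
A minimal C++ sketch of calling the new mixed static/dynamic overload with a
custom (rank-reduced) result type, mirroring the %21 test case below; this is
a hypothetical caller, and `b`, `loc`, and `source` (a Value of type
memref<8x16x4xf32>) are assumed to be in scope:

    // Request the result type with the leading unit dim already dropped; the
    // verifier accepts it as a rank-reduced version of the inferred
    // memref<1x16x4xf32, ...> type.
    MemRefType resultType = MemRefType::get({16, 4}, b.getF32Type());
    // All offsets/sizes/strides are static here, so the ValueRanges are empty.
    auto sub = b.create<SubViewOp>(
        loc, resultType, source,
        /*staticOffsets=*/ArrayRef<int64_t>{0, 0, 0},
        /*staticSizes=*/ArrayRef<int64_t>{1, 16, 4},
        /*staticStrides=*/ArrayRef<int64_t>{1, 1, 1},
        /*offsets=*/ValueRange{}, /*sizes=*/ValueRange{}, /*strides=*/ValueRange{});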

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c0dc87210a3f..1cabf172b7fc 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2728,15 +2728,47 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
         staticStridesVector, offsets, sizes, strides, attrs);
 }
 
+/// Build a SubViewOp with mixed static and dynamic entries and custom result type.
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
+                            MemRefType resultType, Value source,
+                            ArrayRef<int64_t> staticOffsets,
+                            ArrayRef<int64_t> staticSizes,
+                            ArrayRef<int64_t> staticStrides, ValueRange offsets,
+                            ValueRange sizes, ValueRange strides,
+                            ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubViewOp with all dynamic entries and custom result type.
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
+                            MemRefType resultType, Value source,
+                            ValueRange offsets, ValueRange sizes,
+                            ValueRange strides,
+                            ArrayRef<NamedAttribute> attrs) {
+  auto sourceMemRefType = source.getType().cast<MemRefType>();
+  unsigned rank = sourceMemRefType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector;
+  staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector;
+  staticSizesVector.assign(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector;
+  staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
 /// Verify that a particular offset/size/stride static attribute is well-formed.
 static LogicalResult
 verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
                     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
                     ValueRange values) {
   /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != op.getRank())
-    return op.emitError("expected ")
-           << op.getRank() << " " << name << " values";
+  size_t inputRank = op.source().getType().cast<MemRefType>().getRank();
+  if (attr.size() != inputRank)
+    return op.emitError("expected ") << inputRank << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
         return isDynamic(attr.cast<IntegerAttr>().getInt());
@@ -2755,6 +2787,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+/// Checks if `original` MemRef type can be rank reduced to `reduced` type.
+/// This function is a slight variant of the `is subsequence` algorithm where
+/// any non-matching dimension must have static size 1.
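+/// For example, memref<1x9x1x4x1xf32> can be rank reduced to memref<9x4xf32>
+/// or to memref<1x9x4xf32>, but not to memref<4x9xf32>: the kept dimensions
+/// must appear in their original order.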
+static bool isRankReducedType(Type originalType, Type reducedType) {
+  if (originalType == reducedType)
+    return true;
+
+  MemRefType original = originalType.cast<MemRefType>();
+  MemRefType reduced = reducedType.cast<MemRefType>();
+  ArrayRef<int64_t> originalShape = original.getShape();
+  ArrayRef<int64_t> reducedShape = reduced.getShape();
+  unsigned originalRank = originalShape.size(),
+           reducedRank = reducedShape.size();
+  if (reducedRank > originalRank)
+    return false;
+
+  unsigned reducedIdx = 0;
+  SmallVector<bool, 4> keepMask(originalRank);
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    // -2 is never used as a dim size so it will never match.
+    int reducedVal = reducedIdx < reducedRank ? reducedShape[reducedIdx] : -2;
+    // Skip matching dims greedily.
+    if ((keepMask[originalIdx] = originalShape[originalIdx] == reducedVal))
+      reducedIdx++;
+    // A size of 1 is the only non-matching size allowed.
+    else if (originalShape[originalIdx] != 1)
+      return false;
+  }
+  // Must match the reduced rank.
+  if (reducedIdx != reducedRank)
+    return false;
+
+  MLIRContext *c = original.getContext();
+  int64_t originalOffset, symCounter = 0, dimCounter = 0;
+  SmallVector<int64_t, 4> originalStrides;
+  getStridesAndOffset(original, originalStrides, originalOffset);
+  auto getSymbolOrConstant = [&](int64_t offset) {
+    return offset == ShapedType::kDynamicStrideOrOffset
+               ? getAffineSymbolExpr(symCounter++, c)
+               : getAffineConstantExpr(offset, c);
+  };
+
+  AffineExpr expr = getSymbolOrConstant(originalOffset);
+  for (unsigned i = 0, e = originalStrides.size(); i < e; i++) {
+    if (keepMask[i])
+      expr = expr + getSymbolOrConstant(originalStrides[i]) *
+                        getAffineDimExpr(dimCounter++, c);
+  }
+
+  auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c);
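+  // For example, reducing memref<1x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
+  // to memref<16x4xf32> keeps strides [4, 1], so the reconstructed map is
+  // (d0, d1) -> (d0 * 4 + d1), the canonical strided form of the result type.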
+  return original.getElementType() == reduced.getElementType() &&
+         original.getMemorySpace() == reduced.getMemorySpace() &&
+         (reduced.getAffineMaps().empty() ||
+          reducedMap == reduced.getAffineMaps().front());
+}
+
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {
   auto baseType = op.getBaseMemRefType().cast<MemRefType>();
@@ -2790,8 +2878,9 @@ static LogicalResult verify(SubViewOp op) {
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (op.getType() != expectedType)
-    return op.emitError("expected result type to be ") << expectedType;
+  if (!isRankReducedType(expectedType, subViewType))
+    return op.emitError("expected result type to be ")
+           << expectedType << " or a rank-reduced version.";
 
   return success();
 }

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 332bfbe2f457..5bf7857a66e8 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2107,9 +2107,6 @@ LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
   // TODO: expand support to these 2 cases.
   if (!xferOp.permutation_map().isMinorIdentity())
     return failure();
-  // TODO: relax this precondition. This will require rank-reducing subviews.
-  if (xferOp.getMemRefType().getRank() != xferOp.getTransferRank())
-    return failure();
   // Must have some masked dimension to be a candidate for splitting.
   if (!xferOp.hasMaskedDim())
     return failure();

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index f182936c8703..5e3959af29dd 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -19,6 +19,8 @@
 // CHECK-DAG: #[[$SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
 // CHECK-DAG: #[[$SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
+// CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)>
+// CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)>
 
 // CHECK-LABEL: func @func_with_ops
 // CHECK-SAME: %[[ARG:.*]]: f32
@@ -797,6 +799,33 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %11 = subview %9[%arg1, %arg2][4, 4][2, 2]
     : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]>
 
+  %12 = alloc() : memref<1x9x1x4x1xf32, affine_map<(d0, d1, d2, d3, d4) -> (36 * d0 + 36 * d1 + 4 * d2 + 4 * d3 + d4)>>
+  // CHECK: subview %12[%arg1, %arg1, %arg1, %arg1, %arg1]
+  // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] :
+  // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<9x4xf32, #[[$SUBVIEW_MAP2]]>
+  %13 = subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<9x4xf32, offset: ?, strides: [?, ?]>
+  // CHECK: subview %12[%arg1, %arg1, %arg1, %arg1, %arg1]
+  // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] :
+  // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<1x9x4xf32, #[[$BASE_MAP3]]>
+  %14 = subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<1x9x4xf32, offset: ?, strides: [?, ?, ?]>
+
+  %15 = alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>>
+  // CHECK: subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1]  :
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32,  #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32>
+  %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32>
+  // CHECK: subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1]  :
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?x1xf32>
+  %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] :  memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?x1xf32>
+
+  %18 = alloc() : memref<1x8xf32>
+  // CHECK: subview %18[0, 0] [1, 8] [1, 1]  : memref<1x8xf32> to memref<8xf32>
+  %19 = subview %18[0, 0][1, 8][1, 1] : memref<1x8xf32> to memref<8xf32>
+
+  %20 = alloc() : memref<8x16x4xf32>
+  // CHECK: subview %20[0, 0, 0] [1, 16, 4] [1, 1, 1]  : memref<8x16x4xf32> to memref<16x4xf32>
+  %21 = subview %20[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32>
+
+  %22 = subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
   return
 }
 

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index e02dbca494df..ab18845bdb53 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1020,6 +1020,16 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error @+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>'}}
+  %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to memref<16x4xf32>
+  return
+}
+
+// -----
+
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
  // expected-error @+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>

