[Mlir-commits] [mlir] 8f22998 - [mlir][Linalg] Add a linalg.tensor_reshape to operate on tensors

Nicolas Vasilache llvmlistbot at llvm.org
Mon Apr 6 08:23:07 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-06T11:19:17-04:00
New Revision: 8f229989d5394a36624f8ef1abf06f556e0664b7

URL: https://github.com/llvm/llvm-project/commit/8f229989d5394a36624f8ef1abf06f556e0664b7
DIFF: https://github.com/llvm/llvm-project/commit/8f229989d5394a36624f8ef1abf06f556e0664b7.diff

LOG: [mlir][Linalg] Add a linalg.tensor_reshape to operate on tensors

Summary:
This revision adds a tensor_reshape operation that operates on tensors.
In the tensor world the constraints are less stringent and we can allow more
arbitrary dynamic reshapes, as long as they are contractions.

The expansion of a dynamic dimension into multiple dynamic dimensions is under-specified and is punted on for now.

Differential Revision: https://reviews.llvm.org/D77360

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index bf0e1dd48770..3e667d98f822 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -60,9 +60,31 @@ def Linalg_RangeOp :
   let assemblyFormat = "$min `:` $max `:` $step attr-dict `:` type(results)";
 }
 
-def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
-    Results<(outs AnyStridedMemRef)> {
+class Linalg_ReshapeLikeOp<string mnemonic> :
+    Linalg_Op<mnemonic, [NoSideEffect]> {
+  let builders = [
+    // Builder for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<"Builder *b, OperationState &result, Value src, "
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">,
+    // Builder for a reshape whose result type is passed explicitly. This may be
+    // either a contracting or expanding reshape.
+    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value src,"
+    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
+    "ArrayRef<NamedAttribute> attrs = {}">];
+
+  code commonExtraClassDeclaration = [{
+    static StringRef getReassociationAttrName() { return "reassociation"; }
+  }];
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type(results)
+  }];
+}
+
+def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
+    Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef:$result)> {
   let summary = "linalg.reshape produces a new view into the operand view";
   let description = [{
     The `linalg.reshape` op produces a new view whose sizes are a reassociation
@@ -102,27 +124,55 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
       memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
     ```
   }];
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+  }];
+  let hasFolder = 1;
+}
 
-  let builders = [
-    // Builder for a contracting reshape whose result type is computed from
-    // `view` and `reassociation`.
-    OpBuilder<"Builder *b, OperationState &result, Value view, "
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">,
-    // Builder for a reshape whose result type is passed explicitly. This may be
-    // either a contracting or expanding reshape.
-    OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
-    "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
-    "ArrayRef<NamedAttribute> attrs = {}">];
+def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
+    Arguments<(ins AnyTensor:$src,
+                   AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyTensor:$result)> {
+  let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
+  let description = [{
+    The `linalg.tensor_reshape` op produces a new tensor whose sizes are a
+    reassociation of the original `src`.
 
-  let extraClassDeclaration = [{
-    static StringRef getReassociationAttrName() { return "reassociation"; }
-    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an affine map array attribute. In the future,
+    non-contiguous groupings may be allowed (e.g. permutations, reindexings,
+    etc.).
+
+    A reshape may either collapse or expand dimensions, depending on the
+    relationship between source and target tensor ranks. The verification rule
+    is that the reassociation maps are applied to the tensor with the larger
+    rank to obtain the tensor with the smaller rank. In the case of a dimension
+    expansion, the reassociation maps can be interpreted as inverse maps.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?x?xf32> into tensor<?x?xf32>
+    ```
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      tensor<?x?xf32> into tensor<?x?x?xf32>
+    ```
   }];
-  let assemblyFormat = [{
-    $view $reassociation attr-dict `:` type($view) `into` type(results)
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    RankedTensorType getSrcType() {
+      return src().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return result().getType().cast<RankedTensorType>();
+    }
   }];
-  let hasFolder = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index cb66ae9f5013..07c8111941e4 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -164,7 +164,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto reshapeOp = cast<ReshapeOp>(op);
-    MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+    MemRefType dstType = reshapeOp.getResultType();
 
     if (!dstType.hasStaticShape())
       return failure();
@@ -179,7 +179,7 @@ class ReshapeOpConversion : public ConvertToLLVMPattern {
 
     edsc::ScopedContext context(rewriter, op->getLoc());
     ReshapeOpOperandAdaptor adaptor(operands);
-    BaseViewConversionHelper baseDesc(adaptor.view());
+    BaseViewConversionHelper baseDesc(adaptor.src());
     BaseViewConversionHelper desc(typeConverter.convertType(dstType));
     desc.setAllocatedPtr(baseDesc.allocatedPtr());
     desc.setAlignedPtr(baseDesc.alignedPtr());

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 24dcf7370943..3d81cce0e883 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -531,30 +531,33 @@ getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Value view,
+    Builder *b, OperationState &result, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  auto memRefType = view.getType().cast<MemRefType>();
+  auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(memRefType, maps);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
 void mlir::linalg::ReshapeOp::build(
-    Builder *b, OperationState &result, Type resultType, Value view,
+    Builder *b, OperationState &result, Type resultType, Value src,
     ArrayRef<ArrayRef<AffineExpr>> reassociation,
     ArrayRef<NamedAttribute> attrs) {
   auto maps = getSymbolLessAffineMaps(reassociation);
-  build(b, result, resultType, view, attrs);
+  build(b, result, resultType, src, attrs);
   result.addAttribute(ReshapeOp::getReassociationAttrName(),
                       b->getAffineMapArrayAttr(maps));
 }
 
-static LogicalResult verify(ReshapeOp op) {
-  MemRefType expandedType = op.getViewType();
-  MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+// Common verifier for reshape-like types. Fills `expandedType` and
+// `collapsedType` with the proper `src` or `result` type.
+template <typename Op, typename T>
+LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, T &collapsedType) {
+  expandedType = op.getSrcType();
+  collapsedType = op.getResultType();
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
   bool isCollapse = expandedRank > collapsedRank;
@@ -568,7 +571,7 @@ static LogicalResult verify(ReshapeOp op) {
     return op.emitOpError("expected to collapse or expand dims");
 
   if (collapsedRank != op.reassociation().size())
-    return op.emitOpError("expected rank of the collapsed view(")
+    return op.emitOpError("expected rank of the collapsed type(")
            << collapsedRank << ") to be the number of reassociation maps("
            << op.reassociation().size() << ")";
   auto maps = getAffineMaps(op.reassociation());
@@ -581,6 +584,14 @@ static LogicalResult verify(ReshapeOp op) {
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
            << invalidIdx << " to be valid and contiguous";
+  return success();
+}
+
+static LogicalResult verify(ReshapeOp op) {
+  MemRefType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
   if (collapsedType != expectedType)
     return op.emitOpError("expected collapsed type to be ")
@@ -588,6 +599,75 @@ static LogicalResult verify(ReshapeOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
+static RankedTensorType
+computeTensorReshapeCollapsedType(RankedTensorType type,
+                                  ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  newShape.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's number of results.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.drop_front(currentDim).take_front(dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamicSize))
+      size = ShapedType::kDynamicSize;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+
+  return RankedTensorType::get(newShape, type.getElementType());
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  auto resultType = computeTensorReshapeCollapsedType(
+      src.getType().cast<RankedTensorType>(), maps);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+void mlir::linalg::TensorReshapeOp::build(
+    Builder *b, OperationState &result, Type resultType, Value src,
+    ArrayRef<ArrayRef<AffineExpr>> reassociation,
+    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getSymbolLessAffineMaps(reassociation);
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
+                      b->getAffineMapArrayAttr(maps));
+}
+
+static LogicalResult verify(TensorReshapeOp op) {
+  RankedTensorType expandedType, collapsedType;
+  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
+    return failure();
+  auto maps = getAffineMaps(op.reassociation());
+  // TODO(ntv): expanding a ? with a non-constant is under-specified. Error
+  // out.
+  RankedTensorType expectedType =
+      computeTensorReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 7a8291504ae6..0041f97d7eea 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -485,7 +485,7 @@ func @reshape(%arg0: memref<?xf32>) {
 // -----
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>] :
     memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
 }

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 05d35f8f43e4..c28c671d2885 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -505,8 +505,8 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
 // CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
 // CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
 
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous tensor.
+func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous buffer.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
     memref<3x4x5xf32> into memref<12x5xf32>
@@ -523,7 +523,7 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
     memref<3x4x5xf32> into memref<60xf32>
   %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
     memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
   %3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                              affine_map<(i, j, k, l, m) -> (k)>,
                              affine_map<(i, j, k, l, m) -> (l, m)>] :
@@ -532,6 +532,23 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
                            affine_map<(i, j, k, l, m) -> (k)>,
                            affine_map<(i, j, k, l, m) -> (l, m)>] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors.
+  %t0 = linalg.tensor_reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                   affine_map<(i, j, k, l, m) -> (k)>,
+                                   affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                                     affine_map<(i, j, k, l, m) -> (k)>,
+                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.tensor_reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
+                                    affine_map<(i, j, k, l, m) -> (j, k)>,
+                                    affine_map<(i, j, k, l, m) -> (l, m)>] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   return
 }
 // CHECK-LABEL: func @reshape_static
@@ -551,6 +568,11 @@ func @reshape_static(%arg0: memref<3x4x5xf32>) {
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
 //       CHECK:   linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
 //  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
 
 // -----
 


        


More information about the Mlir-commits mailing list