[Mlir-commits] [mlir] aa37318 - [mlir][Linalg] Rewrite DownscaleSizeOneWindowed2DConvolution to use rank-reducing insert/extract slices.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Nov 12 03:57:16 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-12T11:57:12Z
New Revision: aa3731806723a2a12914aecda2af6e40e1903702

URL: https://github.com/llvm/llvm-project/commit/aa3731806723a2a12914aecda2af6e40e1903702
DIFF: https://github.com/llvm/llvm-project/commit/aa3731806723a2a12914aecda2af6e40e1903702.diff

LOG: [mlir][Linalg] Rewrite DownscaleSizeOneWindowed2DConvolution to use rank-reducing insert/extract slices.

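This rewrite replaces the tensor_collapse_shape/tensor_expand_shape pair with rank-reducing tensor.extract_slice/tensor.insert_slice ops, which enables better bufferization and more canonicalizations.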

Differential Revision: https://reviews.llvm.org/D113745

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/decompose-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 72e33d08dea33..f9e3014838a1f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
+
+/// Extract all of `tensor` (offsets 0, sizes = its full shape, strides 1) as a
+/// rank-reduced slice of type `targetType`. Returns the possibly-folded slice.
+Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
+                                                Value tensor,
+                                                RankedTensorType targetType);
+
+/// Insert the rank-reduced `tensor` into all of `dest` (offsets 0, sizes = full
+/// shape of `dest`, strides 1). Returns the possibly-folded result.
+Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
+                                               Value tensor, Value dest);
+
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 0daabe25ea1cd..fb8dbd8a6c5be 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -202,6 +202,57 @@ class BaseMemRefType : public ShapedType {
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
 
 namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+/// Builder for RankedTensorType. The shape is held as a non-owning ArrayRef, so
+/// any shape passed into the builder must out-live the builder.
+class RankedTensorType::Builder {
+public:
+  /// Build from another RankedTensorType.
+  explicit Builder(RankedTensorType other)
+      : shape(other.getShape()), elementType(other.getElementType()),
+        encoding(other.getEncoding()) {}
+
+  /// Build from scratch.
+  Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
+      : shape(shape), elementType(elementType), encoding(encoding) {}
+
+  Builder &setShape(ArrayRef<int64_t> newShape) {
+    shape = newShape;
+    return *this;
+  }
+
+  Builder &setElementType(Type newElementType) {
+    elementType = newElementType;
+    return *this;
+  }
+
+  Builder &setEncoding(Attribute newEncoding) {
+    encoding = newEncoding;
+    return *this;
+  }
+
+  /// Create a new RankedTensorType with dim `dim` erased from the shape. The
+  /// builder is left unmodified, so it never refers to the temporary shape.
+  RankedTensorType dropDim(unsigned dim) const {
+    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+    newShape.erase(newShape.begin() + dim);
+    return RankedTensorType::get(newShape, elementType, encoding);
+  }
+
+  operator RankedTensorType() const {
+    return RankedTensorType::get(shape, elementType, encoding);
+  }
+
+private:
+  ArrayRef<int64_t> shape;
+  Type elementType;
+  Attribute encoding;
+};
+
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b27086144761e..c05ff64805921 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -701,6 +701,11 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
       return $_get(elementType.getContext(), shape, elementType, encoding);
     }]>
   ];
+  let extraClassDeclaration = [{
+    /// Builder that keeps references to its arguments, which must out-live it;
+    /// the class itself is defined in mlir/IR/BuiltinTypes.h.
+    class Builder;
+  }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bcfca706900ed..f15337f0f70e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -852,7 +852,6 @@ struct DownscaleSizeOneWindowed2DConvolution final
     auto filterType = filter.getType().dyn_cast<RankedTensorType>();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
-    auto inputShape = inputType.getShape();
     auto filterShape = filterType.getShape();
     auto outputShape = outputType.getShape();
 
@@ -860,52 +859,47 @@ struct DownscaleSizeOneWindowed2DConvolution final
     // of size 1. Other cases can rely on tiling to reduce to such cases.
     int64_t fhSize = filterShape[0], fwSize = filterShape[1];
     int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+    bool removeH = (fhSize == 1 && ohSize == 1);
+    bool removeW = (fwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
       return failure();
-    bool removeH = ohSize == 1;
 
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    auto newFilterType = RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
-    SmallVector<int64_t, 3> newInputShape{
-        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
-    auto newInputType = RankedTensorType::get(
-        newInputShape, inputType.getElementType(), inputType.getEncoding());
-
-    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
-                                           filterShape[2], filterShape[3]};
-    auto newFilterType = RankedTensorType::get(
-        newFilterShape, filterType.getElementType(), filterType.getEncoding());
-
-    SmallVector<int64_t, 3> newOutputShape{
-        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
-    auto newOutputType = RankedTensorType::get(
-        newOutputShape, outputType.getElementType(), outputType.getEncoding());
-
-    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
-    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
-
-    // Reshape all operands for 1-D convolution.
+    // Rank-reduce operands.
     Location loc = convOp.getLoc();
-    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newInputType, input, ioReshapeIndices);
-    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newFilterType, filter, fReshapeIndices);
-    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newOutputType, output, ioReshapeIndices);
-
-    // We need to shrink the strides and dilations too.
-    auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
-    auto stridesAttr = rewriter.getI64VectorAttr(stride);
-    auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, filter, newFilterType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
         loc, newOutputType, ValueRange{newInput, newFilter},
         ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
-    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+    // Insert back.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
     return success();
   };
 };

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1d8d8e2d50688..0eff26bc86de4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1070,6 +1070,27 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   return OpFoldResult();
 }
 
+Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
+    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
+  // Take the whole source: zero offsets, unit strides, sizes = full shape.
+  auto srcType = tensor.getType().cast<RankedTensorType>();
+  int64_t srcRank = srcType.getRank();
+  SmallVector<OpFoldResult> sizes;
+  for (int64_t d = 0; d < srcRank; ++d) {
+    // Static extents stay attributes; dynamic extents are read via tensor.dim.
+    if (!srcType.isDynamicDim(d)) {
+      sizes.push_back(b.getIndexAttr(srcType.getDimSize(d)));
+      continue;
+    }
+    Value dimIdx = b.create<arith::ConstantIndexOp>(loc, d);
+    sizes.push_back(b.createOrFold<tensor::DimOp>(loc, tensor, dimIdx));
+  }
+  SmallVector<OpFoldResult> offsets(srcRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(srcRank, b.getIndexAttr(1));
+  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
+                                                offsets, sizes, strides);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceOp
 //===----------------------------------------------------------------------===//
@@ -1309,6 +1330,29 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertSliceOpSourceCastInserter>(context);
 }
 
+Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(
+    OpBuilder &b, Location loc, Value tensor, Value dest) {
+  // Cover the whole destination: zero offsets, unit strides, sizes = full
+  // shape of `dest`; `tensor` may have lower rank (rank-reducing insert).
+  auto destType = dest.getType().cast<RankedTensorType>();
+  int64_t destRank = destType.getRank();
+  SmallVector<OpFoldResult> sizes;
+  sizes.reserve(destRank);
+  for (int64_t d = 0; d < destRank; ++d) {
+    // Static extents stay attributes; dynamic extents are read via tensor.dim.
+    if (!destType.isDynamicDim(d)) {
+      sizes.push_back(b.getIndexAttr(destType.getDimSize(d)));
+      continue;
+    }
+    Value dimIdx = b.create<arith::ConstantIndexOp>(loc, d);
+    sizes.push_back(b.createOrFold<tensor::DimOp>(loc, dest, dimIdx));
+  }
+  SmallVector<OpFoldResult> offsets(destRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(destRank, b.getIndexAttr(1));
+  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
+                                               sizes, strides);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
index ebd7dd6d4a2af..381ed1bd6080e 100644
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -10,21 +10,20 @@ func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x
   return %0 : tensor<4x1x2x8xf32>
 }
 
-//               CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
-//               CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}:   [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
-//               CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+//               CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME{LITERAL}:   [0, 0, 0, 0] [4, 1, 6, 3] [1, 1, 1, 1] : tensor<4x1x6x3xf32> to tensor<4x6x3xf32>
+//               CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME{LITERAL}:   [0, 0, 0, 0] [1, 2, 3, 8] [1, 1, 1, 1] : tensor<1x2x3x8xf32> to tensor<2x3x8xf32>
+//               CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME{LITERAL}:   [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x1x2x8xf32> to tensor<4x2x8xf32>
 //               CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
 //          CHECK-SAME:     dilations = dense<3> : vector<1xi64>
 //          CHECK-SAME:     strides = dense<2> : vector<1xi64>
 //          CHECK-SAME:   ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
 //          CHECK-SAME:   outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
-//               CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+//               CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+// CHECK-SAME{LITERAL}:   [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
 //               CHECK: return %[[CONV_2D]]
-
 // -----
 
 // CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
@@ -37,19 +36,23 @@ func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x
   return %0 : tensor<?x?x1x?xf32>
 }
 
-//               CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
-//               CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}:   [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
-//               CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+//               CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+//          CHECK-SAME:   [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+//          CHECK-SAME:     tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
+//               CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+//          CHECK-SAME:   [0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] :
+//          CHECK-SAME:     tensor<?x1x?x?xf32> to tensor<?x?x?xf32>
+//               CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+//          CHECK-SAME:   [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+//          CHECK-SAME:     tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
 //               CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
 //          CHECK-SAME:     dilations = dense<2> : vector<1xi64>
 //          CHECK-SAME:     strides = dense<3> : vector<1xi64>
 //          CHECK-SAME:   ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 //          CHECK-SAME:   outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
-//               CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}:   [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+//               CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+//          CHECK-SAME:   [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+//          CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
 //               CHECK: return %[[CONV_2D]]
 
 // -----


        


More information about the Mlir-commits mailing list