[Mlir-commits] [mlir] aa37318 - [mlir][Linalg] Rewrite DownscaleSizeOneWindowed2DConvolution to use rank-reducing insert/extract slices.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Nov 12 03:57:16 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-12T11:57:12Z
New Revision: aa3731806723a2a12914aecda2af6e40e1903702
URL: https://github.com/llvm/llvm-project/commit/aa3731806723a2a12914aecda2af6e40e1903702
DIFF: https://github.com/llvm/llvm-project/commit/aa3731806723a2a12914aecda2af6e40e1903702.diff
LOG: [mlir][Linalg] Rewrite DownscaleSizeOneWindowed2DConvolution to use rank-reducing insert/extract slices.
This rewrite enables better bufferization and more canonicalizations.
Differential Revision: https://reviews.llvm.org/D113745
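
For context: instead of collapsing and re-expanding the unit dimension with
linalg.tensor_collapse_shape / linalg.tensor_expand_shape, the pattern now
carves the unit dimension off with a rank-reducing tensor.extract_slice, runs
the 1-D convolution, and puts the result back with a rank-reducing
tensor.insert_slice. A sketch of the resulting IR for the static test case
updated below (SSA names are illustrative):

    %input_1d = tensor.extract_slice %input[0, 0, 0, 0] [4, 1, 6, 3] [1, 1, 1, 1]
        : tensor<4x1x6x3xf32> to tensor<4x6x3xf32>
    %filter_1d = tensor.extract_slice %filter[0, 0, 0, 0] [1, 2, 3, 8] [1, 1, 1, 1]
        : tensor<1x2x3x8xf32> to tensor<2x3x8xf32>
    %init_1d = tensor.extract_slice %init[0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1]
        : tensor<4x1x2x8xf32> to tensor<4x2x8xf32>
    %conv_1d = linalg.conv_1d_nwc_wcf
        {dilations = dense<3> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
        ins(%input_1d, %filter_1d : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
        outs(%init_1d : tensor<4x2x8xf32>) -> tensor<4x2x8xf32>
    %conv_2d = tensor.insert_slice %conv_1d into %init[0, 0, 0, 0] [4, 1, 2, 8]
        [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>

Extract/insert slices bufferize to subviews and come with richer folding
patterns than the reshape pair they replace, which is where the bufferization
and canonicalization improvements come from.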
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/decompose-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 72e33d08dea33..f9e3014838a1f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,18 @@ bool canFoldIntoConsumerOp(CastOp castOp);
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult foldTensorCast(Operation *op);
+
+/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
+/// appropriate sizes to reduce the rank of `tensor` to `targetType`.
+Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
+ Value tensor,
+ RankedTensorType targetType);
+
+/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
+/// appropriate sizes to insert `tensor` into the higher-rank `dest`.
+Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
+ Value tensor, Value dest);
+
} // namespace tensor
} // namespace mlir
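
A caller-side sketch of the two new helpers (hedged: `b`, `loc`, `tensor`, and
`dest` are assumed to be in scope, and the shapes are illustrative):

    // Extract a rank-reduced tensor<4x6x3xf32> from a tensor<4x1x6x3xf32>
    // value, at offsets [0 .. 0] with strides [1 .. 1].
    auto reducedType = RankedTensorType::get({4, 6, 3}, b.getF32Type());
    Value reduced = tensor::createCanonicalRankReducingExtractSliceOp(
        b, loc, tensor, reducedType);
    // ... rewrite computations at the reduced rank ...
    // Re-insert the result into the higher-rank destination tensor.
    Value restored =
        tensor::createCanonicalRankReducingInsertSliceOp(b, loc, reduced, dest);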
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 0daabe25ea1cd..fb8dbd8a6c5be 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -202,6 +202,57 @@ class BaseMemRefType : public ShapedType {
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+/// This is a builder type that keeps local references to arguments. Arguments
+/// that are passed into the builder must outlive the builder.
+class RankedTensorType::Builder {
+public:
+ /// Build from another RankedTensorType.
+ explicit Builder(RankedTensorType other)
+ : shape(other.getShape()), elementType(other.getElementType()),
+ encoding(other.getEncoding()) {}
+
+ /// Build from scratch.
+ Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
+ : shape(shape), elementType(elementType), encoding(encoding) {}
+
+ Builder &setShape(ArrayRef<int64_t> newShape) {
+ shape = newShape;
+ return *this;
+ }
+
+ Builder &setElementType(Type newElementType) {
+ elementType = newElementType;
+ return *this;
+ }
+
+ Builder &setEncoding(Attribute newEncoding) {
+ encoding = newEncoding;
+ return *this;
+ }
+
+ /// Create a new RankedTensor by erasing a dim from shape.
+ // Note: the newly created type has ownership of a new shape vector.
+ RankedTensorType dropDim(unsigned dim) {
+ SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+ newShape.erase(newShape.begin() + dim);
+ return setShape(newShape);
+ }
+
+ operator RankedTensorType() {
+ return RankedTensorType::get(shape, elementType, encoding);
+ }
+
+private:
+ ArrayRef<int64_t> shape;
+ Type elementType;
+ Attribute encoding;
+};
+
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
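
Usage is a one-liner thanks to the implicit conversion; e.g., to drop the unit
H dimension (dim 1) of an NHWC tensor type (sketch, with `inputType` assumed to
be a RankedTensorType):

    using RTTBuilder = RankedTensorType::Builder;
    // Element type and encoding are preserved; only dim 1 is removed.
    RankedTensorType newType = RTTBuilder(inputType).dropDim(1);

Note that dropDim converts to RankedTensorType in its return statement, while
the local shape vector is still alive, and RankedTensorType::get copies the
shape into context storage, so the returned type is safe to use after the
builder is gone.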
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b27086144761e..c05ff64805921 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -701,6 +701,11 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
return $_get(elementType.getContext(), shape, elementType, encoding);
}]>
];
+ let extraClassDeclaration = [{
+ /// This is a builder type that keeps local references to arguments.
+ /// Arguments that are passed into the builder must outlive the builder.
+ class Builder;
+ }];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index bcfca706900ed..f15337f0f70e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -852,7 +852,6 @@ struct DownscaleSizeOneWindowed2DConvolution final
auto filterType = filter.getType().dyn_cast<RankedTensorType>();
auto outputType = output.getType().dyn_cast<RankedTensorType>();
- auto inputShape = inputType.getShape();
auto filterShape = filterType.getShape();
auto outputShape = outputType.getShape();
@@ -860,52 +859,47 @@ struct DownscaleSizeOneWindowed2DConvolution final
// of size 1. Other cases can rely on tiling to reduce to such cases.
int64_t fhSize = filterShape[0], fwSize = filterShape[1];
int64_t ohSize = outputShape[1], owSize = outputShape[2];
- if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+ bool removeH = (fhSize == 1 && ohSize == 1);
+ bool removeW = (fwSize == 1 && owSize == 1);
+ if (!removeH && !removeW)
return failure();
- bool removeH = ohSize == 1;
// Get new shapes and types for all operands by removing the size-1
// dimension.
+ using RTTBuilder = RankedTensorType::Builder;
+ auto newInputType = RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
+ auto newFilterType = RTTBuilder(filterType).dropDim(removeH ? 0 : 1);
+ auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
- SmallVector<int64_t, 3> newInputShape{
- inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
- auto newInputType = RankedTensorType::get(
- newInputShape, inputType.getElementType(), inputType.getEncoding());
-
- SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
- filterShape[2], filterShape[3]};
- auto newFilterType = RankedTensorType::get(
- newFilterShape, filterType.getElementType(), filterType.getEncoding());
-
- SmallVector<int64_t, 3> newOutputShape{
- outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
- auto newOutputType = RankedTensorType::get(
- newOutputShape, outputType.getElementType(), outputType.getEncoding());
-
- SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
- SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
-
- // Reshape all operands for 1-D convolution.
+ // Rank-reduce operands.
Location loc = convOp.getLoc();
- Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
- loc, newInputType, input, ioReshapeIndices);
- Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
- loc, newFilterType, filter, fReshapeIndices);
- Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
- loc, newOutputType, output, ioReshapeIndices);
-
- // We need to shrink the strides and dilations too.
- auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
- auto stridesAttr = rewriter.getI64VectorAttr(stride);
- auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
- auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+ Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, input, newInputType);
+ Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, filter, newFilterType);
+ Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, output, newOutputType);
+
+ // Rank-reduce strides and dilations too.
+ // TODO: dropDim 1-liner helper.
+ auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+ auto dilations =
+ llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
loc, newOutputType, ValueRange{newInput, newFilter},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
- rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
- convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+ // Insert back.
+ Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, conv1DOp.getResult(0), output);
+ rewriter.replaceOp(convOp, inserted);
+
return success();
};
};
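
Regarding the `TODO: dropDim 1-liner helper` above, one possible shape for it
(hypothetical, not part of this patch):

    // Hypothetical helper: return a copy of `values` with element `dim` erased.
    static SmallVector<int64_t> dropDim(ArrayRef<int64_t> values, unsigned dim) {
      SmallVector<int64_t> result(values.begin(), values.end());
      result.erase(result.begin() + dim);
      return result;
    }

With it, each strides/dilations pair above would collapse to a single call such
as rewriter.getI64VectorAttr(dropDim(strides, removeH ? 0 : 1)).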
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1d8d8e2d50688..0eff26bc86de4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1070,6 +1070,27 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
+Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
+ OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
+ auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
+ unsigned rank = rankedTensorType.getRank();
+ auto shape = rankedTensorType.getShape();
+ SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes;
+ for (unsigned i = 0, e = rank; i < e; ++i) {
+ OpFoldResult dim;
+ if (rankedTensorType.isDynamicDim(i))
+ dim = b.createOrFold<tensor::DimOp>(
+ loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
+ else
+ dim = b.getIndexAttr(shape[i]);
+ sizes.push_back(dim);
+ }
+ SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+ return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
+ offsets, sizes, strides);
+}
+
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//
@@ -1309,6 +1330,29 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
InsertSliceOpSourceCastInserter>(context);
}
+Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
+ Location loc,
+ Value tensor,
+ Value dest) {
+ auto rankedTensorType = dest.getType().cast<RankedTensorType>();
+ unsigned rank = rankedTensorType.getRank();
+ auto shape = rankedTensorType.getShape();
+ SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes;
+ for (unsigned i = 0, e = rank; i < e; ++i) {
+ OpFoldResult dim;
+ if (rankedTensorType.isDynamicDim(i))
+ dim = b.createOrFold<tensor::DimOp>(
+ loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
+ else
+ dim = b.getIndexAttr(shape[i]);
+ sizes.push_back(dim);
+ }
+ SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+ return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
+ sizes, strides);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
index ebd7dd6d4a2af..381ed1bd6080e 100644
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -10,21 +10,20 @@ func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x
return %0 : tensor<4x1x2x8xf32>
}
-// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
-// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
-// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+// CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 6, 3] [1, 1, 1, 1] : tensor<4x1x6x3xf32> to tensor<4x6x3xf32>
+// CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [1, 2, 3, 8] [1, 1, 1, 1] : tensor<1x2x3x8xf32> to tensor<2x3x8xf32>
+// CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x1x2x8xf32> to tensor<4x2x8xf32>
// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
// CHECK-SAME: dilations = dense<3> : vector<1xi64>
// CHECK-SAME: strides = dense<2> : vector<1xi64>
// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
// CHECK-SAME: outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
-// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+// CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
// CHECK: return %[[CONV_2D]]
-
// -----
// CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
@@ -37,19 +36,23 @@ func @conv2d_nhwc_qxqx1xq_tensor(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x
return %0 : tensor<?x?x1x?xf32>
}
-// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
-// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
-// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
+// CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x1x?x?xf32> to tensor<?x?x?xf32>
+// CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
// CHECK-SAME: dilations = dense<2> : vector<1xi64>
// CHECK-SAME: strides = dense<3> : vector<1xi64>
// CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SAME: outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
-// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+// CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
// CHECK: return %[[CONV_2D]]
// -----