[Mlir-commits] [mlir] e3de249 - [mlir] Add a subtensor operation
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 2 02:37:46 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-02T05:35:30-04:00
New Revision: e3de249a4c94d6962b36c2b4747c134d152bed37
URL: https://github.com/llvm/llvm-project/commit/e3de249a4c94d6962b36c2b4747c134d152bed37
DIFF: https://github.com/llvm/llvm-project/commit/e3de249a4c94d6962b36c2b4747c134d152bed37.diff
LOG: [mlir] Add a subtensor operation
This revision introduces a `subtensor` op, which is the counterpart of `subview` for a tensor operand. This also refactors the relevant pieces to allow reusing the `subview` implementation where appropriate.
This operation will be used to implement tiling for Linalg on tensors.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 76ce4eb30e7f..b4e5be58bad7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -185,7 +185,7 @@ struct ProcInfo {
Value nprocs;
};
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
- OpBuilder &b, Location loc, ArrayRef<SubViewOp::Range> parallelLoopRanges)>;
+ OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
@@ -216,7 +216,7 @@ struct GenerateLoopNest {
AffineIndexedValue, StdIndexedValue>::type;
static void
- doit(ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+ doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions> = None);
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 2500343c0af3..fbe735e31cff 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -33,6 +33,17 @@ class Builder;
class FuncOp;
class OpBuilder;
+/// Auxiliary range data structure used to unpack the offset, size and stride
+/// operands of a SubViewOp / SubTensorOp into a list of (offset, size, stride)
+/// triples. Such a list of triples is sometimes more convenient to manipulate.
+struct Range {
+ Value offset; // Offset entry (unpacked as a loop lower bound by Linalg).
+ Value size; // Size entry (unpacked as a loop upper bound by Linalg).
+ Value stride; // Stride entry (unpacked as a loop step by Linalg).
+};
+
+raw_ostream &operator<<(raw_ostream &os, Range &range); // Streams `range` (used in LLVM_DEBUG output). NOTE(review): could take `const Range &`.
+
#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
@@ -300,8 +311,6 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
SmallVectorImpl<Value> &operands,
unsigned &numDims);
-raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
-
/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index ff1a82c26561..dbc3e4ca521b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2706,11 +2706,214 @@ def SubIOp : IntArithmeticOp<"subi"> {
// SubViewOp
//===----------------------------------------------------------------------===//
-def SubViewOp : Std_Op<"subview", [
- AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<ViewLikeOpInterface>,
- NoSideEffect,
- ]> {
+class BaseOpWithOffsetSizesAndStrides<string mnemonic, list<OpTrait> traits = []> :
+ Std_Op<mnemonic,
+ !listconcat(traits, [NoSideEffect, AttrSizedOperandSegments])> {
+ let builders = [
+ // Build the op with mixed static and dynamic offsets, sizes and strides.
+ OpBuilder<
+ "Value source, ArrayRef<int64_t> staticOffsets, "
+ "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+ "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Build the op with all-dynamic offsets, sizes and strides.
+ OpBuilder<
+ "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">
+ ];
+
+ code extraBaseClassDeclaration = [{
+ /// Returns the number of dynamic offset operands.
+ int64_t getNumOffsets() { return llvm::size(offsets()); }
+
+ /// Returns the number of dynamic size operands.
+ int64_t getNumSizes() { return llvm::size(sizes()); }
+
+ /// Returns the number of dynamic stride operands.
+ int64_t getNumStrides() { return llvm::size(strides()); }
+
+ /// Returns the dynamic size operands of this op (empty if all are static).
+ operand_range getDynamicSizes() { return sizes(); }
+
+ /// Fills `staticStrides` with the static value of every stride. Returns
+ /// failure() if any stride is dynamic (i.e. there are stride operands), in
+ /// which case `staticStrides` is left unmodified.
+ LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
+ if (!strides().empty())
+ return failure();
+ staticStrides.reserve(static_strides().size());
+ for (auto s : static_strides().getAsValueRange<IntegerAttr>())
+ staticStrides.push_back(s.getSExtValue());
+ return success();
+ }
+
+ /// Return the list of Range (i.e. offset, size, stride). Each
+ /// Range entry contains either the dynamic value or a ConstantIndexOp
+ /// constructed with `b` at location `loc`.
+ SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+ /// Return the offsets as Values. Each Value is either the dynamic
+ /// value specified in the op or an index ConstantOp constructed
+ /// with `b` at location `loc`.
+ SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc) {
+ unsigned dynamicIdx = 1; // Operand 0 is the source.
+ return llvm::to_vector<4>(llvm::map_range(
+ static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+ int64_t staticOffset = a.cast<IntegerAttr>().getInt();
+ if (ShapedType::isDynamicStrideOrOffset(staticOffset))
+ return getOperand(dynamicIdx++);
+ else
+ return b.create<ConstantOp>(
+ loc, b.getIndexType(), b.getIndexAttr(staticOffset));
+ }));
+ }
+
+ /// Return the sizes as Values. Each Value is either the dynamic
+ /// value specified in the op or an index ConstantOp constructed
+ /// with `b` at location `loc`.
+ SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc) {
+ unsigned dynamicIdx = 1 + offsets().size(); // Skip source and offsets.
+ return llvm::to_vector<4>(llvm::map_range(
+ static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+ int64_t staticSize = a.cast<IntegerAttr>().getInt();
+ if (ShapedType::isDynamic(staticSize))
+ return getOperand(dynamicIdx++);
+ else
+ return b.create<ConstantOp>(
+ loc, b.getIndexType(), b.getIndexAttr(staticSize));
+ }));
+ }
+
+ /// Return the strides as Values. Each Value is either the dynamic
+ /// value specified in the op or an index ConstantOp constructed with
+ /// `b` at location `loc`.
+ SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc) {
+ unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); // Skip source, offsets and sizes.
+ return llvm::to_vector<4>(llvm::map_range(
+ static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
+ int64_t staticStride = a.cast<IntegerAttr>().getInt();
+ if (ShapedType::isDynamicStrideOrOffset(staticStride))
+ return getOperand(dynamicIdx++);
+ else
+ return b.create<ConstantOp>(
+ loc, b.getIndexType(), b.getIndexAttr(staticStride));
+ }));
+ }
+
+ /// Return the rank of the source ShapedType.
+ unsigned getSourceRank() {
+ return source().getType().cast<ShapedType>().getRank();
+ }
+
+ /// Return the rank of the result ShapedType.
+ unsigned getResultRank() { return getType().getRank(); }
+
+ /// Return true if the offset `idx` is dynamic (supplied as an operand).
+ bool isDynamicOffset(unsigned idx) {
+ APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
+ return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
+ }
+ /// Return true if the size `idx` is dynamic (supplied as an operand).
+ bool isDynamicSize(unsigned idx) {
+ APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+ return ShapedType::isDynamic(v.getSExtValue());
+ }
+
+ /// Return true if the stride `idx` is dynamic (supplied as an operand).
+ bool isDynamicStride(unsigned idx) {
+ APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
+ return ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
+ }
+
+ /// Assert the offset `idx` is a static constant and return its value.
+ int64_t getStaticOffset(unsigned idx) {
+ assert(!isDynamicOffset(idx) && "expected static offset");
+ APInt v = *(static_offsets().getAsValueRange<IntegerAttr>().begin() + idx);
+ return v.getSExtValue();
+ }
+ /// Assert the size `idx` is a static constant and return its value.
+ int64_t getStaticSize(unsigned idx) {
+ assert(!isDynamicSize(idx) && "expected static size");
+ APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
+ return v.getSExtValue();
+ }
+ /// Assert the stride `idx` is a static constant and return its value.
+ int64_t getStaticStride(unsigned idx) {
+ assert(!isDynamicStride(idx) && "expected static stride");
+ APInt v = *(static_strides().getAsValueRange<IntegerAttr>().begin() + idx);
+ return v.getSExtValue();
+ }
+ /// Count the entries of `attr` in [0, idx) for which `isDynamic` holds.
+ unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr,
+ llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
+ return std::count_if(
+ attr.getValue().begin(), attr.getValue().begin() + idx,
+ [&](Attribute a) {
+ return isDynamic(a.cast<IntegerAttr>().getInt());
+ });
+ }
+ /// Assert the offset `idx` is dynamic and return the position of the
+ /// corresponding operand.
+ unsigned getIndexOfDynamicOffset(unsigned idx) {
+ assert(isDynamicOffset(idx) && "expected dynamic offset");
+ auto numDynamic =
+ getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
+ ShapedType::isDynamicStrideOrOffset, idx);
+ return 1 + numDynamic;
+ }
+ /// Assert the size `idx` is dynamic and return the position of the
+ /// corresponding operand.
+ unsigned getIndexOfDynamicSize(unsigned idx) {
+ assert(isDynamicSize(idx) && "expected dynamic size");
+ auto numDynamic = getNumDynamicEntriesUpToIdx(
+ static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
+ return 1 + offsets().size() + numDynamic;
+ }
+ /// Assert the stride `idx` is dynamic and return the position of the
+ /// corresponding operand.
+ unsigned getIndexOfDynamicStride(unsigned idx) {
+ assert(isDynamicStride(idx) && "expected dynamic stride");
+ auto numDynamic =
+ getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
+ ShapedType::isDynamicStrideOrOffset, idx);
+ return 1 + offsets().size() + sizes().size() + numDynamic;
+ }
+
+ /// Assert the offset `idx` is dynamic and return its value.
+ Value getDynamicOffset(unsigned idx) {
+ return getOperand(getIndexOfDynamicOffset(idx));
+ }
+ /// Assert the size `idx` is dynamic and return its value.
+ Value getDynamicSize(unsigned idx) {
+ return getOperand(getIndexOfDynamicSize(idx));
+ }
+ /// Assert the stride `idx` is dynamic and return its value.
+ Value getDynamicStride(unsigned idx) {
+ return getOperand(getIndexOfDynamicStride(idx));
+ }
+
+ static StringRef getStaticOffsetsAttrName() {
+ return "static_offsets";
+ }
+ static StringRef getStaticSizesAttrName() {
+ return "static_sizes";
+ }
+ static StringRef getStaticStridesAttrName() {
+ return "static_strides";
+ }
+ static ArrayRef<StringRef> getSpecialAttrNames() {
+ static SmallVector<StringRef, 4> names{
+ getStaticOffsetsAttrName(),
+ getStaticSizesAttrName(),
+ getStaticStridesAttrName(),
+ getOperandSegmentSizeAttr()};
+ return names;
+ }
+ }];
+}
+
+def SubViewOp : BaseOpWithOffsetSizesAndStrides<
+ "subview", [DeclareOpInterfaceMethods<ViewLikeOpInterface>] > {
let summary = "memref subview operation";
let description = [{
The "subview" operation converts a memref type to another memref type
@@ -2726,8 +2929,11 @@ def SubViewOp : Std_Op<"subview", [
* Sizes: memref-rank number of dynamic sizes or static integer attributes
which specify the sizes of the result "view" memref type.
* Strides: memref-rank number of dynamic strides or static integer
- attributes multiplicatively to the base memref strides in each
- dimension.
+ attributes that compose multiplicatively with the base memref
+ strides in each dimension.
+
+ A subview operation may additionally reduce the rank of the resulting view
+ by removing dimensions that are statically known to be of size 1.
Example 1:
@@ -2817,6 +3023,15 @@ def SubViewOp : Std_Op<"subview", [
// memref is "inbounds" w.r.t to base memref. It is upto the client
// to ensure that the subview is accessed in a manner that is
// in-bounds.
+
+ Example 5:
+
+ ```
+ // Rank-reducing subview.
+ %1 = subview %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+ memref<8x16x4xf32> to memref<16x4xf32>
+ %3 = subview %2[3, 4, 2][1, 6, 3][1, 1, 1] :
+ memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
```
}
}];
@@ -2859,137 +3074,97 @@ def SubViewOp : Std_Op<"subview", [
"ArrayRef<NamedAttribute> attrs = {}">
];
- let extraClassDeclaration = [{
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Returns the type of the base memref operand.
- MemRefType getBaseMemRefType() {
+ MemRefType getSourceMemRefType() {
return source().getType().cast<MemRefType>();
}
/// The result of a subview is always a memref.
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
- /// Returns as integer value the number of offset operands.
- int64_t getNumOffsets() { return llvm::size(offsets()); }
+ /// A subview result type can be fully inferred from the source type and the
+ /// static representation of offsets, sizes and strides. Special sentinels
+ /// encode the dynamic case.
+ static Type inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
+ }];
- /// Returns as integer value the number of size operands.
- int64_t getNumSizes() { return llvm::size(sizes()); }
+ let hasCanonicalizer = 1;
+}
- /// Returns as integer value the number of stride operands.
- int64_t getNumStrides() { return llvm::size(strides()); }
+//===----------------------------------------------------------------------===//
+// SubTensorOp
+//===----------------------------------------------------------------------===//
- /// Returns the dynamic sizes for this subview operation if specified.
- operand_range getDynamicSizes() { return sizes(); }
+def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
+ let summary = "subtensor operation";
+ let description = [{
+ The "subtensor" operation extracts a tensor from another tensor as
+ specified by the operation's offsets, sizes and strides arguments.
- /// Returns in `staticStrides` the static value of the stride
- /// operands. Returns failure() if the static value of the stride
- /// operands could not be retrieved.
- LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
-
- /// Auxiliary range data structure and helper function that unpacks the
- /// offset, size and stride operands of the SubViewOp into a list of triples.
- /// Such a list of triple is sometimes more convenient to manipulate.
- struct Range {
- Value offset, size, stride;
- };
- /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each
- /// Range entry contains either the dynamic value or a ConstantIndexOp
- /// constructed with `b` at location `loc`.
- SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+ The subtensor operation supports the following arguments:
- /// Return the offsets as Values. Each Value is either the dynamic
- /// value specified in the op or a ConstantIndexOp constructed
- /// with `b` at location `loc`
- SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc);
+ * tensor: the "base" tensor from which to extract a subtensor.
+ * offsets: tensor-rank number of dynamic offsets or static integer
+ attributes into the "base" tensor from which to extract the
+ subtensor.
+ * sizes: tensor-rank number of dynamic sizes or static integer attributes
+ which specify the sizes of the result tensor type.
+ * strides: tensor-rank number of dynamic strides or static integer
+ attributes specifying subsampling in each dimension.
- /// Return the sizes as Values. Each Value is either the dynamic
- /// value specified in the op or a ConstantIndexOp constructed
- /// with `b` at location `loc`
- SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc);
+ After buffer-allocation, the "subtensor" op is expected to lower into a
+ "subview" op.
- /// Return the strides as Values. Each Value is either the dynamic
- /// value specified in the op or a ConstantIndexOp constructed with
- /// `b` at location `loc`
- SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc);
+ A subtensor operation may additionally reduce the rank of the resulting
+ tensor by removing dimensions that are statically known to be of size 1.
- /// A subview result type can be fully inferred from the source type and the
- /// static representation of offsets, sizes and strides. Special sentinels
- /// encode the dynamic case.
- static Type inferSubViewResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides);
+ Example:
- /// Return the rank of the result MemRefType.
- unsigned getRank() { return getType().getRank(); }
+ ```
+ // Rank-reducing subtensor.
+ %1 = subtensor %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+ tensor<8x16x4xf32> to tensor<16x4xf32>
+ %3 = subtensor %2[3, 4, 2][1, 6, 3][1, 1, 1] :
+ tensor<8x16x4xf32> to tensor<6x3xf32>
+ ```
+ }];
- /// Return true if the offset `idx` is a static constant.
- bool isDynamicOffset(unsigned idx);
- /// Return true if the size `idx` is a static constant.
- bool isDynamicSize(unsigned idx);
- /// Return true if the stride `idx` is a static constant.
- bool isDynamicStride(unsigned idx);
+ let arguments = (ins
+ AnyRankedTensor:$source,
+ Variadic<Index>:$offsets,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides,
+ I64ArrayAttr:$static_offsets,
+ I64ArrayAttr:$static_sizes,
+ I64ArrayAttr:$static_strides
+ );
+ let results = (outs AnyRankedTensor:$result);
- /// Assert the offset `idx` is a static constant and return its value.
- int64_t getStaticOffset(unsigned idx) {
- assert(!isDynamicOffset(idx) && "expected static offset");
- return
- static_offsets().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
- }
- /// Assert the size `idx` is a static constant and return its value.
- int64_t getStaticSize(unsigned idx) {
- assert(!isDynamicSize(idx) && "expected static size");
- return static_sizes().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
- }
- /// Assert the stride `idx` is a static constant and return its value.
- int64_t getStaticStride(unsigned idx) {
- assert(!isDynamicStride(idx) && "expected static stride");
- return
- static_strides().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
+ /// Returns the type of the base tensor operand.
+ RankedTensorType getSourceRankedTensorType() {
+ return source().getType().cast<RankedTensorType>();
}
- /// Assert the offset `idx` is dynamic and return the position of the
- /// corresponding operand.
- unsigned getIndexOfDynamicOffset(unsigned idx);
- /// Assert the size `idx` is dynamic and return the position of the
- /// corresponding operand.
- unsigned getIndexOfDynamicSize(unsigned idx);
- /// Assert the stride `idx` is dynamic and return the position of the
- /// corresponding operand.
- unsigned getIndexOfDynamicStride(unsigned idx);
-
- /// Assert the offset `idx` is dynamic and return its value.
- Value getDynamicOffset(unsigned idx) {
- return getOperand(getIndexOfDynamicOffset(idx));
- }
- /// Assert the size `idx` is dynamic and return its value.
- Value getDynamicSize(unsigned idx) {
- return getOperand(getIndexOfDynamicSize(idx));
- }
- /// Assert the stride `idx` is dynamic and return its value.
- Value getDynamicStride(unsigned idx) {
- return getOperand(getIndexOfDynamicStride(idx));
+ /// The result of a subtensor is always a tensor.
+ RankedTensorType getType() {
+ return getResult().getType().cast<RankedTensorType>();
}
- static StringRef getStaticOffsetsAttrName() {
- return "static_offsets";
- }
- static StringRef getStaticSizesAttrName() {
- return "static_sizes";
- }
- static StringRef getStaticStridesAttrName() {
- return "static_strides";
- }
- static ArrayRef<StringRef> getSpecialAttrNames() {
- static SmallVector<StringRef, 4> names{
- getStaticOffsetsAttrName(),
- getStaticSizesAttrName(),
- getStaticStridesAttrName(),
- getOperandSegmentSizeAttr()};
- return names;
- }
+ /// A subtensor result type can be fully inferred from the source type and the
+ /// static representation of offsets, sizes and strides. Special sentinels
+ /// encode the dynamic case.
+ static Type inferResultType(RankedTensorType sourceRankedTensorType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides);
}];
- let hasCanonicalizer = 1;
+ // let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index c964c2466d5c..7b16a9197f11 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -60,7 +60,7 @@ using llvm::dbgs;
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
- ArrayRef<SubViewOp::Range> loopRanges) {
+ ArrayRef<Range> loopRanges) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto maps = op.indexing_maps();
SmallVector<Value, 8> clonedViews;
@@ -73,7 +73,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
auto map = maps[idx].cast<AffineMapAttr>().getValue();
LLVM_DEBUG(dbgs() << "map: " << map << "\n");
Value view = en.value();
- SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
+ SmallVector<Range, 4> viewRanges(map.getNumResults());
for (auto en2 : llvm::enumerate(map.getResults())) {
unsigned d = en2.index();
// loopToOperandRangesMaps are permutations-only.
@@ -182,7 +182,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
- SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
+ SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
// Iterate over dimensions identified by the producer map for `producerIdx`.
// This defines a subset of the loop ranges that we need to complete later.
@@ -202,9 +202,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto viewDim = getViewDefiningLoopRange(producer, i);
- loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
- std_dim(viewDim.view, viewDim.dimension),
- folded_std_constant_index(folder, 1)};
+ loopRanges[i] = Range{folded_std_constant_index(folder, 0),
+ std_dim(viewDim.view, viewDim.dimension),
+ folded_std_constant_index(folder, 1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
@@ -300,8 +300,6 @@ static bool isSameSubView(Value a, Value b) {
return false;
if (sva.getType() != svb.getType())
return false;
- if (sva.getRank() != svb.getRank())
- return false;
if (sva.getNumOperands() != svb.getNumOperands())
return false;
if (sva.static_offsets() != svb.static_offsets())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index eb452cc40305..a9e7a8660230 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -65,22 +65,21 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
/// It expects a non-inverted, concatenated map and last values in
/// allViewSizes will be applied to the symbols in the map if it contains any.
-static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
- Location loc,
- AffineMap map,
- ValueRange viewSizes) {
+static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc,
+ AffineMap map,
+ ValueRange viewSizes) {
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
unsigned numSym = map.getNumSymbols();
assert(viewSizes.size() == numRes + numSym &&
"viewSizes must contain sizes of all views and values for symbols");
- SmallVector<SubViewOp::Range, 4> res(numDims);
+ SmallVector<Range, 4> res(numDims);
for (unsigned idx = 0; idx < numRes; ++idx) {
auto result = map.getResult(idx);
if (auto d = result.dyn_cast<AffineDimExpr>()) {
if (res[d.getPosition()].offset)
continue;
- res[d.getPosition()] = SubViewOp::Range{
- std_constant_index(0), viewSizes[idx], std_constant_index(1)};
+ res[d.getPosition()] =
+ Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
}
// If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
@@ -124,7 +123,7 @@ static SmallVector<SubViewOp::Range, 4> emitLoopRanges(OpBuilder &b,
// Construction of the lower bound (s floordiv 2).
Value from = applyMapToValues(b, loc, fromMap, values).front();
Value to = applyMapToValues(b, loc, toMap, values).front();
- res[mPos] = SubViewOp::Range{from, to, std_constant_index(1)};
+ res[mPos] = Range{from, to, std_constant_index(1)};
}
}
return res;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 68d69549611c..3e8e0b74c145 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -54,7 +54,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// are tiled and for which new loops will be created. Also the function returns
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
-static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
+static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value> allViewSizes,
ArrayRef<Value> allTileSizes) {
@@ -76,10 +76,9 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
}
// Create a new range with the applied tile sizes.
- SmallVector<SubViewOp::Range, 4> res;
+ SmallVector<Range, 4> res;
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
- res.push_back(SubViewOp::Range{std_constant_index(0), viewSizes[idx],
- tileSizes[idx]});
+ res.push_back(Range{std_constant_index(0), viewSizes[idx], tileSizes[idx]});
return std::make_tuple(res, loopIndexToRangeIndex);
}
@@ -346,7 +345,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
if (!viewSizesToLoopsMap)
return llvm::None;
- SmallVector<SubViewOp::Range, 4> loopRanges;
+ SmallVector<Range, 4> loopRanges;
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 204716b40746..f9ea9092d55d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -133,11 +133,10 @@ template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
-static void unpackRanges(ArrayRef<SubViewOp::Range> ranges,
- SmallVectorImpl<Value> &lbs,
+static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
SmallVectorImpl<Value> &ubs,
SmallVectorImpl<Value> &steps) {
- for (SubViewOp::Range range : ranges) {
+ for (Range range : ranges) {
lbs.emplace_back(range.offset);
ubs.emplace_back(range.size);
steps.emplace_back(range.stride);
@@ -194,7 +193,7 @@ getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
- ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+ ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions>) {
@@ -206,7 +205,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
- ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+ ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions>) {
@@ -364,7 +363,7 @@ generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
- ArrayRef<SubViewOp::Range> loopRanges, ValueRange iterArgInitValues,
+ ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions> distributionOptions) {
@@ -391,7 +390,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
Location loc = edsc::ScopedContext::getLocation();
distributionMethod.assign(distributionOptions->distributionMethod.begin(),
distributionOptions->distributionMethod.end());
- SmallVector<SubViewOp::Range, 2> parallelLoopRanges;
+ SmallVector<Range, 2> parallelLoopRanges;
for (auto iteratorType : enumerate(iteratorTypes)) {
if (isParallelIteratorType(iteratorType.value()))
parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1cabf172b7fc..d684a4b98e55 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2587,10 +2587,10 @@ Wrapper operator*(Wrapper a, int64_t b) {
/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
-Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
+Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides) {
unsigned rank = sourceMemRefType.getRank();
(void)rank;
assert(staticOffsets.size() == rank &&
@@ -2638,7 +2638,8 @@ Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
-static void print(OpAsmPrinter &p, SubViewOp op) {
+template <typename OpType>
+static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.getOperand(0);
@@ -2649,16 +2650,22 @@ static void print(OpAsmPrinter &p, SubViewOp op) {
printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
ShapedType::isDynamicStrideOrOffset);
p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()});
+ /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
+/// Prints a SubViewOp using the shared offsets/sizes/strides custom syntax.
+static void print(OpAsmPrinter &p, SubViewOp op) {
+  printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
+}
+
/// Parse SubViewOp of the form:
/// ```
-/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+template <typename OpType>
+static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
+ OperationState &result) {
OpAsmParser::OperandType srcInfo;
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
@@ -2666,13 +2673,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
if (parser.parseOperand(srcInfo))
return failure();
if (parseListOfOperandsOrIntegers(
- parser, result, SubViewOp::getStaticOffsetsAttrName(),
+ parser, result, OpType::getStaticOffsetsAttrName(),
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
parseListOfOperandsOrIntegers(parser, result,
- SubViewOp::getStaticSizesAttrName(),
+ OpType::getStaticSizesAttrName(),
ShapedType::kDynamicSize, sizesInfo) ||
parseListOfOperandsOrIntegers(
- parser, result, SubViewOp::getStaticStridesAttrName(),
+ parser, result, OpType::getStaticStridesAttrName(),
ShapedType::kDynamicStrideOrOffset, stridesInfo))
return failure();
@@ -2680,7 +2687,7 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
static_cast<int>(sizesInfo.size()),
static_cast<int>(stridesInfo.size())};
- result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
+ result.addAttribute(OpType::getOperandSegmentSizeAttr(),
b.getI32VectorAttr(segmentSizes));
return failure(
@@ -2694,6 +2701,10 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
parser.addTypeToList(dstType, result.types));
}
+/// Parses a SubViewOp with the shared offsets/sizes/strides custom syntax.
+static ParseResult parseSubViewOp(OpAsmParser &parser,
+                                  OperationState &result) {
+  return parseOpWithOffsetsSizesAndStrides<SubViewOp>(parser, result);
+}
+
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
@@ -2701,8 +2712,8 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
auto sourceMemRefType = source.getType().cast<MemRefType>();
- auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
- staticSizes, staticStrides);
+ auto resultType = inferResultType(sourceMemRefType, staticOffsets,
+ staticSizes, staticStrides);
build(b, result, resultType, source, offsets, sizes, strides,
b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
b.getI64ArrayAttr(staticStrides));
@@ -2760,15 +2771,18 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
staticStridesVector, offsets, sizes, strides, attrs);
}
+/// ViewLikeOpInterface hook: the value being viewed is the `source` operand.
+Value SubViewOp::getViewSource() {
+  return source();
+}
+
/// Verify that a particular offset/size/stride static attribute is well-formed.
-static LogicalResult
-verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
- ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
- ValueRange values) {
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+ OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
+ llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
/// Check static and dynamic offsets/sizes/strides breakdown.
- size_t inputRank = op.source().getType().cast<MemRefType>().getRank();
- if (attr.size() != inputRank)
- return op.emitError("expected ") << inputRank << " " << name << " values";
+ if (attr.size() != op.getSourceRank())
+ return op.emitError("expected ")
+ << op.getSourceRank() << " " << name << " values";
unsigned expectedNumDynamicEntries =
llvm::count_if(attr.getValue(), [&](Attribute attr) {
return isDynamic(attr.cast<IntegerAttr>().getInt());
@@ -2787,17 +2801,26 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
-/// Checks if `original` MemRef type can be rank reduced to `reduced` type.
+/// Checks if `original` type can be rank reduced to `reduced` type.
/// This function is slight variant of `is subsequence` algorithm where
/// not matching dimension must be 1.
static bool isRankReducedType(Type originalType, Type reducedType) {
if (originalType == reducedType)
return true;
+ if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
+ return true;
+ if (originalType.isa<RankedTensorType>() &&
+ !reducedType.isa<RankedTensorType>())
+ return true;
+ if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+ return true;
- MemRefType original = originalType.cast<MemRefType>();
- MemRefType reduced = reducedType.cast<MemRefType>();
- ArrayRef<int64_t> originalShape = original.getShape();
- ArrayRef<int64_t> reducedShape = reduced.getShape();
+ ShapedType originalShapedType = originalType.cast<ShapedType>();
+ ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+
+ // Rank and size logic is valid for all ShapedTypes.
+ ArrayRef<int64_t> originalShape = originalShapedType.getShape();
+ ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
reducedRank = reducedShape.size();
if (reducedRank > originalRank)
@@ -2819,6 +2842,13 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
if (reducedIdx != reducedRank)
return false;
+ // We are done for the tensor case.
+ if (originalType.isa<RankedTensorType>())
+ return true;
+
+ // Strided layout logic is relevant for MemRefType only.
+ MemRefType original = originalType.cast<MemRefType>();
+ MemRefType reduced = reducedType.cast<MemRefType>();
MLIRContext *c = original.getContext();
int64_t originalOffset, symCounter = 0, dimCounter = 0;
SmallVector<int64_t, 4> originalStrides;
@@ -2843,10 +2873,29 @@ static bool isRankReducedType(Type originalType, Type reducedType) {
reducedMap == reduced.getAffineMaps().front());
}
+/// Verifies the static offset, size and stride attributes of `op` against the
+/// source rank and the dynamic offset/size/stride operands. Offsets and
+/// strides use the stride/offset dynamic sentinel; sizes use the size one.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+  // Short-circuiting `||` checks offsets, then sizes, then strides, and stops
+  // at the first part that fails to verify.
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
+          ShapedType::isDynamicStrideOrOffset, op.offsets())) ||
+      failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
+          ShapedType::isDynamic, op.sizes())) ||
+      failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
+          ShapedType::isDynamicStrideOrOffset, op.strides())))
+    return failure();
+  return success();
+}
+
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
- auto baseType = op.getBaseMemRefType().cast<MemRefType>();
- auto subViewType = op.getType();
+ MemRefType baseType = op.getSourceMemRefType();
+ MemRefType subViewType = op.getType();
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
@@ -2858,24 +2907,12 @@ static LogicalResult verify(SubViewOp op) {
if (!isStrided(baseType))
return op.emitError("base type ") << baseType << " is not strided";
- // Verify static attributes offsets/sizes/strides.
- if (failed(verifySubViewOpPart(
- op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset, op.offsets())))
- return failure();
-
- if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
- op.static_sizes(), ShapedType::isDynamic,
- op.sizes())))
- return failure();
- if (failed(verifySubViewOpPart(
- op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset, op.strides())))
+ if (failed(verifyOpWithOffsetSizesAndStrides(op)))
return failure();
// Verify result type against inferred type.
- auto expectedType = SubViewOp::inferSubViewResultType(
- op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
+ auto expectedType = SubViewOp::inferResultType(
+ baseType, extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
if (!isRankReducedType(expectedType, subViewType))
@@ -2885,123 +2922,41 @@ static LogicalResult verify(SubViewOp op) {
return success();
}
-raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
+raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
return os << "range " << range.offset << ":" << range.size << ":"
<< range.stride;
}
-static unsigned getNumDynamicEntriesUpToIdx(
- ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
- return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
- [&](Attribute attr) {
- return isDynamic(attr.cast<IntegerAttr>().getInt());
- });
-}
-
-bool SubViewOp::isDynamicOffset(unsigned idx) {
- return ShapedType::isDynamicStrideOrOffset(
- extractFromI64ArrayAttr(static_offsets())[idx]);
-}
-bool SubViewOp::isDynamicSize(unsigned idx) {
- return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
-}
-bool SubViewOp::isDynamicStride(unsigned idx) {
- return ShapedType::isDynamicStrideOrOffset(
- extractFromI64ArrayAttr(static_strides())[idx]);
-}
-
-unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
- assert(isDynamicOffset(idx) && "expected static offset");
- auto numDynamic =
- getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
- ShapedType::isDynamicStrideOrOffset, idx);
- return 1 + numDynamic;
-}
-unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
- assert(isDynamicSize(idx) && "expected static size");
- auto numDynamic = getNumDynamicEntriesUpToIdx(
- static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
- return 1 + offsets().size() + numDynamic;
-}
-unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
- assert(isDynamicStride(idx) && "expected static stride");
- auto numDynamic =
- getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
- ShapedType::isDynamicStrideOrOffset, idx);
- return 1 + offsets().size() + sizes().size() + numDynamic;
-}
-
-/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range
+/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
-SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
- Location loc) {
+template <typename OpType>
+static SmallVector<Range, 8> getOrCreateRangesImpl(OpType op, OpBuilder &b,
+ Location loc) {
SmallVector<Range, 8> res;
- unsigned rank = getType().getRank();
+ unsigned rank = op.getSourceRank();
res.reserve(rank);
for (unsigned idx = 0; idx < rank; ++idx) {
- auto offset = isDynamicOffset(idx)
- ? getDynamicOffset(idx)
- : b.create<ConstantIndexOp>(loc, getStaticOffset(idx));
- auto size = isDynamicSize(idx)
- ? getDynamicSize(idx)
- : b.create<ConstantIndexOp>(loc, getStaticSize(idx));
- auto stride = isDynamicStride(idx)
- ? getDynamicStride(idx)
- : b.create<ConstantIndexOp>(loc, getStaticStride(idx));
+ Value offset =
+ op.isDynamicOffset(idx)
+ ? op.getDynamicOffset(idx)
+ : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
+ Value size = op.isDynamicSize(idx)
+ ? op.getDynamicSize(idx)
+ : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
+ Value stride =
+ op.isDynamicStride(idx)
+ ? op.getDynamicStride(idx)
+ : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
res.emplace_back(Range{offset, size, stride});
}
return res;
}
-SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
- Location loc) {
- unsigned dynamicIdx = 1;
- return llvm::to_vector<4>(llvm::map_range(
- static_offsets().cast<ArrayAttr>(), [&](Attribute a) -> Value {
- int64_t staticOffset = a.cast<IntegerAttr>().getInt();
- if (ShapedType::isDynamicStrideOrOffset(staticOffset))
- return getOperand(dynamicIdx++);
- else
- return b.create<ConstantIndexOp>(loc, staticOffset);
- }));
-}
-
-SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
- unsigned dynamicIdx = 1 + offsets().size();
- return llvm::to_vector<4>(llvm::map_range(
- static_sizes().cast<ArrayAttr>(), [&](Attribute a) -> Value {
- int64_t staticSize = a.cast<IntegerAttr>().getInt();
- if (ShapedType::isDynamic(staticSize))
- return getOperand(dynamicIdx++);
- else
- return b.create<ConstantIndexOp>(loc, staticSize);
- }));
-}
-
-SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
- Location loc) {
- unsigned dynamicIdx = 1 + offsets().size() + sizes().size();
- return llvm::to_vector<4>(llvm::map_range(
- static_strides().cast<ArrayAttr>(), [&](Attribute a) -> Value {
- int64_t staticStride = a.cast<IntegerAttr>().getInt();
- if (ShapedType::isDynamicStrideOrOffset(staticStride))
- return getOperand(dynamicIdx++);
- else
- return b.create<ConstantIndexOp>(loc, staticStride);
- }));
+SmallVector<Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) {
+ return ::getOrCreateRangesImpl(*this, b, loc);
}
-LogicalResult
-SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
- if (!strides().empty())
- return failure();
- staticStrides = extractFromI64ArrayAttr(static_strides());
- return success();
-}
-
-Value SubViewOp::getViewSource() { return source(); }
-
namespace {
/// Take a list of `values` with potential new constant to extract and a list
@@ -3053,20 +3008,20 @@ class SubViewOpConstantArgumentFolder final
SmallVector<Value, 8> newOffsets(subViewOp.offsets());
SmallVector<int64_t, 8> newStaticOffsets =
extractFromI64ArrayAttr(subViewOp.static_offsets());
- assert(newStaticOffsets.size() == subViewOp.getRank());
+ assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newOffsets, newStaticOffsets,
ShapedType::isDynamicStrideOrOffset);
SmallVector<Value, 8> newSizes(subViewOp.sizes());
SmallVector<int64_t, 8> newStaticSizes =
extractFromI64ArrayAttr(subViewOp.static_sizes());
- assert(newStaticOffsets.size() == subViewOp.getRank());
+ assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
SmallVector<Value, 8> newStrides(subViewOp.strides());
SmallVector<int64_t, 8> newStaticStrides =
extractFromI64ArrayAttr(subViewOp.static_strides());
- assert(newStaticOffsets.size() == subViewOp.getRank());
+ assert(newStaticOffsets.size() == subViewOp.getSourceRank());
canonicalizeSubViewPart(newStrides, newStaticStrides,
ShapedType::isDynamicStrideOrOffset);
@@ -3210,7 +3165,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
/// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
- Type resultType = SubViewOp::inferSubViewResultType(
+ Type resultType = SubViewOp::inferResultType(
castOp.source().getType().cast<MemRefType>(),
extractFromI64ArrayAttr(subViewOp.static_offsets()),
extractFromI64ArrayAttr(subViewOp.static_sizes()),
@@ -3232,6 +3187,94 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
+//===----------------------------------------------------------------------===//
+// SubTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Prints a SubTensorOp using the shared offsets/sizes/strides custom syntax.
+static void print(OpAsmPrinter &p, SubTensorOp op) {
+  printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
+}
+
+/// Parses a SubTensorOp with the shared offsets/sizes/strides custom syntax.
+static ParseResult parseSubTensorOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  return parseOpWithOffsetsSizesAndStrides<SubTensorOp>(parser, result);
+}
+
+/// Infers the result type of a subtensor from the source tensor type and the
+/// static offsets, sizes and strides (special sentinels encode dynamic
+/// entries). Only the sizes shape the result: it is a ranked tensor with the
+/// source element type whose dimensions are exactly `staticSizes`.
+Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                  ArrayRef<int64_t> staticOffsets,
+                                  ArrayRef<int64_t> staticSizes,
+                                  ArrayRef<int64_t> staticStrides) {
+  unsigned sourceRank = sourceRankedTensorType.getRank();
+  // Only referenced by the asserts; avoid unused warnings in release builds.
+  (void)sourceRank;
+  assert(staticOffsets.size() == sourceRank &&
+         "unexpected staticOffsets size mismatch");
+  assert(staticSizes.size() == sourceRank &&
+         "unexpected staticSizes size mismatch");
+  assert(staticStrides.size() == sourceRank &&
+         "unexpected staticStrides size mismatch");
+  Type elementType = sourceRankedTensorType.getElementType();
+  return RankedTensorType::get(staticSizes, elementType);
+}
+
+/// Builds a SubTensorOp from static offset/size/stride vectors plus the
+/// dynamic values they refer to. The result type is inferred from the source
+/// tensor type and the static entries; `attrs` are added afterwards.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              Value source, ArrayRef<int64_t> staticOffsets,
+                              ArrayRef<int64_t> staticSizes,
+                              ArrayRef<int64_t> staticStrides,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  Type resultType =
+      inferResultType(sourceType, staticOffsets, staticSizes, staticStrides);
+  ArrayAttr staticOffsetsAttr = b.getI64ArrayAttr(staticOffsets);
+  ArrayAttr staticSizesAttr = b.getI64ArrayAttr(staticSizes);
+  ArrayAttr staticStridesAttr = b.getI64ArrayAttr(staticStrides);
+  build(b, result, resultType, source, offsets, sizes, strides,
+        staticOffsetsAttr, staticSizesAttr, staticStridesAttr);
+  result.addAttributes(attrs);
+}
+
+/// Builds a SubTensorOp whose offsets, sizes and strides are all dynamic: the
+/// static offset/size/stride vectors are filled with the sentinel values that
+/// encode dynamic entries, one per source dimension.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              Value source, ValueRange offsets,
+                              ValueRange sizes, ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  unsigned rank = source.getType().cast<RankedTensorType>().getRank();
+  SmallVector<int64_t, 4> dynamicOffsets(rank,
+                                         ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> dynamicSizes(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> dynamicStrides(rank,
+                                         ShapedType::kDynamicStrideOrOffset);
+  build(b, result, source, dynamicOffsets, dynamicSizes, dynamicStrides,
+        offsets, sizes, strides, attrs);
+}
+
+/// Returns one Range (offset, size, stride) per source dimension; static
+/// entries are materialized as ConstantIndexOps created with `b` at `loc`.
+SmallVector<Range, 8> SubTensorOp::getOrCreateRanges(OpBuilder &b,
+                                                     Location loc) {
+  SubTensorOp op = *this;
+  return ::getOrCreateRangesImpl(op, b, loc);
+}
+
+/// Verifier for SubTensorOp.
+///
+/// Checks the static/dynamic offsets, sizes and strides. It then checks that
+/// the result type is the type inferred from the source type and the static
+/// entries, or a rank-reduced version of it, with the source element type.
+static LogicalResult verify(SubTensorOp op) {
+  if (failed(verifyOpWithOffsetSizesAndStrides(op)))
+    return failure();
+
+  // Verify result type against inferred type.
+  auto expectedType = SubTensorOp::inferResultType(
+      op.getSourceRankedTensorType(),
+      extractFromI64ArrayAttr(op.static_offsets()),
+      extractFromI64ArrayAttr(op.static_sizes()),
+      extractFromI64ArrayAttr(op.static_strides()));
+
+  // isRankReducedType only compares shapes for ranked tensors and returns
+  // early for them, so a mismatching element type must be rejected here.
+  Type sourceElementType = op.getSourceRankedTensorType().getElementType();
+  Type resultElementType = op.getType().cast<ShapedType>().getElementType();
+  if (sourceElementType != resultElementType ||
+      !isRankReducedType(expectedType, op.getType()))
+    return op.emitError("expected result type to be ")
+           << expectedType << " or a rank-reduced version.";
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 5e3959af29dd..72a063ff9d51 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -900,3 +900,27 @@ func @assume_alignment(%0: memref<4x4xf16>) {
assume_alignment %0, 16 : memref<4x4xf16>
return
}
+
+
+// Parses and prints subtensor with dynamic (SSA) and static entries,
+// including a rank-reducing result type.
+// CHECK-LABEL: func @subtensor({{.*}}) {
+func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+
+ // SSA offsets, sizes and strides: every result dimension is dynamic.
+ // CHECK: subtensor
+ // CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
+ %1 = subtensor %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+ : tensor<8x16x4xf32> to tensor<?x?x?xf32>
+
+ // All-static entries: the result shape is the static sizes.
+ // CHECK: subtensor
+ // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
+ %2 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
+ : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+ // A static size of 1 lets that dimension be dropped (rank reduction).
+ // CHECK: subtensor
+ // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
+ %3 = subtensor %t[0, 2, 0][4, 1, 4][1, 1, 1]
+ : tensor<8x16x4xf32> to tensor<4x4xf32>
+
+ return
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index ab18845bdb53..7356c07577db 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1255,3 +1255,23 @@ func @imaginary_part_from_incompatible_complex_type(%cplx: complex<f64>) {
std.re %cplx : complex<f32>
return
}
+
+// -----
+
+// All sizes are static, so the inferred result type is fully static and a
+// result type with a dynamic dimension is rejected.
+func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
+ // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}}
+ %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
+ : tensor<8x16x4xf32> to tensor<?x4x4xf32>
+
+ return
+}
+
+// -----
+
+// Sizes given as SSA values stay dynamic in the inferred result type, so a
+// fully static result type is rejected.
+func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
+ // expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>'}}
+ %0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
+ : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index edcc66c9b6a6..ffb0f92dae99 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -301,8 +301,7 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx,
template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
-getGpuProcIds(OpBuilder &b, Location loc,
- ArrayRef<SubViewOp::Range> parallelLoopRanges) {
+getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
Type indexType = b.getIndexType();
SmallVector<ProcInfo, 2> procInfo(2);
procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
More information about the Mlir-commits
mailing list