[Mlir-commits] [mlir] fd15e2b - [mlir][Linalg] Use rank-reduced versions of subtensor and subtensor insert when possible.
llvmlistbot at llvm.org
Mon May 3 12:51:37 PDT 2021
Author: MaheshRavishankar
Date: 2021-05-03T12:51:24-07:00
New Revision: fd15e2b825f26dd7eac3b4a52aab36c88e52850a
URL: https://github.com/llvm/llvm-project/commit/fd15e2b825f26dd7eac3b4a52aab36c88e52850a
DIFF: https://github.com/llvm/llvm-project/commit/fd15e2b825f26dd7eac3b4a52aab36c88e52850a.diff
LOG: [mlir][Linalg] Use rank-reduced versions of subtensor and subtensor insert when possible.
Convert subtensor and subtensor_insert operations to use their
rank-reduced versions to drop unit dimensions.
Differential Revision: https://reviews.llvm.org/D101495
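As a rough illustration (a minimal sketch adapted from the documentation comment removed below; %arg0, %o and %s are placeholder SSA names, and the reassociation map assumes the 3-D case shown), a subtensor whose unit dimensions have offset 0 and size 1, such as

    %0 = subtensor %arg0[0, %o, 0] [1, %s, 1] [1, 1, 1]
        : tensor<1x?x1xf32> to tensor<1x?x1xf32>

is now emitted in its rank-reduced form, with a linalg.tensor_reshape restoring the original type:

    %0 = subtensor %arg0[0, %o, 0] [1, %s, 1] [1, 1, 1]
        : tensor<1x?x1xf32> to tensor<?xf32>
    %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
        : tensor<?xf32> into tensor<1x?x1xf32>

The reshape can then cancel with other reshapes that drop unit dimensions.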
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/MemRef/fold-subview-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 18be136ac6d18..d98d510a134a2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -18,7 +18,9 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
from/to the original memref.
}];
let constructor = "mlir::memref::createFoldSubViewOpsPass()";
- let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
+ let dependentDialects = [
+ "AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+ ];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5d8a664ef9646..f9320f358ab54 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -544,77 +544,87 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
return success();
}
};
+} // namespace
-/// Pattern to fold subtensors that are just taking a slice of unit-dimension
-/// tensor. For example
-///
-/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
-/// : tensor<1x?x1xf32> to tensor<1x?x1xf32>
-///
-/// can be replaced with
-///
-/// %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-/// : tensor<1x?x1xf32> into tensor<?xf32>
-/// %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
-/// %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
-/// : tensor<?xf32> into tensor<1x?x1xf32>
-///
-/// The additional tensor_reshapes will hopefully get canonicalized away with
-/// other reshapes that drop unit dimensions. Three conditions to fold a
-/// dimension
-/// - The offset must be 0
-/// - The size must be 1
-/// - The dimension of the source type must be 1.
-struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
+/// Get the reassociation maps to fold the result of a subtensor (or the
+/// source of a subtensor_insert) operation with the given offsets and sizes
+/// into its rank-reduced version. This is only done for the cases where the
+/// size is 1 and the offset is 0. Strictly speaking, a zero offset is not
+/// required in general, but non-zero offsets are not handled by the SPIR-V
+/// backend at this point (and potentially cannot be handled).
+static Optional<SmallVector<ReassociationIndices>>
+getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices curr;
+ for (auto it : llvm::enumerate(mixedSizes)) {
+ auto dim = it.index();
+ auto size = it.value();
+ curr.push_back(dim);
+ auto attr = size.dyn_cast<Attribute>();
+ if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+ continue;
+ reassociation.emplace_back(ReassociationIndices{});
+ std::swap(reassociation.back(), curr);
+ }
+ if (!curr.empty())
+ reassociation.back().append(curr.begin(), curr.end());
+ return reassociation;
+}
+
+namespace {
+/// Convert `subtensor` operations to rank-reduced versions.
+struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
using OpRewritePattern<SubTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
PatternRewriter &rewriter) const override {
- SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
- auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
- auto attr = valueOrAttr.dyn_cast<Attribute>();
- return attr && attr.cast<IntegerAttr>().getInt() == val;
- };
-
- if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
- return !hasValue(valueOrAttr, 1);
- }))
+ RankedTensorType resultType = subTensorOp.getType();
+ SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
+ auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ if (!reassociation ||
+ reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
+ auto rankReducedType =
+ SubTensorOp::inferRankReducedResultType(reassociation->size(),
+ subTensorOp.getSourceType(),
+ offsets, sizes, strides)
+ .cast<RankedTensorType>();
+
+ Location loc = subTensorOp.getLoc();
+ Value newSubTensor = rewriter.create<SubTensorOp>(
+ loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
+ rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
+ newSubTensor, *reassociation);
+ return success();
+ }
+};
- // Find the expanded unit dimensions.
- SmallVector<ReassociationIndices> reassociation;
- SmallVector<OpFoldResult> newOffsets, newSizes;
- ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
- ReassociationIndices curr;
- for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
- curr.push_back(dim);
- if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
- hasValue(mixedSizes[dim], 1)) {
- continue;
- }
- newOffsets.push_back(mixedOffsets[dim]);
- newSizes.push_back(mixedSizes[dim]);
- reassociation.emplace_back(ReassociationIndices{});
- std::swap(reassociation.back(), curr);
- }
- if (newOffsets.size() == mixedOffsets.size())
+/// Convert `subtensor_insert` operations to rank-reduced versions.
+struct UseRankReducedSubTensorInsertOp
+ : public OpRewritePattern<SubTensorInsertOp> {
+ using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ RankedTensorType sourceType = insertOp.getSourceType();
+ SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+ auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ if (!reassociation ||
+ reassociation->size() == static_cast<size_t>(sourceType.getRank()))
return failure();
- reassociation.back().append(curr.begin(), curr.end());
- SmallVector<OpFoldResult> newStrides(newOffsets.size(),
- rewriter.getI64IntegerAttr(1));
- Location loc = subTensorOp->getLoc();
- auto srcReshape = rewriter.create<TensorReshapeOp>(
- loc, subTensorOp.source(), reassociation);
- auto newSubTensorOp = rewriter.create<SubTensorOp>(
- loc, srcReshape, newOffsets, newSizes, newStrides);
- rewriter.replaceOpWithNewOp<TensorReshapeOp>(
- subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
+ Location loc = insertOp.getLoc();
+ auto reshapedSource = rewriter.create<TensorReshapeOp>(
+ loc, insertOp.source(), *reassociation);
+ rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
+ insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
return success();
}
};
-
} // namespace
/// Patterns that are used to canonicalize the use of unit-extent dims for
@@ -623,8 +633,10 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
- FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
- ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
+ ReplaceUnitExtentTensors<GenericOp>,
+ ReplaceUnitExtentTensors<IndexedGenericOp>,
+ UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
+ context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldReshapeOpWithUnitExtent>(context);
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index cb27354f36dfe..e795a86f69d74 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRMemRefPassIncGen
LINK_LIBS PUBLIC
+ MLIRAffine
MLIRMemRef
MLIRPass
MLIRStandard
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index ae76966ba25d6..4e1424083e96b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -41,27 +42,53 @@ static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
memref::SubViewOp subViewOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- // TODO: Aborting when the offsets are static. There might be a way to fold
- // the subview op with load even if the offsets have been canonicalized
- // away.
- SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
- if (opRanges.size() != indices.size()) {
- // For the rank-reduced cases, we can only handle the folding when the
- // offset is zero, size is 1 and stride is 1.
- return failure();
+ SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
+
+ SmallVector<Value> useIndices;
+ // Check if this is a rank-reducing case. Then, for every unit-dim size, add
+ // a zero to the indices.
+ ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
+ unsigned resultDim = 0;
+ for (auto size : llvm::enumerate(mixedSizes)) {
+ auto attr = size.value().dyn_cast<Attribute>();
+ // Check if this dimension has been dropped, i.e. the size is 1, but the
+ // associated dimension is not 1.
+ if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
+ (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+ useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
+ else if (resultDim < resultShape.size()) {
+ useIndices.push_back(indices[resultDim++]);
+ }
}
- auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
- auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-
- // New indices for the load are the current indices * subview_stride +
- // subview_offset.
- sourceIndices.resize(indices.size());
- for (auto index : llvm::enumerate(indices)) {
- auto offset = *(opOffsets.begin() + index.index());
- auto stride = *(opStrides.begin() + index.index());
- auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
- sourceIndices[index.index()] =
- rewriter.create<AddIOp>(loc, offset, mul).getResult();
+ if (useIndices.size() != mixedOffsets.size())
+ return failure();
+ sourceIndices.resize(useIndices.size());
+ for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
+ SmallVector<Value> dynamicOperands;
+ AffineExpr expr = rewriter.getAffineDimExpr(0);
+ unsigned numSymbols = 0;
+ dynamicOperands.push_back(useIndices[index]);
+
+ // Multiply by the stride.
+ if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
+ expr = expr * attr.cast<IntegerAttr>().getInt();
+ } else {
+ dynamicOperands.push_back(mixedStrides[index].get<Value>());
+ expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
+ }
+
+ // Add the offset.
+ if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
+ expr = expr + attr.cast<IntegerAttr>().getInt();
+ } else {
+ dynamicOperands.push_back(mixedOffsets[index].get<Value>());
+ expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
+ }
+ Location loc = subViewOp.getLoc();
+ sourceIndices[index] = rewriter.create<AffineApplyOp>(
+ loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
}
return success();
}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index e9dd74faad64b..2c6ab57782dd2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -476,67 +476,32 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
// -----
func @fold_subtensor(
- %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index,
- %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
- -> tensor<1x?x?x1x?x1x1xf32> {
- %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0]
- [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+ %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : tensor<1x?x?x?x?x1x1xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+ %arg6 : index, %arg7 : index) -> (tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>) {
+ %0 = subtensor %arg0[0, %arg2, %arg3, 0, %arg4, 0, 0]
+ [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
- return %0 : tensor<1x?x?x1x?x1x1xf32>
+ %1 = subtensor %arg1[%arg2, 0, %arg3, 0, 0, %arg4, 0]
+ [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+ tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+ return %0, %1 : tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// CHECK: func @fold_subtensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
+// CHECK: %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
+// CHECK-SAME: to tensor<?x?x?xf32>
+// CHECK: %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[ARG3]]]
-// CHECK-SAME: [%[[ARG4]], %[[ARG5]], %[[ARG6]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+// CHECK: %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
+// CHECK-SAME: to tensor<?x?x?xf32>
+// CHECK: %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK: return %[[RESULT_RESHAPE]]
-
-// -----
-
-func @no_fold_subtensor(
- %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index,
- %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
- -> tensor<1x?x?x1x?x1x1xf32> {
- %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0]
- [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
- tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
- return %0 : tensor<1x?x?x1x?x1x1xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)>
-// CHECK: func @no_fold_subtensor
-// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index
-// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index
-// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
-// CHECK-SAME: [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]]
-// CHECK-SAME: [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1]
-// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK: return %[[RESULT1]], %[[RESULT2]]
// -----
diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index 2cddeb93dc301..246c0b3552947 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -1,99 +1,162 @@
-// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-subview-ops -split-input-file %s -o - | FileCheck %s
-// CHECK-LABEL: @fold_static_stride_subview_with_load
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
- // CHECK-NOT: memref.subview
- // CHECK: [[C2:%.*]] = constant 2 : index
- // CHECK: [[C3:%.*]] = constant 3 : index
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: memref.load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return %1 : f32
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
+// CHECK: func @fold_static_stride_subview_with_load
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[ARG3]])[%[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG2]]]
+// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
-// CHECK-LABEL: @fold_dynamic_stride_subview_with_load
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index
func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> f32 {
- // CHECK-NOT: memref.subview
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: memref.load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
%1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return %1 : f32
}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_dynamic_stride_subview_with_load
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
-// CHECK-LABEL: @fold_static_stride_subview_with_store
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: f32
func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
- // CHECK-NOT: memref.subview
- // CHECK: [[C2:%.*]] = constant 2 : index
- // CHECK: [[C3:%.*]] = constant 3 : index
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: memref.store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
memref.store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
+// CHECK: func @fold_static_stride_subview_with_store
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[ARG3]])[%[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG2]]]
+// CHECK: memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
-// CHECK-LABEL: @fold_dynamic_stride_subview_with_store
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: f32
func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : f32) {
- // CHECK-NOT: memref.subview
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[ARG5]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: memref.store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return
}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_dynamic_stride_subview_with_store
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
-// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
-func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
- // CHECK-NOT: memref.subview
- // CHECK-DAG: [[F1:%.*]] = constant 1.000000e+00 : f32
- // CHECK-DAG: [[C2:%.*]] = constant 2 : index
- // CHECK-DAG: [[C3:%.*]] = constant 3 : index
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}, [[F1]] {in_bounds = [true]}
+func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
%f1 = constant 1.0 : f32
- %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
- %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
+ %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, offset:?, strides: [?, ?]>, vector<4xf32>
return %1 : vector<4xf32>
}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_subview_with_transfer_read
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[I1]], %[[I2]]]
-// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write
-// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32>
-func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) {
- // CHECK-NOT: memref.subview
- // CHECK: [[C2:%.*]] = constant 2 : index
- // CHECK: [[C3:%.*]] = constant 3 : index
- // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
- // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
- // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
- // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
- // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} {in_bounds = [true]}
- %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] :
- memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
- vector.transfer_write %arg5, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
+// -----
+
+func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) {
+ %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
+ memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [?, ?]>
return
}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_static_stride_subview_with_transfer_write
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: vector.transfer_write %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]]
+
+// -----
+
+func @fold_rank_reducing_subview_with_load
+ (%arg0 : memref<?x?x?x?x?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10: index,
+ %arg11 : index, %arg12 : index, %arg13 : index, %arg14: index,
+ %arg15 : index, %arg16 : index) -> f32 {
+ %0 = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4, %arg5, %arg6][4, 1, 1, 4, 1, 1][%arg7, %arg8, %arg9, %arg10, %arg11, %arg12] : memref<?x?x?x?x?x?xf32> to memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]>
+ %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]>
+ return %1 : f32
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_rank_reducing_subview_with_load
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG12:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG13:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG14:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG15:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG13]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG8]], %[[ARG2]]]
+// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG9]], %[[ARG3]]]
+// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]](%[[ARG15]])[%[[ARG10]], %[[ARG4]]]
+// CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
+// CHECK-DAG: %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]
+// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]]