[Mlir-commits] [mlir] 6c7be41 - Support buffers in LinalgFoldUnitExtentDims
Tres Popp
llvmlistbot at llvm.org
Mon Jun 14 23:22:32 PDT 2021
Author: Tres Popp
Date: 2021-06-15T08:22:22+02:00
New Revision: 6c7be4176703fff69d20acc466a879e080346f30
URL: https://github.com/llvm/llvm-project/commit/6c7be4176703fff69d20acc466a879e080346f30
DIFF: https://github.com/llvm/llvm-project/commit/6c7be4176703fff69d20acc466a879e080346f30.diff
LOG: Support buffers in LinalgFoldUnitExtentDims
This doesn't add any canonicalizations, but executes the same
simplification on linalg.generic ops with buffer semantics by using
linalg::CollapseShapeOp/linalg::ExpandShapeOp instead of
linalg::TensorCollapseShapeOp/linalg::TensorExpandShapeOp.
Differential Revision: https://reviews.llvm.org/D103513
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 1deea9476674..68102cb0f480 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@ struct UnitExtentReplacementInfo {
} // namespace
/// Utility function for replacing operands/results to a linalg generic
-/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that didimensionmension has a
/// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
@@ -301,10 +302,19 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
++dim;
}
// Compute the tensor or scalar replacement type.
+ Type actualType = opOperand->get().getType();
Type elementType = getElementTypeOrSelf(opOperand->get());
- Type replacementType = elementType == opOperand->get().getType()
- ? elementType
- : RankedTensorType::get(newShape, elementType);
+ Type replacementType;
+ if (elementType == opOperand->get().getType()) {
+ replacementType = elementType;
+ } else if (actualType.isa<RankedTensorType>()) {
+ replacementType = RankedTensorType::get(newShape, elementType);
+ } else if (actualType.isa<MemRefType>()) {
+ assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
+ "unsupported strided memrefs");
+ replacementType = MemRefType::get(newShape, elementType);
+ }
+ assert(replacementType && "unsupported shaped type");
UnitExtentReplacementInfo info = {replacementType,
AffineMap::get(indexingMap.getNumDims(),
indexingMap.getNumSymbols(),
@@ -324,14 +334,53 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
return reassociationExprs;
}
-/// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+/// Pattern to replace tensor/buffer operands/results that are unit extents.
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  // Expand `result` back to `origResultType` using `reassociationMap`. Returns
+  // `result` itself if no reshape is needed, or nullptr for unsupported types.
+  Value maybeExpand(Value result, Type origResultType,
+                    ArrayAttr reassociationMap, Location loc,
+                    PatternRewriter &rewriter) const {
+    if (origResultType == result.getType())
+      return result;
+    if (origResultType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (origResultType.isa<MemRefType>()) {
+      return rewriter.create<linalg::ExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  }
+
+  // Collapse `operand` to `newInputOutputType` using `reassociationMap`.
+  // Returns `operand` if no reshape is needed, or nullptr if unsupported.
+  Value maybeCollapse(Value operand, Type newInputOutputType,
+                      ArrayAttr reassociationMap, Location loc,
+                      PatternRewriter &rewriter) const {
+    Type operandType = operand.getType();
+    if (operandType == newInputOutputType)
+      return operand;
+    if (operandType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorCollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (operandType.isa<MemRefType>()) {
+      return rewriter.create<linalg::CollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  }
+
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
- return failure();
-
MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
@@ -339,7 +388,6 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
SmallVector<ArrayAttr> reassociationMaps;
SmallVector<Type> newInputOutputTypes;
bool doCanonicalization = false;
-
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
UnitExtentReplacementInfo replacementInfo =
replaceUnitExtents(genericOp, opOperand, context);
@@ -362,14 +410,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
auto insertReshapes = [&](ValueRange values) {
SmallVector<Value, 4> res;
res.reserve(values.size());
- for (auto operand : llvm::enumerate(values)) {
- if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
- res.push_back(operand.value());
- else {
- res.push_back(rewriter.create<TensorCollapseShapeOp>(
- loc, newInputOutputTypes[flattenedIdx], operand.value(),
- convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
- }
+ for (auto operand : values) {
+ auto reshapedValue =
+ maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
+ reassociationMaps[flattenedIdx], loc, rewriter);
+ assert(reshapedValue &&
+ "expected ranked MemRef or Tensor operand type");
+ res.push_back(reshapedValue);
++flattenedIdx;
}
return res;
@@ -396,15 +443,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
unsigned index = result.index() + replacementOp.getNumInputs();
- RankedTensorType origResultType = genericOp.getResult(result.index())
- .getType()
- .template cast<RankedTensorType>();
- if (origResultType != result.value().getType()) {
- resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
- loc, origResultType, result.value(),
- convertAffineMapArrayToExprs(reassociationMaps[index])));
- } else
- resultReplacements.push_back(result.value());
+ auto origResultType = genericOp.getResult(result.index()).getType();
+
+ auto newResult = maybeExpand(result.value(), origResultType,
+ reassociationMaps[index], loc, rewriter);
+ assert(newResult &&
+ "unexpected output type other than ranked MemRef or Tensor");
+ resultReplacements.push_back(newResult);
}
rewriter.replaceOp(genericOp, resultReplacements);
return success();
@@ -501,9 +546,8 @@ struct UseRankReducedSubTensorInsertOp
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
- UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
- context);
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents>(context)
+      .add<UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5a53c228bea5..f5359e54a7d5 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -451,3 +451,303 @@ func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>)
// CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
// CHECK-SAME: tensor<f32> into tensor<1x3xf32>
// CHECK: return %[[RESULT]]
+
+// -----
+
+// Memref variant: loops k and l only index size-1 dims of the operands, so the
+// generic is expected to become 3-d and the ?x1x? input collapsed to ?x?.
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> ()>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+ linalg.generic #trait
+ ins(%arg0, %arg1 : memref<?x1x?xf32>, f32)
+ outs(%shape : memref<?x1x?x1x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+ linalg.yield %arg3 : f32
+ }
+ return %shape : memref<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed
+ (%arg0 : memref<?x1x?xi32>, %shape: memref<?x1x?x1x?xi32>) -> memref<?x1x?x1x?xi32>
+{
+ linalg.generic #trait
+ ins(%arg0 : memref<?x1x?xi32>)
+ outs(%shape: memref<?x1x?x1x?xi32>) {
+ ^bb0(%arg6 : i32, %arg7 : i32) :
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %idx2 = linalg.index 2 : index
+ %idx3 = linalg.index 3 : index
+ %idx4 = linalg.index 4 : index
+ %1 = addi %idx0, %idx1 : index
+ %2 = subi %1, %idx2 : index
+ %3 = subi %2, %idx3 : index
+ %4 = addi %3, %idx4 : index
+ %5 = index_cast %4 : index to i32
+ %6 = addi %5, %arg6 : i32
+ linalg.yield %6 : i32
+ }
+ return %shape : memref<?x1x?x1x?xi32>
+}
+// The subtractions disappear because the access map of the output memref maps
+// its unit dimensions 1 and 3 to loop dimensions 2 and 3, so `linalg.index 2`
+// and `linalg.index 3` are always 0.
+// CHECK-LABEL: func @drop_one_trip_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK: %[[T3:.+]] = addi %[[IDX0]], %[[IDX1]]
+// CHECK: %[[T4:.+]] = addi %[[T3]], %[[IDX2]]
+// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+
+// -----
+
+// Every loop has a single trip and %arg0 is both input and output; the
+// memref<1x1xf32> is expected to be collapsed to rank 0, leaving no loops.
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
+{
+ linalg.generic #trait
+ ins(%arg0 : memref<1x1xf32>)
+ outs(%arg0 : memref<1x1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32) :
+ linalg.yield %arg1 : f32
+ }
+ return %arg0 : memref<1x1xf32>
+}
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.collapse_shape %{{.*}} []
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+// Fully unit-extent generic using linalg.index: once both loops are dropped the
+// index arithmetic is expected to fold away, so the body yields its input arg.
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed
+ (%arg0 : memref<1x1xi32>) -> memref<1x1xi32>{
+ linalg.generic #trait
+ ins(%arg0 : memref<1x1xi32>)
+ outs(%arg0 : memref<1x1xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32) :
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %1 = addi %idx0, %idx1 : index
+ %2 = index_cast %1 : index to i32
+ %3 = addi %2, %arg3 : i32
+ linalg.yield %3 : i32
+ }
+ return %arg0 : memref<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+// CHECK: linalg.yield %[[ARG1]] : i32
+
+// -----
+
+// The memref<1x5xf32> input is read with a constant 0 in its leading dim, so it
+// is expected to be collapsed to memref<5xf32> with identity indexing maps.
+#accesses = [
+ affine_map<(d0) -> (0, d0)>,
+ affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+ indexing_maps = #accesses,
+ iterator_types = ["parallel"],
+ library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf32>) -> memref<5xf32> {
+ linalg.generic #trait
+ ins(%arg0 : memref<1x5xf32>)
+ outs(%shape : memref<5xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ }
+ return %shape : memref<5xf32>
+}
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+
+// -----
+
+// The collapse_shape ops inserted for the unit dims are expected to fold with
+// their expand_shape producers, so no memref reshape op should remain.
+#accesses = [
+ affine_map<(d0, d1) -> (0, d1)>,
+ affine_map<(d0, d1) -> (d0, 0)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel"],
+ library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
+{
+ %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
+ %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+ linalg.generic #trait
+ ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
+ outs(%shape : memref<5x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %3 = addf %arg3, %arg4 : f32
+ linalg.yield %3 : f32
+ }
+ return %shape : memref<5x5xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+// CHECK-NOT: linalg.{{(collapse|expand)}}_shape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-NOT: linalg.{{(collapse|expand)}}_shape
+
+// -----
+
+// The memref<1x1xf32> input is only read at (0, 0): it is expected to be
+// collapsed to a rank-0 memref<f32> while both parallel loops are kept.
+#accesses = [
+ affine_map<(d0, d1) -> (0, 0)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel"],
+ library_call = "some_external_fn"
+}
+
+func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32>
+{
+ linalg.generic #trait
+ ins(%arg0 : memref<1x1xf32>)
+ outs(%shape : memref<?x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ }
+ return %shape : memref<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar
+// CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32>
+// CHECK: %[[A:.*]] = linalg.collapse_shape %[[ARG0]] []
+// CHECK-SAME: memref<1x1xf32> into memref<f32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: %[[A]]
+
+// -----
+
+// The unit leading dim of the 1x2x5 output is dropped by collapsing the alloc
+// before the generic; the collapse_shape written in the input is kept as-is.
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
+{
+ %1 = memref.alloc() : memref<1x2x5xf32>
+ linalg.generic {indexing_maps = [#map1, #map0],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ linalg.yield %arg1 : f32
+ }
+ %3 = linalg.collapse_shape %1 [[0, 1], [2]]
+ : memref<1x2x5xf32> into memref<2x5xf32>
+ return %3 : memref<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
+// CHECK: %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: linalg.generic
+// CHECK-SAME: outs(%[[OUT]] :
+// CHECK: %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// The 1-element init buffer is collapsed to rank 0 so only the reduction loop
+// remains; the function must still return the original allocation.
+func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32> {
+ %cst = constant 0.0 : f32
+ %init = memref.alloc() : memref<1xf32>
+ linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%input : memref<1x1000xf32>) outs(%init : memref<1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1823 = addf %arg1, %arg2 : f32
+ linalg.yield %1823 : f32
+ }
+ return %init : memref<1xf32>
+}
+
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK: func @fold_unit_dim_for_init_memref
+// CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
+// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
+// CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-SAME: ins(%[[INPUT_RESHAPE]] : memref<1000xf32>)
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] : memref<f32>)
+// CHECK: return %[[INIT]] : memref<1xf32>
+
+
+
More information about the Mlir-commits
mailing list