[Mlir-commits] [mlir] 2b0c854 - [mlir][Linalg] Add pass to remove unit-extent dims from tensor
llvmlistbot at llvm.org
Thu May 28 11:07:08 PDT 2020
Author: MaheshRavishankar
Date: 2020-05-28T11:06:47-07:00
New Revision: 2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7
URL: https://github.com/llvm/llvm-project/commit/2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7
DIFF: https://github.com/llvm/llvm-project/commit/2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7.diff
LOG: [mlir][Linalg] Add pass to remove unit-extent dims from tensor
operands of Generic ops.
Unit-extent dimensions are typically used to achieve broadcasting
behavior. The pattern added here (along with the canonicalization
patterns added previously) removes uses of unit-extent dimensions and
instead uses a more canonical representation of the computation. This
new pattern is not added as a canonicalization for now, since it
entails adding additional reshape operations. A pass is added to
exercise these patterns, along with an API entry point to populate a
pattern list with them.
Differential Revision: https://reviews.llvm.org/D79766
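
An illustrative way to exercise the new pass from mlir-opt, using the
flags added in Passes.td (the input file name here is arbitrary):

  mlir-opt input.mlir -linalg-fold-unit-extent-dims
  mlir-opt input.mlir -linalg-fold-unit-extent-dims="fold-one-trip-loops-only"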
Added:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index beac1135a0bc..b03001c9b8e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -123,6 +123,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
"Return the range over inputs (irrespective of type) and output buffers.",
"Operation::operand_range", "getInputsAndOutputBuffers"
>,
+ InterfaceMethod<
+ "Return the shaped types for all the inputs and outputs",
+ "SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
+ >,
//===------------------------------------------------------------------===//
// Other interface methods.
@@ -153,6 +157,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
"Return the indexing maps attribute within the current operation.",
"ArrayAttr", "indexing_maps"
>,
+ InterfaceMethod<
+ "Return the indexing maps within the current operation.",
+ "SmallVector<AffineMap, 4>", "getIndexingMaps"
+ >,
InterfaceMethod<"Return the input or output indexing map at index `i`.",
"AffineMap", "getIndexingMap", (ins "unsigned":$i)
>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index b7bba5a31011..4ab547be2019 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -217,6 +217,18 @@ class StructuredOpTraits
return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
.template cast<ShapedType>();
}
+ /// Return the shaped types for all the inputs and outputs
+ SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
+ SmallVector<Type, 4> inputOutputTypes(
+ this->getOperation()->operand_type_begin(),
+ this->getOperation()->operand_type_end());
+ inputOutputTypes.append(this->getOperation()->result_type_begin(),
+ this->getOperation()->result_type_end());
+ return llvm::to_vector<4>(
+ llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
+ return type.cast<ShapedType>();
+ }));
+ }
//==========================================================================//
// Other interface methods.
@@ -295,6 +307,13 @@ class StructuredOpTraits
return attr;
}
+ SmallVector<AffineMap, 4> getIndexingMaps() {
+ return llvm::to_vector<4>(
+ llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+ }
+
AffineMap getIndexingMap(unsigned i) {
assert(i < getNumInputsAndOutputs());
return indexing_maps()
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d3bfa90e6bdb..8a274ed48dc5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -24,6 +24,8 @@ template <typename T> class OperationPass;
class OwningRewritePatternList;
class Pass;
+std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
+
std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -59,6 +61,11 @@ createConvertLinalgOnTensorsToBuffersPass();
void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateLinalgFoldUnitExtentDimsPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns);
+
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 850f381dd4ef..1fc7fa5bf729 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
include "mlir/Pass/PassBase.td"
+def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
+ let summary = "Remove unit-extent dimensions in Linalg ops on tensors";
+ let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
+ let options = [
+ Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
+ /*default=*/"false",
+ "Only folds the one-trip loops from Linalg ops on tensors "
+ "(for testing purposes only)">
+ ];
+}
+
def LinalgFusion : FunctionPass<"linalg-fusion"> {
let summary = "Fuse operations in the linalg dialect";
let constructor = "mlir::createLinalgFusionPass()";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c7a0f9d3812d..db4587fce014 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -265,7 +265,7 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
ArrayRef<AffineMap> mapsConsumer,
MLIRContext *context) {
- if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 ||
+ if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
return nullptr;
@@ -277,7 +277,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
for (AffineExpr rhsExpr : rhs.getResults()) {
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
- i != e; ++i) {
+ i < e; ++i) {
reassociations.push_back(getAffineDimExpr(currDim++, context));
}
}
@@ -1129,8 +1129,6 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
return {};
}
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
- if (succeeded(foldMemRefCast(*this)))
- return getResult();
return foldReshapeOp(*this);
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 097fa355a131..c87e3d4f15b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransforms
+ DropUnitDims.cpp
Fusion.cpp
Interchange.cpp
Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
new file mode 100644
index 000000000000..e08c43d48ba0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -0,0 +1,375 @@
+//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass and patterns to remove the use of unit-extent
+// dimensions for specifying broadcasting, in favor of a more canonical
+// representation of the computation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-drop-unit-dims"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+
+/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
+/// broadcasting. For example,
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (0, d1)>,
+/// affine_map<(d0, d1) -> (d0, 0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<1x5xf32>
+/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<5x1xf32>
+/// %2 = linalg.generic #trait %0, %1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+/// return %2 : tensor<5x5xf32>
+/// }
+///
+/// would canonicalize to
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (d1)>,
+/// affine_map<(d0, d1) -> (d0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.generic #trait %arg0, %arg1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
+/// return %0 : tensor<5x5xf32>
+/// }
+
+/// Given dims of the iteration space of a structured op that are known to be
+/// single trip count (`unitDims`), return the indexing maps to use in the
+/// canonicalized op with these dims removed, given the original `indexingMaps`.
+static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
+ ArrayRef<AffineMap> indexingMaps,
+ MLIRContext *context) {
+ if (indexingMaps.empty())
+ return nullptr;
+ unsigned numIterationDims = indexingMaps.front().getNumDims();
+ unsigned numSymbols = indexingMaps.front().getNumSymbols();
+
+ // Compute the replacement for each dim expr.
+ SmallVector<AffineExpr, 4> dimReplacements;
+ dimReplacements.reserve(numIterationDims);
+ unsigned numKeptDims = 0;
+ for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
+ if (unitDims.count(dim))
+ dimReplacements.push_back(getAffineConstantExpr(0, context));
+ else
+ dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
+ }
+
+ // Symbols remain the same.
+ SmallVector<AffineExpr, 4> symReplacements;
+ symReplacements.reserve(numSymbols);
+ for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
+ symReplacements.push_back(getAffineSymbolExpr(symbol, context));
+
+ SmallVector<AffineMap, 4> newIndexingMaps;
+ newIndexingMaps.reserve(indexingMaps.size());
+ for (AffineMap operandMap : indexingMaps) {
+ // Expected indexing maps to have no symbols.
+ if (operandMap.getNumSymbols())
+ return nullptr;
+ newIndexingMaps.push_back(simplifyAffineMap(
+ operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
+ numIterationDims - unitDims.size(),
+ numSymbols)));
+ }
+
+ // Check that the new index maps are invertible. If not, something went
+ // wrong, so abort.
+ if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
+ return nullptr;
+ return ArrayAttr::get(
+ llvm::to_vector<4>(llvm::map_range(
+ newIndexingMaps,
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
+ context);
+}
+
+namespace {
+/// Pattern to fold unit-trip count loops in GenericOps.
+// TODO: Generalize this to indexed-generic ops as well by also modifying the
+// region args.
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.empty())
+ return failure();
+
+ // Check if any of the iteration dimensions are unit-trip count. They will
+ // end up being unit-trip count if they are used to index into a unit-dim
+ // tensor/memref.
+ AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+ if (!invertedMap)
+ return failure();
+ SmallVector<int64_t, 4> dims;
+ for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+ dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+ DenseSet<unsigned> unitDims;
+ ArrayAttr iteratorTypes = genericOp.iterator_types();
+ for (auto expr : enumerate(invertedMap.getResults())) {
+ if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
+ if (dims[dimExpr.getPosition()] == 1 &&
+ iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName())
+ unitDims.insert(expr.index());
+ }
+ if (unitDims.empty())
+ return failure();
+
+ // Compute the modified indexing maps.
+ MLIRContext *context = rewriter.getContext();
+ ArrayAttr newIndexingMapAttr =
+ replaceUnitDims(unitDims, indexingMaps, context);
+ if (!newIndexingMapAttr)
+ return genericOp.emitError("unable to compute modified indexing_maps");
+
+ // Compute the iterator types of the modified op by dropping the one-trip
+ // count loops.
+ SmallVector<Attribute, 4> newIteratorTypes;
+ for (auto attr : llvm::enumerate(iteratorTypes)) {
+ if (!unitDims.count(attr.index()))
+ newIteratorTypes.push_back(attr.value());
+ }
+
+ rewriter.startRootUpdate(genericOp);
+ genericOp.indexing_mapsAttr(newIndexingMapAttr);
+ genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+ rewriter.finalizeRootUpdate(genericOp);
+ return success();
+ }
+};
+
+struct UnitExtentReplacementInfo {
+ RankedTensorType type;
+ AffineMap indexMap;
+ ArrayAttr reassociation;
+};
+} // namespace
+
+/// Utility function for replacing operands/results to a linalg generic
+/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// an operand/result with the unit-extent dimension removed. This is only done
+/// if the indexing map used to access that dimension has an
+/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
+/// Linalg op, and its `indexMap`, the utility function returns:
+/// - the new type with dimensions of size 1 removed.
+/// - modified index map that can be used to access the replaced result/operand
+/// - the reassociation that converts from the original tensor type to the
+/// modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+ RankedTensorType type,
+ MLIRContext *context) {
+ ArrayRef<int64_t> shape = type.getShape();
+ ArrayRef<AffineExpr> exprs = indexMap.getResults();
+ SmallVector<AffineExpr, 2> reassociations;
+ SmallVector<Attribute, 4> reassociationMaps;
+ SmallVector<AffineExpr, 4> newIndexExprs;
+ SmallVector<int64_t, 4> newShape;
+
+ int64_t origRank = type.getRank();
+ AffineExpr zeroExpr = getAffineConstantExpr(0, context);
+ auto isUnitExtent = [&](int64_t dim) -> bool {
+ return shape[dim] == 1 && exprs[dim] == zeroExpr;
+ };
+
+ unsigned dim = 0;
+ // Fold dimensions that are unit-extent at the beginning of the tensor.
+ while (dim < origRank && isUnitExtent(dim))
+ reassociations.push_back(getAffineDimExpr(dim++, context));
+ while (dim < origRank) {
+ reassociations.push_back(getAffineDimExpr(dim, context));
+ newIndexExprs.push_back(exprs[dim]);
+ newShape.push_back(shape[dim]);
+ // Fold all following dimensions that are unit-extent.
+ while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
+ ++dim;
+ reassociations.push_back(getAffineDimExpr(dim, context));
+ }
+ reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
+ origRank, /*numSymbols = */ 0, reassociations, context)));
+ reassociations.clear();
+ ++dim;
+ }
+ UnitExtentReplacementInfo info = {
+ RankedTensorType::get(newShape, type.getElementType()),
+ AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
+ newIndexExprs, context),
+ ArrayAttr::get(reassociationMaps, context)};
+ return info;
+}
+
+namespace {
+/// Pattern to replace tensor operands/results that have unit-extent dimensions.
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+
+ MLIRContext *context = rewriter.getContext();
+ Location loc = genericOp.getLoc();
+
+ SmallVector<AffineMap, 4> newIndexingMaps;
+ SmallVector<ArrayAttr, 4> reassociationMaps;
+ SmallVector<ShapedType, 4> newInputOutputTypes;
+ bool doCanonicalization = false;
+ for (auto it : llvm::zip(genericOp.getIndexingMaps(),
+ genericOp.getInputOutputShapedTypes())) {
+ auto replacementInfo = replaceUnitExtents(
+ std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+ reassociationMaps.push_back(replacementInfo.reassociation);
+ newIndexingMaps.push_back(replacementInfo.indexMap);
+ newInputOutputTypes.push_back(replacementInfo.type);
+ doCanonicalization =
+ doCanonicalization || replacementInfo.type != std::get<1>(it);
+ }
+
+ // If the indexing maps of the result operation are not invertible (i.e. not
+ // legal), abort.
+ if (!doCanonicalization ||
+ !inversePermutation(concatAffineMaps(newIndexingMaps)))
+ return failure();
+
+ // If any operand type changed, insert a reshape to convert from the
+ // original type to the new type.
+ SmallVector<Value, 4> newOperands;
+ newOperands.reserve(genericOp.getNumOperands());
+ for (auto operand : llvm::enumerate(genericOp.getOperands())) {
+ if (operand.value().getType() == newInputOutputTypes[operand.index()]) {
+ newOperands.push_back(operand.value());
+ } else {
+ newOperands.push_back(rewriter.create<linalg::TensorReshapeOp>(
+ loc, newInputOutputTypes[operand.index()], operand.value(),
+ reassociationMaps[operand.index()]));
+ }
+ }
+
+ // If any result type changed, insert a reshape to convert from the
+ // original type to the new type.
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.reserve(genericOp.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+ resultTypes.push_back(
+ newInputOutputTypes[i + genericOp.getNumOperands()]);
+ GenericOp replacementOp = rewriter.create<GenericOp>(
+ loc, resultTypes, newOperands, genericOp.args_in(),
+ genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ genericOp.iterator_types(),
+ /*doc = */ nullptr,
+ /*library_call = */ nullptr);
+ rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+ replacementOp.region().begin());
+
+ // If any result tensor has a modified shape, then add reshape to recover
+ // the original shape.
+ SmallVector<Value, 4> resultReplacements;
+ for (auto result : llvm::enumerate(replacementOp.getResults())) {
+ unsigned index = result.index() + replacementOp.getNumOperands();
+ RankedTensorType origResultType = genericOp.getResult(result.index())
+ .getType()
+ .cast<RankedTensorType>();
+ if (origResultType != result.value().getType()) {
+ resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
+ loc, origResultType, result.value(), reassociationMaps[index]));
+ } else {
+ resultReplacements.push_back(result.value());
+ }
+ }
+ rewriter.replaceOp(genericOp, resultReplacements);
+ return success();
+ }
+};
+} // namespace
+
+/// Patterns that are used to canonicalize the use of unit-extent dims for
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
+ TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+}
+
+namespace {
+/// Pass that removes unit-extent dims within generic ops.
+struct LinalgFoldUnitExtentDimsPass
+ : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ FuncOp funcOp = getFunction();
+ MLIRContext *context = funcOp.getContext();
+ if (foldOneTripLoopsOnly)
+ patterns.insert<FoldUnitDimLoops>(context);
+ else
+ populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
+ applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgFoldUnitExtentDimsPass() {
+ return std::make_unique<LinalgFoldUnitExtentDimsPass>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3123f95452fd..3f3c1c53fc3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -575,8 +575,8 @@ struct FuseGenericOpsOnTensors {
if (auto yieldOp = dyn_cast<YieldOp>(op)) {
// Lookup the value the yield operation is mapped to.
Value yieldVal = yieldOp.getOperand(0);
- auto clonedVal = mapper.lookup(yieldVal);
- mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
+ if (Value clonedVal = mapper.lookupOrNull(yieldVal))
+ mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
continue;
}
rewriter.clone(op, mapper);
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
new file mode 100644
index 000000000000..a5169c35d18d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg1 : f32) :
+ linalg.yield %arg1 : f32
+ } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+ return %0 : tensor<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP6:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP4]], #[[MAP5]], #[[MAP6]]]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg1: f32) :
+ linalg.yield %arg1 : f32
+ } : tensor<1x1xf32> -> tensor<1x1xf32>
+ return %0 : tensor<1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#accesses = [
+ affine_map<(d0) -> (0, d0)>,
+ affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = #accesses,
+ iterator_types = ["parallel"],
+ library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg2: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } : tensor<1x5xf32> -> tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+
+// -----
+
+#accesses = [
+ affine_map<(d0, d1) -> (0, d1)>,
+ affine_map<(d0, d1) -> (d0, 0)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel"],
+ library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5xf32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+ tensor<5xf32> into tensor<1x5xf32>
+ %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+ tensor<5xf32> into tensor<5x1xf32>
+ %2 = linalg.generic #trait %0, %1 {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %3 = addf %arg2, %arg3 : f32
+ linalg.yield %3 : f32
+ } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+ return %2 : tensor<5x5xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+#accesses = [
+ affine_map<(d0, d1) -> (0, 0)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel"],
+ library_call = "some_external_fn"
+}
+
+func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
+{
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg1 : f32):
+ linalg.yield %arg1 : f32
+ } : tensor<1x1xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1xf32>
+// CHECK: %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] []
+// CHECK-SAME: tensor<1x1xf32> into tensor<f32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: %[[A]]
diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
new file mode 100644
index 000000000000..a977ab4cadd9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" -split-input-file | FileCheck %s
+
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg1 : f32) :
+ linalg.yield %arg1 : f32
+ } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+ return %0 : tensor<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg1: f32) :
+ linalg.yield %arg1 : f32
+ } : tensor<1x1xf32> -> tensor<1x1xf32>
+ return %0 : tensor<1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
+{
+ linalg.generic #trait %arg0, %arg1 {
+ ^bb0(%arg2: f32, %arg3 : f32) :
+ linalg.yield %arg2 : f32
+ } : memref<1x1xf32>, memref<1x1xf32>
+ return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#accesses = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>
+]
+
+#trait = {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel"],
+ library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+ %0 = linalg.generic #trait %arg0 {
+ ^bb0(%arg2: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } : tensor<1x5xf32> -> tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
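
For downstream users, the same rewrites can be pulled into a custom pass via
the populateLinalgFoldUnitExtentDimsPatterns entry point added in Passes.h. A
minimal sketch (the pass wrapper and its name are illustrative; only the
populate call and the greedy driver mirror the code in this commit):

  #include "mlir/Dialect/Linalg/Passes.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  namespace {
  // Hypothetical function pass that reuses the patterns exposed by this commit.
  struct MyFoldUnitDimsPass
      : public PassWrapper<MyFoldUnitDimsPass, FunctionPass> {
    void runOnFunction() override {
      OwningRewritePatternList patterns;
      // Populate the unit-extent-dim folding patterns added in this commit.
      populateLinalgFoldUnitExtentDimsPatterns(&getContext(), patterns);
      // Apply them greedily to the function body, as the in-tree pass does.
      applyPatternsAndFoldGreedily(getFunction().getBody(), patterns);
    }
  };
  } // namespace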