[Mlir-commits] [mlir] 2b0c854 - [mlir][Linalg] Add pass to remove unit-extent dims from tensor

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 28 11:07:08 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-28T11:06:47-07:00
New Revision: 2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7

URL: https://github.com/llvm/llvm-project/commit/2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7
DIFF: https://github.com/llvm/llvm-project/commit/2b0c8546ac9fb47e1bf9c5e54f1450420eadeab7.diff

LOG: [mlir][Linalg] Add pass to remove unit-extent dims from tensor
operands of Generic ops.

Unit-extent dimensions are typically used to achieve broadcasting
behavior. The patterns added here (along with the canonicalization
patterns added previously) remove the use of unit-extent dimensions
and instead use a more canonical representation of the computation.
These patterns are not registered as canonicalizations for now since
they entail adding additional reshape operations. A pass is added to
exercise these patterns, along with an API entry point to populate a
pattern list with them.

Differential Revision: https://reviews.llvm.org/D79766

Added: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index beac1135a0bc..b03001c9b8e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -123,6 +123,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "Return the range over inputs (irrespective of type) and output buffers.",
       "Operation::operand_range", "getInputsAndOutputBuffers"
     >,
+    InterfaceMethod<
+      "Return the shaped types for all the inputs and outputs",
+      "SmallVector<ShapedType, 4>", "getInputOutputShapedTypes"
+    >,
 
     //===------------------------------------------------------------------===//
     // Other interface methods.
@@ -153,6 +157,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "Return the indexing maps attribute within the current operation.",
       "ArrayAttr", "indexing_maps"
     >,
+    InterfaceMethod<
+      "Return the indexing maps within the current operation.",
+      "SmallVector<AffineMap, 4>", "getIndexingMaps"
+    >,
     InterfaceMethod<"Return the input or output indexing map at index `i`.",
       "AffineMap", "getIndexingMap", (ins "unsigned":$i)
     >,

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index b7bba5a31011..4ab547be2019 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -217,6 +217,18 @@ class StructuredOpTraits
     return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
         .template cast<ShapedType>();
   }
+  /// Return the shaped types for all the inputs and outputs
+  SmallVector<ShapedType, 4> getInputOutputShapedTypes() {
+    SmallVector<Type, 4> inputOutputTypes(
+        this->getOperation()->operand_type_begin(),
+        this->getOperation()->operand_type_end());
+    inputOutputTypes.append(this->getOperation()->result_type_begin(),
+                            this->getOperation()->result_type_end());
+    return llvm::to_vector<4>(
+        llvm::map_range(inputOutputTypes, [](Type type) -> ShapedType {
+          return type.cast<ShapedType>();
+        }));
+  }
 
   //==========================================================================//
   // Other interface methods.
@@ -295,6 +307,13 @@ class StructuredOpTraits
     return attr;
   }
 
+  SmallVector<AffineMap, 4> getIndexingMaps() {
+    return llvm::to_vector<4>(
+        llvm::map_range(indexing_maps(), [](Attribute attr) -> AffineMap {
+          return attr.cast<AffineMapAttr>().getValue();
+        }));
+  }
+
   AffineMap getIndexingMap(unsigned i) {
     assert(i < getNumInputsAndOutputs());
     return indexing_maps()
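
Taken together, the getIndexingMaps() and getInputOutputShapedTypes()
accessors added above let a transformation walk each operand/result of a
structured op alongside its indexing map, which is how the new pattern in
DropUnitDims.cpp below uses them. A small illustrative sketch (`linalgOp`
and the loop body are assumptions for illustration, not part of this
change):

    SmallVector<AffineMap, 4> maps = linalgOp.getIndexingMaps();
    SmallVector<ShapedType, 4> types = linalgOp.getInputOutputShapedTypes();
    for (auto it : llvm::zip(maps, types)) {
      AffineMap map = std::get<0>(it);
      ShapedType type = std::get<1>(it);
      // A unit-extent dimension of `type` whose corresponding result
      // expression in `map` is the constant 0 is a candidate for removal.
    }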

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d3bfa90e6bdb..8a274ed48dc5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -24,6 +24,8 @@ template <typename T> class OperationPass;
 class OwningRewritePatternList;
 class Pass;
 
+std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
 
@@ -59,6 +61,11 @@ createConvertLinalgOnTensorsToBuffersPass();
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);
 
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_PASSES_H_
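
For downstream users, the new populate entry point can be mixed into an
existing rewrite pattern set. A minimal sketch, assuming the usual
FunctionPass boilerplate of this revision (includes elided; the wrapper
pass itself is hypothetical, while the populate call and greedy driver
invocation mirror the pass added in DropUnitDims.cpp below):

    struct MyCleanupPass : public PassWrapper<MyCleanupPass, FunctionPass> {
      void runOnFunction() override {
        FuncOp funcOp = getFunction();
        MLIRContext *context = funcOp.getContext();
        OwningRewritePatternList patterns;
        // Pull in the unit-extent folding patterns, possibly alongside
        // other locally defined patterns.
        populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
        applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
      }
    };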

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 850f381dd4ef..1fc7fa5bf729 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
+  let summary = "Remove unit-extent dimension in Linalg ops on tensors";
+  let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
+  let options = [
+    Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
+            /*default=*/"false",
+            "Only folds the one-trip loops from Linalg ops on tensors "
+            "(for testing purposes only)">
+  ];
+}
+
 def LinalgFusion : FunctionPass<"linalg-fusion"> {
   let summary = "Fuse operations in the linalg dialect";
   let constructor = "mlir::createLinalgFusionPass()";

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c7a0f9d3812d..db4587fce014 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -265,7 +265,7 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
                                            ArrayRef<AffineMap> mapsConsumer,
                                            MLIRContext *context) {
-  if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 ||
+  if (mapsProducer.empty() || mapsConsumer.empty() ||
       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
       mapsProducer.size() != mapsConsumer[0].getNumDims())
     return nullptr;
@@ -277,7 +277,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
     for (AffineExpr rhsExpr : rhs.getResults()) {
       AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
       for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
-           i != e; ++i) {
+           i < e; ++i) {
         reassociations.push_back(getAffineDimExpr(currDim++, context));
       }
     }
@@ -1129,8 +1129,6 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   return foldReshapeOp(*this);
 }
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {

diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 097fa355a131..c87e3d4f15b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  DropUnitDims.cpp
   Fusion.cpp
   Interchange.cpp
   Loops.cpp

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
new file mode 100644
index 000000000000..e08c43d48ba0
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -0,0 +1,375 @@
+//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass and patterns to remove the use of unit-extent
+// dimensions for specifying broadcasting, in favor of a more canonical
+// representation of the computation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-drop-unit-dims"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::linalg;
+
+/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
+/// broadcasting. For example,
+///
+/// ```mlir
+/// #accesses = [
+///   affine_map<(d0, d1) -> (0, d1)>,
+///   affine_map<(d0, d1) -> (d0, 0)>,
+///   affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+///   args_in = 2,
+///   args_out = 1,
+///   indexing_maps = #accesses,
+///   iterator_types = ["parallel", "parallel"],
+///   library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+///        tensor<5xf32> into tensor<1x5xf32>
+///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+///        tensor<5xf32> into tensor<5x1xf32>
+///   %2 = linalg.generic #trait %0, %1 {
+///        ^bb0(%arg2: f32, %arg3: f32):
+///          %3 = addf %arg2, %arg3 : f32
+///          linalg.yield %3 : f32
+///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+///   return %2 : tensor<5x5xf32>
+/// }
+/// ```
+/// would canonicalize to
+///
+/// ```mlir
+/// #accesses = [
+///   affine_map<(d0, d1) -> (d1)>,
+///   affine_map<(d0, d1) -> (d0)>,
+///   affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+///   args_in = 2,
+///   args_out = 1,
+///   indexing_maps = #accesses,
+///   iterator_types = ["parallel", "parallel"],
+///   library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+///   %0 = linalg.generic #trait %arg0, %arg1 {
+///        ^bb0(%arg2: f32, %arg3: f32):
+///          %3 = addf %arg2, %arg3 : f32
+///          linalg.yield %3 : f32
+///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
+///   return %0 : tensor<5x5xf32>
+/// }
+/// ```
+
+/// Given dims of the iteration space of a structured op that are known to be
+/// single trip count (`unitDims`), return the indexing maps to use in the
+/// canonicalized op with these dims removed, given the original `indexingMaps`.
+static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
+                                 ArrayRef<AffineMap> indexingMaps,
+                                 MLIRContext *context) {
+  if (indexingMaps.empty())
+    return nullptr;
+  unsigned numIterationDims = indexingMaps.front().getNumDims();
+  unsigned numSymbols = indexingMaps.front().getNumSymbols();
+
+  // Compute the replacement for each dim expr.
+  SmallVector<AffineExpr, 4> dimReplacements;
+  dimReplacements.reserve(numIterationDims);
+  unsigned numKeptDims = 0;
+  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
+    if (unitDims.count(dim))
+      dimReplacements.push_back(getAffineConstantExpr(0, context));
+    else
+      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
+  }
+
+  // Symbols remain the same.
+  SmallVector<AffineExpr, 4> symReplacements;
+  symReplacements.reserve(numSymbols);
+  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
+    symReplacements.push_back(getAffineSymbolExpr(symbol, context));
+
+  SmallVector<AffineMap, 4> newIndexingMaps;
+  newIndexingMaps.reserve(indexingMaps.size());
+  for (AffineMap operandMap : indexingMaps) {
+    // Expected indexing maps to have no symbols.
+    if (operandMap.getNumSymbols())
+      return nullptr;
+    newIndexingMaps.push_back(simplifyAffineMap(
+        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
+                                         numIterationDims - unitDims.size(),
+                                         numSymbols)));
+  }
+
+  // Check that the new index maps are invertible. If not, something went
+  // wrong, so abort.
+  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
+    return nullptr;
+  return ArrayAttr::get(
+      llvm::to_vector<4>(llvm::map_range(
+          newIndexingMaps,
+          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
+      context);
+}
+
+namespace {
+/// Pattern to fold unit-trip count loops in GenericOps.
+// TODO: Generalize this to indexed-generic ops as well, by also modifying the
+// region arguments.
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+    if (indexingMaps.empty())
+      return failure();
+
+    // Check if any of the iteration dimensions are unit-trip count. They will
+    // end up being unit-trip count if they are used to index into a unit-dim
+    // tensor/memref.
+    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+    if (!invertedMap)
+      return failure();
+    SmallVector<int64_t, 4> dims;
+    for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+    DenseSet<unsigned> unitDims;
+    ArrayAttr iteratorTypes = genericOp.iterator_types();
+    for (auto expr : enumerate(invertedMap.getResults())) {
+      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
+        if (dims[dimExpr.getPosition()] == 1 &&
+            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
+                getParallelIteratorTypeName())
+          unitDims.insert(expr.index());
+    }
+    if (unitDims.empty())
+      return failure();
+
+    // Compute the modified indexing maps.
+    MLIRContext *context = rewriter.getContext();
+    ArrayAttr newIndexingMapAttr =
+        replaceUnitDims(unitDims, indexingMaps, context);
+    if (!newIndexingMapAttr)
+      return genericOp.emitError("unable to compute modified indexing_maps");
+
+    // Compute the iterator types of the modified op by dropping the one-trip
+    // count loops.
+    SmallVector<Attribute, 4> newIteratorTypes;
+    for (auto attr : llvm::enumerate(iteratorTypes)) {
+      if (!unitDims.count(attr.index()))
+        newIteratorTypes.push_back(attr.value());
+    }
+
+    rewriter.startRootUpdate(genericOp);
+    genericOp.indexing_mapsAttr(newIndexingMapAttr);
+    genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    rewriter.finalizeRootUpdate(genericOp);
+    return success();
+  }
+};
+
+struct UnitExtentReplacementInfo {
+  RankedTensorType type;
+  AffineMap indexMap;
+  ArrayAttr reassociation;
+};
+} // namespace
+
+/// Utility function for replacing operands/results of a linalg generic
+/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// an operand/result with the unit-extent dimension removed. This is only done
+/// if the indexing map used to access that dimension has an
+/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
+/// Linalg op, and its `indexMap`, the utility function returns:
+/// - the new type with dimensions of size 1 removed.
+/// - modified index map that can be used to access the replaced result/operand
+/// - the reassociation that converts from the original tensor type to the
+///   modified tensor type.
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
+                                                    RankedTensorType type,
+                                                    MLIRContext *context) {
+  ArrayRef<int64_t> shape = type.getShape();
+  ArrayRef<AffineExpr> exprs = indexMap.getResults();
+  SmallVector<AffineExpr, 2> reassociations;
+  SmallVector<Attribute, 4> reassociationMaps;
+  SmallVector<AffineExpr, 4> newIndexExprs;
+  SmallVector<int64_t, 4> newShape;
+
+  int64_t origRank = type.getRank();
+  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
+  auto isUnitExtent = [&](int64_t dim) -> bool {
+    return shape[dim] == 1 && exprs[dim] == zeroExpr;
+  };
+
+  unsigned dim = 0;
+  // Fold dimensions that are unit-extent at the beginning of the tensor.
+  while (dim < origRank && isUnitExtent(dim))
+    reassociations.push_back(getAffineDimExpr(dim++, context));
+  while (dim < origRank) {
+    reassociations.push_back(getAffineDimExpr(dim, context));
+    newIndexExprs.push_back(exprs[dim]);
+    newShape.push_back(shape[dim]);
+    // Fold all following dimensions that are unit-extent.
+    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
+      ++dim;
+      reassociations.push_back(getAffineDimExpr(dim, context));
+    }
+    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
+        origRank, /*numSymbols = */ 0, reassociations, context)));
+    reassociations.clear();
+    ++dim;
+  }
+  UnitExtentReplacementInfo info = {
+      RankedTensorType::get(newShape, type.getElementType()),
+      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
+                     newIndexExprs, context),
+      ArrayAttr::get(reassociationMaps, context)};
+  return info;
+}
+
+namespace {
+/// Pattern to replace tensor operands/results that have unit extents.
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    MLIRContext *context = rewriter.getContext();
+    Location loc = genericOp.getLoc();
+
+    SmallVector<AffineMap, 4> newIndexingMaps;
+    SmallVector<ArrayAttr, 4> reassociationMaps;
+    SmallVector<ShapedType, 4> newInputOutputTypes;
+    bool doCanonicalization = false;
+    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
+                             genericOp.getInputOutputShapedTypes())) {
+      auto replacementInfo = replaceUnitExtents(
+          std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+      reassociationMaps.push_back(replacementInfo.reassociation);
+      newIndexingMaps.push_back(replacementInfo.indexMap);
+      newInputOutputTypes.push_back(replacementInfo.type);
+      doCanonicalization =
+          doCanonicalization || replacementInfo.type != std::get<1>(it);
+    }
+
+    // If the indexing maps of the result operation are not invertible (i.e. not
+    // legal), abort.
+    if (!doCanonicalization ||
+        !inversePermutation(concatAffineMaps(newIndexingMaps)))
+      return failure();
+
+    // If any operand type changed, insert a reshape to convert from the
+    // original type to the new type.
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(genericOp.getNumOperands());
+    for (auto operand : llvm::enumerate(genericOp.getOperands())) {
+      if (operand.value().getType() == newInputOutputTypes[operand.index()]) {
+        newOperands.push_back(operand.value());
+      } else {
+        newOperands.push_back(rewriter.create<linalg::TensorReshapeOp>(
+            loc, newInputOutputTypes[operand.index()], operand.value(),
+            reassociationMaps[operand.index()]));
+      }
+    }
+
+    // If any result type changed, insert a reshape to convert from the
+    // original type to the new type.
+    SmallVector<Type, 4> resultTypes;
+    resultTypes.reserve(genericOp.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+      resultTypes.push_back(
+          newInputOutputTypes[i + genericOp.getNumOperands()]);
+    GenericOp replacementOp = rewriter.create<GenericOp>(
+        loc, resultTypes, newOperands, genericOp.args_in(),
+        genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(),
+        /*doc = */ nullptr,
+        /*library_call = */ nullptr);
+    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+                                replacementOp.region().begin());
+
+    // If any result tensor has a modified shape, add a reshape to recover
+    // the original shape.
+    SmallVector<Value, 4> resultReplacements;
+    for (auto result : llvm::enumerate(replacementOp.getResults())) {
+      unsigned index = result.index() + replacementOp.getNumOperands();
+      RankedTensorType origResultType = genericOp.getResult(result.index())
+                                            .getType()
+                                            .cast<RankedTensorType>();
+      if (origResultType != result.value().getType()) {
+        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
+            loc, origResultType, result.value(), reassociationMaps[index]));
+      } else {
+        resultReplacements.push_back(result.value());
+      }
+    }
+    rewriter.replaceOp(genericOp, resultReplacements);
+    return success();
+  }
+};
+} // namespace
+
+/// Patterns that are used to canonicalize the use of unit-extent dims for
+/// broadcasting.
+void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
+  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+}
+
+namespace {
+/// Pass that removes unit-extent dims within generic ops.
+struct LinalgFoldUnitExtentDimsPass
+    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    FuncOp funcOp = getFunction();
+    MLIRContext *context = funcOp.getContext();
+    if (foldOneTripLoopsOnly)
+      patterns.insert<FoldUnitDimLoops>(context);
+    else
+      populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
+    applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgFoldUnitExtentDimsPass() {
+  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
+}

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3123f95452fd..3f3c1c53fc3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -575,8 +575,8 @@ struct FuseGenericOpsOnTensors {
       if (auto yieldOp = dyn_cast<YieldOp>(op)) {
         // Lookup the value the yield operation is mapped to.
         Value yieldVal = yieldOp.getOperand(0);
-        auto clonedVal = mapper.lookup(yieldVal);
-        mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
+        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
+          mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
         continue;
       }
       rewriter.clone(op, mapper);

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
new file mode 100644
index 000000000000..a5169c35d18d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1 : f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+  return %0 : tensor<?x1x?x1x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
+//   CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+//   CHECK-DAG: #[[MAP6:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK-LABEL: func @drop_one_trip_loops
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP4]], #[[MAP5]], #[[MAP6]]]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1: f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<1x1xf32> -> tensor<1x1xf32>
+  return %0 : tensor<1x1xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
+//  CHECK-SAME:     iterator_types = []
+
+// -----
+
+#accesses = [
+  affine_map<(d0) -> (0, d0)>,
+  affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic #trait %arg0 {
+  ^bb0(%arg2: f32):     // no predecessors
+    linalg.yield %arg2 : f32
+  }  : tensor<1x5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+//       CHECK:   linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, d1)>,
+  affine_map<(d0, d1) -> (d0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+       tensor<5xf32> into tensor<1x5xf32>
+  %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+       tensor<5xf32> into tensor<5x1xf32>
+  %2 = linalg.generic #trait %0, %1 {
+       ^bb0(%arg2: f32, %arg3: f32):
+         %3 = addf %arg2, %arg3 : f32
+         linalg.yield %3 : f32
+       } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+//   CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
+{
+   %0 = linalg.generic #trait %arg0 {
+        ^bb0(%arg1 : f32):
+          linalg.yield %arg1 : f32
+        } : tensor<1x1xf32> -> tensor<?x?xf32>
+   return %0 : tensor<?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> ()>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<1x1xf32>
+//       CHECK:   %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] []
+//  CHECK-SAME:     tensor<1x1xf32> into tensor<f32>
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+//  CHECK-SAME:     %[[A]]

diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
new file mode 100644
index 000000000000..a977ab4cadd9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="fold-one-trip-loops-only" -split-input-file | FileCheck %s
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1 : f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+  return %0 : tensor<?x1x?x1x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d1, 0, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+{
+  %0 = linalg.generic #trait %arg0 {
+       ^bb0(%arg1: f32) :
+         linalg.yield %arg1 : f32
+       } : tensor<1x1xf32> -> tensor<1x1xf32>
+  return %0 : tensor<1x1xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
+//  CHECK-SAME:     iterator_types = []
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
+{
+  linalg.generic #trait %arg0, %arg1 {
+    ^bb0(%arg2: f32, %arg3 : f32) :
+      linalg.yield %arg2 : f32
+    } : memref<1x1xf32>, memref<1x1xf32>
+  return
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<() -> (0, 0)>
+// CHECK-LABEL: func @drop_all_loops
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
+//  CHECK-SAME:     iterator_types = []
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (d0, d1)>,
+  affine_map<(d0, d1) -> (d1)>
+]
+
+#trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic #trait %arg0 {
+  ^bb0(%arg2: f32):     // no predecessors
+    linalg.yield %arg2 : f32
+  }  : tensor<1x5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (0, d0)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel"]


        

