[Mlir-commits] [mlir] 542668d - [mlir][Linalg] Add support for fusing linalg.tensor_reshape with

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 23 13:42:19 PDT 2020


Author: MaheshRavishankar
Date: 2020-04-23T13:41:47-07:00
New Revision: 542668d1e2060693279462b67d07756fe93f3eb9

URL: https://github.com/llvm/llvm-project/commit/542668d1e2060693279462b67d07756fe93f3eb9
DIFF: https://github.com/llvm/llvm-project/commit/542668d1e2060693279462b67d07756fe93f3eb9.diff

LOG: [mlir][Linalg] Add support for fusing linalg.tensor_reshape with
linalg.generic operations.

Differential Revision: https://reviews.llvm.org/D78464

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/IR/StandardTypes.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 3e667d98f822..10883d03b38b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -172,6 +172,10 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
     RankedTensorType getResultType() {
       return result().getType().cast<RankedTensorType>();
     }
+    SmallVector<AffineMap, 4> getReassociationMaps() {
+      return llvm::to_vector<4>(llvm::map_range(reassociation(),
+        [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
+    }
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d8886acc5992..5c4868c4c870 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,8 +18,10 @@
 
 namespace mlir {
 class FuncOp;
+class MLIRContext;
 class ModuleOp;
 template <typename T> class OperationPass;
+class OwningRewritePatternList;
 class Pass;
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
@@ -48,6 +50,10 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass();
 /// Placeholder for now, this is NYI.
 std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 
+/// Patterns for fusing linalg operations on tensors.
+void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
+                                           OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_PASSES_H_

diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index cc94d27dedbb..4c5bbba0aa6a 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -672,9 +672,18 @@ MemRefType canonicalizeStridedLayout(MemRefType t);
 /// varying stride is always `1`.
 ///
 /// Examples:
-///   - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
-///   - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
-///   - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
+///   - memref<3x4x5xf32> has canonical stride expression
+///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
+///   - memref<3x?x5xf32> has canonical stride expression
+///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
+///   - memref<3x4x?xf32> has canonical stride expression
+///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                          ArrayRef<AffineExpr> exprs,
+                                          MLIRContext *context);
+
+/// Return the result of makeCanonicalStridedLayoutExpr for the common case
+/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                           MLIRContext *context);
 

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8fa90f444f63..0aa149ef907f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -554,7 +554,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
   unsigned currentDim = 0;
   for (AffineMap m : reassociation) {
     unsigned dim = m.getNumResults();
-    auto band = shape.drop_front(currentDim).take_front(dim);
+    auto band = shape.slice(currentDim, dim);
     int64_t size = 1;
     if (llvm::is_contained(band, ShapedType::kDynamicSize))
       size = ShapedType::kDynamicSize;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 1184b5f87ea6..cd6301ae249c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -559,6 +559,187 @@ struct FuseGenericOpsOnTensors {
 };
 } // namespace
 
+/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
+/// provided, given the shape of the source tensor that corresponds to the
+/// `sourceMap`. Note that this implicitly assumes that the tensor dimensions
+/// are "row-major" ordered logically.
+///
+/// For example:
+///
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
+///
+/// and reshape:
+/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
+///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
+///
+/// would be rewritten into:
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map
+///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
+                                        ArrayRef<int64_t> sourceShape,
+                                        ArrayRef<AffineMap> reassociationMaps) {
+  SmallVector<AffineExpr, 4> resultExprs;
+  resultExprs.reserve(reassociationMaps.size());
+  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
+  MLIRContext *context = sourceMap.getContext();
+
+  // Compute the result exprs based on the reassociation maps.
+  for (AffineMap map : reassociationMaps) {
+    ArrayRef<AffineExpr> collapsedDims = map.getResults();
+    // Assume that they are in-order and contiguous (already checked in
+    // verifier).
+    assert(!collapsedDims.empty());
+    unsigned startDim =
+        collapsedDims.front().cast<AffineDimExpr>().getPosition();
+    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
+        sourceShape.slice(startDim, collapsedDims.size()),
+        sourceExprs.slice(startDim, collapsedDims.size()), context);
+    resultExprs.push_back(linearizedExpr);
+  }
+  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
+                        resultExprs, context);
+}
+
+/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
+/// true) or its producer (if `asProducer` is false) given the indexing map at
+/// its use.
+static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
+                                     AffineMap useIndexMap, bool asProducer) {
+  RankedTensorType returnType = reshapeOp.getResultType();
+  RankedTensorType operandType = reshapeOp.getSrcType();
+  // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
+  // operand is of lesser rank than the result. Fusing when operand has higher
+  // rank will require use of mods and divs in the indexing maps of the fused op
+  // which would make it non-invertible. Similarly reshape is fused with its
+  // producer (i.e. reshape as consumer) only if the return type has lesser
+  // rank.
+  if ((asProducer && returnType.getRank() < operandType.getRank()) ||
+      (!asProducer && operandType.getRank() < returnType.getRank()))
+    return false;
+  return useIndexMap.isIdentity();
+}
+
+namespace {
+/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
+template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
+  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+                        unsigned consumerIdx) {
+    return isTensorReshapeOpFusible(
+        producer, consumer.getInputIndexingMap(consumerIdx), true);
+  }
+
+  static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
+                         unsigned consumerIdx, PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isFusible(producer, consumer, consumerIdx))
+      return nullptr;
+
+    // Compute the fused operands list.
+    SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
+                                        consumer.operand_end());
+    fusedOperands[consumerIdx] = producer.src();
+
+    // Compute indexing_maps for the fused operation. The indexing_maps for the
+    // operands of the consumer that aren't fused are the same.
+    SmallVector<AffineMap, 4> fusedIndexMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+              return attr.cast<AffineMapAttr>().getValue();
+            }));
+
+    // Compute the indexing map to use for the operand of the producer.
+    AffineMap modifiedMap = linearizeCollapsedDims(
+        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
+        producer.getReassociationMaps());
+    for (AffineExpr expr : modifiedMap.getResults()) {
+      if (!expr.isPureAffine())
+        return nullptr;
+    }
+    fusedIndexMaps[consumerIdx] = modifiedMap;
+
+    // Further check that the resulting index maps can be fused and
+    // inverted. Without this the resultant op is not legal.
+    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+      return nullptr;
+
+    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+          return AffineMapAttr::get(map);
+        }));
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(fusedOperands.size()),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+                               fusedRegion.begin());
+    return fusedOp;
+  }
+};
+
+/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
+template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
+  static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+                        unsigned consumerIdx) {
+    return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+                                    false);
+  }
+
+  static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
+                         unsigned consumerIdx, PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isFusible(producer, consumer, consumerIdx))
+      return nullptr;
+
+    // The indexing_maps for the operands of the fused operation are same as
+    // those for the operands of the producer.
+    SmallVector<AffineMap, 4> fusedIndexMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
+              return attr.cast<AffineMapAttr>().getValue();
+            }));
+    // Compute the indexing map to use for the result of the producer.
+    AffineMap modifiedMap = linearizeCollapsedDims(
+        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
+        consumer.getReassociationMaps());
+    for (AffineExpr expr : modifiedMap.getResults()) {
+      if (!expr.isPureAffine())
+        return nullptr;
+    }
+    fusedIndexMaps.back() = modifiedMap;
+
+    // Further check that the resulting index maps can be fused and
+    // inverted. Without this the resultant op is not legal.
+    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+      return nullptr;
+
+    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+          return AffineMapAttr::get(map);
+        }));
+
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultType(),
+        producer.getOperands(),
+        rewriter.getI64IntegerAttr(producer.getNumOperands()),
+        rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
+        producer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+                               fusedRegion.begin());
+    return fusedOp;
+  }
+};
+} // namespace
+
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                        Operation *consumer,
                                        unsigned consumerIdx,
@@ -569,6 +750,7 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
   if (!producer || producer->getNumResults() != 1)
     return nullptr;
 
+  // Fuse when consumer is GenericOp.
   if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
     if (!genericOp.hasTensorSemantics())
       return nullptr;
@@ -576,7 +758,21 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
       if (genericOpProducer.hasTensorSemantics())
         return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
                                              consumerIdx, rewriter, folder);
+    } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
+      return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
+          reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
     }
+    return nullptr;
+  }
+
+  // Fuse when consumer is a TensorReshapeOp.
+  if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
+      if (genericOpProducer.hasTensorSemantics())
+        return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
+            genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+    }
+    return nullptr;
   }
   return nullptr;
 }
@@ -612,7 +808,7 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     Operation *op = getOperation();
-    patterns.insert<FuseTensorOps<GenericOp>>(op->getContext());
+    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
   };
 };
@@ -622,6 +818,12 @@ struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
 };
 } // namespace
 
+void mlir::populateLinalgTensorOpsFusionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
+      context);
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
   return std::make_unique<LinalgFusionPass>();
 }

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 903ae92e6baf..94156c358eb0 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -728,35 +728,47 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
 }
 
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                                ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
   AffineExpr expr;
   bool dynamicPoisonBit = false;
+  unsigned numDims = 0;
   unsigned nSymbols = 0;
+  // Compute the number of symbols and dimensions of the passed exprs.
+  for (AffineExpr expr : exprs) {
+    expr.walk([&numDims, &nSymbols](AffineExpr d) {
+      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
+        numDims = std::max(numDims, dim.getPosition() + 1);
+      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
+        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
+    });
+  }
   int64_t runningSize = 1;
-  unsigned rank = sizes.size();
-  for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
-    auto size = en.value();
-    auto position = rank - 1 - en.index();
+  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
+    int64_t size = std::get<1>(en);
     // Degenerate case, no size =-> no stride
     if (size == 0)
       continue;
-    auto d = getAffineDimExpr(position, context);
-    // Static case: stride = runningSize and runningSize *= size.
-    if (!dynamicPoisonBit) {
-      auto cst = getAffineConstantExpr(runningSize, context);
-      expr = expr ? expr + cst * d : cst * d;
-      if (size > 0)
-        runningSize *= size;
-      else
-        // From now on bail into dynamic mode.
-        dynamicPoisonBit = true;
-      continue;
-    }
-    // Dynamic case, new symbol for each new stride.
-    auto sym = getAffineSymbolExpr(nSymbols++, context);
-    expr = expr ? expr + d * sym : d * sym;
+    AffineExpr dimExpr = std::get<0>(en);
+    AffineExpr stride = dynamicPoisonBit
+                            ? getAffineSymbolExpr(nSymbols++, context)
+                            : getAffineConstantExpr(runningSize, context);
+    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
+    if (size > 0)
+      runningSize *= size;
+    else
+      dynamicPoisonBit = true;
   }
-  return simplifyAffineExpr(expr, rank, nSymbols);
+  return simplifyAffineExpr(expr, numDims, nSymbols);
+}
+
+AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+                                                MLIRContext *context) {
+  SmallVector<AffineExpr, 4> exprs;
+  exprs.reserve(sizes.size());
+  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
+    exprs.push_back(getAffineDimExpr(dim, context));
+  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
 }
 
 /// Return true if the layout for `t` is compatible with strided semantics.

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 11c38fcb7601..2c00f77edd3f 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -129,3 +129,93 @@ func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tenso
 
   return %1 : tensor<f32>
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
+                                         %arg1 : tensor<?x?x4x?xf32>) ->
+                                         tensor<?x?x4x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+  %1 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    %0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32> -> tensor<?x?x4x?xf32>
+  return %1 : tensor<?x?x4x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+//       CHECK: linalg.generic
+//  CHECK-SAME:   args_in = 2
+//  CHECK-SAME:   args_out = 1
+//  CHECK-SAME:   indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+//   CHECK-NOT: linalg.generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
+                                         %arg1 : tensor<?x?x4x5xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32> -> tensor<?x?x4x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
+//       CHECK: linalg.generic
+//  CHECK-SAME:   args_in = 2
+//  CHECK-SAME:   args_out = 1
+//  CHECK-SAME:   indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT: linalg.generic
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
+                                           %arg1 : tensor<?x?x?x5xf32>) ->
+                                           tensor<?x?xf32>
+{
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32> -> tensor<?x?x?x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x?x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
+//       CHECK: linalg.tensor_reshape


        


More information about the Mlir-commits mailing list