[Mlir-commits] [mlir] 774c9c6 - [mlir][Linalg] Add canonicalization of linalg op -> dim op.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 14 16:17:29 PST 2021


Author: MaheshRavishankar
Date: 2021-01-14T16:17:08-08:00
New Revision: 774c9c6ef3addc735939a388965a0a694bbd4f57

URL: https://github.com/llvm/llvm-project/commit/774c9c6ef3addc735939a388965a0a694bbd4f57
DIFF: https://github.com/llvm/llvm-project/commit/774c9c6ef3addc735939a388965a0a694bbd4f57.diff

LOG: [mlir][Linalg] Add canonicalization of linalg op -> dim op.

Add a canonicalization that rewrites a `dim` operation on the result of a
linalg operation on tensors so that it uses one of the linalg operation's
operands instead. This allows the linalg op itself to be deleted once all its
non-dim uses are removed (say, through tiling).

Differential Revision: https://reviews.llvm.org/D93076

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
    mlir/include/mlir/IR/AffineExprVisitor.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 8ac82b768ad3..a706d67d2988 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -32,6 +32,9 @@ def Linalg_Dialect : Dialect {
     the op semantics.
   }];
   let cppNamespace = "::mlir::linalg";
+  let dependentDialects = [
+    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+  ];
 }
 
 // Whether a type is a RangeType.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index f3b7181d71a5..85133604cda0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -946,6 +946,56 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of getLoopsToShapesMap() that
+        represents dimension `dim` of shaped operand `operandIdx`.
+        Returns None if `operandIdx` is out of range; `dim` is not checked.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        unsigned pos = 0;
+        for (auto type : llvm::enumerate(getShapedOperandTypes())) {
+          if (type.index() == operandIdx) return pos + dim;
+          pos += type.value().getRank();
+        }
+        return {};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of getLoopsToShapesMap() that
+        represents dimension `dim` of input operand `inputIdx`. Returns
+        None if `inputIdx` >= getNumInputs(); `dim` is not checked.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (inputIdx >= getNumInputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of getLoopsToShapesMap() that
+        represents dimension `dim` of result `resultIdx`, i.e. of the output
+        operand at index getNumInputs() + resultIdx. None if out of range.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (resultIdx >= getNumOutputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(
+            getNumInputs() + resultIdx, dim);
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other static interface methods.
@@ -1027,6 +1077,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }
       return res;
     }
+
+    /// Returns the size of dimension `dim` of result `resultIdx`, expressed in
+    /// terms of the input operand shapes where possible (llvm::None if not).
+    Optional<Value> inferResultDimFromInputShapes(
+        OpBuilder &b, Location loc, unsigned resultIdx, unsigned dim);
+
     //========================================================================//
     // Helper functions to mutate the `operand_segment_sizes` attribute.
     // These are useful when cloning and changing operand types.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index a4e32b9263e8..71ac601977fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
 

diff  --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 65019c8830f6..03bb4b24db54 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -159,29 +159,31 @@ template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
 
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
-  void visitAddExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitMulExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitModExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitConstantExpr(AffineConstantExpr expr) {}
-  void visitDimExpr(AffineDimExpr expr) {}
-  void visitSymbolExpr(AffineSymbolExpr expr) {}
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 
 private:
   // Walk the operands - each operand is itself walked in post order.
+  // Kept `void` (not RetTy): the walks run only for their side effects, and a
+  // non-void function that flows off its end without returning is UB.
   void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
     walkPostOrder(expr.getLHS());
     walkPostOrder(expr.getRHS());
   }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b74e44d91176..30a6b9c0c371 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -16,12 +16,14 @@
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   return res;
 }
 
+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr. Symbols and constants never match.
+struct HasAffineDimExprVisitor
+    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+  explicit HasAffineDimExprVisitor(
+      const llvm::SmallSet<unsigned, 4> &positions)
+      : positions(positions) {}
+
+  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+  }
+
+  bool visitDimExpr(AffineDimExpr dimExpr) {
+    return positions.count(dimExpr.getPosition()) != 0;
+  }
+
+  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+  llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+                                                        Location loc,
+                                                        unsigned resultIdx,
+                                                        unsigned dim) {
+  // Returns llvm::None when the size cannot be derived from inputs alone.
+  // An example that helps understand the logic below.
+  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+  // This is achieved as follows.
+  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
+  //   resultFromInputDim = subMapOfResultDim.compose(shapesToLoopsMap)
+  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3)
+  AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+  // Find the position in the above map that represents the shape of the
+  // result:dim being inferred.
+  Optional<unsigned> resultDimSubMapPos =
+      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+  if (!resultDimSubMapPos)
+    return {};
+
+  // From loopsToShapesMap extract the submap that represents the shape of the
+  // (resultIdx, dim) needed.
+  AffineMap loopToResultDimShapeMap =
+      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+  AffineMap operandShapesToResultDimMap =
+      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+  // Check that the result dim map does not use any shape position belonging
+  // to an output. Outputs are the trailing shaped operands, so their positions
+  // form the half-open range [outputDimPosStart, outputDimPosEnd), where the
+  // end is the total number of shape positions (the number of results of
+  // loopsToShapesMap). `llvm::seq` excludes its end, so the end must be one
+  // past the last output position; deriving it from the map also avoids an
+  // unsigned underflow when the last output is rank-0.
+  llvm::SmallSet<unsigned, 4> outputDims;
+  unsigned outputDimPosStart =
+      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+  unsigned outputDimPosEnd = loopsToShapesMap.getNumResults();
+  for (unsigned pos : llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd))
+    outputDims.insert(pos);
+  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+    return llvm::None;
+
+  // The size depends only on input shapes; materialize it from the flat list
+  // of operand dims (only the input dims are referenced by the map).
+  return applyMapToValues(b, loc, operandShapesToResultDimMap,
+                          createFlatListOfOperandDims(b, loc))[0];
+}
+
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
     return success();
   }
 };
+
+/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
+/// with std.dim operations that use one of the arguments. For example,
+///
+/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
+/// %1 = dim %0, %c0
+///
+/// with
+///
+/// %1 = dim %arg0, %c0
+///
+/// where possible. After this, the result of the `linalg.matmul` is no longer
+/// used by dim operations, so once its other uses are replaced (say by tiling
+/// the `linalg.matmul`) the op becomes truly dead, instead of being kept alive
+/// by a dim op that would prevent its DCE.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    // Negative indices are rejected: they would wrap when cast to unsigned.
+    if (!dimIndex || *dimIndex < 0)
+      return failure();
+    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
+    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
+        rewriter, dimOp.getLoc(), resultIndex,
+        static_cast<unsigned>(*dimIndex));
+    if (!operandDimValue) {
+      // It's always possible to fall back to the corresponding `outs` operand.
+      operandDimValue = rewriter.create<DimOp>(
+          dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
+    }
+    rewriter.replaceOp(dimOp, *operandDimValue);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -2166,26 +2287,6 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
     return success();
   }
 };
-
-/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
-/// with the corresponding output tensor argument of the linalg op.
-struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    Value dimOpArg = dimOp.memrefOrTensor();
-    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
-    if (!linalgOp)
-      return failure();
-
-    auto results = linalgOp.getOperation()->getResults();
-    int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
-    auto outputTensors = linalgOp.getOutputTensors();
-    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
-    return success();
-  }
-};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
@@ -2193,7 +2294,7 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
                                         MLIRContext *context) {                \
     results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,     \
                    RemoveIdentityLinalgOps>();                                 \
-    results.insert<ReplaceDimOfLinalgResult>(context);                         \
+    results.insert<ReplaceDimOfLinalgOpResult>(context);                       \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index ba31ca5a034b..9d39e4e8c75a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,9 +58,6 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//
 
 void mlir::linalg::LinalgDialect::initialize() {
-  getContext()->getOrLoadDialect("std");
-  getContext()->getOrLoadDialect("tensor");
-
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b2de3fdc6c8e..ca7f82c1b254 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -390,10 +390,147 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
 
 // -----
 
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index) {
+  %c0 = constant 0 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %1 = mulf %arg3, %arg4 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?x?xf32>
+  %3 = dim %0, %c0 : tensor<?x?xf32>
+  return %3 : index
+}
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @remove_dim_result_uses
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
+//       CHECK:   return %[[T2]]
+
+// -----
+
+func @remove_dim_result_uses_outs
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32) :
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c1 : tensor<?x?xf32>
+  return %2 : index
+}
+//      CHECK: func @remove_dim_result_uses_outs
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %4 = mulf %arg3, %arg4 : f32
+      %5 = addf %4, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  %6 = dim %3, %c0 : tensor<?x?xf32>
+  %7 = dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @keep_result_dim_uses_sequence2
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c0 : tensor<?x?xf32>
+  %3 = dim %1, %c1 : tensor<?x?xf32>
+  return %2, %3 : index, index
+}
+//       CHECK: func @keep_result_dim_uses_sequence2
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//       CHECK:   return %[[T0]], %[[ARG1]]
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 
 func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
-    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+    %arg_1: tensor<?xf32>) -> (index, index) {
   %0, %1 = linalg.generic {
     indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel"]
@@ -405,16 +542,16 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
 
   %c0 = constant 0 : index
   %num_elem_0 = dim %0, %c0 : tensor<?xf32>
-  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
 
   %num_elem_1 = dim %1, %c0 : tensor<?xf32>
-  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
-  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+  return %num_elem_0, %num_elem_1 : index, index
 }
-// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
-// CHECK: dim [[ARG_0]]
-// CHECK: dim [[ARG_1]]
+//      CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME:   %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME:   %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+//      CHECK:   %[[R0:.+]] = dim %[[ARG_0]]
+//      CHECK:   %[[R1:.+]] = dim %[[ARG_0]]
+//      CHECK:   return %[[R0]], %[[R1]]
 
 // -----
 


        


More information about the Mlir-commits mailing list