[llvm-branch-commits] [mlir] 774c9c6 - [mlir][Linalg] Add canonicalization of linalg op -> dim op.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 14 16:21:40 PST 2021
Author: MaheshRavishankar
Date: 2021-01-14T16:17:08-08:00
New Revision: 774c9c6ef3addc735939a388965a0a694bbd4f57
URL: https://github.com/llvm/llvm-project/commit/774c9c6ef3addc735939a388965a0a694bbd4f57
DIFF: https://github.com/llvm/llvm-project/commit/774c9c6ef3addc735939a388965a0a694bbd4f57.diff
LOG: [mlir][Linalg] Add canonicalization of linalg op -> dim op.
Add a canonicalization that rewrites a dim operation on the result of a linalg
operation on tensors to use the operands of the linalg operation instead.
This allows the linalg op itself to be deleted once all of its non-dim uses
are removed (e.g. through tiling).
Differential Revision: https://reviews.llvm.org/D93076
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/IR/AffineExprVisitor.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 8ac82b768ad3..a706d67d2988 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -32,6 +32,9 @@ def Linalg_Dialect : Dialect {
the op semantics.
}];
let cppNamespace = "::mlir::linalg";
+ // Dialects whose operations Linalg transformations create (e.g. the std.dim
+ // and affine.apply ops produced by canonicalization patterns); they are
+ // loaded together with the Linalg dialect.
+ let dependentDialects = [
+ "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+ ];
}
// Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index f3b7181d71a5..85133604cda0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -946,6 +946,56 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return inversePermutation(getLoopsToShapesMap());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the position in the results of the affine map computed
+ by getLoopsToShapesMap() that represents the shape of an
+ operand (input or output) at a dimension.
+ }],
+ /*retTy=*/"Optional<unsigned>",
+ /*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
+ /*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Shapes of the shaped operands are laid out consecutively in the results
+ // of getLoopsToShapesMap(), so skip past the ranks of earlier operands.
+ // Returns None when `operandIdx` is not a shaped operand index.
+ // NOTE(review): `dim` is not checked against the operand's rank, so an
+ // out-of-range `dim` yields a position inside a later operand's shape.
+ unsigned pos = 0;
+ for (auto type : llvm::enumerate(getShapedOperandTypes())) {
+ if (type.index() == operandIdx) return pos + dim;
+ pos += type.value().getRank();
+ }
+ return {};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the position in the results of the affine map computed
+ by getLoopsToShapesMap() that represents the shape of an
+ input operand at a dimension.
+ }],
+ /*retTy=*/"Optional<unsigned>",
+ /*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
+ /*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Inputs come first among the shaped operands, so input `inputIdx` is
+ // also shaped operand `inputIdx`.
+ if (inputIdx >= getNumInputs()) return {};
+ return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the position in the results of the affine map computed
+ by getLoopsToShapesMap() that represents the shape of the
+ result value at a dimension.
+ }],
+ /*retTy=*/"Optional<unsigned>",
+ /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
+ /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // Outputs follow the inputs among the shaped operands, so result
+ // `resultIdx` is shaped operand `getNumInputs() + resultIdx`.
+ if (resultIdx >= getNumOutputs()) return {};
+ return getOperandDimPositionInLoopsToShapeMap(
+ getNumInputs() + resultIdx, dim);
+ }]
+ >,
//===------------------------------------------------------------------===//
// Other static interface methods.
@@ -1027,6 +1077,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}
return res;
}
+
+ /// Returns a value that expresses dimension `dim` of result `resultIdx` in
+ /// terms of the shapes of the input operands, where possible. Returns None
+ /// when `resultIdx` does not name a result, or when the dimension depends
+ /// on the shape of an output operand and so cannot be derived from the
+ /// inputs alone.
+ Optional<Value> inferResultDimFromInputShapes
+ (OpBuilder &b, Location loc, unsigned resultIdx, unsigned dim);
+
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index a4e32b9263e8..71ac601977fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -9,6 +9,9 @@
#ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 65019c8830f6..03bb4b24db54 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -159,29 +159,29 @@ template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
- void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
- void visitAddExpr(AffineBinaryOpExpr expr) {
- static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ // Every default returns RetTy: the op-specific binary defaults forward the
+ // result of visitAffineBinaryOpExpr, and the remaining defaults return a
+ // value-initialized RetTy() (`return void();` is well-formed, so visitors
+ // with RetTy = void keep working).
+ RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+ RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitMulExpr(AffineBinaryOpExpr expr) {
- static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitModExpr(AffineBinaryOpExpr expr) {
- static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ RetTy visitModExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitFloorDivExpr(AffineBinaryOpExpr expr) {
- static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitCeilDivExpr(AffineBinaryOpExpr expr) {
- static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+ RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitConstantExpr(AffineConstantExpr expr) {}
- void visitDimExpr(AffineDimExpr expr) {}
- void visitSymbolExpr(AffineSymbolExpr expr) {}
+ RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+ RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+ RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
private:
// Walk the operands - each operand is itself walked in post order.
+ // Kept `void` for every RetTy: the body below has no `return` statement,
+ // so declaring it to return a non-void RetTy would flow off the end of a
+ // non-void function (undefined behavior) whenever it is instantiated.
void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b74e44d91176..30a6b9c0c371 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -16,12 +16,14 @@
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
return res;
}
+namespace {
+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr.
+///
+/// visit() returns true as soon as a dimension whose position is in
+/// `positions` is found; constants and symbols never match. The position set
+/// is copied, so the visitor does not depend on the caller's set staying alive.
+/// Kept in an anonymous namespace because it is local to this file.
+struct HasAffineDimExprVisitor
+ : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+ explicit HasAffineDimExprVisitor(
+ const llvm::SmallSet<unsigned, 4> &positions)
+ : positions(positions) {}
+
+ bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+ // Short-circuits: the RHS is only visited when the LHS has no match.
+ return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+ }
+
+ bool visitDimExpr(AffineDimExpr dimExpr) {
+ return positions.count(dimExpr.getPosition()) != 0;
+ }
+
+ bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+ bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+ llvm::SmallSet<unsigned, 4> positions;
+};
+} // namespace
+
+/// Expresses dimension `dim` of result `resultIdx` as an affine function of the
+/// dimensions of the input operands. Returns None if `resultIdx` does not name
+/// a result, or if the extent depends on the shape of any output operand.
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+ Location loc,
+ unsigned resultIdx,
+ unsigned dim) {
+ // An example that helps understand the logic below.
+ // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+ // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+ // This is achieved as follows.
+ // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+ // subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+ // shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
+ // resultFromInputDim = subMapOfResultDim.compose(shapesToLoopsMap)
+ // = (d0, d1, d2, d3, d4, d5) -> (d0 + d3)
+ // i.e. dim 0 of O is dim 0 of A plus dim 1 of B.
+ AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+ // Find the position in the above map that represents the shape of the
+ // result:dim being inferred.
+ Optional<unsigned> resultDimSubMapPos =
+ getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+ if (!resultDimSubMapPos)
+ return {};
+
+ // From loopsToShapesMap extract the submap that represents the shape of the
+ // (resultIdx, dim) needed, then rewrite it in terms of operand shapes.
+ AffineMap loopToResultDimShapeMap =
+ loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+ AffineMap operandShapesToResultDimMap =
+ loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+ // Check that the result dim map does not contain the positions corresponding
+ // to the outputs. Output shapes occupy the trailing results of
+ // loopsToShapesMap, i.e. the half-open range
+ // [outputDimPosStart, loopsToShapesMap.getNumResults()). Using the map size
+ // as the exclusive end includes the last output dimension (llvm::seq is
+ // half-open) and avoids computing `rank - 1` for a rank-0 output.
+ llvm::SmallSet<unsigned, 4> outputDims;
+ unsigned outputDimPosStart =
+ getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+ unsigned outputDimPosEnd = loopsToShapesMap.getNumResults();
+ for (unsigned pos : llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd))
+ outputDims.insert(pos);
+ HasAffineDimExprVisitor checkDimExpr(outputDims);
+ if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+ return llvm::None;
+ return applyMapToValues(b, loc, operandShapesToResultDimMap,
+ createFlatListOfOperandDims(b, loc))[0];
+}
+
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
return success();
}
};
+
+/// Rewrites a std.dim whose source is a result of a LinalgOp on tensors so
+/// that it no longer refers to that result. For example,
+///
+/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
+/// %1 = dim %0, %c0
+///
+/// becomes
+///
+/// %1 = dim %arg0, %c0
+///
+/// The extent is computed from the input shapes when possible; otherwise the
+/// dim is taken from the `outs` operand tied to the result. Once no dim op
+/// refers to the result, replacing its remaining uses (e.g. by tiling the
+/// `linalg.matmul`) leaves the op dead instead of kept alive by a dim op.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ Value source = dimOp.memrefOrTensor();
+ auto linalgOp = source.getDefiningOp<LinalgOp>();
+ Optional<int64_t> constIndex = dimOp.getConstantIndex();
+ // Only constant-index dims of a LinalgOp result are handled.
+ if (!linalgOp || !constIndex)
+ return failure();
+
+ unsigned resultNumber = source.cast<OpResult>().getResultNumber();
+ if (Optional<Value> inferred = linalgOp.inferResultDimFromInputShapes(
+ rewriter, dimOp.getLoc(), resultNumber,
+ static_cast<unsigned>(*constIndex))) {
+ rewriter.replaceOp(dimOp, *inferred);
+ return success();
+ }
+
+ // It's always possible to fall back to the dim of the corresponding
+ // `outs` operand.
+ rewriter.replaceOpWithNewOp<DimOp>(dimOp, linalgOp.getOutput(resultNumber),
+ *constIndex);
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -2166,26 +2287,6 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
return success();
}
};
-
-/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
-/// with the corresponding output tensor argument of the linalg op.
-struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- Value dimOpArg = dimOp.memrefOrTensor();
- auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
- if (!linalgOp)
- return failure();
-
- auto results = linalgOp.getOperation()->getResults();
- int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
- auto outputTensors = linalgOp.getOutputTensors();
- rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
- return success();
- }
-};
} // namespace
#define CANONICALIZERS_AND_FOLDERS(XXX) \
@@ -2193,7 +2294,7 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
MLIRContext *context) { \
results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps>(); \
- results.insert<ReplaceDimOfLinalgResult>(context); \
+ results.insert<ReplaceDimOfLinalgOpResult>(context); \
} \
\
LogicalResult XXX::fold(ArrayRef<Attribute>, \
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index ba31ca5a034b..9d39e4e8c75a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,9 +58,6 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void mlir::linalg::LinalgDialect::initialize() {
- getContext()->getOrLoadDialect("std");
- getContext()->getOrLoadDialect("tensor");
-
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b2de3fdc6c8e..ca7f82c1b254 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -390,10 +390,147 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
// -----
+// Each dim of a linalg.init_tensor folds to the corresponding size operand.
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+ %1 = dim %0, %c0 : tensor<?x?xf32>
+ %2 = dim %0, %c1 : tensor<?x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK-LABEL: func @init_tensor_dynamic_dim2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+// A result dim that is an affine function of input dims (here d0 + d1) is
+// rewritten to an affine.apply over dims of the inputs.
+func @remove_dim_result_uses
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index) {
+ %c0 = constant 0 : index
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %1 = mulf %arg3, %arg4 : f32
+ %2 = addf %1, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ %3 = dim %0, %c0 : tensor<?x?xf32>
+ return %3 : index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @remove_dim_result_uses(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
+// CHECK: return %[[T2]]
+
+// -----
+
+// A result dim that only the `outs` operand determines is taken from that
+// operand, and then folds through linalg.init_tensor.
+func @remove_dim_result_uses_outs
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32) :
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = dim %1, %c1 : tensor<?x?xf32>
+ return %2 : index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_outs
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+// Chained LinalgOps: dims of both the matmul result and the generic result
+// (which uses the matmul result as its `outs`) resolve to dims of the inputs.
+func @remove_dim_result_uses_sequence
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = dim %0, %c0 : tensor<?x?xf32>
+ %2 = dim %0, %c1 : tensor<?x?xf32>
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %4 = mulf %arg3, %arg4 : f32
+ %5 = addf %4, %arg5 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ %6 = dim %3, %c0 : tensor<?x?xf32>
+ %7 = dim %3, %c1 : tensor<?x?xf32>
+ return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+// Dim 0 of the result is inferred from the input; dim 1 depends only on the
+// `outs` operand and folds through linalg.init_tensor.
+func @keep_result_dim_uses_sequence2
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = dim %1, %c0 : tensor<?x?xf32>
+ %3 = dim %1, %c1 : tensor<?x?xf32>
+ return %2, %3 : index, index
+}
+// CHECK-LABEL: func @keep_result_dim_uses_sequence2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK: return %[[T0]], %[[ARG1]]
+
+// -----
+
#map = affine_map<(d0) -> (d0)>
func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
- %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ %arg_1: tensor<?xf32>) -> (index, index) {
%0, %1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]
@@ -405,16 +542,16 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
%c0 = constant 0 : index
%num_elem_0 = dim %0, %c0 : tensor<?xf32>
- %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
%num_elem_1 = dim %1, %c0 : tensor<?xf32>
- %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
- return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+ return %num_elem_0, %num_elem_1 : index, index
}
-// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
-// CHECK: dim [[ARG_0]]
-// CHECK: dim [[ARG_1]]
+// CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK: %[[R0:.+]] = dim %[[ARG_0]]
+// CHECK: %[[R1:.+]] = dim %[[ARG_0]]
+// CHECK: return %[[R0]], %[[R1]]
// -----
More information about the llvm-branch-commits
mailing list