[Mlir-commits] [mlir] [mlir][linalg] Add linalg.transpose constant folding (PR #92589)
Ryan Holt
llvmlistbot at llvm.org
Wed May 22 14:38:28 PDT 2024
https://github.com/ryan-holt-1 updated https://github.com/llvm/llvm-project/pull/92589
From 4114e8aae6699dbba0a2803adec75338e1fe3d7e Mon Sep 17 00:00:00 2001
From: ryan-holt-1 <ryanholt at mathworks.com>
Date: Fri, 17 May 2024 14:35:15 -0400
Subject: [PATCH] [mlir][linalg] Add linalg.transpose constant folding
There was existing support for constant folding a `linalg.generic` that
was actually a transpose. This commit adds support for the named op
`linalg.transpose` as well, by making use of the `LinalgOp` interface.
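
For example (mirroring the new `named_transpose_fold_2d_fp32` test added
below), after this change a named transpose of a constant such as

  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %1 = linalg.transpose ins(%input : tensor<2x3xf32>)
                        outs(%init : tensor<3x2xf32>) permutation = [1, 0]

folds to a single constant holding the transposed values,

  %1 = arith.constant dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>

just as the equivalent `linalg.generic` form already did.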
---
.../Linalg/Transforms/ConstantFold.cpp      |  62 ++++----
mlir/test/Dialect/Linalg/constant-fold.mlir | 148 ++++++++++++++++++
.../Linalg/fusion-elementwise-ops.mlir      | 133 ----------------
3 files changed, 180 insertions(+), 163 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/constant-fold.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 8fffabf11f3fd..2e6079e1402e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -23,21 +23,21 @@ using namespace mlir;
using namespace mlir::linalg;
namespace {
-/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
-/// and permutation indexing maps.
+/// Base class for constant folding linalg structured ops with N inputs, 1
+/// output, and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
-/// bool matchIndexingMaps(GenericOp genericOp) const;
-/// RegionComputationFn getRegionComputeFn(GenericOp) const;
+/// bool matchIndexingMaps(LinalgOp linalgOp) const;
+/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
template <typename ConcreteType>
-class FoldConstantBase : public OpRewritePattern<GenericOp> {
+class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
public:
struct APIntOrFloat {
std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+ controlFn(controlFn) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
// Mixed and buffer semantics aren't supported.
- if (!genericOp.hasPureTensorSemantics())
+ if (!linalgOp.hasPureTensorSemantics())
return failure();
// Only support ops generating one output for now.
- if (genericOp.getNumDpsInits() != 1)
+ if (linalgOp.getNumDpsInits() != 1)
return failure();
- auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
+ auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();
- if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+ if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
return isa<ShapedType>(input.getType());
}))
return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
return cast<ShapedType>(value.getType()).getElementType();
};
if (!llvm::all_equal(
- llvm::map_range(genericOp->getOperands(), getOperandElementType)))
+ llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
return failure();
// We can only handle the case where we have int/float elements.
@@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// entirely in the compiler, without needing to turn all indices into
// Values, and then do affine apply on them, and then match back the
// constant again.
- if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+ if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isPermutation(); }))
return failure();
- for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
- if (genericOp.payloadUsesValueFromOperand(&operand))
+ for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
+ if (linalgOp.payloadUsesValueFromOperand(&operand))
return failure();
}
// Further check the indexing maps are okay for the ConcreteType.
- if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+ if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
return failure();
// Defer to the concrete type to check the region and discover the
// computation inside.
RegionComputationFn computeFn =
- static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+ static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
if (!computeFn)
return failure();
// All inputs should be constants.
- int numInputs = genericOp.getNumDpsInputs();
+ int numInputs = linalgOp.getNumDpsInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
- for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+ for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
if (!matchPattern(en.value()->get(),
m_Constant(&inputValues[en.index()])))
return failure();
@@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
- for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+ for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
if (!controlFn(operand))
return failure();
}
- auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
int64_t numElements = outputType.getNumElements();
@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
SmallVector<SmallVector<unsigned>> inputDims;
for (int i = 0; i < numInputs; ++i)
- inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
- auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
+ inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
+ auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
auto outputShape = outputType.getShape();
// Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
APIntOrFloatArray computeFnInputs;
auto inputShapes = llvm::to_vector<4>(
- llvm::map_range(genericOp.getInputs(), [](Value value) {
+ llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
return cast<ShapedType>(value.getType()).getShape();
}));
@@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
return success();
}
@@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
ControlFusionFn controlFn;
};
-// Folds linalg.generic ops that are actually transposes on constant values.
+// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
+// on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+
using FoldConstantBase::FoldConstantBase;
- bool matchIndexingMaps(GenericOp genericOp) const {
+ bool matchIndexingMaps(LinalgOp linalgOp) const {
// We should have one input and one output.
- return genericOp.getIndexingMapsArray().size() == 2;
+ return linalgOp.getIndexingMapsArray().size() == 2;
}
- RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+ RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
// Make sure the region only contains a yield op.
- Block &body = genericOp.getRegion().front();
+ Block &body = linalgOp->getRegion(0).front();
if (!llvm::hasSingleElement(body))
return nullptr;
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
diff --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir
new file mode 100644
index 0000000000000..3929c26a3382f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/constant-fold.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @transpose_fold_2d_fp32
+func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: %[[CST:.+]] = arith.constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp64
+func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
+ // CHECK: %[[CST:.+]] = arith.constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
+ ^bb0(%arg1: f64, %arg2: f64):
+ linalg.yield %arg1 : f64
+ } -> tensor<3x2xf64>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i32
+func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+ %input = arith.constant dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi32>
+ // CHECK: %[[CST:.+]] = arith.constant dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+ ^bb0(%arg1: i32, %arg2: i32):
+ linalg.yield %arg1 : i32
+ } -> tensor<3x1x4x2xi32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i16
+func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
+ %input = arith.constant dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi16>
+ // CHECK: %[[CST:.+]] = arith.constant dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
+ ^bb0(%arg1: i16, %arg2: i16):
+ linalg.yield %arg1 : i16
+ } -> tensor<3x1x4x2xi16>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ %cst = arith.constant 8.0 : f32
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %cst : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %add = arith.addf %arg1, %arg1 : f32
+ linalg.yield %add : f32
+ } -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @named_transpose_fold_2d_fp32
+func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+ // CHECK: %[[CST:.+]] = arith.constant
+ // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 15a4f6cdd3bbe..e45a9fbb1052c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -777,139 +777,6 @@ func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, te
// -----
-// CHECK-LABEL: @transpose_fold_2d_fp32
-func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
- %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
- // CHECK: %[[CST:.+]] = arith.constant
- // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- linalg.yield %arg1 : f32
- } -> tensor<3x2xf32>
- // CHECK: return %[[CST]]
- return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_2d_fp64
-func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
- %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
- // CHECK: %[[CST:.+]] = arith.constant
- // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
- ^bb0(%arg1: f64, %arg2: f64):
- linalg.yield %arg1 : f64
- } -> tensor<3x2xf64>
- // CHECK: return %[[CST]]
- return %1 : tensor<3x2xf64>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_4d_i32
-func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
- %input = arith.constant dense<[[
- [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
- [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
- ]]> : tensor<1x2x3x4xi32>
- // CHECK: %[[CST:.+]] = arith.constant dense<[
- // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
- // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
- // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
- // CHECK-SAME{LITERAL}: ]>
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]
- } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
- ^bb0(%arg1: i32, %arg2: i32):
- linalg.yield %arg1 : i32
- } -> tensor<3x1x4x2xi32>
- // CHECK: return %[[CST]]
- return %1 : tensor<3x1x4x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_4d_i16
-func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
- %input = arith.constant dense<[[
- [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
- [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
- ]]> : tensor<1x2x3x4xi16>
- // CHECK: %[[CST:.+]] = arith.constant dense<[
- // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
- // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
- // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
- // CHECK-SAME{LITERAL}: ]>
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]
- } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
- ^bb0(%arg1: i16, %arg2: i16):
- linalg.yield %arg1 : i16
- } -> tensor<3x1x4x2xi16>
- // CHECK: return %[[CST]]
- return %1 : tensor<3x1x4x2xi16>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_non_cst_input
-func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
- // CHECK: linalg.generic
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- linalg.yield %arg1 : f32
- } -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_yield_const
-func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
- %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
- %cst = arith.constant 8.0 : f32
- // CHECK: linalg.generic
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- linalg.yield %cst : f32
- } -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
-func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
- %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
- // CHECK: linalg.generic
- %1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
- ^bb0(%arg1: f32, %arg2: f32):
- %add = arith.addf %arg1, %arg1 : f32
- linalg.yield %add : f32
- } -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
-}
-
-// -----
-
// Fusing the broadcast into a reduction would require to insert extra knowledge
// about the size of the reduction dimension. As long, as this is not
// implemented, we check that two linalg operations remain.