[Mlir-commits] [mlir] 74ed79f - [mlir][linalg] Add linalg.transpose constant folding (#92589)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 28 07:42:36 PDT 2024


Author: Ryan Holt
Date: 2024-05-28T10:42:32-04:00
New Revision: 74ed79f7f123788d95f1552800e1af9ceaee4a08

URL: https://github.com/llvm/llvm-project/commit/74ed79f7f123788d95f1552800e1af9ceaee4a08
DIFF: https://github.com/llvm/llvm-project/commit/74ed79f7f123788d95f1552800e1af9ceaee4a08.diff

LOG: [mlir][linalg] Add linalg.transpose constant folding (#92589)

There was existing support for constant folding a `linalg.generic` that
was actually a transpose. This commit also adds support for the named op,
`linalg.transpose`, by making use of the `LinalgOp` interface.

Added: 
    mlir/test/Dialect/Linalg/constant-fold.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 8fffabf11f3fd..2e6079e1402e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -23,21 +23,21 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 namespace {
-/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
-/// and permutation indexing maps.
+/// Base class for constant folding linalg structured ops with N inputs, 1
+/// output, and permutation indexing maps.
 ///
 /// `ConcreteType` should provide methods with signatures
 ///
 /// ```c++
-///   bool matchIndexingMaps(GenericOp genericOp) const;
-///   RegionComputationFn getRegionComputeFn(GenericOp) const;
+///   bool matchIndexingMaps(LinalgOp linalgOp) const;
+///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
 /// ```
 ///
 /// The latter inspects the region and returns the computation inside as a
 /// functor. The functor will be invoked with constant elements for all inputs
 /// and should return the corresponding computed constant element for output.
 template <typename ConcreteType>
-class FoldConstantBase : public OpRewritePattern<GenericOp> {
+class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   struct APIntOrFloat {
     std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
   FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                    PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+        controlFn(controlFn) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
     // Mixed and buffer sematics aren't supported.
-    if (!genericOp.hasPureTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();
 
     // Only support ops generating one output for now.
-    if (genericOp.getNumDpsInits() != 1)
+    if (linalgOp.getNumDpsInits() != 1)
       return failure();
 
-    auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
+    auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
     // Require the output types to be static given that we are generating
     // constants.
     if (!outputType || !outputType.hasStaticShape())
       return failure();
 
-    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+    if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
           return isa<ShapedType>(input.getType());
         }))
       return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
       return cast<ShapedType>(value.getType()).getElementType();
     };
     if (!llvm::all_equal(
-            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
+            llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
       return failure();
 
     // We can only handle the case where we have int/float elements.
@@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     // entirely in the compiler, without needing to turn all indices into
     // Values, and then do affine apply on them, and then match back the
     // constant again.
-    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+    if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
                       [](AffineMap map) { return map.isPermutation(); }))
       return failure();
 
-    for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
-      if (genericOp.payloadUsesValueFromOperand(&operand))
+    for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
+      if (linalgOp.payloadUsesValueFromOperand(&operand))
         return failure();
     }
 
     // Further check the indexing maps are okay for the ConcreteType.
-    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
       return failure();
 
     // Defer to the concrete type to check the region and discover the
     // computation inside.
     RegionComputationFn computeFn =
-        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+        static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
     if (!computeFn)
       return failure();
 
     // All inputs should be constants.
-    int numInputs = genericOp.getNumDpsInputs();
+    int numInputs = linalgOp.getNumDpsInputs();
     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
       if (!matchPattern(en.value()->get(),
                         m_Constant(&inputValues[en.index()])))
         return failure();
@@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
       if (!controlFn(operand))
         return failure();
     }
 
-    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
     SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
     int64_t numElements = outputType.getNumElements();
 
@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
 
     SmallVector<SmallVector<unsigned>> inputDims;
     for (int i = 0; i < numInputs; ++i)
-      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
-    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
+      inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
+    auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
     auto outputShape = outputType.getShape();
 
     // Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
     APIntOrFloatArray computeFnInputs;
 
     auto inputShapes = llvm::to_vector<4>(
-        llvm::map_range(genericOp.getInputs(), [](Value value) {
+        llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
           return cast<ShapedType>(value.getType()).getShape();
         }));
 
@@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
         isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                 : DenseElementsAttr::get(outputType, intOutputValues);
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
     return success();
   }
 
@@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
   ControlFusionFn controlFn;
 };
 
-// Folds linalg.generic ops that are actually transposes on constant values.
+// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
+// on constant values.
 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+
   using FoldConstantBase::FoldConstantBase;
 
-  bool matchIndexingMaps(GenericOp genericOp) const {
+  bool matchIndexingMaps(LinalgOp linalgOp) const {
     // We should have one input and one output.
-    return genericOp.getIndexingMapsArray().size() == 2;
+    return linalgOp.getIndexingMapsArray().size() == 2;
   }
 
-  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
     // Make sure the region only contains a yield op.
-    Block &body = genericOp.getRegion().front();
+    Block &body = linalgOp->getRegion(0).front();
     if (!llvm::hasSingleElement(body))
       return nullptr;
     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());

diff  --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir
new file mode 100644
index 0000000000000..3929c26a3382f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/constant-fold.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @transpose_fold_2d_fp32
+func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  //               CHECK: %[[CST:.+]] = arith.constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp64
+func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
+  //               CHECK: %[[CST:.+]] = arith.constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
+  ^bb0(%arg1: f64, %arg2: f64):
+    linalg.yield %arg1 : f64
+  } -> tensor<3x2xf64>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i32
+func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+  %input = arith.constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>
+  //               CHECK: %[[CST:.+]] = arith.constant dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+  ^bb0(%arg1: i32, %arg2: i32):
+    linalg.yield %arg1 : i32
+  } -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i16
+func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
+  %input = arith.constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi16>
+  //               CHECK: %[[CST:.+]] = arith.constant dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
+  ^bb0(%arg1: i16, %arg2: i16):
+    linalg.yield %arg1 : i16
+  } -> tensor<3x1x4x2xi16>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  %cst = arith.constant 8.0 : f32
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %cst : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %add = arith.addf %arg1, %arg1 : f32
+    linalg.yield %add : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @named_transpose_fold_2d_fp32
+func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  //               CHECK: %[[CST:.+]] = arith.constant
+  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 15a4f6cdd3bbe..e45a9fbb1052c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -777,139 +777,6 @@ func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, te
 
 // -----
 
-// CHECK-LABEL: @transpose_fold_2d_fp32
-func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
-  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-  //               CHECK: %[[CST:.+]] = arith.constant
-  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    linalg.yield %arg1 : f32
-  } -> tensor<3x2xf32>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_2d_fp64
-func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
-  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
-  //               CHECK: %[[CST:.+]] = arith.constant
-  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
-  ^bb0(%arg1: f64, %arg2: f64):
-    linalg.yield %arg1 : f64
-  } -> tensor<3x2xf64>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x2xf64>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_4d_i32
-func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
-  %input = arith.constant dense<[[
-    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
-    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
-  ]]> : tensor<1x2x3x4xi32>
-  //               CHECK: %[[CST:.+]] = arith.constant dense<[
-  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
-  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
-  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
-  // CHECK-SAME{LITERAL}: ]>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
-  ^bb0(%arg1: i32, %arg2: i32):
-    linalg.yield %arg1 : i32
-  } -> tensor<3x1x4x2xi32>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x1x4x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_fold_4d_i16
-func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
-  %input = arith.constant dense<[[
-    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
-    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
-  ]]> : tensor<1x2x3x4xi16>
-  //               CHECK: %[[CST:.+]] = arith.constant dense<[
-  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
-  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
-  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
-  // CHECK-SAME{LITERAL}: ]>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
-  ^bb0(%arg1: i16, %arg2: i16):
-    linalg.yield %arg1 : i16
-  } -> tensor<3x1x4x2xi16>
-  // CHECK: return %[[CST]]
-  return %1 : tensor<3x1x4x2xi16>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_non_cst_input
-func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
-  // CHECK: linalg.generic
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    linalg.yield %arg1 : f32
-  } -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_yield_const
-func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
-  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-  %cst = arith.constant 8.0 : f32
-  // CHECK: linalg.generic
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    linalg.yield %cst : f32
-  } -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
-func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
-  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-  // CHECK: linalg.generic
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32):
-    %add = arith.addf %arg1, %arg1 : f32
-    linalg.yield %add : f32
-  } -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
-}
-
-// -----
-
 // Fusing the broadcast into a reduction would require to insert extra knowledge
 // about the size of the reduction dimension. As long, as this is not
 // implemented, we check that two linalg operations remain.


        


More information about the Mlir-commits mailing list