[Mlir-commits] [mlir] 3ed3e43 - [mlir] Move `memref.dim` canonicalization using `InferShapedTypeOpInterface` to a separate pass.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 16 22:13:24 PDT 2021
Author: MaheshRavishankar
Date: 2021-06-16T22:13:11-07:00
New Revision: 3ed3e438a75d9cf756f6004b60dd5b3feec96b0b
URL: https://github.com/llvm/llvm-project/commit/3ed3e438a75d9cf756f6004b60dd5b3feec96b0b
DIFF: https://github.com/llvm/llvm-project/commit/3ed3e438a75d9cf756f6004b60dd5b3feec96b0b.diff
LOG: [mlir] Move `memref.dim` canonicalization using `InferShapedTypeOpInterface` to a separate pass.
Based on discussion in
[this](https://llvm.discourse.group/t/remove-canonicalizer-for-memref-dim-via-shapedtypeopinterface/3641)
thread, the pattern to resolve the `memref.dim` of a value that is a
result of an operation that implements the
`InferShapedTypeOpInterface` is moved to a separate pass instead of
running as a canonicalization pattern. This allows shape resolution to
happen when explicitly required, instead of automatically through
canonicalization.
Differential Revision: https://reviews.llvm.org/D104321
Added:
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-sequence.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/Transforms/test-canonicalize.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 1eae023a6df19..153991b5696d1 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -16,6 +16,15 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+
+class AffineDialect;
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+namespace vector {
+class VectorDialect;
+} // namespace vector
+
namespace memref {
//===----------------------------------------------------------------------===//
@@ -26,6 +35,11 @@ namespace memref {
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns that resolve `memref.dim` operations on values that are
+/// results of operations implementing the `InferShapedTypeOpInterface`, in
+/// terms of the shapes of those operations' operands.
+void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -34,6 +48,11 @@ void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();
+/// Creates an operation pass to resolve `memref.dim` operations on values
+/// that are results of operations implementing the
+/// `InferShapedTypeOpInterface`, in terms of those operations' operands.
+std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d98d510a134a2..d7a7ddccf39a5 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -23,6 +23,18 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
];
}
+def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
+ let summary = "Resolve memref.dim of result values";
+ let description = [{
+ The pass resolves memref.dim of result of operations that
+ implement the `InferShapedTypeOpInterface` in terms of shapes of
+ its operands.
+ }];
+ let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect", "tensor::TensorDialect"
+ ];
+}
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b9f4dc91634bc..9f597382fdc5d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -794,84 +794,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
return success();
}
};
-
-/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
-/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
-/// TODO(ravishankarm): This is better put as a interface utility method
-/// somewhere, but that would imply the interface will depend on the `tensor`
-/// dialect. Ideally maybe a utility method in the `tensor` dialect.
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
- int64_t dimIndex) {
- unsigned resultNumber = result.getResultNumber();
- auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
- Location loc = result.getOwner()->getLoc();
- if (!shapedTypeOp)
- return nullptr;
-
- // The interface exposes two methods, one that returns the shape of all the
- // results as `Value` and other that returns the shape as a list of
- // `SmallVector<Value>`. The former takes precedence over the latter. So first
- // check if the op implements the first interface method or the second, and
- // get the value to use appropriately.
- SmallVector<Value> reifiedResultShapes;
- if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
- builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
- if (reifiedResultShapes.size() <= resultNumber)
- return nullptr;
- Value resultShape = reifiedResultShapes[resultNumber];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
- return nullptr;
- return builder.create<tensor::ExtractOp>(
- loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
- }
-
- SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
- if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
- builder, reifiedResultShapesPerDim)))
- return nullptr;
- if (reifiedResultShapesPerDim.size() <= resultNumber ||
- reifiedResultShapesPerDim[resultNumber].size() !=
- static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
- return nullptr;
- OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
- if (auto attr = valueOrAttr.dyn_cast<Attribute>())
- return builder.createOrFold<ConstantIndexOp>(
- loc, attr.cast<IntegerAttr>().getInt());
- return valueOrAttr.get<Value>();
-}
-
-/// Fold dim of an operation that implements the InferShapedTypeOpInterface
-struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
- if (!dimValue)
- return failure();
- auto shapedTypeOp =
- dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
- if (!shapedTypeOp)
- return failure();
-
- Optional<int64_t> dimIndex = dimOp.getConstantIndex();
- if (!dimIndex)
- return failure();
- Value replacement =
- getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
- if (!replacement)
- return failure();
- rewriter.replaceOp(dimOp, replacement);
- return success();
- }
-};
} // end anonymous namespace.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
- DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
+ DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index e795a86f69d74..672d89772499f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefTransforms
FoldSubViewOps.cpp
+ ResolveShapedTypeResultDims.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
@@ -9,9 +10,11 @@ add_mlir_dialect_library(MLIRMemRefTransforms
LINK_LIBS PUBLIC
MLIRAffine
+ MLIRInferTypeOpInterface
MLIRMemRef
MLIRPass
MLIRStandard
+ MLIRTensor
MLIRTransforms
MLIRVector
)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
new file mode 100644
index 0000000000000..1b16efe5f2e18
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -0,0 +1,127 @@
+//===- ResolveShapedTypeResultDims.cpp - Resolve dim of result values -----===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass resolves `memref.dim` operations of result values in terms of
+// shapes of their operands using the `InferShapedTypeOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Returns the `Value` for the size of `result` at dimension `dimIndex`, as
+/// computed via the `InferShapedTypeOpInterface`; returns null on failure.
+/// TODO(ravishankarm): This is better put as an interface utility method
+/// somewhere, but that would imply the interface will depend on the `tensor`
+/// dialect. Ideally maybe a utility method in the `tensor` dialect.
+static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
+ int64_t dimIndex) {
+ unsigned resultNumber = result.getResultNumber();
+ auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
+ Location loc = result.getOwner()->getLoc();
+ if (!shapedTypeOp)
+ return nullptr;
+
+ // The interface exposes two methods, one that returns the shape of all the
+ // results as `Value` and other that returns the shape as a list of
+ // `SmallVector<Value>`. The former takes precedence over the latter. So first
+ // check if the op implements the first interface method or the second, and
+ // get the value to use appropriately.
+ SmallVector<Value> reifiedResultShapes;
+ if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
+ builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
+ if (reifiedResultShapes.size() <= resultNumber)
+ return nullptr;
+ Value resultShape = reifiedResultShapes[resultNumber];
+ auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+ if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+ return nullptr;
+ return builder.create<tensor::ExtractOp>(
+ loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
+ }
+
+ SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
+ if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
+ builder, reifiedResultShapesPerDim)))
+ return nullptr;
+ if (reifiedResultShapesPerDim.size() <= resultNumber ||
+ reifiedResultShapesPerDim[resultNumber].size() !=
+ static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
+ return nullptr;
+ OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+ return builder.createOrFold<ConstantIndexOp>(
+ loc, attr.cast<IntegerAttr>().getInt());
+ return valueOrAttr.get<Value>();
+}
+
+namespace {
+/// Fold dim of an operation that implements the InferShapedTypeOpInterface
+struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
+ using OpRewritePattern<memref::DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
+ if (!dimValue)
+ return failure();
+ auto shapedTypeOp =
+ dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
+ if (!shapedTypeOp)
+ return failure();
+
+ Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+ if (!dimIndex)
+ return failure();
+ Value replacement =
+ getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
+ if (!replacement)
+ return failure();
+ rewriter.replaceOp(dimOp, replacement);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+struct ResolveShapedTypeResultDimsPass final
+ : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void memref::populateResolveShapedTypeResultDimsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext());
+}
+
+void ResolveShapedTypeResultDimsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ memref::populateResolveShapedTypeResultDimsPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+ std::move(patterns))))
+ return signalPassFailure();
+}
+
+std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
+ return std::make_unique<ResolveShapedTypeResultDimsPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 029ac621ca4b5..16895590a55fd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -532,205 +532,6 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
// -----
-func @init_tensor_static_dim() -> (index, index) {
- %c0 = constant 0 : index
- %c2 = constant 2 : index
- %c6 = constant 6 : index
- %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
- %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
- %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: func @init_tensor_static_dim
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C6:.+]] = constant 6 : index
-// CHECK: return %[[C6]], %[[C4]]
-
-// -----
-
-func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
- %c2 = constant 2 : index
- %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
- %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
- return %1 : index
-}
-// CHECK: func @init_tensor_dynamic_dim
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG0]]
-
-// -----
-
-func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: func @init_tensor_dynamic_dim2
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG0]], %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses
- (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : tensor<?x?xf32>) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
- iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %1 = mulf %arg3, %arg4 : f32
- %2 = addf %1, %arg5 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?xf32>
- %3 = memref.dim %0, %c0 : tensor<?x?xf32>
- %4 = memref.dim %0, %c1 : tensor<?x?xf32>
- return %3, %4 : index, index
-}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
-// CHECK: func @remove_dim_result_uses
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
-// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
-// CHECK: return %[[T2]], %[[T5]]
-
-// -----
-
-func @remove_dim_result_uses_outs
- (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
- %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg2: f32, %arg3: f32) :
- linalg.yield %arg2 : f32
- } -> tensor<?x?xf32>
- %2 = memref.dim %1, %c1 : tensor<?x?xf32>
- return %2 : index
-}
-// CHECK: func @remove_dim_result_uses_outs
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses_sequence
- (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?xf32>
- %3 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d2)>],
- iterator_types = ["parallel", "reduction", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %4 = mulf %arg3, %arg4 : f32
- %5 = addf %4, %arg5 : f32
- linalg.yield %5 : f32
- } -> tensor<?x?xf32>
- %6 = memref.dim %3, %c0 : tensor<?x?xf32>
- %7 = memref.dim %3, %c1 : tensor<?x?xf32>
- return %1, %2, %6, %7 : index, index, index, index
-}
-// CHECK-LABEL: func @remove_dim_result_uses_sequence
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
-
-// -----
-
-func @keep_result_dim_uses_sequence2
- (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
- %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg2: f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- } -> tensor<?x?xf32>
- %2 = memref.dim %1, %c0 : tensor<?x?xf32>
- %3 = memref.dim %1, %c1 : tensor<?x?xf32>
- return %2, %3 : index, index
-}
-// CHECK: func @keep_result_dim_uses_sequence2
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: return %[[T0]], %[[ARG1]]
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-
-func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
- %arg_1: tensor<?xf32>) -> (index, index) {
- %0, %1 = linalg.generic {
- indexing_maps = [#map, #map, #map],
- iterator_types = ["parallel"]
- } ins(%arg_0 : tensor<?xf32>)
- outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%in: f32, %out_0: f32, %out_1: f32):
- linalg.yield %in, %in : f32, f32
- } -> (tensor<?xf32>, tensor<?xf32>)
-
- %c0 = constant 0 : index
- %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
-
- %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
- return %num_elem_0, %num_elem_1 : index, index
-}
-// CHECK: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
-// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
-// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
-// CHECK: return %[[R0]], %[[R1]]
-
-// -----
-
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
%0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
@@ -740,9 +541,12 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[T1]]
+// CHECK: %[[C2:.+]] = constant 2
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[INIT2]]
// -----
@@ -755,9 +559,12 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[T1]]
+// CHECK: %[[C4:.+]] = constant 4
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[INIT2]]
// -----
@@ -906,54 +713,6 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
} : tensor<5x6xf32> to tensor<5x6xf32>
return %0 : tensor<5x6xf32>
}
-
-// -----
-
-func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
-{
- %c1 = constant 1 : index
- %c3 = constant 3 : index
- %c4 = constant 4 : index
- %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
- : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
- %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
- %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
- %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
- return %1, %2, %3 : index, index, index
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @dim_reshape_expansion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
-
-// -----
-
-func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
-{
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
- : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
- %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
- %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @dim_reshape_collapse
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C5:.+]] = constant 5 : index
-// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C5]], %[[D1]]
-
-// -----
-
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
@@ -1083,41 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
// -----
-func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
- %arg3: f32) -> (index, index, index)
-{
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %c3 = constant 3 : index
- %c4 = constant 4 : index
- %c5 = constant 5 : index
- %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
- ^bb0(%arg4: index, %arg5: index, %arg6: index):
- linalg.yield %arg3 : f32
- } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
- %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
- return %1, %2, %3 : index, index, index
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
-// CHECK: func @dim_of_pad_op
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
-// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = constant 12 : index
-// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
-// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
-// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
-
-// -----
-
#map = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
index 3321595cf9a4b..8455991e8f5e9 100644
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
module {
func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 730482d957578..4ea25dffdcef1 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
new file mode 100644
index 0000000000000..0bcf60158416c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,278 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func @init_tensor_static_dim() -> (index, index) {  // Dims of a linalg.init_tensor resolve to entries of its size list.
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c6 = constant 6 : index
+ %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+ %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>  // Dynamic dim: expected to become the %c6 size operand.
+ %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>  // Static dim: expected to become constant 4.
+ return %1, %2 : index, index
+}
+// CHECK: func @init_tensor_static_dim
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C6:.+]] = constant 6 : index
+// CHECK: return %[[C6]], %[[C4]]
+
+// -----
+
+func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {  // A dynamic init_tensor size supplied by a function argument.
+ %c2 = constant 2 : index
+ %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
+ %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>  // Expected to be replaced by %arg0 itself.
+ return %1 : index
+}
+// CHECK: func @init_tensor_dynamic_dim
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]]
+
+// -----
+
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {  // Every init_tensor size is dynamic.
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?xf32>  // Expected to be replaced by %arg0.
+ %2 = memref.dim %0, %c1 : tensor<?x?xf32>  // Expected to be replaced by %arg1.
+ return %1, %2 : index, index
+}
+// CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses  // Result dims of a linalg.generic are recomputed from its input operand dims.
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],  // Non-trivial output map exercises affine.apply generation.
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %1 = mulf %arg3, %arg4 : f32
+ %2 = addf %1, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ %3 = memref.dim %0, %c0 : tensor<?x?xf32>  // Expected: dim(%arg0, 0) + dim(%arg1, 1), from the (d0 + d1) expression.
+ %4 = memref.dim %0, %c1 : tensor<?x?xf32>  // Expected: dim(%arg1, 1) - dim(%arg0, 0), from the (d1 - d0) expression.
+ return %3, %4 : index, index
+}
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)>
+// CHECK: func @remove_dim_result_uses
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
+// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
+// CHECK: return %[[T2]], %[[T5]]
+
+// -----
+
+func @remove_dim_result_uses_outs  // Loop d1 appears only in the outs map, so its size comes from the outs operand.
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32) :
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = memref.dim %1, %c1 : tensor<?x?xf32>  // Expected to be replaced by %arg1, the init_tensor size for d1.
+ return %2 : index
+}
+// CHECK: func @remove_dim_result_uses_outs
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence  // Dims resolve across a matmul whose result feeds the outs of a linalg.generic.
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?xf32>  // Expected: dim(%arg0, 0).
+ %2 = memref.dim %0, %c1 : tensor<?x?xf32>  // Expected: dim(%arg1, 1).
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %4 = mulf %arg3, %arg4 : f32
+ %5 = addf %4, %arg5 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ %6 = memref.dim %3, %c0 : tensor<?x?xf32>  // Expected: dim(%arg0, 1), via the transposed (d1, d0) input map.
+ %7 = memref.dim %3, %c1 : tensor<?x?xf32>  // Expected: dim(%arg1, 1).
+ return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @keep_result_dim_uses_sequence2  // Both result dims resolve back to dim(%arg0, 0) and %arg1.
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = memref.dim %1, %c0 : tensor<?x?xf32>  // Expected: dim(%arg0, 0).
+ %3 = memref.dim %1, %c1 : tensor<?x?xf32>  // Expected: %arg1.
+ return %2, %3 : index, index
+}
+// CHECK: func @keep_result_dim_uses_sequence2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: return %[[T0]], %[[ARG1]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,  // Dims of both results of a two-result linalg.generic.
+ %arg_1: tensor<?xf32>) -> (index, index) {
+ %0, %1 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel"]
+ } ins(%arg_0 : tensor<?xf32>)
+ outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+ linalg.yield %in, %in : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ %c0 = constant 0 : index
+ %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>  // Expected: a memref.dim of %arg_0.
+
+ %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>  // Expected: a memref.dim of %arg_0 too, although this result's outs is %arg_1.
+ return %num_elem_0, %num_elem_1 : index, index
+}
+// CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
+// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
+// CHECK: return %[[R0]], %[[R1]]
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)  // Dims of a tensor_expand_shape result.
+{
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>  // Static: expected constant 3.
+ %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>  // Static: expected constant 4.
+ %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>  // Expected: dim(%arg0, 2) floordiv 28, the static 4 * 7 of its group.
+ return %1, %2, %3 : index, index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @dim_reshape_expansion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)  // Dims of a tensor_collapse_shape result.
+{
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+ %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>  // Static: expected constant 5.
+ %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>  // Expected: dim(%arg0, 4) * 28, the static 4 * 7 of its group.
+ return %1, %2 : index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @dim_reshape_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C5]], %[[D1]]
+
+// -----
+
+func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,  // Dims of a linalg.pad_tensor result: low + input + high.
+ %arg3: f32) -> (index, index, index)
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %c5 = constant 5 : index
+ %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index):
+ linalg.yield %arg3 : f32
+ } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>  // Expected constant 12 (3 + 2 + 7).
+ %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>  // Expected: %arg1 + dim(%arg0, 1) + 5.
+ %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>  // Expected: %arg2 + dim(%arg0, 2) + 4.
+ return %1, %2, %3 : index, index, index
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 4)>
+// CHECK: func @dim_of_pad_op
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C12:.+]] = constant 12 : index
+// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
+// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
+// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 952652170356e..bffbca1df75ab 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -205,16 +205,14 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
// CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)>
-// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)>
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)>
-// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 2, -d1 + s0)>
// CHECK: func @conv_tensors_dynamic
// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
@@ -240,16 +238,20 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
-// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]]
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
@@ -257,7 +259,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
@@ -266,7 +268,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
new file mode 100644
index 0000000000000..2568e23ca8bb1
--- /dev/null
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+
+func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)  // Op implementing only reifyReturnTypeShapes.
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)  // Result 0 mirrors %arg1's shape, result 1 mirrors %arg0's.
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>  // Expected: tensor.extract from a tensor.from_elements shape.
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>  // Expected: tensor.extract from a tensor.from_elements shape.
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+
+// -----
+
+func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)  // Op implementing only reifyReturnTypeShapesPerResultDim.
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)  // Result 0 mirrors %arg1's shape, result 1 mirrors %arg0's.
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>  // Expected: dim(%arg1, 0) directly.
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>  // Expected: dim(%arg0, 2) directly.
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_per_dim(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)  // Op implementing both methods; the shape-tensor path is expected.
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)  // Result 0 mirrors %arg1's shape, result 1 mirrors %arg0's.
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>  // Expected: tensor.extract from a tensor.from_elements shape.
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>  // Expected: tensor.extract from a tensor.from_elements shape.
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_and_per_dim(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index d245a560bcfaa..8c919da3f0643 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -82,30 +82,6 @@ func @typemismatch() -> i32 {
return %0 : i32
}
-// CHECK-LABEL: func @result_shape_per_dim
-// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
-func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
- -> (index, index, index, index, index) {
- // CHECK-DAG: %[[C0:.+]] = constant 0 : index
- // CHECK-DAG: %[[C2:.+]] = constant 2 : index
- // CHECK-DAG: %[[C3:.+]] = constant 3 : index
- // CHECK-DAG: %[[C5:.+]] = constant 5 : index
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
- : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
- %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
- %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
- %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
- %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
- %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
- // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
- // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
- // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
- return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
-
// CHECK-LABEL: test_dialect_canonicalizer
func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index c5bda7ba72c8c..a591ab5dd543a 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -65,6 +65,7 @@ add_mlir_library(MLIRTestDialect
MLIRReduce
MLIRStandard
MLIRStandardOpsTransforms
+ MLIRTensor
MLIRTransformUtils
MLIRTransforms
)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 8ef6ec6000c6a..ca6f18019a2ac 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -802,22 +803,75 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(  // Reifies each result shape as a tensor.from_elements of index values.
+ OpBuilder &builder, ValueRange operands,
+ llvm::SmallVectorImpl<Value> &shapes) {
+ Location loc = getLoc();
+ shapes.reserve(operands.size());  // One shape tensor is pushed per operand.
+ for (Value operand : llvm::reverse(operands)) {  // Reversed: result i takes the shape of operand N-1-i.
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);  // createOrFold lets static dims fold to constants.
+ }));
+ shapes.push_back(builder.create<tensor::FromElementsOp>(
+ loc, builder.getIndexType(), currShape));  // Reuse the cached `loc`, like the dim ops above.
+ }
+ return success();
+}
+
LogicalResult
OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
- SmallVector<Value> operand1Shape, operand2Shape;
Location loc = getLoc();
- for (auto i :
- llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
- operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
+ shapes.reserve(getNumOperands());  // One list of per-dimension Values is pushed per operand.
+ for (Value operand : llvm::reverse(getOperands())) {  // Reversed: result i takes the shape of operand N-1-i.
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);  // createOrFold lets static dims fold to constants.
+ }));
+ shapes.emplace_back(std::move(currShape));  // Dim Values are returned directly; no shape tensor is built.
}
- for (auto i :
- llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
- operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
+ return success();
+}
+
+LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(  // Shape-tensor reification for the op that implements both methods.
+ OpBuilder &builder, ValueRange operands,
+ llvm::SmallVectorImpl<Value> &shapes) {
+ Location loc = getLoc();
+ shapes.reserve(operands.size());  // One shape tensor is pushed per operand.
+ for (Value operand : llvm::reverse(operands)) {  // Reversed: result i takes the shape of operand N-1-i.
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);  // createOrFold lets static dims fold to constants.
+ }));
+ shapes.push_back(builder.create<tensor::FromElementsOp>(
+ loc, builder.getIndexType(), currShape));  // Reuse the cached `loc`, like the dim ops above.
+ }
+ return success();
+}
+
+LogicalResult  // Per-dimension reification for the op that also implements reifyReturnTypeShapes.
+OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
+ OpBuilder &builder,
+ llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+ Location loc = getLoc();
+ shapes.reserve(getNumOperands());  // One list of per-dimension Values is pushed per operand.
+ for (Value operand : llvm::reverse(getOperands())) {  // Reversed: result i takes the shape of operand N-1-i.
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);  // createOrFold lets static dims fold to constants.
+ }));
+ shapes.emplace_back(std::move(currShape));
}
- shapes.emplace_back(std::move(operand2Shape));
- shapes.emplace_back(std::move(operand1Shape));
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ea39b9c8c65ab..0f1775f7a78f1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -571,9 +571,25 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
let results = (outs AnyTensor);
}
-def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
+def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",  // Declares only reifyReturnTypeShapes.
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>]> {
+ ["reifyReturnTypeShapes"]>]> {
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);  // Ranked, so the C++ reification can cast to RankedTensorType.
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapePerDimInterfaceOp :  // Declares only reifyReturnTypeShapesPerResultDim.
+ TEST_Op<"op_with_result_shape_per_dim_interface",
+ [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["reifyReturnTypeShapesPerResultDim"]>]> {
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);  // Ranked, so the C++ reification can cast to RankedTensorType.
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapeAndPerDimInterfaceOp :  // Declares both reification methods of InferShapedTypeOpInterface.
+ TEST_Op<"op_with_result_shape_and_per_dim_interface",
+ [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
More information about the Mlir-commits
mailing list