[llvm-branch-commits] [mlir] be7352c - [mlir][splitting std] move 2 more ops to `tensor`
Sean Silva via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 13:58:56 PST 2021
Author: Sean Silva
Date: 2021-01-19T13:49:25-08:00
New Revision: be7352c00d51f4358db3a23ed6a077f7cb48eafd
URL: https://github.com/llvm/llvm-project/commit/be7352c00d51f4358db3a23ed6a077f7cb48eafd
DIFF: https://github.com/llvm/llvm-project/commit/be7352c00d51f4358db3a23ed6a077f7cb48eafd.diff
LOG: [mlir][splitting std] move 2 more ops to `tensor`
- DynamicTensorFromElementsOp
- TensorFromElements
Differential Revision: https://reviews.llvm.org/D94994
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Standard/bufferize.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/Dialect/Standard/ops.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6eabe1179234..8e3f1f1a7a85 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1591,47 +1591,6 @@ def DivFOp : FloatArithmeticOp<"divf"> {
let summary = "floating point division operation";
}
-//===----------------------------------------------------------------------===//
-// DynamicTensorFromElementsOp
-//===----------------------------------------------------------------------===//
-
-def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
- [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> {
- string summary = "Creates a dynamically sized tensor from elements";
- string description = [{
- This operation creates a dynamically sized tensor with elements of any type.
- It expects one index operand per dynamic extent of the result tensor.
-
- The body region defines the tensor's elements. It takes index operands as
- its region arguments that span the index space. The element at the given
- position is yielded with the `yield` operation (see `YieldOp`). There is
- no defined ordering to the invocations of the body. It is conceptually
- a "parallel map" operation.
-
- Example:
-
- ```mlir
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index, %k : index):
- ...
- yield %elem : f32
- } : tensor<?x3x?f32>
- ```
- }];
-
- let arguments = (ins Variadic<Index>:$dynamicExtents);
- let results = (outs AnyRankedTensor:$result);
- let regions = (region SizedRegion<1>:$body);
-
- let builders = [
- // Build op and populate its body per callback function.
- OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents,
- "function_ref<void(OpBuilder &, Location, ValueRange)>")>,
- ];
-
- let hasCanonicalizer = 1;
-}
-
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
@@ -1672,46 +1631,6 @@ def Exp2Op : FloatUnaryOp<"exp2"> {
let summary = "base-2 exponential of the specified value";
}
-//===----------------------------------------------------------------------===//
-// TensorFromElementsOp
-//===----------------------------------------------------------------------===//
-
-def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
- NoSideEffect,
- TypesMatchWith<"operand types match result element type",
- "result", "elements", "SmallVector<Type, 2>("
- "$_self.cast<ShapedType>().getDimSize(0), "
- "$_self.cast<ShapedType>().getElementType())">
- ]> {
- string summary = "tensor from elements operation.";
- string description = [{
- Create a 1D tensor from a range of same-type arguments.
-
- Example:
-
- ```mlir
- tensor_from_elements(i_1, ..., i_N) : tensor<Nxindex>
- ```
- }];
-
- let arguments = (ins Variadic<AnyType>:$elements);
- let results = (outs 1DTensorOf<[AnyType]>:$result);
-
- let assemblyFormat = "$elements attr-dict `:` type($result)";
-
- // This op is fully verified by its traits.
- let verifier = ?;
-
- let skipDefaultBuilders = 1;
- let builders = [
- OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>,
- // Special case builder for when `elements` has size >=1.
- OpBuilderDAG<(ins "ValueRange":$elements)>
- ];
-
- let hasCanonicalizer = 1;
-}
-
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
@@ -3837,24 +3756,6 @@ def ViewOp : Std_Op<"view", [
let hasCanonicalizer = 1;
}
-//===----------------------------------------------------------------------===//
-// YieldOp
-//===----------------------------------------------------------------------===//
-
-def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
- HasParent<"DynamicTensorFromElementsOp">]> {
- let summary = "Yield a value from a region";
- let description = [{
- This operation is used to yield a single value from a within a region. It
- is used to create dynamically sized tensors
- (see `DynamicTensorFromElementsOp`).
- }];
-
- let arguments = (ins AnyType:$value);
- let assemblyFormat = "$value attr-dict `:` type($value)";
- let verifier = ?;
-}
-
//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 53980db64dc0..3a1a20835959 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e0500b8fcfa6..e7776c4e8a9b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -10,6 +10,7 @@
#define TENSOR_OPS
include "mlir/Dialect/Tensor/IR/TensorBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Tensor_Op<string mnemonic, list<OpTrait> traits = []>
@@ -105,4 +106,109 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// FromElementsOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
+ NoSideEffect,
+ TypesMatchWith<"operand types match result element type",
+ "result", "elements", "SmallVector<Type, 2>("
+ "$_self.cast<ShapedType>().getDimSize(0), "
+ "$_self.cast<ShapedType>().getElementType())">
+ ]> {
+ string summary = "tensor from elements operation.";
+ string description = [{
+ Create a 1D tensor from a range of same-type arguments.
+
+ Example:
+
+ ```mlir
+ tensor.from_elements(i_1, ..., i_N) : tensor<Nxindex>
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$elements);
+ let results = (outs 1DTensorOf<[AnyType]>:$result);
+
+ let assemblyFormat = "$elements attr-dict `:` type($result)";
+
+ // This op is fully verified by its traits.
+ let verifier = ?;
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>,
+ // Special case builder for when `elements` has size >=1.
+ OpBuilderDAG<(ins "ValueRange":$elements)>
+ ];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GenerateOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_GenerateOp : Tensor_Op<"generate",
+ [RecursiveSideEffects,
+ SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
+ string summary = "Creates a dynamically sized tensor from elements";
+ string description = [{
+ This operation creates a dynamically sized tensor with elements of any type.
+ It expects one index operand per dynamic extent of the result tensor.
+
+ The body region defines the tensor's elements. It takes index operands as
+ its region arguments that span the index space. The element at the given
+ position is yielded with the `yield` operation (see `YieldOp`). There is
+ no defined ordering to the invocations of the body. It is conceptually
+ a "parallel map" operation.
+
+ Example:
+
+ ```mlir
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index, %k : index):
+ ...
+ yield %elem : f32
+ } : tensor<?x3x?f32>
+ ```
+ }];
+
+ let arguments = (ins Variadic<Index>:$dynamicExtents);
+ let results = (outs AnyRankedTensor:$result);
+ let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat = "$dynamicExtents $body attr-dict `:` type($result)";
+
+ let builders = [
+ // Build op and populate its body per callback function.
+ OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>")>,
+ ];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_YieldOp : Tensor_Op<"yield",
+ [NoSideEffect, ReturnLike, Terminator,
+ HasParent<"::mlir::tensor::GenerateOp">]> {
+ let summary = "Yield a value from a region";
+ let description = [{
+ This operation is used to yield a single value from within a region. It
+ is used to create dynamically sized tensors
+ (see `tensor.generate` op).
+ }];
+
+ let arguments = (ins AnyType:$value);
+ let assemblyFormat = "$value attr-dict `:` type($value)";
+ // Dummy builder to appease code in templated ensureTerminator that
+ // GenerateOp's auto-generated parser calls.
+ let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];
+ let verifier = ?;
+}
+
#endif // TENSOR_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index 327c7499e0c8..7abb3daed2fe 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -14,6 +14,7 @@ include "mlir/Pass/PassBase.td"
def TensorBufferize : FunctionPass<"tensor-bufferize"> {
let summary = "Bufferize the `tensor` dialect";
let constructor = "mlir::createTensorBufferizePass()";
+ let dependentDialects = ["scf::SCFDialect"];
}
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
index 25c835d97723..f65e9aec3142 100644
--- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
MLIREDSC
MLIRIR
MLIRShape
+ MLIRTensor
MLIRPass
MLIRSCF
MLIRTransforms
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0d87d4f10975..0eeea250f19f 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -113,7 +113,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
- rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
+ rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value outputDimension = args[0];
@@ -151,7 +151,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
greaterRankOperandExtent);
b.create<scf::YieldOp>(loc, broadcastedExtent);
});
- b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
+ b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
});
return success();
}
@@ -184,7 +184,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
}
Type indexTy = rewriter.getIndexType();
Value tensor =
- rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
+ rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
@@ -503,7 +503,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
if (op.getType().isa<ShapeType>())
return failure();
- // For ranked tensor arguments, lower to `tensor_from_elements`.
+ // For ranked tensor arguments, lower to `tensor.from_elements`.
auto loc = op.getLoc();
ShapeOfOp::Adaptor transformed(operands);
Value tensor = transformed.arg();
@@ -526,22 +526,22 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
}
// Materialize extent tensor.
- Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
+ Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
return success();
}
- // Lower to `dynamic_tensor_from_elements` otherwise.
+ // Lower to `tensor.generate` otherwise.
auto *ctx = rewriter.getContext();
Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
- rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
+ rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
Value extent = b.create<DimOp>(loc, tensor, dim);
- b.create<mlir::YieldOp>(loc, extent);
+ b.create<tensor::YieldOp>(loc, extent);
});
return success();
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c4a8a0155f50..e1be47f54798 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1392,9 +1392,8 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return getResult();
}
- // Fold dim to the operand of dynamic_tensor_from_elements.
- if (auto fromElements =
- dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
+ // Fold dim to the operand of tensor.generate.
+ if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
auto resultType =
fromElements.getResult().getType().cast<RankedTensorType>();
// The case where the type encodes the size of the dimension is handled
@@ -1734,258 +1733,6 @@ LogicalResult DmaWaitOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// DynamicTensorFromElementsOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser,
- OperationState &result) {
- // Parse operands.
- SmallVector<OpAsmParser::OperandType, 4> dynamicExtents;
- Type indexTy = parser.getBuilder().getIndexType();
- if (parser.parseOperandList(dynamicExtents) ||
- parser.resolveOperands(dynamicExtents, indexTy, result.operands))
- return failure();
-
- // Parse body.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, {}, {}))
- return failure();
-
- // Parse result type.
- Type resultType;
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(resultType))
- return failure();
- result.addTypes(resultType);
-
- return success();
-}
-
-static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) {
- p << "dynamic_tensor_from_elements " << op.dynamicExtents();
- p.printRegion(op.body());
- p.printOptionalAttrDict(op.getAttrs());
- p << " : " << op.getType();
-}
-
-static LogicalResult verify(DynamicTensorFromElementsOp op) {
- // Ensure that the tensor type has as many dynamic dimensions as are specified
- // by the operands.
- RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
- if (op.getNumOperands() != resultTy.getNumDynamicDims())
- return op.emitError("must have as many index operands as dynamic extents "
- "in the result type");
-
- // Ensure that region arguments span the index space.
- if (!llvm::all_of(op.body().getArgumentTypes(),
- [](Type ty) { return ty.isIndex(); }))
- return op.emitError("all body arguments must be index");
- if (op.body().getNumArguments() != resultTy.getRank())
- return op.emitError("must have one body argument per input dimension");
-
- // Ensure that the region yields an element of the right type.
- auto yieldOp =
- llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
- if (yieldOp.value().getType() != resultTy.getElementType())
- return op.emitOpError(
- "body must be terminated with a `yield` operation of the tensor "
- "element type");
-
- return success();
-}
-
-void DynamicTensorFromElementsOp::build(
- OpBuilder &b, OperationState &result, Type resultTy,
- ValueRange dynamicExtents,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
- build(b, result, resultTy, dynamicExtents);
-
- // Build and populate body.
- OpBuilder::InsertionGuard guard(b);
- Region *bodyRegion = result.regions.front().get();
- auto rank = resultTy.cast<RankedTensorType>().getRank();
- SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
- Block *bodyBlock =
- b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
- bodyBuilder(b, result.location, bodyBlock->getArguments());
-}
-
-namespace {
-
-/// Canonicalizes dynamic_tensor_from_elements operations with a constant
-/// operand into the equivalent operation with the operand expressed in the
-/// result type, instead. We also insert a type cast to make sure that the
-/// resulting IR is still well-typed.
-struct StaticDynamicTensorFromElements
- : public OpRewritePattern<DynamicTensorFromElementsOp> {
- using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
- PatternRewriter &rewriter) const final {
- auto resultType =
- tensorFromElements.getResult().getType().cast<RankedTensorType>();
-
- if (resultType.hasStaticShape())
- return failure();
-
- SmallVector<Value, 4> newOperands;
- SmallVector<int64_t, 4> newShape;
- auto operandsIt = tensorFromElements.dynamicExtents().begin();
-
- for (int64_t dim : resultType.getShape()) {
- if (dim != RankedTensorType::kDynamicSize) {
- newShape.push_back(dim);
- continue;
- }
- APInt index;
- if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
- newShape.push_back(RankedTensorType::kDynamicSize);
- newOperands.push_back(*operandsIt++);
- continue;
- }
- newShape.push_back(index.getSExtValue());
- operandsIt++;
- }
-
- if (newOperands.size() == tensorFromElements.dynamicExtents().size())
- return failure();
-
- auto loc = tensorFromElements.getLoc();
- auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
- loc, RankedTensorType::get(newShape, resultType.getElementType()),
- newOperands);
- rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
- newOp.body().begin());
- rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
- newOp);
- return success();
- }
-};
-
-/// Canonicalizes the pattern of the form
-///
-/// %tensor = dynamic_tensor_from_elements %x {
-/// ^bb0(%arg0: index): // no predecessors
-/// <computation>
-/// yield %1 : index
-/// } : tensor<?xindex>
-/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
-///
-/// to just <computation> with %arg0 replaced by %c0. We only do this if the
-/// dynamic_tensor_from_elements operation has no side-effects.
-struct ExtractFromDynamicTensorFromElements
- : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- auto tensorFromElements =
- extract.tensor().getDefiningOp<DynamicTensorFromElementsOp>();
- if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
- return failure();
-
- BlockAndValueMapping mapping;
- Block *body = tensorFromElements.getBody();
- mapping.map(body->getArguments(), extract.indices());
- for (auto &op : body->without_terminator())
- rewriter.clone(op, mapping);
-
- auto yield = cast<YieldOp>(body->getTerminator());
-
- rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
- return success();
- }
-};
-
-/// Canonicalizes the pattern of the form
-///
-/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
-/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
-///
-/// to
-///
-/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
-struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
- if (!tensorCast)
- return failure();
-
- rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
- extract.indices());
- return success();
- }
-};
-
-} // namespace
-
-void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- // TODO: Move extract patterns to tensor::ExtractOp.
- results.insert<ExtractFromDynamicTensorFromElements, ExtractFromTensorCast,
- StaticDynamicTensorFromElements>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// TensorFromElementsOp
-//===----------------------------------------------------------------------===//
-
-void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
- Type elementType, ValueRange elements) {
- Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
- elementType);
- result.addOperands(elements);
- result.addTypes(resultTy);
-}
-
-void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
- ValueRange elements) {
- assert(!elements.empty() && "expected at least one element");
- build(builder, result, elements.front().getType(), elements);
-}
-
-namespace {
-
-// Canonicalizes the pattern of the form
-//
-// %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
- : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- if (extract.indices().size() != 1)
- return failure();
-
- auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
- extract.tensor().getDefiningOp());
- if (tensorFromElements == nullptr)
- return failure();
-
- APInt index;
- if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
- return failure();
- rewriter.replaceOp(extract,
- tensorFromElements.getOperand(index.getZExtValue()));
- return success();
- }
-};
-
-} // namespace
-
-void TensorFromElementsOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ExtractElementFromTensorFromElements>(context);
-}
-
//===----------------------------------------------------------------------===//
// FPExtOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 98792838deff..2a3a464cd0a8 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -35,70 +35,6 @@ class BufferizeDimOp : public OpConversionPattern<DimOp> {
};
} // namespace
-namespace {
-class BufferizeDynamicTensorFromElementsOp
- : public OpConversionPattern<DynamicTensorFromElementsOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(DynamicTensorFromElementsOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- // Allocate memory.
- Location loc = op.getLoc();
- DynamicTensorFromElementsOp::Adaptor transformed(operands);
- RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
- MemRefType memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- Value result =
- rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents());
-
- // Collect loop bounds.
- int64_t rank = tensorType.getRank();
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- SmallVector<Value, 4> lowerBounds(rank, zero);
- SmallVector<Value, 4> steps(rank, one);
- SmallVector<Value, 4> upperBounds;
- int nextDynamicIndex = 0;
- for (int i = 0; i < rank; i++) {
- Value upperBound =
- tensorType.isDynamicDim(i)
- ? transformed.dynamicExtents()[nextDynamicIndex++]
- : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
- upperBounds.push_back(upperBound);
- }
-
- // Generate tensor elements with a parallel loop that stores into
- // each element of the resulting memref.
- //
- // This is a bit tricky. We cannot simply clone the ops because when an op
- // is cloned, it must be legalized. However, we want to allow arbitrary ops
- // in the body that we don't necessarily have legalization patterns for as
- // part of this dialect conversion invocation.
- //
- // To accomplish this, we use mergeBlockBefore to "move" this op's body
- // into the scf.parallel's body.
- auto parallel =
- rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
- Block *parallelBody = parallel.getBody();
- rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
- parallelBody->getArguments());
- // Replace the inlined yield op with a store op. The scf.parallel's builder
- // already populated an scf.yield at the end, so we don't need to worry
- // about creating that.
- Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
- rewriter.setInsertionPointAfter(elementYield);
- rewriter.replaceOpWithNewOp<StoreOp>(elementYield,
- elementYield->getOperands()[0], result,
- parallelBody->getArguments());
-
- rewriter.replaceOp(op, {result});
- return success();
- }
-};
-} // namespace
-
namespace {
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
@@ -117,40 +53,10 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
};
} // namespace
-namespace {
-class BufferizeTensorFromElementsOp
- : public OpConversionPattern<TensorFromElementsOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- int numberOfElements = op.elements().size();
- auto resultType = MemRefType::get(
- {numberOfElements}, op.getType().cast<TensorType>().getElementType());
- Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
- for (auto element : llvm::enumerate(op.elements())) {
- Value index =
- rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
- rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
- }
- rewriter.replaceOp(op, {result});
- return success();
- }
-};
-} // namespace
-
void mlir::populateStdBufferizePatterns(MLIRContext *context,
BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<
- // clang-format off
- BufferizeDimOp,
- BufferizeDynamicTensorFromElementsOp,
- BufferizeSelectOp,
- BufferizeTensorFromElementsOp
- // clang-format on
- >(typeConverter, context);
+ patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter, context);
}
namespace {
@@ -165,7 +71,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<scf::SCFDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
- target.addIllegalOp<DynamicTensorFromElementsOp, TensorFromElementsOp>();
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index 2d5e2fbd6a31..b8fb44a9f4cb 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -13,5 +13,6 @@ add_mlir_dialect_library(MLIRTensor
LINK_LIBS PUBLIC
MLIRIR
+ MLIRSideEffectInterfaces
MLIRSupport
)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index aaae7fbf807c..e231a3a3b56e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
@@ -205,6 +207,223 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+//===----------------------------------------------------------------------===//
+// FromElementsOp
+//===----------------------------------------------------------------------===//
+
+void FromElementsOp::build(OpBuilder &builder, OperationState &result,
+ Type elementType, ValueRange elements) {
+ Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
+ elementType);
+ result.addOperands(elements);
+ result.addTypes(resultTy);
+}
+
+void FromElementsOp::build(OpBuilder &builder, OperationState &result,
+ ValueRange elements) {
+ assert(!elements.empty() && "expected at least one element");
+ build(builder, result, elements.front().getType(), elements);
+}
+
+namespace {
+
+// Canonicalizes the pattern of the form
+//
+// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
+//
+// to just %element.
+struct ExtractElementFromTensorFromElements
+ : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ if (extract.indices().size() != 1)
+ return failure();
+
+ auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
+ if (tensorFromElements == nullptr)
+ return failure();
+
+ APInt index;
+ if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
+ return failure();
+ rewriter.replaceOp(extract,
+ tensorFromElements.getOperand(index.getZExtValue()));
+ return success();
+ }
+};
+
+} // namespace
+
+void FromElementsOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ExtractElementFromTensorFromElements>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// GenerateOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(GenerateOp op) {
+ // Ensure that the tensor type has as many dynamic dimensions as are specified
+ // by the operands.
+ RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
+ if (op.getNumOperands() != resultTy.getNumDynamicDims())
+ return op.emitError("must have as many index operands as dynamic extents "
+ "in the result type");
+
+ // Ensure that region arguments span the index space.
+ if (!llvm::all_of(op.body().getArgumentTypes(),
+ [](Type ty) { return ty.isIndex(); }))
+ return op.emitError("all body arguments must be index");
+ if (op.body().getNumArguments() != resultTy.getRank())
+ return op.emitError("must have one body argument per input dimension");
+
+ // Ensure that the region yields an element of the right type.
+ auto yieldOp =
+ llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+ if (yieldOp.value().getType() != resultTy.getElementType())
+ return op.emitOpError(
+ "body must be terminated with a `yield` operation of the tensor "
+ "element type");
+
+ return success();
+}
+
+void GenerateOp::build(
+ OpBuilder &b, OperationState &result, Type resultTy,
+ ValueRange dynamicExtents,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+ build(b, result, resultTy, dynamicExtents);
+
+ // Build and populate body.
+ OpBuilder::InsertionGuard guard(b);
+ Region *bodyRegion = result.regions.front().get();
+ auto rank = resultTy.cast<RankedTensorType>().getRank();
+ SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
+ Block *bodyBlock =
+ b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
+ bodyBuilder(b, result.location, bodyBlock->getArguments());
+}
+
+namespace {
+
+/// Canonicalizes tensor.generate operations with a constant
+/// operand into the equivalent operation with the operand expressed in the
+/// result type, instead. We also insert a type cast to make sure that the
+/// resulting IR is still well-typed.
+struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
+ using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
+ PatternRewriter &rewriter) const final {
+ auto resultType =
+ tensorFromElements.getResult().getType().cast<RankedTensorType>();
+
+ if (resultType.hasStaticShape())
+ return failure();
+
+ SmallVector<Value, 4> newOperands;
+ SmallVector<int64_t, 4> newShape;
+ auto operandsIt = tensorFromElements.dynamicExtents().begin();
+
+ for (int64_t dim : resultType.getShape()) {
+ if (dim != RankedTensorType::kDynamicSize) {
+ newShape.push_back(dim);
+ continue;
+ }
+ APInt index;
+ if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
+ newShape.push_back(RankedTensorType::kDynamicSize);
+ newOperands.push_back(*operandsIt++);
+ continue;
+ }
+ newShape.push_back(index.getSExtValue());
+ operandsIt++;
+ }
+
+ if (newOperands.size() == tensorFromElements.dynamicExtents().size())
+ return failure();
+
+ auto loc = tensorFromElements.getLoc();
+ auto newOp = rewriter.create<GenerateOp>(
+ loc, RankedTensorType::get(newShape, resultType.getElementType()),
+ newOperands);
+ rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
+ newOp.body().begin());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
+ newOp);
+ return success();
+ }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = tensor.generate %x {
+/// ^bb0(%arg0: index): // no predecessors
+/// <computation>
+/// yield %1 : index
+/// } : tensor<?xindex>
+/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
+///
+/// to just <computation> with %arg0 replaced by %c0. We only do this if the
+/// tensor.generate operation has no side-effects.
+struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
+ if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
+ return failure();
+
+ BlockAndValueMapping mapping;
+ Block *body = tensorFromElements.getBody();
+ mapping.map(body->getArguments(), extract.indices());
+ for (auto &op : body->without_terminator())
+ rewriter.clone(op, mapping);
+
+ auto yield = cast<YieldOp>(body->getTerminator());
+
+ rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
+ return success();
+ }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
+struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
+ if (!tensorCast)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
+ extract.indices());
+ return success();
+ }
+};
+
+} // namespace
+
+void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ // TODO: Move extract patterns to tensor::ExtractOp.
+ results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
+ StaticTensorGenerate>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 05ff96fb8d69..66de78758692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -12,6 +12,7 @@
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
@@ -48,10 +49,97 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
};
} // namespace
+namespace {
+class BufferizeFromElementsOp
+ : public OpConversionPattern<tensor::FromElementsOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ int numberOfElements = op.elements().size();
+ auto resultType = MemRefType::get(
+ {numberOfElements}, op.getType().cast<TensorType>().getElementType());
+ Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
+ for (auto element : llvm::enumerate(op.elements())) {
+ Value index =
+ rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
+ rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
+ }
+ rewriter.replaceOp(op, {result});
+ return success();
+ }
+};
+} // namespace
+
+namespace {
+class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Allocate memory.
+ Location loc = op.getLoc();
+ tensor::GenerateOp::Adaptor transformed(operands);
+ RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
+ MemRefType memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ Value result =
+ rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents());
+
+ // Collect loop bounds.
+ int64_t rank = tensorType.getRank();
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 4> lowerBounds(rank, zero);
+ SmallVector<Value, 4> steps(rank, one);
+ SmallVector<Value, 4> upperBounds;
+ int nextDynamicIndex = 0;
+ for (int i = 0; i < rank; i++) {
+ Value upperBound =
+ tensorType.isDynamicDim(i)
+ ? transformed.dynamicExtents()[nextDynamicIndex++]
+ : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
+ upperBounds.push_back(upperBound);
+ }
+
+ // Generate tensor elements with a parallel loop that stores into
+ // each element of the resulting memref.
+ //
+ // This is a bit tricky. We cannot simply clone the ops because when an op
+ // is cloned, it must be legalized. However, we want to allow arbitrary ops
+ // in the body that we don't necessarily have legalization patterns for as
+ // part of this dialect conversion invocation.
+ //
+ // To accomplish this, we use mergeBlockBefore to "move" this op's body
+ // into the scf.parallel's body.
+ auto parallel =
+ rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+ Block *parallelBody = parallel.getBody();
+ rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
+ parallelBody->getArguments());
+ // Replace the inlined yield op with a store op. The scf.parallel's builder
+ // already populated an scf.yield at the end, so we don't need to worry
+ // about creating that.
+ Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+ rewriter.setInsertionPointAfter(elementYield);
+ rewriter.replaceOpWithNewOp<StoreOp>(elementYield,
+ elementYield->getOperands()[0], result,
+ parallelBody->getArguments());
+
+ rewriter.replaceOp(op, {result});
+ return success();
+ }
+};
+} // namespace
+
void mlir::populateTensorBufferizePatterns(
MLIRContext *context, BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
+ patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
+ BufferizeGenerateOp>(typeConverter, context);
}
namespace {
@@ -62,9 +150,13 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
OwningRewritePatternList patterns;
ConversionTarget target(*context);
+ populateBufferizeMaterializationLegality(target);
+
populateTensorBufferizePatterns(context, typeConverter, patterns);
- target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
+ target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+ tensor::FromElementsOp, tensor::GenerateOp>();
target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<scf::SCFDialect>();
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 141f8caebb57..6d29bd56dca6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
+ MLIRSCF
MLIRTensor
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
index fd1f1cf22bd6..bd4a61e6b7ee 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
@@ -13,6 +13,10 @@
namespace mlir {
+namespace scf {
+class SCFDialect;
+} // end namespace scf
+
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 9f7a20ab9de6..2bd4a1d34901 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -87,14 +87,14 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
// -----
-// Lower `const_shape` to `tensor_from_elements`.
+// Lower `const_shape` to `tensor.from_elements`.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape() -> tensor<?xindex> {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[C3:.*]] = constant 3 : index
- // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]]
+ // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]]
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
@@ -107,7 +107,7 @@ func @const_shape() -> tensor<?xindex> {
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
- // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
+ // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [] : tensor<?xindex>
@@ -204,7 +204,7 @@ func @shape_of(%arg : tensor<*xf32>) {
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
// CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
- // CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] {
+ // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] {
// CHECK: ^bb0(%[[I:.*]]: index):
// CHECK: %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
// CHECK: yield %[[EXTENT]] : index
@@ -233,7 +233,7 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex>
%shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
return
}
@@ -244,7 +244,7 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @shape_of_zero_d(%arg : tensor<f32>) {
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements : tensor<0xindex>
%shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
return
}
@@ -259,7 +259,7 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// CHECK-DAG: %[[C5:.*]] = constant 5 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex>
%shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
return
}
@@ -321,7 +321,7 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
- // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+ // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
@@ -361,7 +361,7 @@ func @broadcast_known_
diff erent_extents(%a : tensor<2xindex>, %b : tensor<3xinde
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
- // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+ // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 4e8f1282c36b..10310542f138 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -11,56 +11,6 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
return %0 : index
}
-// CHECK-LABEL: func @dynamic_tensor_from_elements(
-// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
-// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
-// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK: %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32>
-// CHECK: %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32>
-// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex>
-// CHECK: return %[[RET]] : tensor<?xindex>
-// CHECK: }
-func @dynamic_tensor_from_elements(%arg: tensor<*xf32>, %rank: index) -> tensor<?xindex> {
- %result = dynamic_tensor_from_elements %rank {
- ^bb0(%i : index):
- %elem = dim %arg, %i : tensor<*xf32>
- yield %elem : index
- } : tensor<?xindex>
- return %result : tensor<?xindex>
-}
-
-// Additional test that checks the logic for intermixed static and dynamic
-// extents.
-//
-// CHECK-LABEL: func @dynamic_tensor_from_elements_static_and_dynamic(
-// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
-// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[C16:.*]] = constant 16 : index
-// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
-// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index
-// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex>
-// CHECK: return %[[RET]] : tensor<16x?xindex>
-// CHECK: }
-func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
- %result = dynamic_tensor_from_elements %arg0 {
- ^bb0(%i: index, %j: index):
- %sum = addi %i, %j : index
- yield %sum : index
- } : tensor<16x?xindex>
- return %result : tensor<16x?xindex>
-}
-
// CHECK-LABEL: func @select(
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
@@ -74,36 +24,3 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = select %arg0, %arg1, %arg2 : tensor<f32>
return %0 : tensor<f32>
}
-
-// CHECK-LABEL: func @tensor_from_elements(
-// CHECK-SAME: %[[ELEM0:.*]]: index,
-// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK: %[[MEMREF:.*]] = alloc()
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
-// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]]
-// CHECK: return %[[RET]] : tensor<2xindex>
-func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
- %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
- return %0 : tensor<2xindex>
-}
-
-// The dynamic_tensor_from_elements op needs to put its body into the
-// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
-// the body because that would require the cloned ops to be legalized
-// immediately, which is usually not possible since they might be from various
-// other dialects.
-//
-// CHECK-LABEL: func @unknown_ops_in_body
-func @unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
- // CHECK-NOT: dynamic_tensor_from_elements
- %tensor = dynamic_tensor_from_elements %arg0 {
- ^bb0(%iv: index):
- // CHECK: test.source
- %0 = "test.source"() : () -> index
- yield %0 : index
- } : tensor<?xindex>
- return %tensor : tensor<?xindex>
-}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index e7e4d4f49222..8187c2f3215d 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -59,16 +59,16 @@ func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf
return %1 : f32
}
-// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
-// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
+// Test case: Folding of dim(tensor.generate %idx) -> %idx
+// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-NOT: dim
// CHECK: return %[[IDX1]] : index
-func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
+func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
%c3 = constant 3 : index
- %0 = dynamic_tensor_from_elements %arg0, %arg1 {
+ %0 = tensor.generate %arg0, %arg1 {
^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
- yield %c3 : index
+ tensor.yield %c3 : index
} : tensor<2x?x4x?x5xindex>
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
return %1 : index
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 48d2ae23466c..2d6e0342786c 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -16,72 +16,6 @@ func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
// -----
-func @dynamic_tensor_from_elements(%m : index)
- -> tensor<?x3x?xf32> {
- // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}}
- %tnsr = dynamic_tensor_from_elements %m {
- ^bb0(%i : index, %j : index, %k : index):
- %elem = constant 8.0 : f32
- yield %elem : f32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
-// -----
-
-func @dynamic_tensor_from_elements(%m : index, %n : index)
- -> tensor<?x3x?xf32> {
- // expected-error @+1 {{must have one body argument per input dimension}}
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index):
- %elem = constant 8.0 : f32
- yield %elem : f32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
-// -----
-
-func @dynamic_tensor_from_elements(%m : index, %n : index)
- -> tensor<?x3x?xf32> {
- // expected-error @+1 {{all body arguments must be index}}
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index, %k : i64):
- %elem = constant 8.0 : f32
- yield %elem : f32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
-// -----
-
-func @dynamic_tensor_from_elements(%m : index, %n : index)
- -> tensor<?x3x?xf32> {
- // expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}}
- // expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}}
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index, %k : index):
- %elem = constant 8.0 : f32
- return %elem : f32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
-// -----
-
-func @dynamic_tensor_from_elements(%m : index, %n : index)
- -> tensor<?x3x?xf32> {
- // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index, %k : index):
- %elem = constant 8 : i32
- yield %elem : i32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
-// -----
-
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index cd173670ae54..e81d0fa03b7d 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -32,17 +32,6 @@ func @assert(%arg : i1) {
return
}
-// CHECK-LABEL: @dynamic_tensor_from_elements
-func @dynamic_tensor_from_elements(%m : index, %n : index)
- -> tensor<?x3x?xf32> {
- %tnsr = dynamic_tensor_from_elements %m, %n {
- ^bb0(%i : index, %j : index, %k : index):
- %elem = constant 8.0 : f32
- yield %elem : f32
- } : tensor<?x3x?xf32>
- return %tnsr : tensor<?x3x?xf32>
-}
-
// CHECK-LABEL: @atan
func @atan(%arg : f32) -> f32 {
%result = atan %arg : f32
@@ -107,4 +96,3 @@ func @read_global_memref() {
%1 = tensor_load %0 : memref<2xf32>
return
}
-
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 0e55040ec116..abc7d2af5676 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -33,14 +33,96 @@ func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}
-// CHECK-LABEL: func @extract(
+// CHECK-LABEL: func @tensor.extract(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> f32 {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
// CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
// CHECK: return %[[RET]] : f32
// CHECK: }
-func @extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
+func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
%0 = tensor.extract %arg0[%arg1] : tensor<?xf32>
return %0 : f32
}
+
+// CHECK-LABEL: func @tensor.from_elements(
+// CHECK-SAME: %[[ELEM0:.*]]: index,
+// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
+// CHECK: %[[MEMREF:.*]] = alloc()
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
+// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<2xindex>
+func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+ %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
+ return %0 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func @tensor.generate(
+// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
+// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
+// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
+// CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<?xindex>
+// CHECK: return %[[RET]] : tensor<?xindex>
+// CHECK: }
+func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
+ %result = tensor.generate %dynamic_extent {
+ ^bb0(%i : index):
+ %elem = dim %arg, %i : tensor<*xf32>
+ tensor.yield %elem : index
+ } : tensor<?xindex>
+ return %result : tensor<?xindex>
+}
+
+// Additional test that checks the logic for intermixed static and dynamic
+// extents.
+//
+// CHECK-LABEL: func @tensor.generate_static_and_dynamic(
+// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
+// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
+// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index
+// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex>
+// CHECK: return %[[RET]] : tensor<16x?xindex>
+// CHECK: }
+func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
+ %result = tensor.generate %arg0 {
+ ^bb0(%i: index, %j: index):
+ %sum = addi %i, %j : index
+ tensor.yield %sum : index
+ } : tensor<16x?xindex>
+ return %result : tensor<16x?xindex>
+}
+
+// The tensor.generate op needs to put its body into the
+// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
+// the body because that would require the cloned ops to be legalized
+// immediately, which is usually not possible since they might be from various
+// other dialects.
+//
+// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
+func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
+ // CHECK-NOT: tensor.generate
+ %tensor = tensor.generate %arg0 {
+ ^bb0(%iv: index):
+ // CHECK: test.source
+ %0 = "test.source"() : () -> index
+ tensor.yield %0 : index
+ } : tensor<?xindex>
+ return %tensor : tensor<?xindex>
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9dcd4da13cc5..ae145934ef4d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -107,3 +107,90 @@ func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
%result = tensor.extract %casted[%c0] : tensor<?xf32>
return %result : f32
}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.from_elements
+func @extract_from_tensor.from_elements(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c0 = constant 0 : index
+ %tensor = tensor.from_elements %element : tensor<1xindex>
+ %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.generate
+// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
+ %0 = tensor.generate %size {
+ ^bb0(%arg0: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ tensor.yield %1 : index
+ } : tensor<?xindex>
+ %1 = tensor.extract %0[%idx] : tensor<?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.generate_2d
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
+ // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
+ // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
+ %0 = tensor.generate %size, %size {
+ ^bb0(%arg0: index, %arg1: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ %2 = dim %tensor, %arg1 : tensor<*xf32>
+ %3 = addi %1, %2 : index
+ tensor.yield %3 : index
+ } : tensor<?x?xindex>
+ %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %4 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
+// CHECK-SAME: %[[IDX:.*]]: index
+func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ %mem = alloc(%size) : memref<?xindex>
+ // CHECK: %[[DTENSOR:.*]] = tensor.generate
+ %0 = tensor.generate %size {
+ ^bb0(%arg0: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ store %1, %mem[%arg0] : memref<?xindex>
+ tensor.yield %1 : index
+ } : tensor<?xindex>
+ // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
+ %1 = tensor.extract %0[%idx] : tensor<?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @static_tensor.generate
+// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
+func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
+ %c5 = constant 5 : index
+ // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
+ %0 = tensor.generate %size1, %c5, %size4 {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ %1 = constant 32 : index
+ tensor.yield %1 : index
+ // CHECK: : tensor<3x?x5x7x?xindex>
+ } : tensor<3x?x?x7x?xindex>
+ // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+ return %0 : tensor<3x?x?x7x?xindex>
+}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index cb38ac884bc3..11866990c885 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -13,3 +13,87 @@ func @extract_too_many_indices(%arg0: tensor<?xf32>) {
%0 = tensor.extract %arg0[] : tensor<?xf32>
return
}
+
+// -----
+
+func @tensor.from_elements_wrong_result_type() {
+ // expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
+ %c0 = constant 0 : i32
+ %0 = tensor.from_elements %c0 : tensor<*xi32>
+ return
+}
+
+// -----
+
+func @tensor.from_elements_wrong_elements_count() {
+ // expected-error at +2 {{1 operands present, but expected 2}}
+ %c0 = constant 0 : index
+ %0 = tensor.from_elements %c0 : tensor<2xindex>
+ return
+}
+
+// -----
+
+func @tensor.generate(%m : index)
+ -> tensor<?x3x?xf32> {
+ // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}}
+ %tnsr = tensor.generate %m {
+ ^bb0(%i : index, %j : index, %k : index):
+ %elem = constant 8.0 : f32
+ tensor.yield %elem : f32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @tensor.generate(%m : index, %n : index)
+ -> tensor<?x3x?xf32> {
+ // expected-error @+1 {{must have one body argument per input dimension}}
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index):
+ %elem = constant 8.0 : f32
+ tensor.yield %elem : f32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @tensor.generate(%m : index, %n : index)
+ -> tensor<?x3x?xf32> {
+ // expected-error @+1 {{all body arguments must be index}}
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index, %k : i64):
+ %elem = constant 8.0 : f32
+ tensor.yield %elem : f32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @tensor.generate(%m : index, %n : index)
+ -> tensor<?x3x?xf32> {
+ // expected-error @+2 {{op expects regions to end with 'tensor.yield', found 'std.return'}}
+ // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tensor.yield'}}
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index, %k : index):
+ %elem = constant 8.0 : f32
+ return %elem : f32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
+
+// -----
+
+func @tensor.generate(%m : index, %n : index)
+ -> tensor<?x3x?xf32> {
+ // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}}
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index, %k : index):
+ %elem = constant 8 : i32
+ tensor.yield %elem : i32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 06db2bb237cd..9b15712058a2 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -21,3 +21,35 @@ func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
%0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
return
}
+
+// CHECK-LABEL: func @tensor.from_elements() {
+func @tensor.from_elements() {
+ %c0 = "std.constant"() {value = 0: index} : () -> index
+ // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex>
+ %0 = tensor.from_elements %c0 : tensor<1xindex>
+
+ %c1 = "std.constant"() {value = 1: index} : () -> index
+ // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
+ %1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
+
+ %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32
+ // CHECK: [[C0_F32:%.*]] = constant
+ // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32>
+ %2 = tensor.from_elements %c0_f32 : tensor<1xf32>
+
+ // CHECK: tensor.from_elements : tensor<0xindex>
+ %3 = tensor.from_elements : tensor<0xindex>
+
+ return
+}
+
+// CHECK-LABEL: @tensor.generate
+func @tensor.generate(%m : index, %n : index)
+ -> tensor<?x3x?xf32> {
+ %tnsr = tensor.generate %m, %n {
+ ^bb0(%i : index, %j : index, %k : index):
+ %elem = constant 8.0 : f32
+ tensor.yield %elem : f32
+ } : tensor<?x3x?xf32>
+ return %tnsr : tensor<?x3x?xf32>
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 1deeb3ec49d0..0e86050870ff 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -675,27 +675,6 @@ func @calls(%arg0: i32) {
return
}
-// CHECK-LABEL: func @tensor_from_elements() {
-func @tensor_from_elements() {
- %c0 = "std.constant"() {value = 0: index} : () -> index
- // CHECK: %0 = tensor_from_elements %c0 : tensor<1xindex>
- %0 = tensor_from_elements %c0 : tensor<1xindex>
-
- %c1 = "std.constant"() {value = 1: index} : () -> index
- // CHECK: %1 = tensor_from_elements %c0, %c1 : tensor<2xindex>
- %1 = tensor_from_elements %c0, %c1 : tensor<2xindex>
-
- %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32
- // CHECK: [[C0_F32:%.*]] = constant
- // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
- %2 = tensor_from_elements %c0_f32 : tensor<1xf32>
-
- // CHECK: tensor_from_elements : tensor<0xindex>
- %3 = tensor_from_elements : tensor<0xindex>
-
- return
-}
-
// CHECK-LABEL: func @memref_cast(%arg0
func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
// CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 45ebfff34d57..364c9155e2da 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -541,24 +541,6 @@ func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-note {{prio
// -----
-func @tensor_from_elements_wrong_result_type() {
- // expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
- %c0 = constant 0 : i32
- %0 = tensor_from_elements %c0 : tensor<*xi32>
- return
-}
-
-// -----
-
-func @tensor_from_elements_wrong_elements_count() {
- // expected-error at +2 {{1 operands present, but expected 2}}
- %c0 = constant 0 : index
- %0 = tensor_from_elements %c0 : tensor<2xindex>
- return
-}
-
-// -----
-
func @index_cast_index_to_index(%arg0: index) {
// expected-error at +1 {{are cast incompatible}}
%0 = index_cast %arg0: index to index
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 5b6f8cde9fec..62c07dd8a063 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1032,93 +1032,6 @@ func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: i
// -----
-// CHECK-LABEL: func @extract_from_tensor_from_elements
-func @extract_from_tensor_from_elements(%element : index) -> index {
- // CHECK-SAME: ([[ARG:%.*]]: index)
- %c0 = constant 0 : index
- %tensor = tensor_from_elements %element : tensor<1xindex>
- %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
- // CHECK: [[ARG]] : index
- return %extracted_element : index
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements
-// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
-func @extract_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
- %size = rank %tensor : tensor<*xf32>
- // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
- %0 = dynamic_tensor_from_elements %size {
- ^bb0(%arg0: index):
- %1 = dim %tensor, %arg0 : tensor<*xf32>
- yield %1 : index
- } : tensor<?xindex>
- %1 = tensor.extract %0[%idx] : tensor<?xindex>
- // CHECK-NEXT: return %[[RES]]
- return %1 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_2d
-// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
-func @extract_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
- %size = rank %tensor : tensor<*xf32>
- // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
- // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
- // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
- %0 = dynamic_tensor_from_elements %size, %size {
- ^bb0(%arg0: index, %arg1: index):
- %1 = dim %tensor, %arg0 : tensor<*xf32>
- %2 = dim %tensor, %arg1 : tensor<*xf32>
- %3 = addi %1, %2 : index
- yield %3 : index
- } : tensor<?x?xindex>
- %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
- // CHECK-NEXT: return %[[RES]]
- return %4 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_sideeffects
-// CHECK-SAME: %[[IDX:.*]]: index
-func @extract_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
- %size = rank %tensor : tensor<*xf32>
- %mem = alloc(%size) : memref<?xindex>
- // CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
- %0 = dynamic_tensor_from_elements %size {
- ^bb0(%arg0: index):
- %1 = dim %tensor, %arg0 : tensor<*xf32>
- store %1, %mem[%arg0] : memref<?xindex>
- yield %1 : index
- } : tensor<?xindex>
- // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
- %1 = tensor.extract %0[%idx] : tensor<?xindex>
- // CHECK-NEXT: return %[[RES]]
- return %1 : index
-}
-
-// -----
-
-// CHECK-LABEL: @static_dynamic_tensor_from_elements
-// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
-func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
- %c5 = constant 5 : index
- // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
- %0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
- ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
- %1 = constant 32 : index
- yield %1 : index
- // CHECK: : tensor<3x?x5x7x?xindex>
- } : tensor<3x?x?x7x?xindex>
- // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
- return %0 : tensor<3x?x?x7x?xindex>
-}
-
-// -----
-
// CHECK-LABEL: func @subtensor
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index
func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
More information about the llvm-branch-commits
mailing list