[Mlir-commits] [mlir] f77e9f8 - [mlir] Extend `tensor.from_elements` to support N-D case.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Dec 16 05:59:16 PST 2021
Author: Alexander Belyaev
Date: 2021-12-16T14:52:41+01:00
New Revision: f77e9f876839e70d8adf0f02e4b2018cea6aedd5
URL: https://github.com/llvm/llvm-project/commit/f77e9f876839e70d8adf0f02e4b2018cea6aedd5
DIFF: https://github.com/llvm/llvm-project/commit/f77e9f876839e70d8adf0f02e4b2018cea6aedd5.diff
LOG: [mlir] Extend `tensor.from_elements` to support N-D case.
RFC: https://llvm.discourse.group/t/rfc-extend-tensor-fromelementsop-to-n-d/4715
Differential Revision: https://reviews.llvm.org/D115821
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21331fc649cd5..1a95d921fee22 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -312,22 +312,29 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
NoSideEffect,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type, 2>("
- "$_self.cast<ShapedType>().getDimSize(0), "
+ "$_self.cast<ShapedType>().getNumElements(), "
"$_self.cast<ShapedType>().getElementType())">
]> {
string summary = "tensor from elements operation.";
string description = [{
- Create a 1D tensor from a range of same-type arguments.
+ Create an N-D tensor from a range of same-type arguments. The number of
+ provided `elements` must equal the number of elements in the result
+ type. The `elements` are the result tensor flattened in row-major order.
Example:
```mlir
- tensor.from_elements i_1, ..., i_N : tensor<Nxindex>
+ tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xindex>
```
+
+ will result in a tensor
+
+ [[%a, %b, %c]
+ [%d, %e, %f]]
}];
let arguments = (ins Variadic<AnyType>:$elements);
- let results = (outs 1DTensorOf<[AnyType]>:$result);
+ let results = (outs AnyStaticShapeTensor:$result);
let assemblyFormat = "$elements attr-dict `:` type($result)";
@@ -336,7 +343,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
let skipDefaultBuilders = 1;
let builders = [
- OpBuilder<(ins "Type":$elementType, "ValueRange":$elements)>,
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
// Special case builder for when `elements` has size >=1.
OpBuilder<(ins "ValueRange":$elements)>
];
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5a1af7b33132e..deb4cd502ab98 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -193,10 +193,10 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
extentOperands.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
}
- Type indexTy = rewriter.getIndexType();
+ Type resultTy =
+ RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
Value tensor =
- rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
- Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy);
+ rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
return success();
}
@@ -569,7 +569,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
// Materialize extent tensor.
Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
- loc, rewriter.getIndexType(), extentValues);
+ loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
+ extentValues);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
staticExtentTensor);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 4ddd88c342b99..9be95a1815334 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -28,8 +28,8 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
// A detensored value is converted back by creating a new tensor from its
// element(s).
- auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
- loc, inputs[0].getType(), inputs[0]);
+ auto createNewTensorOp =
+ builder.create<tensor::FromElementsOp>(loc, inputs[0]);
// FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
// a tensor<dtype> instead.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ecdd966a3c35e..37906adb29186 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -364,17 +364,17 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
+/// Builds a `tensor.from_elements` op of type `resultType` from `elements`.
+/// The builder performs no consistency checks itself; the element count and
+/// element type are constrained by the op definition (`TypesMatchWith` on the
+/// result's number of elements / element type, and `AnyStaticShapeTensor`).
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
- Type elementType, ValueRange elements) {
- Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
- elementType);
+ Type resultType, ValueRange elements) {
result.addOperands(elements);
- result.addTypes(resultTy);
+ result.addTypes(resultType);
}
+/// Convenience builder: infers a 1-D result type `tensor<N x T>`, where N is
+/// the number of `elements` and T is the type of the first element. At least
+/// one element is required (asserted) because T cannot be inferred otherwise.
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
ValueRange elements) {
assert(!elements.empty() && "expected at least one element");
- build(builder, result, elements.front().getType(), elements);
+ Type resultType = RankedTensorType::get(
+ {static_cast<int64_t>(elements.size())}, elements.front().getType());
+ build(builder, result, resultType, elements);
}
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
@@ -397,23 +397,27 @@ struct ExtractElementFromTensorFromElements
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
- if (extract.indices().size() != 1)
- return failure();
-
auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
- if (tensorFromElements == nullptr)
- return failure();
-
- APInt index;
- if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
+ if (!tensorFromElements)
return failure();
+ auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
+ int64_t rank = tensorType.getRank();
+
+ // Collect the indices; the fold only applies when all of them are
+ // constants. Also record whether any index lies outside its own dimension:
+ // an out-of-range index can still linearize to an in-range flat index, so
+ // checking only the flat index is not sufficient.
+ SmallVector<int64_t, 3> indices;
+ indices.reserve(rank);
+ bool outOfBounds = false;
+ for (int64_t i = 0; i < rank; ++i) {
+ APInt index;
+ if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
+ return failure();
+ int64_t dimIndex = index.getSExtValue();
+ outOfBounds |= dimIndex < 0 || dimIndex >= tensorType.getDimSize(i);
+ indices.push_back(dimIndex);
+ }
// Prevent out of bounds accesses. This can happen in invalid code that will
// never execute.
- if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
- index.getSExtValue() < 0)
+ if (outOfBounds)
return failure();
- rewriter.replaceOp(extract,
- tensorFromElements.getOperand(index.getZExtValue()));
+
+ // The `elements` are the result tensor flattened in row-major order, so
+ // linearize the multi-index with Horner's scheme:
+ // flat = ((i_0 * d_1 + i_1) * d_2 + i_2) * ... + i_{rank-1}.
+ // Every index is in bounds here, so the result is in [0, numElements).
+ int64_t flatIndex = 0;
+ for (int64_t i = 0; i < rank; ++i)
+ flatIndex = flatIndex * tensorType.getDimSize(i) + indices[i];
+ rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
return success();
}
};
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 0fd5b2d75d677..90c82515177e8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -65,19 +66,65 @@ struct BufferizeFromElementsOp
LogicalResult
matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- int numberOfElements = op.elements().size();
- auto resultType = MemRefType::get(
- {numberOfElements}, op.getType().cast<TensorType>().getElementType());
- Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
- for (auto element : llvm::enumerate(op.elements())) {
- Value index =
- rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
- rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
- index);
+ Location loc = op.getLoc();
+ auto tensorType = op.getType().cast<RankedTensorType>();
+ auto shape = tensorType.getShape();
+
+ // Allocate a buffer with the same static shape as the result tensor.
+ auto resultType = MemRefType::get(shape, tensorType.getElementType());
+ Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
+
+ // Case: no elements (e.g. tensor<0xelem_type>); there is nothing to store.
+ // The converted operands (`adaptor`) are used throughout, as required for
+ // dialect conversion patterns.
+ if (adaptor.elements().empty()) {
+ rewriter.replaceOp(op, {buffer});
+ return success();
}
- rewriter.replaceOp(op, {result});
+
+ // Case: tensor<elem_type>. The single element is stored without indices.
+ if (shape.empty()) {
+ rewriter.create<memref::StoreOp>(loc, adaptor.elements().front(), buffer);
+ rewriter.replaceOp(op, {buffer});
+ return success();
+ }
+
+ // Create constants for the range of possible indices [0, max{shape_i}).
+ // They are shared by all dimensions so each index value is built once.
+ auto maxDim = *std::max_element(shape.begin(), shape.end());
+ SmallVector<Value, 2> constants;
+ constants.reserve(maxDim);
+ for (int64_t i = 0; i < maxDim; ++i)
+ constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+ // Traverse all elements in row-major order and create `memref.store` ops.
+ ImplicitLocOpBuilder b(loc, rewriter);
+ auto elementIt = adaptor.elements().begin();
+ SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+ createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
+
+ rewriter.replaceOp(op, {buffer});
return success();
}
+
+private:
+ // Implements backtracking to traverse the indices of the output buffer in
+ // row-major order while iterating over the converted elements. `indices`
+ // holds the current multi-index (entries before `dim` are already fixed) and
+ // `elementIt` is advanced once per emitted store. The builder is shared by
+ // reference instead of being copied at every recursion level.
+ void createStores(size_t dim, Value buffer, ArrayRef<int64_t> shape,
+ ArrayRef<Value> constants, ValueRange::iterator &elementIt,
+ SmallVectorImpl<Value> &indices,
+ ImplicitLocOpBuilder &b) const {
+ if (dim + 1 == shape.size()) {
+ for (int64_t i = 0; i < shape.back(); ++i) {
+ indices.back() = constants[i];
+ b.create<memref::StoreOp>(*elementIt, buffer, indices);
+ ++elementIt;
+ }
+ return;
+ }
+ for (int64_t i = 0; i < shape[dim]; ++i) {
+ indices[dim] = constants[i];
+ createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
+ }
+ }
};
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5b3bb149d6180..c6dd6b9310d92 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -65,21 +65,116 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
return %0 : f32
}
-// CHECK-LABEL: func @tensor.from_elements(
+// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<0xindex>
+func @tensor.from_elements_no_elements() -> tensor<0xindex> {
+ %0 = tensor.from_elements : tensor<0xindex>
+ return %0 : tensor<0xindex>
+}
+
+// CHECK-LABEL: func @tensor.from_elements_0d(
+// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<index>
+// CHECK: store %[[ELEM0]], %[[MEMREF]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<index>
+func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
+ %0 = tensor.from_elements %arg0 : tensor<index>
+ return %0 : tensor<index>
+}
+
+// CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc()
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<2xindex>
-func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
%0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
return %0 : tensor<2xindex>
}
+// CHECK-LABEL: func @tensor.from_elements_2d(
+// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
+// CHECK-SAME: -> tensor<3x2xindex> {
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
+// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
+// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
+// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
+// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
+// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<3x2xindex>
+func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+ %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+ : tensor<3x2xindex>
+ return %0 : tensor<3x2xindex>
+}
+
+// CHECK-LABEL: func @tensor.from_elements_3d()
+
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+
+// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
+// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
+// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]
+
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<3x2x2xf32>
+func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
+ %f5 = arith.constant 5.0 : f32
+ %f6 = arith.constant 6.0 : f32
+ %f7 = arith.constant 7.0 : f32
+ %f8 = arith.constant 8.0 : f32
+ %f9 = arith.constant 9.0 : f32
+ %f10 = arith.constant 10.0 : f32
+ %f11 = arith.constant 11.0 : f32
+ %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+ return %0 : tensor<3x2x2xf32>
+}
+
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ec9601e269939..5331e50790a6a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -135,6 +135,61 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
+// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
+func @extract_from_tensor.from_elements_3d()
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
+ %f5 = arith.constant 5.0 : f32
+ %f6 = arith.constant 6.0 : f32
+ %f7 = arith.constant 7.0 : f32
+ %f8 = arith.constant 8.0 : f32
+ %f9 = arith.constant 9.0 : f32
+ %f10 = arith.constant 10.0 : f32
+ %f11 = arith.constant 11.0 : f32
+
+ %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
+ %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
+ %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
+ %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
+ %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
+ %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
+ %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
+ %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
+ %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
+ %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
+ %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
+ %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
+ return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+ : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
+// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
+
+// -----
+
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
func @extract_negative_from_tensor.from_elements(%element : index) -> index {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 564526f16370f..ece2f54d84012 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -33,7 +33,7 @@ func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// -----
func @tensor.from_elements_wrong_result_type() {
- // expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
+ // expected-error at +2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}}
%c0 = arith.constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<*xi32>
return
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 8d50d15184218..d461dffeb6d5b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -38,21 +38,26 @@ func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*
// CHECK-LABEL: func @tensor.from_elements() {
func @tensor.from_elements() {
%c0 = "arith.constant"() {value = 0: index} : () -> index
- // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex>
+ // CHECK: tensor.from_elements %c0 : tensor<1xindex>
%0 = tensor.from_elements %c0 : tensor<1xindex>
%c1 = "arith.constant"() {value = 1: index} : () -> index
- // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
+ // CHECK: tensor.from_elements %c0, %c1 : tensor<2xindex>
%1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
%c0_f32 = "arith.constant"() {value = 0.0: f32} : () -> f32
// CHECK: [[C0_F32:%.*]] = arith.constant
- // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32>
+ // CHECK: tensor.from_elements [[C0_F32]] : tensor<1xf32>
%2 = tensor.from_elements %c0_f32 : tensor<1xf32>
// CHECK: tensor.from_elements : tensor<0xindex>
%3 = tensor.from_elements : tensor<0xindex>
+ // CHECK: tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex>
+ %4 = tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex>
+
+ // CHECK: tensor.from_elements %c0 : tensor<index>
+ %5 = tensor.from_elements %c0 : tensor<index>
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a6b317db467cf..aae98baf1c26c 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -945,14 +945,14 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
- auto currShape = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(
- 0, operand.getType().cast<RankedTensorType>().getRank()),
- [&](int64_t dim) -> Value {
+ auto rank = operand.getType().cast<RankedTensorType>().getRank();
+ auto currShape = llvm::to_vector<4>(
+ llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
}));
shapes.push_back(builder.create<tensor::FromElementsOp>(
- getLoc(), builder.getIndexType(), currShape));
+ getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
+ currShape));
}
return success();
}
More information about the Mlir-commits
mailing list