[Mlir-commits] [mlir] f77e9f8 - [mlir] Extend `tensor.from_elements` to support N-D case.

Alexander Belyaev llvmlistbot at llvm.org
Thu Dec 16 05:59:16 PST 2021


Author: Alexander Belyaev
Date: 2021-12-16T14:52:41+01:00
New Revision: f77e9f876839e70d8adf0f02e4b2018cea6aedd5

URL: https://github.com/llvm/llvm-project/commit/f77e9f876839e70d8adf0f02e4b2018cea6aedd5
DIFF: https://github.com/llvm/llvm-project/commit/f77e9f876839e70d8adf0f02e4b2018cea6aedd5.diff

LOG: [mlir] Extend `tensor.from_elements` to support N-D case.

RFC: https://llvm.discourse.group/t/rfc-extend-tensor-fromelementsop-to-n-d/4715

Differential Revision: https://reviews.llvm.org/D115821
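
For illustration, a sketch of the forms this change enables (the SSA values
are placeholders): before, `tensor.from_elements` could only produce 1-D
results; now any statically shaped result is allowed, with the operands read
as the row-major flattening of the tensor.

```mlir
// Previously the only supported form:
%t1 = tensor.from_elements %a, %b : tensor<2xindex>
// Now 0-D and N-D results are accepted as well:
%t0 = tensor.from_elements %a : tensor<index>
%t2 = tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xindex>
```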

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21331fc649cd5..1a95d921fee22 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -312,22 +312,29 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
     NoSideEffect,
     TypesMatchWith<"operand types match result element type",
                    "result", "elements", "SmallVector<Type, 2>("
-                   "$_self.cast<ShapedType>().getDimSize(0), "
+                   "$_self.cast<ShapedType>().getNumElements(), "
                    "$_self.cast<ShapedType>().getElementType())">
   ]> {
   string summary = "tensor from elements operation.";
   string description = [{
-    Create a 1D tensor from a range of same-type arguments.
+    Create an N-D tensor from a range of same-type arguments. The number of
+    provided `elements` should equal the number of elements in the result
+    type. The `elements` correspond to a flattened tensor.
 
     Example:
 
     ```mlir
-    tensor.from_elements i_1, ..., i_N :  tensor<Nxindex>
+    tensor.from_elements %a, %b, %c, %d, %e, %f :  tensor<2x3xindex>
     ```
+
+    will result in a tensor
+
+    [[%a, %b, %c]
+     [%d, %e, %f]]
   }];
 
   let arguments = (ins Variadic<AnyType>:$elements);
-  let results = (outs 1DTensorOf<[AnyType]>:$result);
+  let results = (outs AnyStaticShapeTensor:$result);
 
   let assemblyFormat = "$elements attr-dict `:` type($result)";
 
@@ -336,7 +343,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<(ins "Type":$elementType, "ValueRange":$elements)>,
+    OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
     // Special case builder for when `elements` has size >=1.
     OpBuilder<(ins "ValueRange":$elements)>
   ];
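
A note on the `TypesMatchWith` change above: the expected operand types are
now derived from the result's `getNumElements()` rather than `getDimSize(0)`,
so the verifier requires exactly one operand per element of the result type.
A sketch with placeholder values:

```mlir
// OK: 2 * 3 = 6 elements are provided.
%ok = tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xindex>
// Would be rejected: tensor<2x3xindex> needs 6 operands, not 5.
// %bad = tensor.from_elements %a, %b, %c, %d, %e : tensor<2x3xindex>
```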

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5a1af7b33132e..deb4cd502ab98 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -193,10 +193,10 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
     extentOperands.push_back(
         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
   }
-  Type indexTy = rewriter.getIndexType();
+  Type resultTy =
+      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
   Value tensor =
-      rewriter.create<tensor::FromElementsOp>(loc, indexTy, extentOperands);
-  Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy);
+      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
   return success();
 }
@@ -569,7 +569,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 
     // Materialize extent tensor.
     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
-        loc, rewriter.getIndexType(), extentValues);
+        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
+        extentValues);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                 staticExtentTensor);
     return success();
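
With the builder now taking the full result type, call sites construct the
`RankedTensorType` up front, as in the two hunks above. The IR produced for
`shape.const_shape` is unchanged; roughly (names illustrative):

```mlir
// shape.const_shape [2, 3] lowers to approximately:
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%t = tensor.from_elements %c2, %c3 : tensor<2xindex>
%r = tensor.cast %t : tensor<2xindex> to tensor<?xindex>
```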

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 4ddd88c342b99..9be95a1815334 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -28,8 +28,8 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
 
   // A detensored value is converted back by creating a new tensor from its
   // element(s).
-  auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
-      loc, inputs[0].getType(), inputs[0]);
+  auto createNewTensorOp =
+      builder.create<tensor::FromElementsOp>(loc, inputs[0]);
 
   // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
   // a tensor<dtype> instead.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ecdd966a3c35e..37906adb29186 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -364,17 +364,17 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
-                           Type elementType, ValueRange elements) {
-  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
-                                        elementType);
+                           Type resultType, ValueRange elements) {
   result.addOperands(elements);
-  result.addTypes(resultTy);
+  result.addTypes(resultType);
 }
 
 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                            ValueRange elements) {
   assert(!elements.empty() && "expected at least one element");
-  build(builder, result, elements.front().getType(), elements);
+  Type resultType = RankedTensorType::get(
+      {static_cast<int64_t>(elements.size())}, elements.front().getType());
+  build(builder, result, resultType, elements);
 }
 
 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
@@ -397,23 +397,27 @@ struct ExtractElementFromTensorFromElements
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                 PatternRewriter &rewriter) const final {
-    if (extract.indices().size() != 1)
-      return failure();
-
     auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
-    if (tensorFromElements == nullptr)
-      return failure();
-
-    APInt index;
-    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
+    if (!tensorFromElements)
       return failure();
+    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    SmallVector<APInt, 3> indices(rank);
+    int64_t flatIndex = 0;
+    int64_t stride = 1;
+    for (int i = rank - 1; i >= 0; --i) {
+      APInt index;
+      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
+        return failure();
+      if (i < rank - 1)
+        stride *= tensorType.getDimSize(i);
+      flatIndex += index.getSExtValue() * stride;
+    }
     // Prevent out of bounds accesses. This can happen in invalid code that will
     // never execute.
-    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
-        index.getSExtValue() < 0)
+    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
       return failure();
-    rewriter.replaceOp(extract,
-                       tensorFromElements.getOperand(index.getZExtValue()));
+    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
     return success();
   }
 };
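
The rewritten pattern folds `tensor.extract` with constant indices by
linearizing them into the flattened operand list. A worked example of the
intended row-major arithmetic (values illustrative):

```mlir
%t = tensor.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : tensor<2x3xf32>
// Indices [1, 2] linearize to 1 * 3 + 2 = 5, so this extract
// canonicalizes to %e5.
%v = tensor.extract %t[%c1, %c2] : tensor<2x3xf32>
```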

diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 0fd5b2d75d677..90c82515177e8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -65,19 +66,65 @@ struct BufferizeFromElementsOp
   LogicalResult
   matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    int numberOfElements = op.elements().size();
-    auto resultType = MemRefType::get(
-        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
-    Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
-    for (auto element : llvm::enumerate(op.elements())) {
-      Value index =
-          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
-      rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
-                                       index);
+    Location loc = op.getLoc();
+    auto tensorType = op.getType().cast<RankedTensorType>();
+    auto shape = tensorType.getShape();
+
+    // Allocate a buffer for the result.
+    auto resultType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
+
+    // Case: tensor<0xelem_type>.
+    if (op.elements().empty()) {
+      rewriter.replaceOp(op, {buffer});
+      return success();
     }
-    rewriter.replaceOp(op, {result});
+
+    // Case: tensor<elem_type>.
+    if (shape.empty()) {
+      rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
+      rewriter.replaceOp(op, {buffer});
+      return success();
+    }
+
+    // Create constants for the range of possible indices [0, max{shape_i}).
+    auto maxDim = *std::max_element(shape.begin(), shape.end());
+    SmallVector<Value, 2> constants;
+    constants.reserve(maxDim);
+    for (int i = 0; i < maxDim; ++i)
+      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+
+    // Traverse all `elements` and create `memref.store` ops.
+    ImplicitLocOpBuilder b(loc, rewriter);
+    auto element_it = adaptor.elements().begin();
+    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
+    CreateStores(/*dim=*/0, buffer, shape, constants, element_it, indices, b);
+
+    rewriter.replaceOp(op, {buffer});
     return success();
   }
+
+private:
+  // Implements backtracking to traverse indices of the output buffer while
+  // iterating over op.elements().
+  void CreateStores(int dim, Value buffer, ArrayRef<int64_t> shape,
+                    ArrayRef<Value> constants, ValueRange::iterator &element_it,
+                    SmallVectorImpl<Value> &indices,
+                    ImplicitLocOpBuilder b) const {
+    if (dim == shape.size() - 1) {
+      for (int i = 0; i < shape.back(); ++i) {
+        indices.back() = constants[i];
+        b.create<memref::StoreOp>(*element_it, buffer, indices);
+        ++element_it;
+      }
+      return;
+    }
+    for (int i = 0; i < shape[dim]; ++i) {
+      indices[dim] = constants[i];
+      CreateStores(dim + 1, buffer, shape, constants, element_it, indices, b);
+    }
+  }
 };
 
 struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
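
For reference, `CreateStores` above walks the output indices in row-major
order, consuming one operand per `memref.store`, and the index constants in
[0, max dim) are materialized once and shared across dimensions. A rough
sketch of the output for a 2x2 case (SSA names are illustrative):

```mlir
%buf = memref.alloc() : memref<2x2xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
memref.store %e0, %buf[%c0, %c0] : memref<2x2xf32>
memref.store %e1, %buf[%c0, %c1] : memref<2x2xf32>
memref.store %e2, %buf[%c1, %c0] : memref<2x2xf32>
memref.store %e3, %buf[%c1, %c1] : memref<2x2xf32>
```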

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5b3bb149d6180..c6dd6b9310d92 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -65,21 +65,116 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
   return %0 : f32
 }
 
-// CHECK-LABEL:   func @tensor.from_elements(
+// CHECK-LABEL:   func @tensor.from_elements_no_elements() -> tensor<0xindex> {
+// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK:           return %[[RET]] : tensor<0xindex>
+func @tensor.from_elements_no_elements() -> tensor<0xindex> {
+  %0 = tensor.from_elements : tensor<0xindex>
+  return %0 : tensor<0xindex>
+}
+
+// CHECK-LABEL:   func @tensor.from_elements_0d(
+// CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
+// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<index>
+// CHECK:           store %[[ELEM0]], %[[MEMREF]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK:           return %[[RET]] : tensor<index>
+func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
+  %0 = tensor.from_elements %arg0 : tensor<index>
+  return %0 : tensor<index>
+}
+
+// CHECK-LABEL:   func @tensor.from_elements_1d(
 // CHECK-SAME:                               %[[ELEM0:.*]]: index,
 // CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK:           %[[MEMREF:.*]] = memref.alloc()
+// CHECK:           %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
 // CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
 // CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK:           return %[[RET]] : tensor<2xindex>
-func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
   %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex>
   return %0 : tensor<2xindex>
 }
 
+// CHECK-LABEL: func @tensor.from_elements_2d(
+// CHECK-SAME:      %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
+// CHECK-SAME:      -> tensor<3x2xindex> {
+// CHECK:         %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
+// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
+// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
+// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
+// CHECK:         store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
+// CHECK:         store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
+// CHECK:         %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK:         return %[[RET]] : tensor<3x2xindex>
+func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
+  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
+         : tensor<3x2xindex>
+  return %0 : tensor<3x2xindex>
+}
+
+// CHECK-LABEL: func @tensor.from_elements_3d()
+
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+
+// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]]
+// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]]
+// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]]
+// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]]
+// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C0]]]
+// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]]
+
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: return %[[RET]] : tensor<3x2x2xf32>
+func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+  %f9 = arith.constant 9.0 : f32
+  %f10 = arith.constant 10.0 : f32
+  %f11 = arith.constant 11.0 : f32
+  %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+         : tensor<3x2x2xf32>
+  return %0 : tensor<3x2x2xf32>
+}
+
 // CHECK-LABEL:   func @tensor.generate(
 // CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ec9601e269939..5331e50790a6a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -135,6 +135,61 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
+func @extract_from_tensor.from_elements_3d()
+    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+  %f9 = arith.constant 9.0 : f32
+  %f10 = arith.constant 10.0 : f32
+  %f11 = arith.constant 11.0 : f32
+
+  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+         : tensor<3x2x2xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+
+  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
+  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
+  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
+  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
+  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
+  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
+  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
+  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
+  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
+  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
+  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
+  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
+  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+         : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
+// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
+// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
+// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
+// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
+// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
+// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
+// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
+// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
+// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
+// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
+
+// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
+// CHECK-SAME:   %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
+
+// -----
+
 // Ensure the optimization doesn't segfault from bad constants
 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
 func @extract_negative_from_tensor.from_elements(%element : index) -> index {

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 564526f16370f..ece2f54d84012 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -33,7 +33,7 @@ func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
 // -----
 
 func @tensor.from_elements_wrong_result_type() {
-  // expected-error at +2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
+  // expected-error at +2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}}
   %c0 = arith.constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<*xi32>
   return

diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 8d50d15184218..d461dffeb6d5b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -38,21 +38,26 @@ func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*
 // CHECK-LABEL: func @tensor.from_elements() {
 func @tensor.from_elements() {
   %c0 = "arith.constant"() {value = 0: index} : () -> index
-  // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex>
+  // CHECK: tensor.from_elements %c0 : tensor<1xindex>
   %0 = tensor.from_elements %c0 : tensor<1xindex>
 
   %c1 = "arith.constant"() {value = 1: index} : () -> index
-  // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
+  // CHECK: tensor.from_elements %c0, %c1 : tensor<2xindex>
   %1 = tensor.from_elements %c0, %c1 : tensor<2xindex>
 
   %c0_f32 = "arith.constant"() {value = 0.0: f32} : () -> f32
   // CHECK: [[C0_F32:%.*]] = arith.constant
-  // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32>
+  // CHECK: tensor.from_elements [[C0_F32]] : tensor<1xf32>
   %2 = tensor.from_elements %c0_f32 : tensor<1xf32>
 
   // CHECK: tensor.from_elements : tensor<0xindex>
   %3 = tensor.from_elements : tensor<0xindex>
 
+  // CHECK: tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex>
+  %4 = tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex>
+
+  // CHECK: tensor.from_elements %c0 : tensor<index>
+  %5 = tensor.from_elements %c0 : tensor<index>
   return
 }
 

diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a6b317db467cf..aae98baf1c26c 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -945,14 +945,14 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
   Location loc = getLoc();
   shapes.reserve(operands.size());
   for (Value operand : llvm::reverse(operands)) {
-    auto currShape = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<int64_t>(
-            0, operand.getType().cast<RankedTensorType>().getRank()),
-        [&](int64_t dim) -> Value {
+    auto rank = operand.getType().cast<RankedTensorType>().getRank();
+    auto currShape = llvm::to_vector<4>(
+        llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
         }));
     shapes.push_back(builder.create<tensor::FromElementsOp>(
-        getLoc(), builder.getIndexType(), currShape));
+        getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
+        currShape));
   }
   return success();
 }


        


More information about the Mlir-commits mailing list