[llvm-branch-commits] [mlir] 866cb26 - [mlir] Fix SubTensorInsertOp semantics

Nicolas Vasilache via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 20 12:22:48 PST 2021


Author: Nicolas Vasilache
Date: 2021-01-20T20:16:01Z
New Revision: 866cb26039043581d5ab8b30d5a999a7c273f361

URL: https://github.com/llvm/llvm-project/commit/866cb26039043581d5ab8b30d5a999a7c273f361
DIFF: https://github.com/llvm/llvm-project/commit/866cb26039043581d5ab8b30d5a999a7c273f361.diff

LOG: [mlir] Fix SubTensorInsertOp semantics

Like SubView, SubTensor/SubTensorInsertOp are allowed to have rank-reducing/expanding semantics. In the case of SubTensorInsertOp, the rank of offsets/sizes/strides should be the rank of the destination tensor.

Also, add a builder flavor for SubTensorOp to return a rank-reduced tensor.

Differential Revision: https://reviews.llvm.org/D95076

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 770e68f6da835..6dbb24a4358f8 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3244,6 +3244,17 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
     // Build a SubTensorOp with all dynamic entries.
     OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets,
       "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a SubTensorOp with mixed static and dynamic entries
+    // and custom result type.
+    OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ArrayRef<int64_t>":$staticOffsets, "ArrayRef<int64_t>":$staticSizes,
+      "ArrayRef<int64_t>":$staticStrides, "ValueRange":$offsets,
+      "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a SubTensorOp with all dynamic entries and custom result type.
+    OpBuilderDAG<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
@@ -3349,7 +3360,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
       return source().getType().cast<RankedTensorType>();
     }
 
-    /// The result of a subtensor is always a tensor.
+    /// The result of a subtensor_insert is always a tensor.
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
@@ -3357,7 +3368,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrRanks() {
-      unsigned rank = getSourceType().getRank();
+      unsigned rank = getType().getRank();
       return {rank, rank, rank};
     }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1718ab14d5d12..428006e20d9f9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2892,12 +2892,11 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
                             ArrayRef<NamedAttribute> attrs) {
   auto sourceMemRefType = source.getType().cast<MemRefType>();
   unsigned rank = sourceMemRefType.getRank();
-  SmallVector<int64_t, 4> staticOffsetsVector;
-  staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
-  SmallVector<int64_t, 4> staticSizesVector;
-  staticSizesVector.assign(rank, ShapedType::kDynamicSize);
-  SmallVector<int64_t, 4> staticStridesVector;
-  staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticOffsetsVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
   build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
         staticStridesVector, offsets, sizes, strides, attrs);
 }
@@ -3444,6 +3443,38 @@ void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
         staticStridesVector, offsets, sizes, strides, attrs);
 }
 
+/// Build a SubTensorOp as above but with custom result type.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              RankedTensorType resultType, Value source,
+                              ArrayRef<int64_t> staticOffsets,
+                              ArrayRef<int64_t> staticSizes,
+                              ArrayRef<int64_t> staticStrides,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubTensorOp as above but with custom result type.
+void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,
+                              RankedTensorType resultType, Value source,
+                              ValueRange offsets, ValueRange sizes,
+                              ValueRange strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceRankedTensorType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
 /// Verifier for SubTensorOp.
 static LogicalResult verify(SubTensorOp op) {
   // Verify result type against inferred type.
@@ -3528,8 +3559,8 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
                                     ValueRange offsets, ValueRange sizes,
                                     ValueRange strides,
                                     ArrayRef<NamedAttribute> attrs) {
-  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
-  unsigned rank = sourceRankedTensorType.getRank();
+  auto destRankedTensorType = dest.getType().cast<RankedTensorType>();
+  unsigned rank = destRankedTensorType.getRank();
   SmallVector<int64_t, 4> staticOffsetsVector(
       rank, ShapedType::kDynamicStrideOrOffset);
   SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index b5266fb2e580b..b8ccab7e9a77a 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -908,7 +908,11 @@ func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
 }
 
 // CHECK-LABEL: func @subtensor_insert({{.*}}) {
-func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx : index) {
+func @subtensor_insert(
+    %t: tensor<8x16x4xf32>, 
+    %t2: tensor<16x32x8xf32>, 
+    %t3: tensor<4x4xf32>, 
+    %idx : index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
 
@@ -922,5 +926,10 @@ func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx :
   %2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
     : tensor<8x16x4xf32> into tensor<16x32x8xf32>
 
+  // CHECK: subtensor_insert
+  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
+  %3 = subtensor_insert %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x4xf32>
+
   return
 }


        


More information about the llvm-branch-commits mailing list