[Mlir-commits] [mlir] 86c4972 - [mlir][tosa] Improve lowering support for tosa.concat

Rob Suderman llvmlistbot at llvm.org
Thu Jun 15 11:40:55 PDT 2023


Author: Spenser Bauman
Date: 2023-06-15T11:39:37-07:00
New Revision: 86c4972f5f6759b0ba85c568d934f9b8db8875a2

URL: https://github.com/llvm/llvm-project/commit/86c4972f5f6759b0ba85c568d934f9b8db8875a2
DIFF: https://github.com/llvm/llvm-project/commit/86c4972f5f6759b0ba85c568d934f9b8db8875a2.diff

LOG: [mlir][tosa] Improve lowering support for tosa.concat

The existing lowering for tosa.concat fails in some instances when the
output shape contains more information than the input shapes. The result is
an illegal tensor.empty operation.

This change bases the output shape on the original tosa.concat
operation, while querying the input tensor shapes to build the slicing
operations.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D151707

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5e6e971f1f1ce..0b2b006bfc365 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -12,7 +12,9 @@
 
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -355,56 +357,56 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   LogicalResult
   matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto inputType = cast<ShapedType>(op.getOperand(0).getType());
     auto resultType = dyn_cast<RankedTensorType>(op.getType());
 
     Location loc = op.getLoc();
     int axis = op.getAxis();
     Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getIndexAttr(axis));
-    int rank = resultType.getRank();
-    SmallVector<Value, 3> offsets, sizes, strides;
-    sizes.reserve(rank);
-    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    int64_t rank = resultType.getRank();
 
-    SmallVector<Value> dynDims;
-    for (int i = 0; i < rank; ++i) {
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          loc, adaptor.getOperands()[0], i));
-      if (inputType.isDynamicDim(i)) {
-        dynDims.push_back(
-            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
-      }
-    }
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes = tensor::createDimValues(
+        rewriter, op.getLoc(), adaptor.getOperands()[0]);
+
+    // Pre-compute the offsets along the axis dimension.
+    // The axisOffsets will be of size rank + 1, where the last value
+    // will hold the total size of the tensor along the 'axis' dimension.
+    SmallVector<OpFoldResult> axisOffsets;
+    axisOffsets.push_back(rewriter.getIndexAttr(0));
+    axisOffsets.push_back(sizes[axis]);
 
-    Value resultDimSize = sizes[axis];
     for (auto arg : adaptor.getOperands().drop_front()) {
       auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
+      auto currentOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
+      auto total =
+          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
+      axisOffsets.push_back(getAsOpFoldResult(total));
+    }
+    sizes[axis] = axisOffsets.back();
+
+    // Compute the dynamic sizes of the tensor.empty operation.
+    // This is based off of the specified result type of the tosa.concat
+    // operation, since we don't want to change the result type of the operation
+    // during the conversion.
+    SmallVector<Value> dynDims;
+    for (int64_t i = 0; i < rank; ++i) {
+      if (resultType.isDynamicDim(i)) {
+        dynDims.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
+      }
     }
-    sizes[axis] = resultDimSize;
 
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+    Value result = rewriter.create<tensor::EmptyOp>(
         loc, resultType.getShape(), resultType.getElementType(), dynDims);
 
-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op)
-        return v;
-      return op.getValue();
-    };
-    Value result = emptyTensor;
-    for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
+      auto sizes = tensor::createDimValues(rewriter, op.getLoc(), arg);
+      offsets[axis] = offset;
       result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, arg, result,
-          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-      offsets[axis] =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+          loc, arg, result, offsets, sizes, strides);
     }
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index d96c7848ed5b1..363648afb5180 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -202,23 +202,13 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
 // CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
 func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
-  // CHECK: [[AXIS:%.+]] = arith.constant 0
-  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
+  // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
+  // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK-DAG: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
 
-  // CHECK: [[AXIS:%.+]] = arith.constant 1
-  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
+  // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
   // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
   return
@@ -230,17 +220,16 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 // CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
 func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
-  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0
+  // CHECK-DAG: %[[IDX1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<11x?xf32>
+  // CHECK-DAG: %[[IDX1_1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[IDX1_1]]
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[DIM1]]] [1, 1]
+  // CHECK-DAG: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[IDX1_2]] : tensor<6x?xf32>
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[DIM2]]] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
   return
 }
@@ -251,20 +240,76 @@ func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 // CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
-  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
-  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
-  // CHECK: %[[SUM:.+]]  = arith.addi %[[OFFSET]], %[[DYN1]]
-  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX0]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[AXIS]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[SUM:.+]] = arith.addi %[[DIM0]], %[[DIM1]] : index
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[SUM]]) : tensor<?x3xf32>
+  // CHECK-DAG: %[[IDX0_1:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[IDX0_1]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM2]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
+  // CHECK-DAG: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[IDX0_2]] : tensor<?x3xf32>
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
+
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn_mixed
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @concat_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[OFFSET0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[OFFSET1:.+]] = arith.addi %[[OFFSET0]], %[[DIM1_0]] : index
+  // CHECK-DAG: %[[DIM2_2:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[OFFSET2:.+]] = arith.addi %[[OFFSET1]], %[[DIM2_2]] : index
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_3:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM_4]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_4:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_6:.+]] = tensor.dim %[[ARG1]], %[[C0_4]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[OFFSET0]], 0] [%[[DIM_6]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_8:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_9:.+]] = tensor.dim %[[ARG2]], %[[C0_8]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT3:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][%[[OFFSET1]], 0] [%[[DIM_9]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+
+  // CHECK: return
+
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_non_axis_dyn_mixed
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () {
+  // CHECK-DAG: %[[UNUSED0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[UNUSED1:.+]] = tensor.dim %[[ARG0]], %[[UNUSED0]] : tensor<?x1xf32>
+
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0_0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM0_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_1:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0_1]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][0, 1] [%[[DIM1_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_2:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM2_0:.+]] = tensor.dim %[[ARG2]], %[[C0_2]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][0, 2] [%[[DIM2_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK: return
+
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
+  return
+}


        


More information about the Mlir-commits mailing list