[Mlir-commits] [mlir] e377520 - [mlir] Move tosa.concat lowering from TosaToLinalg to TosaToTensor

Maya Amrami llvmlistbot at llvm.org
Tue Mar 14 02:24:08 PDT 2023


Author: Maya Amrami
Date: 2023-03-14T11:24:01+02:00
New Revision: e377520a47e6ced171fe1a4a39b91297326a1817

URL: https://github.com/llvm/llvm-project/commit/e377520a47e6ced171fe1a4a39b91297326a1817
DIFF: https://github.com/llvm/llvm-project/commit/e377520a47e6ced171fe1a4a39b91297326a1817.diff

LOG: [mlir] Move tosa.concat lowering from TosaToLinalg to TosaToTensor

tosa.concat is lowered to tensor.insert_slice, so its conversion pattern
belongs in TosaToTensor rather than in TosaToLinalg.
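
The pattern creates a tensor.empty of the result shape and inserts each
operand with a tensor.insert_slice, advancing the insertion offset along
the concatenation axis. A minimal sketch of the static-shape case, derived
from the FileCheck expectations in the moved test below (SSA names here are
illustrative, and offsets/sizes are shown after constant folding):

  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64}
         : (tensor<5x1xf32>, tensor<6x1xf32>) -> tensor<11x1xf32>

lowers to roughly:

  // Destination tensor of the concatenated shape.
  %init = tensor.empty() : tensor<11x1xf32>
  // First operand goes at offset 0 along axis 0.
  %ins0 = tensor.insert_slice %arg0 into %init[0, 0] [5, 1] [1, 1]
            : tensor<5x1xf32> into tensor<11x1xf32>
  // Second operand starts where the first one ended (offset 5).
  %ins1 = tensor.insert_slice %arg1 into %ins0[5, 0] [6, 1] [1, 1]
            : tensor<6x1xf32> into tensor<11x1xf32>

When the concat axis is dynamic, the running offset is instead computed
with arith.addi over tensor.dim of each operand (see the @concat_axis_dyn
test below).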

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D145952

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 82cadf2a07daa..f6ca01949632a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1556,68 +1556,6 @@ class ReduceConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
-struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
-  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
-    auto resultType = op.getType().dyn_cast<RankedTensorType>();
-
-    Location loc = op.getLoc();
-    int axis = op.getAxis();
-    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
-        loc, rewriter.getIndexAttr(axis));
-    int rank = resultType.getRank();
-    SmallVector<Value, 3> offsets, sizes, strides;
-    sizes.reserve(rank);
-    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    SmallVector<Value> dynDims;
-    for (int i = 0; i < rank; ++i) {
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          loc, adaptor.getOperands()[0], i));
-      if (inputType.isDynamicDim(i)) {
-        dynDims.push_back(
-            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
-      }
-    }
-
-    Value resultDimSize = sizes[axis];
-    for (auto arg : adaptor.getOperands().drop_front()) {
-      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
-    }
-    sizes[axis] = resultDimSize;
-
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultType.getShape(), resultType.getElementType(), dynDims);
-
-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op)
-        return v;
-      return op.getValue();
-    };
-    Value result = emptyTensor;
-    for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, arg, result,
-          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-      offsets[axis] =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
 public:
   using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
@@ -2110,7 +2048,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>,
       ArgMaxConverter,
-      ConcatConverter,
       GatherConverter,
       RescaleConverter,
       ReverseConverter,

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index b276c773d55ba..1ef31df2defd7 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -349,11 +349,74 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
   }
 };
 
+struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
+  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
+    auto resultType = op.getType().dyn_cast<RankedTensorType>();
+
+    Location loc = op.getLoc();
+    int axis = op.getAxis();
+    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(axis));
+    int rank = resultType.getRank();
+    SmallVector<Value, 3> offsets, sizes, strides;
+    sizes.reserve(rank);
+    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
+    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < rank; ++i) {
+      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          loc, adaptor.getOperands()[0], i));
+      if (inputType.isDynamicDim(i)) {
+        dynDims.push_back(
+            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
+      }
+    }
+
+    Value resultDimSize = sizes[axis];
+    for (auto arg : adaptor.getOperands().drop_front()) {
+      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      resultDimSize =
+          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
+    }
+    sizes[axis] = resultDimSize;
+
+    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, resultType.getShape(), resultType.getElementType(), dynDims);
+
+    auto toOpFoldResult = [](Value v) -> OpFoldResult {
+      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+      if (!op)
+        return v;
+      return op.getValue();
+    };
+    Value result = emptyTensor;
+    for (auto arg : adaptor.getOperands()) {
+      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, arg, result,
+          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
+      offsets[axis] =
+          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter>(patterns->getContext());
+  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
+      patterns->getContext());
   patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
                                           /*benefit=*/100);
   patterns->add<ReshapeConverterExpand>(patterns->getContext(),
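
With ConcatConverter registered here, tosa.concat lowering now runs as part
of the TosaToTensor conversion. Assuming the pass keeps its conventional
registration name (the test RUN lines are not part of this diff, so this
invocation is an assumption), the moved tests can be exercised with:

  mlir-opt --split-input-file --tosa-to-tensor \
      mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir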

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 138a30fa837d4..427fe6b2f16b2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -823,79 +823,6 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   return
 }
 
-// -----
-
-// CHECK-LABEL: @concat
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
-func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
-  // CHECK: [[AXIS:%.+]] = arith.constant 0
-  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
-
-  // CHECK: [[AXIS:%.+]] = arith.constant 1
-  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
-  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_non_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
-  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @concat_axis_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
-func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
-  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
-  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
-  // CHECK: %[[SUM:.+]]  = arith.addi %[[OFFSET]], %[[DYN1]]
-  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
-  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
-  return
-}
-
 // -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 

diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 34084ebf5d3ce..d96c7848ed5b1 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -195,3 +195,76 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
   %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
   return %1 : tensor<?x9xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @concat
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
+func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
+  // CHECK: [[AXIS:%.+]] = arith.constant 0
+  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
+  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>)  -> (tensor<11x1xf32>)
+
+  // CHECK: [[AXIS:%.+]] = arith.constant 1
+  // CHECK: [[STRIDE:%.+]]   = arith.constant 1
+  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
+  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
+  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
+  %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_non_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
+func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
+  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>)  -> (tensor<11x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
+  // CHECK: %[[AXIS:.+]] = arith.constant 0
+  // CHECK: %[[STRIDE:.+]]   = arith.constant 1
+  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
+  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
+  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
+  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
+  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
+  // CHECK: %[[SUM:.+]]  = arith.addi %[[OFFSET]], %[[DYN1]]
+  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
+  return
+}