[clang-tools-extra] [clang] [mlir] [llvm] Generalize depthwise conv (PR #75017)
via cfe-commits
cfe-commits at lists.llvm.org
Sun Dec 10 18:49:44 PST 2023
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/75017
None
>From f6ca8308f15b1200380b04a6a79cd893ef89a709 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 7 Dec 2023 21:57:20 -0600
Subject: [PATCH 01/21] Rough prototype of depthwise conv with switchable
channel dim
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 64 ++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 113 ++++++++++++++++++
2 files changed, 177 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..716d67fae15886 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -384,6 +384,70 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// DepthwiseConv1DOp op.
+//===----------------------------------------------------------------------===//
+
+def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
+ [AttrSizedOperandSegments, LinalgConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs 1-D depthwise convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+
+ Channel position is determined by the `BoolAttr` `channel_first`. If true,
+ layout is
+
+ Input: `NCW`
+ Kernel: `CW`
+
+ otherwise
+
+ Input: `NWC`
+ Kernel: `WC`
+
+
+ }];
+
+ let arguments = (ins
+ Variadic<TensorOrMemref>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let builders = [
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+ "bool":$channel_first,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>",
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare functions necessary for LinalgStructuredInterface.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ ArrayAttr getIndexingMaps();
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ // Implement functions necessary for DestinationStyleOpInterface.
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+ static unsigned getNumRegionArgs() { return 3; }
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ }];
+}
//===----------------------------------------------------------------------===//
// Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 58af9995548e93..99ef37c6dc1221 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1672,6 +1672,119 @@ LogicalResult ReduceOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// DepthwiseConv1DOp
+//===----------------------------------------------------------------------===//
+
+void DepthwiseConv1DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+void DepthwiseConv1DOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange inits, bool channel_first,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+ ArrayRef<NamedAttribute> attributes) {
+ build(builder, result, TypeRange{}, inputs, inits, channel_first);
+ result.addAttribute(getChannelFirstAttrName(result.name),
+ builder.getBoolAttr(channel_first));
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ for (Value init : inits) {
+ Type initType = init.getType();
+ if (llvm::isa<RankedTensorType>(initType))
+ result.addTypes(initType);
+ }
+
+ if (bodyBuild)
+ buildGenericRegion(builder, result.location, *result.regions.front(),
+ inputs, inits, bodyBuild);
+}
+
+SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction};
+ ;
+}
+
+ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
+ ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ MLIRContext *context = getContext();
+ SmallVector<AffineExpr> symbolBindings(
+ {getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
+ getAffineSymbolExpr(2, context), getAffineConstantExpr(1, context),
+ getAffineSymbolExpr(4, context), getAffineConstantExpr(1, context)});
+ // Don't actually do something stupid like this
+ SmallVector<StringRef> rawStrings =
+ (getChannelFirst())
+ ? SmallVector<StringRef>({
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d0, d2, d1 * s3 + d3 * s5)>",
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d2, d3)>",
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d0, d2, d1)>",
+ })
+ : SmallVector<StringRef>({
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d0, d1 * s3 + d3 * s5, d2)>",
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d3, d2)>",
+ "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
+ "-> (d0, d1, d2)>",
+ });
+ SmallVector<AffineMap> maps(llvm::map_range(rawStrings, [&](StringRef &m) {
+ return simplifyAffineMap(
+ llvm::cast<AffineMapAttr>(parseAttribute(m, context))
+ .getValue()
+ .replaceDimsAndSymbols({}, symbolBindings, 4, 0));
+ }));
+
+ cached = Builder(context).getAffineMapArrayAttr(maps);
+ getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+}
+
+void DepthwiseConv1DOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getDpsInits());
+}
+
+ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ getRegionBuilder());
+}
+
+void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
+ printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+}
+
+LogicalResult DepthwiseConv1DOp::verify() { return success(); }
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
>From d9097977f7db4c4b484d4e9cd922f7e26deec9e0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 7 Dec 2023 23:10:51 -0600
Subject: [PATCH 02/21] Support strides and dilations
---
.../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 6 +++++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 9 ++++++---
2 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 716d67fae15886..d8e5b0fcf627a9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -413,7 +413,11 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$inits,
- DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first
+ DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first,
+ DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+ "{ static_cast<int64_t>(1) }">:$strides,
+ DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+ "{ static_cast<int64_t>(1) }">:$dilations
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99ef37c6dc1221..c56562dc6af778 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1726,15 +1726,18 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
}
ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
- ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(LinalgDialect::kMemoizedIndexingMapsAttrName);
+ ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
if (cached)
return cached;
MLIRContext *context = getContext();
SmallVector<AffineExpr> symbolBindings(
{getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
- getAffineSymbolExpr(2, context), getAffineConstantExpr(1, context),
- getAffineSymbolExpr(4, context), getAffineConstantExpr(1, context)});
+ getAffineSymbolExpr(2, context),
+ getAffineConstantExpr(getStrides().getValues<int64_t>()[0], context),
+ getAffineSymbolExpr(4, context),
+ getAffineConstantExpr(getDilations().getValues<int64_t>()[0], context)});
// Don't actually do something stupid like this
SmallVector<StringRef> rawStrings =
(getChannelFirst())
>From 6d3a22a24c33f50bf774eea8d1c39cbeb0168c4d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 01:06:17 -0600
Subject: [PATCH 03/21] Improve map creation
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 19 +++---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 58 ++++++++-----------
2 files changed, 35 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d8e5b0fcf627a9..304c621621c4b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -395,18 +395,19 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
Performs 1-D depthwise convolution with switchable channel position; either first or last.
}];
let description = [{
+ Domain: N, OW, C, KW
- Channel position is determined by the `BoolAttr` `channel_first`. If true,
- layout is
+ Layout of operands is determined by the `channel_first` `BoolAttr`:
- Input: `NCW`
- Kernel: `CW`
-
- otherwise
-
- Input: `NWC`
- Kernel: `WC`
+ `channel_first == true`:
+ Input: `NCW`
+ Kernel: `CW`
+ Output: `NCW`
+ `channel_first == false`:
+ Input: `NWC`
+ Kernel: `WC`
+ Output: `NWC`
}];
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c56562dc6af778..bf1d8d565e64d5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1732,39 +1732,31 @@ ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
return cached;
MLIRContext *context = getContext();
- SmallVector<AffineExpr> symbolBindings(
- {getAffineSymbolExpr(0, context), getAffineSymbolExpr(1, context),
- getAffineSymbolExpr(2, context),
- getAffineConstantExpr(getStrides().getValues<int64_t>()[0], context),
- getAffineSymbolExpr(4, context),
- getAffineConstantExpr(getDilations().getValues<int64_t>()[0], context)});
- // Don't actually do something stupid like this
- SmallVector<StringRef> rawStrings =
- (getChannelFirst())
- ? SmallVector<StringRef>({
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d0, d2, d1 * s3 + d3 * s5)>",
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d2, d3)>",
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d0, d2, d1)>",
- })
- : SmallVector<StringRef>({
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d0, d1 * s3 + d3 * s5, d2)>",
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d3, d2)>",
- "affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] "
- "-> (d0, d1, d2)>",
- });
- SmallVector<AffineMap> maps(llvm::map_range(rawStrings, [&](StringRef &m) {
- return simplifyAffineMap(
- llvm::cast<AffineMapAttr>(parseAttribute(m, context))
- .getValue()
- .replaceDimsAndSymbols({}, symbolBindings, 4, 0));
- }));
-
- cached = Builder(context).getAffineMapArrayAttr(maps);
+
+ // Domain: (n, w, c, kw)
+ AffineExpr n = getAffineDimExpr(0, context);
+ AffineExpr w = getAffineDimExpr(1, context);
+ AffineExpr c = getAffineDimExpr(2, context);
+ AffineExpr kw = getAffineDimExpr(3, context);
+
+ // Temp subsitute for channel position attr
+ int64_t channelPos = (getChannelFirst()) ? 1 : 2;
+ // Initialze operand accesses in nw order and insert c according to channel
+ // position
+ SmallVector<AffineExpr> inExprs(
+ {n, w * getStrides().getValues<int64_t>()[0] +
+ kw * getDilations().getValues<int64_t>()[0]});
+ SmallVector<AffineExpr> kExprs({kw});
+ SmallVector<AffineExpr> outExprs({n, w});
+ inExprs.insert(inExprs.begin() + channelPos, c);
+ kExprs.insert(
+ channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+ outExprs.insert(outExprs.begin() + channelPos, c);
+
+ cached = Builder(context).getAffineMapArrayAttr(
+ {AffineMap::get(4, 0, inExprs, context),
+ AffineMap::get(4, 0, kExprs, context),
+ AffineMap::get(4, 0, outExprs, context)});
getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
return cached;
}
>From 13abaefd0dda4c616fdc97225801ed1b5389f6a0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 01:22:58 -0600
Subject: [PATCH 04/21] Add a couple verification regression tests
---
mlir/test/Dialect/Linalg/named-ops.mlir | 36 +++++++++++++++++++++++++
1 file changed, 36 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 5ca35155854d33..f281c138727640 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,41 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// CHECK-LABEL: func @gen_depthwise_channel_first_memref
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+ // CHECK: depthwise_conv_1d {{.*}}channel_first = true
+ linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_last_memref
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
+ // CHECK: depthwise_conv_1d {{.*}}channel_first = false
+ linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_first_tensor
+func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
+ // CHECK: depthwise_conv_1d {{.*}}channel_first = true
+ %0 = linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32>
+ return %0 : tensor<64x16x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @gen_depthwise_channel_last_tensor
+func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
+ // CHECK: depthwise_conv_1d {{.*}}channel_first = false
+ %0 = linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
+ return %0 : tensor<64x8x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
%zero = arith.constant 0.000000e+00 : f32
>From 840fadb6d12aabf7071488e3d3a3d8cd22e88739 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 02:03:38 -0600
Subject: [PATCH 05/21] Add convert-linalg-to-loops regression tests
---
mlir/test/Dialect/Linalg/loops.mlir | 69 ++++++++++++++++++++++++++---
1 file changed, 64 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 8c13422fd63833..4e9bd0d8118a6c 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-loops | FileCheck %s --check-prefixes=COMMON,CHECK
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefixes=COMMON,CHECKPARALLEL %s
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// RUN: mlir-opt %s -convert-linalg-to-loops -test-lower-to-llvm -o=/dev/null 2>&1
-// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// COMMON: #[[$stride3Dilation2:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 2)>
func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%c0 = arith.constant 0 : index
@@ -843,6 +842,66 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
// CHECKPARALLEL: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
+
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// COMMON-LABEL: func @gen_depthwise_channel_first_memref
+// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON: %[[c0:.*]] = arith.constant 0 : index
+// COMMON: %[[c1:.*]] = arith.constant 1 : index
+// COMMON: %[[c2:.*]] = arith.constant 2 : index
+// COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+// COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?xf32>
+// COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?xf32>
+// COMMON: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
+// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// COMMON: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
+// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[aff]]] : memref<?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kw]]] : memref<?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+
+
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// COMMON-LABEL: func @gen_depthwise_channel_last_memref
+// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
+// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// COMMON: %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
+// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]]] : memref<?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+
// -----
func.func @lower_to_loops_with_rank_reducing_subviews(
>From 488c1702615fed9a932e3d6f04eee8f286e07cd5 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 15:14:21 -0600
Subject: [PATCH 06/21] define depthwise conv interface (works but messy)
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 6 +
.../Dialect/Linalg/IR/LinalgInterfaces.td | 48 ++++++-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 11 +-
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 60 +++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 119 ++++++++++++------
5 files changed, 203 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..c3eb3ff665c20d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class DepthwiseConvolutionOpInterface;
namespace detail {
/// Implementation of the method that check if given operands
@@ -115,6 +116,11 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
namespace detail {
+// DictionaryAttr getStridesDict(DepthwiseConvolutionOpInterface op);
+DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
+DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
+BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
+ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
/// Returns true if the block contains a contraction of the following form:
///
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..e151d8c634dd87 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -178,6 +178,8 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
];
}
+
+
def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
let description = [{
A fill operation is defined in general terms:
@@ -871,7 +873,51 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
return {};
}]
>
- ];
+ ];
}
+def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [LinalgConvolutionOpInterface]> {
+ let description = [{
+ A depthwise convolution is defined in general terms:
+ 1. it is a convolution as defined by `ConvolutionOpInterface`
+ 2. `out_channels = K * in_channels` for some integer `K`
+ 3. The indexing maps of the input have expressions that satisfy
+ ```
+ AffineExpr ::== AffineDimExpr | ConvolvedExpr
+ ConvolvedExpr ::== MulExpr (`+` MulExpr)+
+ MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
+ ```
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns strides attribute.
+ }],
+ "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ printf("whatever\n");
+ return detail::getStridesAttr($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns dilations attribute.
+ }],
+ "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ printf("whatever\n");
+ return detail::getDilationsAttr($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns the `channel_first` attribute.
+ }],
+ "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getChannelFirstAttr($_op);
+ }]>
+ ];
+ }
+
#endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 304c621621c4b2..37d668bc43f1b2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -389,7 +389,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
//===----------------------------------------------------------------------===//
def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
- [AttrSizedOperandSegments, LinalgConvolutionOpInterface]> {
+ [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
let summary = [{
Performs 1-D depthwise convolution with switchable channel position; either first or last.
@@ -414,10 +414,10 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$inits,
- DefaultValuedOptionalAttr<BoolAttr, "true">:$channel_first,
- DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+ DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
"{ static_cast<int64_t>(1) }">:$strides,
- DefaultValuedOptionalAttr<RankedI64ElementsAttr<[1]>,
+ DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
"{ static_cast<int64_t>(1) }">:$dilations
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -435,7 +435,8 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
- ArrayAttr getIndexingMaps();
+ // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working
+ ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..3c8ae952f49bdb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,6 +638,66 @@ enum class MatchConvolutionResult {
};
} // namespace mlir::linalg::detail
+DenseIntElementsAttr mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+ auto maybeStridesAttr = op.getStridesAttr();
+ maybeStridesAttr.dump();
+ if (!maybeStridesAttr) {
+ OpBuilder builder(op.getContext());
+ int64_t numSpatialDims = op.image().getType().cast<ShapedType>().getRank() - 2;
+ auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
+ builder.getI64Type());
+ SmallVector<int64_t> strides(numSpatialDims, 1);
+ return DenseIntElementsAttr::get(type, strides);
+ }
+ return op.getStridesAttr();
+}
+
+DenseIntElementsAttr mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+ return op.getDilationsAttr();
+}
+
+BoolAttr mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+ return op.getChannelFirstAttr();
+}
+
+ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+ ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
+ LinalgDialect::kMemoizedIndexingMapsAttrName);
+ if (cached)
+ return cached;
+
+ MLIRContext *ctx = op.getContext();
+ auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
+ // Domain: (n, spatial dims..., c, kernel spatial dims...)
+ AffineExpr n = getAffineDimExpr(0, ctx);
+ SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+ SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ // Temp substitute for channel position attr
+ int64_t channelPos = (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+
+ // Initialize operand accesses in (n, spatial...) order and insert c according
+ // to channel position
+ SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+ for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+ inExprs.push_back(sp * st + ksp * di);
+ outExprs.push_back(sp);
+ }
+ SmallVector<AffineExpr> kExprs(ks);
+ inExprs.insert(inExprs.begin() + channelPos, c);
+ kExprs.insert(
+ channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+ outExprs.insert(outExprs.begin() + channelPos, c);
+
+
+ cached = Builder(ctx).getAffineMapArrayAttr(
+ {AffineMap::get(4, 0, inExprs, ctx),
+ AffineMap::get(4, 0, kExprs, ctx),
+ AffineMap::get(4, 0, outExprs, ctx)});
+ op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+ return cached;
+}
+
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bf1d8d565e64d5..6e904aa03c1eaf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1725,41 +1725,90 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
;
}
-ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
- ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
- LinalgDialect::kMemoizedIndexingMapsAttrName);
- if (cached)
- return cached;
-
- MLIRContext *context = getContext();
-
- // Domain: (n, w, c, kw)
- AffineExpr n = getAffineDimExpr(0, context);
- AffineExpr w = getAffineDimExpr(1, context);
- AffineExpr c = getAffineDimExpr(2, context);
- AffineExpr kw = getAffineDimExpr(3, context);
-
- // Temp subsitute for channel position attr
- int64_t channelPos = (getChannelFirst()) ? 1 : 2;
- // Initialze operand accesses in nw order and insert c according to channel
- // position
- SmallVector<AffineExpr> inExprs(
- {n, w * getStrides().getValues<int64_t>()[0] +
- kw * getDilations().getValues<int64_t>()[0]});
- SmallVector<AffineExpr> kExprs({kw});
- SmallVector<AffineExpr> outExprs({n, w});
- inExprs.insert(inExprs.begin() + channelPos, c);
- kExprs.insert(
- channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
- outExprs.insert(outExprs.begin() + channelPos, c);
-
- cached = Builder(context).getAffineMapArrayAttr(
- {AffineMap::get(4, 0, inExprs, context),
- AffineMap::get(4, 0, kExprs, context),
- AffineMap::get(4, 0, outExprs, context)});
- getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
- return cached;
-}
+// ArrayAttr getNDIndexngMaps(DepthwiseConvolutionOpInterface op) {
+// ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
+// LinalgDialect::kMemoizedIndexingMapsAttrName);
+// if (cached)
+// return cached;
+
+// MLIRContext *ctx = op.getContext();
+// auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
+// // Domain: (n, w, c, kw)
+// AffineExpr n = getAffineDimExpr(0, ctx);
+// SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+// AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+// SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+
+// // Temp subsitute for channel position attr
+// int64_t channelPos = (1) ? 1 : numSpatial + 1;
+
+// // Initialze operand accesses in nw order and insert c according to channel
+// // position
+// SmallVector<AffineExpr> inExprs, outExprs = {n};
+// for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+// inExprs.push_back(sp * st + ksp * di);
+// outExprs.push_back(sp);
+// }
+// SmallVector<AffineExpr> kExprs(ks);
+// inExprs.insert(inExprs.begin() + channelPos, c);
+// kExprs.insert(
+// channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+// outExprs.insert(outExprs.begin() + channelPos, c);
+
+// n.dump();
+// for (auto sp : s)
+// sp.dump();
+// c.dump();
+// for (auto ksp : ks)
+// ksp.dump();
+
+// for (auto b : inExprs)
+// b.dump();
+
+// cached = Builder(ctx).getAffineMapArrayAttr(
+// {AffineMap::get(4, 0, inExprs, ctx),
+// AffineMap::get(4, 0, kExprs, ctx),
+// AffineMap::get(4, 0, outExprs, ctx)});
+// op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+// return cached;
+// }
+
+// ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
+// return getNDIndexngMaps(this);
+// // ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
+// // LinalgDialect::kMemoizedIndexingMapsAttrName);
+// // if (cached)
+// // return cached;
+
+// // MLIRContext *context = getContext();
+
+// // // Domain: (n, w, c, kw)
+// // AffineExpr n = getAffineDimExpr(0, context);
+// // AffineExpr w = getAffineDimExpr(1, context);
+// // AffineExpr c = getAffineDimExpr(2, context);
+// // AffineExpr kw = getAffineDimExpr(3, context);
+
+// // // Temp subsitute for channel position attr
+// // int64_t channelPos = (getChannelFirst()) ? 1 : 2;
+// // // Initialze operand accesses in nw order and insert c according to channel
+// // // position
+// // SmallVector<AffineExpr> inExprs(
+// // {n, w * getStrides().getValues<int64_t>()[0] +
+// // kw * getDilations().getValues<int64_t>()[0]});
+// // SmallVector<AffineExpr> kExprs({kw});
+// // SmallVector<AffineExpr> outExprs({n, w});
+// // inExprs.insert(inExprs.begin() + channelPos, c);
+// // kExprs.insert(
+// // channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
+// // outExprs.insert(outExprs.begin() + channelPos, c);
+
+// // cached = Builder(context).getAffineMapArrayAttr(
+// // {AffineMap::get(4, 0, inExprs, context),
+// // AffineMap::get(4, 0, kExprs, context),
+// // AffineMap::get(4, 0, outExprs, context)});
+// // getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
+// // return cached;
+// }
void DepthwiseConv1DOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
>From e22dc06986c9bd7d72440b61ba398f49697abc0e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 17:14:31 -0600
Subject: [PATCH 07/21] Add 2d depthwise with tests
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 69 +++++++-
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 41 +++--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 159 ++++++++----------
mlir/test/Dialect/Linalg/loops.mlir | 35 ++++
mlir/test/Dialect/Linalg/named-ops.mlir | 9 +
5 files changed, 212 insertions(+), 101 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 37d668bc43f1b2..866f1f1fde0d51 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -385,7 +385,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
//===----------------------------------------------------------------------===//
-// DepthwiseConv1DOp op.
+// DepthwiseConvNDOp ops.
//===----------------------------------------------------------------------===//
def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
@@ -455,6 +455,73 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
}];
}
+def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
+ [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs 2-D depthwise convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+ Domain: N, OH, OW, C, KH, KW
+
+ Layout of operands is determined by the `channel_first` `BoolAttr`:
+
+ `channel_first == true`:
+ Input: `NCHW`
+ Kernel: `CHW`
+ Output: `NCHW`
+
+ `channel_first == false`:
+ Input: `NHWC`
+ Kernel: `HWC`
+ Output: `NHWC`
+
+ }];
+
+ let arguments = (ins
+ Variadic<TensorOrMemref>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
+ "{ static_cast<int64_t>(1) }">:$strides,
+ DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
+ "{ static_cast<int64_t>(1) }">:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let builders = [
+ OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+ "bool":$channel_first,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>",
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare functions necessary for LinalgStructuredInterface.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working
+ ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ // Implement functions necessary for DestinationStyleOpInterface.
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+ static unsigned getNumRegionArgs() { return 3; }
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Transpose op.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3c8ae952f49bdb..cdd830c9d90a68 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,29 +638,34 @@ enum class MatchConvolutionResult {
};
} // namespace mlir::linalg::detail
-DenseIntElementsAttr mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr
+mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
auto maybeStridesAttr = op.getStridesAttr();
maybeStridesAttr.dump();
if (!maybeStridesAttr) {
OpBuilder builder(op.getContext());
- int64_t numSpatialDims = op.image().getType().cast<ShapedType>().getRank() - 2;
+ int64_t numSpatialDims =
+ op.image().getType().cast<ShapedType>().getRank() - 2;
auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
- builder.getI64Type());
+ builder.getI64Type());
SmallVector<int64_t> strides(numSpatialDims, 1);
return DenseIntElementsAttr::get(type, strides);
}
return op.getStridesAttr();
}
-DenseIntElementsAttr mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr
+mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
return op.getDilationsAttr();
}
-BoolAttr mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+BoolAttr
+mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
return op.getChannelFirstAttr();
}
-ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+ArrayAttr
+mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
LinalgDialect::kMemoizedIndexingMapsAttrName);
if (cached)
@@ -670,16 +675,23 @@ ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface
auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
// Domain: (n, w, c, kw)
AffineExpr n = getAffineDimExpr(0, ctx);
- SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ SmallVector<AffineExpr> s(
+ llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
- SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ SmallVector<AffineExpr> ks(
+ llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
// Temp subsitute for channel position attr
- int64_t channelPos = (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
-
+ int64_t channelPos =
+ (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+
// Initialze operand accesses in nw order and insert c according to channel
// position
SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
- for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
+ for (const auto &[sp, ksp, st, di] :
+ llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(),
+ op.getDilationsAttr().getValues<int64_t>())) {
inExprs.push_back(sp * st + ksp * di);
outExprs.push_back(sp);
}
@@ -689,11 +701,10 @@ ArrayAttr mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface
channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
outExprs.insert(outExprs.begin() + channelPos, c);
-
cached = Builder(ctx).getAffineMapArrayAttr(
- {AffineMap::get(4, 0, inExprs, ctx),
- AffineMap::get(4, 0, kExprs, ctx),
- AffineMap::get(4, 0, outExprs, ctx)});
+ {AffineMap::get(2 + 2 * numSpatial, 0, inExprs, ctx),
+ AffineMap::get(2 + 2 * numSpatial, 0, kExprs, ctx),
+ AffineMap::get(2 + 2 * numSpatial, 0, outExprs, ctx)});
op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
return cached;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6e904aa03c1eaf..2528a6d1971014 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1725,91 +1725,6 @@ SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
;
}
-// ArrayAttr getNDIndexngMaps(DepthwiseConvolutionOpInterface op) {
-// ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
-// LinalgDialect::kMemoizedIndexingMapsAttrName);
-// if (cached)
-// return cached;
-
-// MLIRContext *ctx = op.getContext();
-// auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
-// // Domain: (n, w, c, kw)
-// AffineExpr n = getAffineDimExpr(0, ctx);
-// SmallVector<AffineExpr> s(llvm::map_range(llvm::seq<int64_t>(1, numSpatial), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
-// AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
-// SmallVector<AffineExpr> ks(llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)), [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
-
-// // Temp subsitute for channel position attr
-// int64_t channelPos = (1) ? 1 : numSpatial + 1;
-
-// // Initialze operand accesses in nw order and insert c according to channel
-// // position
-// SmallVector<AffineExpr> inExprs, outExprs = {n};
-// for (const auto &[sp, ksp, st, di] : llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(), op.getDilationsAttr().getValues<int64_t>())) {
-// inExprs.push_back(sp * st + ksp * di);
-// outExprs.push_back(sp);
-// }
-// SmallVector<AffineExpr> kExprs(ks);
-// inExprs.insert(inExprs.begin() + channelPos, c);
-// kExprs.insert(
-// channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-// outExprs.insert(outExprs.begin() + channelPos, c);
-
-// n.dump();
-// for (auto sp : s)
-// sp.dump();
-// c.dump();
-// for (auto ksp : ks)
-// ksp.dump();
-
-// for (auto b : inExprs)
-// b.dump();
-
-// cached = Builder(ctx).getAffineMapArrayAttr(
-// {AffineMap::get(4, 0, inExprs, ctx),
-// AffineMap::get(4, 0, kExprs, ctx),
-// AffineMap::get(4, 0, outExprs, ctx)});
-// op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-// return cached;
-// }
-
-// ArrayAttr DepthwiseConv1DOp::getIndexingMaps() {
-// return getNDIndexngMaps(this);
-// // ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(
-// // LinalgDialect::kMemoizedIndexingMapsAttrName);
-// // if (cached)
-// // return cached;
-
-// // MLIRContext *context = getContext();
-
-// // // Domain: (n, w, c, kw)
-// // AffineExpr n = getAffineDimExpr(0, context);
-// // AffineExpr w = getAffineDimExpr(1, context);
-// // AffineExpr c = getAffineDimExpr(2, context);
-// // AffineExpr kw = getAffineDimExpr(3, context);
-
-// // // Temp subsitute for channel position attr
-// // int64_t channelPos = (getChannelFirst()) ? 1 : 2;
-// // // Initialze operand accesses in nw order and insert c according to channel
-// // // position
-// // SmallVector<AffineExpr> inExprs(
-// // {n, w * getStrides().getValues<int64_t>()[0] +
-// // kw * getDilations().getValues<int64_t>()[0]});
-// // SmallVector<AffineExpr> kExprs({kw});
-// // SmallVector<AffineExpr> outExprs({n, w});
-// // inExprs.insert(inExprs.begin() + channelPos, c);
-// // kExprs.insert(
-// // channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
-// // outExprs.insert(outExprs.begin() + channelPos, c);
-
-// // cached = Builder(context).getAffineMapArrayAttr(
-// // {AffineMap::get(4, 0, inExprs, context),
-// // AffineMap::get(4, 0, kExprs, context),
-// // AffineMap::get(4, 0, outExprs, context)});
-// // getOperation()->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
-// // return cached;
-// }
-
void DepthwiseConv1DOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -1829,6 +1744,80 @@ void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
LogicalResult DepthwiseConv1DOp::verify() { return success(); }
+//===----------------------------------------------------------------------===//
+// DepthwiseConv2DOp
+//===----------------------------------------------------------------------===//
+
+// TODO: refactor into base implementation for all spatial dims
+void DepthwiseConv2DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "DepthwiseConv2DOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+void DepthwiseConv2DOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange inits, bool channel_first,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+ ArrayRef<NamedAttribute> attributes) {
+ build(builder, result, TypeRange{}, inputs, inits, channel_first);
+ result.addAttribute(getChannelFirstAttrName(result.name),
+ builder.getBoolAttr(channel_first));
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ for (Value init : inits) {
+ Type initType = init.getType();
+ if (llvm::isa<RankedTensorType>(initType))
+ result.addTypes(initType);
+ }
+
+ if (bodyBuild)
+ buildGenericRegion(builder, result.location, *result.regions.front(),
+ inputs, inits, bodyBuild);
+}
+
+SmallVector<utils::IteratorType> DepthwiseConv2DOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::reduction, utils::IteratorType::reduction};
+ ;
+}
+
+void DepthwiseConv2DOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getDpsInits());
+}
+
+ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ getRegionBuilder());
+}
+
+void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
+ printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+}
+
+LogicalResult DepthwiseConv2DOp::verify() { return success(); }
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 4e9bd0d8118a6c..d23036e07f2223 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -902,6 +902,41 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+ linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+ return
+}
+
+// COMMON-LABEL: func @gen_depthwise_2D_channel_first_memref
+// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+// COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
+// CHECK: scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim4]], %[[dim5]], %[[dim1]])
+// COMMON: scf.for %[[kh:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// COMMON: %[[affh:.*]] = affine.apply #[[$stride1Dilation1]](%[[oh]], %[[kh]])
+// COMMON: %[[affw:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
+// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[affh]], %[[affw]]] : memref<?x?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kh]], %[[kw]]] : memref<?x?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
+// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+
// -----
func.func @lower_to_loops_with_rank_reducing_subviews(
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index f281c138727640..54075dbe36ab47 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -9,6 +9,15 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1
// -----
+// CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
+ // CHECK: depthwise_conv_2d {{.*}}channel_first = true
+ linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @gen_depthwise_channel_last_memref
func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
// CHECK: depthwise_conv_1d {{.*}}channel_first = false
>From 142c2e35f1e7980851289c80f48a815a27838177 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 8 Dec 2023 21:47:21 -0600
Subject: [PATCH 08/21] Refactor common methods to detail functions
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 5 +-
.../Dialect/Linalg/IR/LinalgInterfaces.td | 101 +++++----
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 34 +--
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 197 ++++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 61 +-----
mlir/test/Dialect/Linalg/loops.mlir | 36 ++--
6 files changed, 283 insertions(+), 151 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index c3eb3ff665c20d..a278d6509c19b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -116,11 +116,14 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
namespace detail {
-// DictionaryAttr getStridesDict(DepthwiseConvolutionOpInterface op);
+// Common implementations for DepthwiseConvolutionOpInterface
DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
+ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs);
+
/// Returns true if the block contains a contraction of the following form:
///
/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index e151d8c634dd87..6b05bc1e3f311f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -876,48 +876,63 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
];
}
-def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [LinalgConvolutionOpInterface]> {
- let description = [{
- A depthwise convolution is defined in general terms:
- 1. it is a convolution as defined by `ConvolutionOpInterface`
- 1. `in_channels = K * out_channels` for some integer `K`
- 4. The indexing maps of the input have expressions that satisfy
- ```
- AffineExpr ::== AffineDimExpr | ConvolvedExpr
- ConvolvedExpr ::== MulExpr (`+` MulExpr)+
- MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
- ```
- }];
- let cppNamespace = "::mlir::linalg";
- let verify = [{ return detail::verifyConvolutionInterface($_op); }];
- let methods = [
- InterfaceMethod<[{
- Returns strides attribute.
- }],
- "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- printf("whatever\n");
- return detail::getStridesAttr($_op);
- }]>,
- InterfaceMethod<[{
- Returns dilations attribute.
- }],
- "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- printf("whatever\n");
- return detail::getDilationsAttr($_op);
- }]>,
- InterfaceMethod<[{
- Returns channel dim attribute.
- }],
- "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return detail::getChannelFirstAttr($_op);
- }]>
- ];
- }
+def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpInterface", [
+ LinalgConvolutionOpInterface]> {
+ let description = [{
+ A depthwise convolution is defined in general terms:
+ 1. it is a convolution as defined by `ConvolutionOpInterface`
+ 2. `in_channels = K * out_channels` for some integer `K`
+ 3. The indexing maps of the input have expressions that satisfy
+ ```
+ AffineExpr ::== AffineDimExpr | ConvolvedExpr
+ ConvolvedExpr ::== MulExpr (`+` MulExpr)+
+ MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
+ ```
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns strides attribute.
+ }],
+ "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getStridesAttr($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns dilations attribute.
+ }],
+ "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getDilationsAttr($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns channel dim attribute.
+ }],
+ "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getChannelFirstAttr($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns indexing maps for any spatial dimension.
+ }],
+ "::mlir::ArrayAttr", "getIndexingMaps", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getIndexingMaps($_op);
+ }]>,
+ InterfaceMethod<[{
+ Returns iterator types for any number of spatial dimensions.
+ }],
+ "::mlir::ArrayAttr", "getIteratorTypes", (ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return detail::getIteratorTypes($_op);
+ }]>
+ ];
+}
#endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 866f1f1fde0d51..679718c5d7ee55 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -430,27 +430,14 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
- // Declare functions necessary for LinalgStructuredInterface.
- SmallVector<utils::IteratorType> getIteratorTypesArray();
- // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working
- ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
- std::string getLibraryCallName() {
- return "op_has_no_registered_library_name";
- }
-
- static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
-
- // Implement functions necessary for DestinationStyleOpInterface.
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return regionBuilder;
+ return detail::regionBuilder;
}
- static unsigned getNumRegionArgs() { return 3; }
+ // Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
}];
}
@@ -497,27 +484,14 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
- // Declare functions necessary for LinalgStructuredInterface.
- SmallVector<utils::IteratorType> getIteratorTypesArray();
- // Figure out how to get implementation in `LinalgDepthwiseConvolutionOpInterface` working
- ArrayAttr getIndexingMaps() { return detail::getIndexingMaps(*this); };
- std::string getLibraryCallName() {
- return "op_has_no_registered_library_name";
- }
-
- static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
-
- // Implement functions necessary for DestinationStyleOpInterface.
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return regionBuilder;
+ return detail::regionBuilder;
}
- static unsigned getNumRegionArgs() { return 3; }
+ // Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
}];
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index cdd830c9d90a68..5686692f9da026 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -8,11 +8,13 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -204,6 +206,171 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
+// TODO: Figure out a way to not copy this from LinalgOps.cpp
+class RegionBuilderHelper {
+public:
+ RegionBuilderHelper(MLIRContext *context, Block &block)
+ : context(context), block(block) {}
+
+ // Build the unary functions defined by OpDSL.
+ Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
+ if (!isFloatingPoint(arg))
+ llvm_unreachable("unsupported non numeric type");
+ OpBuilder builder = getBuilder();
+ switch (unaryFn) {
+ case UnaryFn::exp:
+ return builder.create<math::ExpOp>(arg.getLoc(), arg);
+ case UnaryFn::log:
+ return builder.create<math::LogOp>(arg.getLoc(), arg);
+ case UnaryFn::abs:
+ return builder.create<math::AbsFOp>(arg.getLoc(), arg);
+ case UnaryFn::ceil:
+ return builder.create<math::CeilOp>(arg.getLoc(), arg);
+ case UnaryFn::floor:
+ return builder.create<math::FloorOp>(arg.getLoc(), arg);
+ case UnaryFn::negf:
+ return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+ }
+ llvm_unreachable("unsupported unary function");
+ }
+
+ // Build the binary functions defined by OpDSL.
+ Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+ bool allComplex = isComplex(arg0) && isComplex(arg1);
+ bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
+ bool allInteger = isInteger(arg0) && isInteger(arg1);
+ bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+ arg1.getType().getIntOrFloatBitWidth() == 1;
+ if (!allComplex && !allFloatingPoint && !allInteger)
+ llvm_unreachable("unsupported non numeric type");
+ OpBuilder builder = getBuilder();
+ switch (binaryFn) {
+ case BinaryFn::add:
+ if (allComplex)
+ return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::sub:
+ if (allComplex)
+ return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ llvm_unreachable("unsupported operation: sub with bools");
+ return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::mul:
+ if (allComplex)
+ return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::div:
+ if (allComplex)
+ return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
+ if (allFloatingPoint)
+ return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
+ if (allBool)
+ llvm_unreachable("unsupported operation: div with bools");
+ return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::div_unsigned:
+ if (!allInteger || allBool)
+ llvm_unreachable("unsupported operation: unsigned div not on uint");
+ return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max_signed:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min_signed:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::max_unsigned:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
+ case BinaryFn::min_unsigned:
+ assert(!allComplex);
+ if (allFloatingPoint)
+ return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
+ return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
+ }
+ llvm_unreachable("unsupported binary function");
+ }
+
+ // Build the type functions defined by OpDSL.
+ Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+ switch (typeFn) {
+ case TypeFn::cast_signed:
+ return cast(toType, operand, false);
+ case TypeFn::cast_unsigned:
+ return cast(toType, operand, true);
+ }
+ llvm_unreachable("unsupported type conversion function");
+ }
+
+ void yieldOutputs(ValueRange values) {
+ OpBuilder builder = getBuilder();
+ Location loc = builder.getUnknownLoc();
+ builder.create<YieldOp>(loc, values);
+ }
+
+ Value constant(const std::string &value) {
+ OpBuilder builder = getBuilder();
+ Location loc = builder.getUnknownLoc();
+ Attribute valueAttr = parseAttribute(value, builder.getContext());
+ return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
+ }
+
+ Value index(int64_t dim) {
+ OpBuilder builder = getBuilder();
+ return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
+ }
+
+ Type getIntegerType(unsigned width) {
+ return IntegerType::get(context, width);
+ }
+
+ Type getFloat32Type() { return Float32Type::get(context); }
+ Type getFloat64Type() { return Float64Type::get(context); }
+
+private:
+ // Generates operations to cast the given operand to a specified type.
+ // If the cast cannot be performed, a warning will be issued and the
+ // operand returned as-is (which will presumably yield a verification
+ // issue downstream).
+ Value cast(Type toType, Value operand, bool isUnsignedCast) {
+ OpBuilder builder = getBuilder();
+ auto loc = operand.getLoc();
+ return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
+ }
+
+ bool isComplex(Value value) {
+ return llvm::isa<ComplexType>(value.getType());
+ }
+ bool isFloatingPoint(Value value) {
+ return llvm::isa<FloatType>(value.getType());
+ }
+ bool isInteger(Value value) {
+ return llvm::isa<IntegerType>(value.getType());
+ }
+
+ OpBuilder getBuilder() {
+ OpBuilder builder(context);
+ builder.setInsertionPointToEnd(&block);
+ return builder;
+ }
+
+ MLIRContext *context;
+ Block &block;
+};
} // namespace
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
@@ -664,6 +831,16 @@ mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
return op.getChannelFirstAttr();
}
+ArrayAttr mlir::linalg::detail::getIteratorTypes(DepthwiseConvolutionOpInterface op) {
+ int64_t numSpatialDims =
+ op.image().getType().cast<ShapedType>().getRank() - 2;
+ SmallVector<Attribute> iteratorTypes(2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+ SmallVector<Attribute> reductions(numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
+ iteratorTypes.insert(iteratorTypes.end(), reductions.begin(), reductions.end());
+
+ return Builder(op.getContext()).getArrayAttr(iteratorTypes);
+}
+
ArrayAttr
mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
@@ -709,6 +886,26 @@ mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
return cached;
}
+void mlir::linalg::detail::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ SmallVector<Value> yields;
+
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2528a6d1971014..b8663c7643f592 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1676,25 +1676,6 @@ LogicalResult ReduceOp::verify() {
// DepthwiseConv1DOp
//===----------------------------------------------------------------------===//
-void DepthwiseConv1DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
- ArrayRef<NamedAttribute> attrs) {
- assert(block.getNumArguments() == 3 &&
- "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
- RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
- SmallVector<Value> yields;
-
- Value value1 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(0));
- Value value2 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(1));
- Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
- Value value4 =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
- yields.push_back(value4);
- helper.yieldOutputs(yields);
-}
void DepthwiseConv1DOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
@@ -1718,13 +1699,6 @@ void DepthwiseConv1DOp::build(
inputs, inits, bodyBuild);
}
-SmallVector<utils::IteratorType> DepthwiseConv1DOp::getIteratorTypesArray() {
- return SmallVector<utils::IteratorType>{
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::parallel, utils::IteratorType::reduction};
- ;
-}
-
void DepthwiseConv1DOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -1734,7 +1708,7 @@ void DepthwiseConv1DOp::getEffects(
ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ return parseNamedStructuredOp(parser, result, 3,
getRegionBuilder());
}
@@ -1742,32 +1716,11 @@ void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
}
-LogicalResult DepthwiseConv1DOp::verify() { return success(); }
-
//===----------------------------------------------------------------------===//
// DepthwiseConv2DOp
//===----------------------------------------------------------------------===//
// TODO: refactor into base implementation for all spatial dims
-void DepthwiseConv2DOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
- ArrayRef<NamedAttribute> attrs) {
- assert(block.getNumArguments() == 3 &&
- "DepthwiseConv2DOp regionBuilder expects 3 (>=0) args");
- RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
- SmallVector<Value> yields;
-
- Value value1 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(0));
- Value value2 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(1));
- Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
- Value value4 =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
- yields.push_back(value4);
- helper.yieldOutputs(yields);
-}
void DepthwiseConv2DOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
@@ -1791,14 +1744,6 @@ void DepthwiseConv2DOp::build(
inputs, inits, bodyBuild);
}
-SmallVector<utils::IteratorType> DepthwiseConv2DOp::getIteratorTypesArray() {
- return SmallVector<utils::IteratorType>{
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::reduction, utils::IteratorType::reduction};
- ;
-}
-
void DepthwiseConv2DOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -1808,7 +1753,7 @@ void DepthwiseConv2DOp::getEffects(
ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+ return parseNamedStructuredOp(parser, result, 3,
getRegionBuilder());
}
@@ -1816,8 +1761,6 @@ void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
}
-LogicalResult DepthwiseConv2DOp::verify() { return success(); }
-
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index d23036e07f2223..40d5078fc1d5fc 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -882,17 +882,17 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
-// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
-// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
-// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
+// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
// COMMON: %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
@@ -911,16 +911,16 @@ func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %ar
// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
-// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
-// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
-// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-// COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
-// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
-// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
-// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
-// COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
+// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
+// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
+// COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECK: scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
>From 44a5e434a16c9cd7c4258cf57675d07c400d91bf Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 02:41:55 -0600
Subject: [PATCH 09/21] More refactoring
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 13 +-
.../Dialect/Linalg/IR/LinalgInterfaces.td | 30 ++-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 36 +--
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 213 ++----------------
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 106 ++++-----
5 files changed, 109 insertions(+), 289 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index a278d6509c19b6..1c0e66edeabbb1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -116,13 +116,22 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
namespace detail {
-// Common implementations for DepthwiseConvolutionOpInterface
+// Common implementations for DepthwiseConvolutionOpInterface
+namespace depthwise_convolution_impl {
DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
-void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs);
+void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
+void getEffects(
+ DepthwiseConvolutionOpInterface op,
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects);
+ParseResult parse(OpAsmParser &parser, OperationState &result);
+void print(DepthwiseConvolutionOpInterface op, OpAsmPrinter &p);
+} // namespace depthwise_convolution_impl
/// Returns true if the block contains a contraction of the following form:
///
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6b05bc1e3f311f..9373e8811f443f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -175,6 +175,26 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
return $_op.getOperation()->getOperand(1);
}]
>,
+ InterfaceMethod<
+ /*desc=*/"Return the init operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"init",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getOperand(2);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the result tensor operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"result",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getResult(2);
+ }]
+ >,
];
}
@@ -898,7 +918,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
"::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return detail::getStridesAttr($_op);
+ return detail::depthwise_convolution_impl::getStridesAttr($_op);
}]>,
InterfaceMethod<[{
Returns dilations attribute.
@@ -906,7 +926,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
"::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return detail::getDilationsAttr($_op);
+ return detail::depthwise_convolution_impl::getDilationsAttr($_op);
}]>,
InterfaceMethod<[{
Returns channel dim attribute.
@@ -914,7 +934,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
"::mlir::BoolAttr", "getChannelFirstAttr", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return detail::getChannelFirstAttr($_op);
+ return detail::depthwise_convolution_impl::getChannelFirstAttr($_op);
}]>,
InterfaceMethod<[{
Returns indexing maps for any spatial dimension.
@@ -922,7 +942,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
"::mlir::ArrayAttr", "getIndexingMaps", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return detail::getIndexingMaps($_op);
+ return detail::depthwise_convolution_impl::getIndexingMaps($_op);
}]>,
InterfaceMethod<[{
Returns indexing maps for any spatial dimension.
@@ -930,7 +950,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
"::mlir::ArrayAttr", "getIteratorTypes", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
- return detail::getIteratorTypes($_op);
+ return detail::depthwise_convolution_impl::getIteratorTypes($_op);
}]>
];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 679718c5d7ee55..c4322dd2c7d536 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -423,19 +423,19 @@ def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
- let builders = [
- OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
- "bool":$channel_first,
- "function_ref<void(OpBuilder &, Location, ValueRange)>",
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
- ];
- let hasCustomAssemblyFormat = 1;
-
+ // TODO: Figure out how to move this to the interface
let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::depthwise_convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::depthwise_convolution_impl::parse(parser, result);
+ }
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return detail::regionBuilder;
+ return detail::depthwise_convolution_impl::regionBuilder;
}
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
@@ -477,19 +477,19 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
- let builders = [
- OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
- "bool":$channel_first,
- "function_ref<void(OpBuilder &, Location, ValueRange)>",
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
- ];
- let hasCustomAssemblyFormat = 1;
-
+ // TODO: Figure out how to move this to the interface
let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::depthwise_convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::depthwise_convolution_impl::parse(parser, result);
+ }
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return detail::regionBuilder;
+ return detail::depthwise_convolution_impl::regionBuilder;
}
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 5686692f9da026..b6c073a8f11f26 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -206,171 +206,6 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
-// TODO: Figure out a way to not copy this from LinalgOps.cpp
-class RegionBuilderHelper {
-public:
- RegionBuilderHelper(MLIRContext *context, Block &block)
- : context(context), block(block) {}
-
- // Build the unary functions defined by OpDSL.
- Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
- if (!isFloatingPoint(arg))
- llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
- switch (unaryFn) {
- case UnaryFn::exp:
- return builder.create<math::ExpOp>(arg.getLoc(), arg);
- case UnaryFn::log:
- return builder.create<math::LogOp>(arg.getLoc(), arg);
- case UnaryFn::abs:
- return builder.create<math::AbsFOp>(arg.getLoc(), arg);
- case UnaryFn::ceil:
- return builder.create<math::CeilOp>(arg.getLoc(), arg);
- case UnaryFn::floor:
- return builder.create<math::FloorOp>(arg.getLoc(), arg);
- case UnaryFn::negf:
- return builder.create<arith::NegFOp>(arg.getLoc(), arg);
- }
- llvm_unreachable("unsupported unary function");
- }
-
- // Build the binary functions defined by OpDSL.
- Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
- bool allComplex = isComplex(arg0) && isComplex(arg1);
- bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
- bool allInteger = isInteger(arg0) && isInteger(arg1);
- bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
- arg1.getType().getIntOrFloatBitWidth() == 1;
- if (!allComplex && !allFloatingPoint && !allInteger)
- llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
- switch (binaryFn) {
- case BinaryFn::add:
- if (allComplex)
- return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
- return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::sub:
- if (allComplex)
- return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
- llvm_unreachable("unsupported operation: sub with bools");
- return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::mul:
- if (allComplex)
- return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
- return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::div:
- if (allComplex)
- return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
- if (allFloatingPoint)
- return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
- llvm_unreachable("unsupported operation: div with bools");
- return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::div_unsigned:
- if (!allInteger || allBool)
- llvm_unreachable("unsupported operation: unsigned div not on uint");
- return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::max_signed:
- assert(!allComplex);
- if (allFloatingPoint)
- return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::min_signed:
- assert(!allComplex);
- if (allFloatingPoint)
- return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::max_unsigned:
- assert(!allComplex);
- if (allFloatingPoint)
- return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
- case BinaryFn::min_unsigned:
- assert(!allComplex);
- if (allFloatingPoint)
- return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
- return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
- }
- llvm_unreachable("unsupported binary function");
- }
-
- // Build the type functions defined by OpDSL.
- Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
- switch (typeFn) {
- case TypeFn::cast_signed:
- return cast(toType, operand, false);
- case TypeFn::cast_unsigned:
- return cast(toType, operand, true);
- }
- llvm_unreachable("unsupported type conversion function");
- }
-
- void yieldOutputs(ValueRange values) {
- OpBuilder builder = getBuilder();
- Location loc = builder.getUnknownLoc();
- builder.create<YieldOp>(loc, values);
- }
-
- Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
- Location loc = builder.getUnknownLoc();
- Attribute valueAttr = parseAttribute(value, builder.getContext());
- return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
- }
-
- Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
- return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
- }
-
- Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
- }
-
- Type getFloat32Type() { return Float32Type::get(context); }
- Type getFloat64Type() { return Float64Type::get(context); }
-
-private:
- // Generates operations to cast the given operand to a specified type.
- // If the cast cannot be performed, a warning will be issued and the
- // operand returned as-is (which will presumably yield a verification
- // issue downstream).
- Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder builder = getBuilder();
- auto loc = operand.getLoc();
- return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
- }
-
- bool isComplex(Value value) {
- return llvm::isa<ComplexType>(value.getType());
- }
- bool isFloatingPoint(Value value) {
- return llvm::isa<FloatType>(value.getType());
- }
- bool isInteger(Value value) {
- return llvm::isa<IntegerType>(value.getType());
- }
-
- OpBuilder getBuilder() {
- OpBuilder builder(context);
- builder.setInsertionPointToEnd(&block);
- return builder;
- }
-
- MLIRContext *context;
- Block &block;
-};
} // namespace
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
@@ -806,7 +641,8 @@ enum class MatchConvolutionResult {
} // namespace mlir::linalg::detail
DenseIntElementsAttr
-mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
+mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
+ DepthwiseConvolutionOpInterface op) {
auto maybeStridesAttr = op.getStridesAttr();
maybeStridesAttr.dump();
if (!maybeStridesAttr) {
@@ -822,27 +658,32 @@ mlir::linalg::detail::getStridesAttr(DepthwiseConvolutionOpInterface op) {
}
DenseIntElementsAttr
-mlir::linalg::detail::getDilationsAttr(DepthwiseConvolutionOpInterface op) {
+mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
+ DepthwiseConvolutionOpInterface op) {
return op.getDilationsAttr();
}
-BoolAttr
-mlir::linalg::detail::getChannelFirstAttr(DepthwiseConvolutionOpInterface op) {
+BoolAttr mlir::linalg::detail::depthwise_convolution_impl::getChannelFirstAttr(
+ DepthwiseConvolutionOpInterface op) {
return op.getChannelFirstAttr();
}
-ArrayAttr mlir::linalg::detail::getIteratorTypes(DepthwiseConvolutionOpInterface op) {
+ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
+ DepthwiseConvolutionOpInterface op) {
int64_t numSpatialDims =
- op.image().getType().cast<ShapedType>().getRank() - 2;
- SmallVector<Attribute> iteratorTypes(2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
- SmallVector<Attribute> reductions(numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
- iteratorTypes.insert(iteratorTypes.end(), reductions.begin(), reductions.end());
+ op.image().getType().cast<ShapedType>().getRank() - 2;
+ SmallVector<Attribute> iteratorTypes(
+ 2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+ SmallVector<Attribute> reductions(
+ numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
+ iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
+ reductions.end());
return Builder(op.getContext()).getArrayAttr(iteratorTypes);
}
-ArrayAttr
-mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
+ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
+ DepthwiseConvolutionOpInterface op) {
ArrayAttr cached = op->getAttrOfType<ArrayAttr>(
LinalgDialect::kMemoizedIndexingMapsAttrName);
if (cached)
@@ -886,26 +727,6 @@ mlir::linalg::detail::getIndexingMaps(DepthwiseConvolutionOpInterface op) {
return cached;
}
-void mlir::linalg::detail::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
- ArrayRef<NamedAttribute> attrs) {
- assert(block.getNumArguments() == 3 &&
- "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
- RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
- SmallVector<Value> yields;
-
- Value value1 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(0));
- Value value2 =
- helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
- block.getArgument(1));
- Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
- Value value4 =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
- yields.push_back(value4);
- helper.yieldOutputs(yields);
-}
-
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b8663c7643f592..76879536fe21c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1673,92 +1673,62 @@ LogicalResult ReduceOp::verify() {
}
//===----------------------------------------------------------------------===//
-// DepthwiseConv1DOp
+// DepthwiseConvNDOp
//===----------------------------------------------------------------------===//
-
-void DepthwiseConv1DOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange inits, bool channel_first,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
- ArrayRef<NamedAttribute> attributes) {
- build(builder, result, TypeRange{}, inputs, inits, channel_first);
- result.addAttribute(getChannelFirstAttrName(result.name),
- builder.getBoolAttr(channel_first));
- result.addAttributes(attributes);
-
- // Add output types for `RankedTensorType` output arguments.
- for (Value init : inits) {
- Type initType = init.getType();
- if (llvm::isa<RankedTensorType>(initType))
- result.addTypes(initType);
- }
-
- if (bodyBuild)
- buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, inits, bodyBuild);
-}
-
-void DepthwiseConv1DOp::getEffects(
+// There must be a way to avoid defining the following 3 functions
+void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
+ DepthwiseConvolutionOpInterface op,
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, op.result(), op.image(), op.filter());
}
-ParseResult DepthwiseConv1DOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseNamedStructuredOp(parser, result, 3,
- getRegionBuilder());
+ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
+ OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(
+ parser, result, 3,
+ mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
}
-void DepthwiseConv1DOp::print(OpAsmPrinter &p) {
- printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+void mlir::linalg::detail::depthwise_convolution_impl::print(
+ DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
+ printNamedStructuredOp(p, op.getOperation(), op.image(), op.filter());
}
-//===----------------------------------------------------------------------===//
-// DepthwiseConv2DOp
-//===----------------------------------------------------------------------===//
-
-// TODO: refactor into base implementation for all spatial dims
+// Build {mul, add} region for convolution
+void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 3 &&
+ "DepthwiseConv1DOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ SmallVector<Value> yields;
-void DepthwiseConv2DOp::build(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
- ValueRange inits, bool channel_first,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
- ArrayRef<NamedAttribute> attributes) {
- build(builder, result, TypeRange{}, inputs, inits, channel_first);
- result.addAttribute(getChannelFirstAttrName(result.name),
- builder.getBoolAttr(channel_first));
- result.addAttributes(attributes);
-
- // Add output types for `RankedTensorType` output arguments.
- for (Value init : inits) {
- Type initType = init.getType();
- if (llvm::isa<RankedTensorType>(initType))
- result.addTypes(initType);
- }
-
- if (bodyBuild)
- buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, inits, bodyBuild);
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
}
-void DepthwiseConv2DOp::getEffects(
+// TODO: Figure out how to move this to interface
+void DepthwiseConv1DOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
-
-ParseResult DepthwiseConv2DOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseNamedStructuredOp(parser, result, 3,
- getRegionBuilder());
-}
-
-void DepthwiseConv2DOp::print(OpAsmPrinter &p) {
- printNamedStructuredOp(p, getOperation(), getDpsInputs(), getDpsInits());
+void DepthwiseConv2DOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getDpsInits());
}
//===----------------------------------------------------------------------===//
>From 93552bc8de16fc4748acee9b92a1b89294f9acf6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 03:25:37 -0600
Subject: [PATCH 10/21] Add 3D depthwise (relatively easily after refactor)
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 5 +-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 54 +++++++++++++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +++
mlir/test/Dialect/Linalg/named-ops.mlir | 9 ++++
4 files changed, 72 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9373e8811f443f..8753833b81d20e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -900,8 +900,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
LinalgConvolutionOpInterface]> {
let description = [{
A depthwise convolution is defined in general terms:
- 1. it is a convolution as defined by `ConvolutionOpInterface`
- 1. `in_channels = K * out_channels` for some integer `K`
+ 1. it is a convolution as defined by `ConvolutionOpInterface`.
+ 2. `in_channels = K * out_channels` for some integer `K`.
+ 3. `input_rank == output_rank == kernel_rank + 1` (including batch dim in input and output)
4. The indexing maps of the input have expressions that satisfy
```
AffineExpr ::== AffineDimExpr | ConvolvedExpr
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c4322dd2c7d536..ffa9cc0562aeab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -496,6 +496,60 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
}];
}
+def DepthwiseConv3DOp : LinalgStructuredBase_Op<"depthwise_conv_3d",
+ [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs 3-D depthwise convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+ Domain: N, OD, OH, OW, C, KD, KH, KW
+
+ Layout of operands is determined by the `channel_first` `BoolAttr`:
+
+ `channel_first == true`:
+ Input: `NCDHW`
+ Kernel: `CDHW`
+ Output: `NCDHW`
+
+ `channel_first == false`:
+ Input: `NDHWC`
+ Kernel: `DHWC`
+ Output: `NDHWC`
+
+ }];
+
+ let arguments = (ins
+ Variadic<TensorOrMemref>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
+ "{ static_cast<int64_t>(1) }">:$strides,
+ DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
+ "{ static_cast<int64_t>(1) }">:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ // TODO: Figure out how to move this to the interface
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::depthwise_convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::depthwise_convolution_impl::parse(parser, result);
+ }
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return detail::depthwise_convolution_impl::regionBuilder;
+ }
+ // Implement functions necessary for DestinationStyleOpInterface.
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Transpose op.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76879536fe21c7..23c7c4730a16aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1730,6 +1730,12 @@ void DepthwiseConv2DOp::getEffects(
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
+void DepthwiseConv3DOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getDpsInits());
+}
//===----------------------------------------------------------------------===//
// TransposeOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 54075dbe36ab47..26df1ea316e532 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -18,6 +18,15 @@ func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>,
// -----
+// CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
+ // CHECK: depthwise_conv_3d {{.*}}channel_first = true
+ linalg.depthwise_conv_3d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @gen_depthwise_channel_last_memref
func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
// CHECK: depthwise_conv_1d {{.*}}channel_first = false
>From 123ffc7a757542fe82bda02b8e376c6237ac027d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 03:48:33 -0600
Subject: [PATCH 11/21] Add verifier for depthwise interface
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 4 ++++
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++++++++++++++
3 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 1c0e66edeabbb1..ed6cce6376932b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -189,6 +189,10 @@ LogicalResult verifyContractionInterface(Operation *op);
/// Verify that `op` conforms to the ConvolutionOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op);
+/// Verify that `op` conforms to the DepthwiseConvolutionOpInterface.
+LogicalResult verifyDepthwiseConvolutionInterface(Operation *op);
+
+LogicalResult verifyConvolutionInterface(Operation *op);
/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 8753833b81d20e..9749f301dfb153 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -911,7 +911,7 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
```
}];
let cppNamespace = "::mlir::linalg";
- let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+ let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
let methods = [
InterfaceMethod<[{
Returns strides attribute.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b6c073a8f11f26..6fa9fd92304d30 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -727,6 +727,20 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
return cached;
}
+LogicalResult mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
+ if (failed(verifyConvolutionInterface(op)))
+ return failure();
+ if (DepthwiseConvolutionOpInterface conv = dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
+ const auto imageRank = conv.image().getType().cast<ShapedType>().getRank();
+ const auto kernelRank = conv.filter().getType().cast<ShapedType>().getRank();
+ const auto initRank = conv.init().getType().cast<ShapedType>().getRank();
+ if (imageRank != initRank || imageRank != kernelRank + 1)
+ return op->emitError("Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
+ return success();
+ }
+ return failure();
+}
+
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
Operation *op, ConvolutionDimensions *dimensions) {
>From b9be7566cf8da5bbd9812d9e3845f3d9c7a80a64 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 05:11:36 -0600
Subject: [PATCH 12/21] Refactor separate spatial dims into one
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 1 -
.../Dialect/Linalg/IR/LinalgInterfaces.td | 41 +++--
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 141 +++---------------
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 40 ++---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 +-
mlir/test/Dialect/Linalg/loops.mlir | 6 +-
mlir/test/Dialect/Linalg/named-ops.mlir | 30 ++--
7 files changed, 77 insertions(+), 196 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index ed6cce6376932b..23240052607973 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,7 +120,6 @@ namespace detail {
namespace depthwise_convolution_impl {
DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
-BoolAttr getChannelFirstAttr(DepthwiseConvolutionOpInterface op);
ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9749f301dfb153..d524fefbd43d24 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -913,30 +913,6 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
let cppNamespace = "::mlir::linalg";
let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
let methods = [
- InterfaceMethod<[{
- Returns strides attribute.
- }],
- "::mlir::DenseIntElementsAttr", "getStridesAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return detail::depthwise_convolution_impl::getStridesAttr($_op);
- }]>,
- InterfaceMethod<[{
- Returns dilations attribute.
- }],
- "::mlir::DenseIntElementsAttr", "getDilationsAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return detail::depthwise_convolution_impl::getDilationsAttr($_op);
- }]>,
- InterfaceMethod<[{
- Returns channel dim attribute.
- }],
- "::mlir::BoolAttr", "getChannelFirstAttr", (ins),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{
- return detail::depthwise_convolution_impl::getChannelFirstAttr($_op);
- }]>,
InterfaceMethod<[{
Returns indexing maps for any spatial dimension.
}],
@@ -954,6 +930,23 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
return detail::depthwise_convolution_impl::getIteratorTypes($_op);
}]>
];
+
+ let extraClassDeclaration = [{
+ // Returns channel first attribute.
+ bool getChannelFirst() {
+ return (*this)->getAttrOfType<BoolAttr>("channel_first").getValue();
+ }
+ }];
+ let extraSharedClassDeclaration = [{
+ // Returns strides attribute.
+ ::mlir::DenseIntElementsAttr getStridesAttr() {
+ return detail::depthwise_convolution_impl::getStridesAttr($_op);
+ }
+ // Returns dilations attribute.
+ ::mlir::DenseIntElementsAttr getDilationsAttr() {
+ return detail::depthwise_convolution_impl::getDilationsAttr($_op);
+ }
+ }];
}
#endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ffa9cc0562aeab..c8e7746b17fac9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -388,80 +388,30 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
// DepthwiseConvNDOp ops.
//===----------------------------------------------------------------------===//
-def DepthwiseConv1DOp : LinalgStructuredBase_Op<"depthwise_conv_1d",
+def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
[AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
let summary = [{
- Performs 1-D depthwise convolution with switchable channel position; either first or last.
+ Performs N-D depthwise convolution with switchable channel position; either first or last.
}];
let description = [{
- Domain: N, OW, C, KW
-
- Layout of operands is determined by the `channel_first` `BoolAttr`:
-
- `channel_first == true`:
- Input: `NCW`
- Kernel: `CW`
- Output: `NCW`
-
- `channel_first == false`:
- Input: `NWC`
- Kernel: `WC`
- Output: `NWC`
-
- }];
-
- let arguments = (ins
- Variadic<TensorOrMemref>:$inputs,
- Variadic<TensorOrMemref>:$inits,
- DefaultValuedAttr<BoolAttr, "true">:$channel_first,
- DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
- "{ static_cast<int64_t>(1) }">:$strides,
- DefaultValuedAttr<RankedI64ElementsAttr<[1]>,
- "{ static_cast<int64_t>(1) }">:$dilations
- );
- let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
- let regions = (region AnyRegion:$region);
-
- // TODO: Figure out how to move this to the interface
- let extraClassDeclaration = structuredOpsBaseDecls # [{
- void print(::mlir::OpAsmPrinter &printer) {
- return detail::depthwise_convolution_impl::print(*this, printer);
- }
- static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
- ::mlir::OperationState &result) {
- return detail::depthwise_convolution_impl::parse(parser, result);
- }
- static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
- getRegionBuilder() {
- return detail::depthwise_convolution_impl::regionBuilder;
- }
- // Implement functions necessary for DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
- }];
-}
-
-def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
- [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+ Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
+ will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
+ `bool` attribute. When placing the channel dim first or last, the batch dim is excluded. In
+ any case, the channel and spatial dims are in the same relative order for all operands.
- let summary = [{
- Performs 2-D depthwise convolution with switchable channel position; either first or last.
- }];
- let description = [{
- Domain: N, OH, OW, C, KH, KW
-
- Layout of operands is determined by the `channel_first` `BoolAttr`:
+ Domain: N, S, C, KS
- `channel_first == true`:
- Input: `NCHW`
- Kernel: `CHW`
- Output: `NCHW`
+ Layouts:
+ `channel_first == true`:
+ Input: `NCS`
+ Kernel: `CS`
+ Output: `NCS`
- `channel_first == false`:
- Input: `NHWC`
- Kernel: `HWC`
- Output: `NHWC`
+ `channel_first == false`:
+ Input: `NSC`
+ Kernel: `SC`
+ Output: `NSC`
}];
@@ -469,10 +419,8 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$inits,
DefaultValuedAttr<BoolAttr, "true">:$channel_first,
- DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
- "{ static_cast<int64_t>(1) }">:$strides,
- DefaultValuedAttr<RankedI64ElementsAttr<[2]>,
- "{ static_cast<int64_t>(1) }">:$dilations
+ OptionalAttr<I64ElementsAttr>:$strides,
+ OptionalAttr<I64ElementsAttr>:$dilations
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
@@ -496,59 +444,6 @@ def DepthwiseConv2DOp : LinalgStructuredBase_Op<"depthwise_conv_2d",
}];
}
-def DepthwiseConv3DOp : LinalgStructuredBase_Op<"depthwise_conv_3d",
- [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
-
- let summary = [{
- Performs 3-D depthwise convolution with switchable channel position; either first or last.
- }];
- let description = [{
- Domain: N, OD, OH, OW, C, KD, KH, KW
-
- Layout of operands is determined by the `channel_first` `BoolAttr`:
-
- `channel_first == true`:
- Input: `NCDHW`
- Kernel: `CDHW`
- Output: `NCDHW`
-
- `channel_first == false`:
- Input: `NDHWC`
- Kernel: `DHWC`
- Output: `NDHWC`
-
- }];
-
- let arguments = (ins
- Variadic<TensorOrMemref>:$inputs,
- Variadic<TensorOrMemref>:$inits,
- DefaultValuedAttr<BoolAttr, "true">:$channel_first,
- DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
- "{ static_cast<int64_t>(1) }">:$strides,
- DefaultValuedAttr<RankedI64ElementsAttr<[3]>,
- "{ static_cast<int64_t>(1) }">:$dilations
- );
- let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
- let regions = (region AnyRegion:$region);
-
- // TODO: Figure out how to move this to the interface
- let extraClassDeclaration = structuredOpsBaseDecls # [{
- void print(::mlir::OpAsmPrinter &printer) {
- return detail::depthwise_convolution_impl::print(*this, printer);
- }
- static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
- ::mlir::OperationState &result) {
- return detail::depthwise_convolution_impl::parse(parser, result);
- }
- static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
- getRegionBuilder() {
- return detail::depthwise_convolution_impl::regionBuilder;
- }
- // Implement functions necessary for DestinationStyleOpInterface.
- MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
- }];
-}
//===----------------------------------------------------------------------===//
// Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6fa9fd92304d30..145516d8ab0922 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -8,13 +8,11 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -643,8 +641,7 @@ enum class MatchConvolutionResult {
DenseIntElementsAttr
mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
DepthwiseConvolutionOpInterface op) {
- auto maybeStridesAttr = op.getStridesAttr();
- maybeStridesAttr.dump();
+ auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
if (!maybeStridesAttr) {
OpBuilder builder(op.getContext());
int64_t numSpatialDims =
@@ -654,18 +651,24 @@ mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
SmallVector<int64_t> strides(numSpatialDims, 1);
return DenseIntElementsAttr::get(type, strides);
}
- return op.getStridesAttr();
+ return maybeStridesAttr;
}
DenseIntElementsAttr
mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
DepthwiseConvolutionOpInterface op) {
- return op.getDilationsAttr();
-}
-
-BoolAttr mlir::linalg::detail::depthwise_convolution_impl::getChannelFirstAttr(
- DepthwiseConvolutionOpInterface op) {
- return op.getChannelFirstAttr();
+ auto maybeDilationsAttr =
+ op->getAttrOfType<DenseIntElementsAttr>("dilations");
+ if (!maybeDilationsAttr) {
+ OpBuilder builder(op.getContext());
+ int64_t numSpatialDims =
+ op.image().getType().cast<ShapedType>().getRank() - 2;
+ auto type = RankedTensorType::get({static_cast<int64_t>(numSpatialDims)},
+ builder.getI64Type());
+ SmallVector<int64_t> strides(numSpatialDims, 1);
+ return DenseIntElementsAttr::get(type, strides);
+ }
+ return maybeDilationsAttr;
}
ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
@@ -701,8 +704,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
[&](int64_t d) { return getAffineDimExpr(d, ctx); }));
// Temp subsitute for channel position attr
- int64_t channelPos =
- (op.getChannelFirstAttr().getValue()) ? 1 : numSpatial + 1;
+ int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
// Initialze operand accesses in nw order and insert c according to channel
// position
@@ -727,15 +729,19 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
return cached;
}
-LogicalResult mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
+LogicalResult
+mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
if (failed(verifyConvolutionInterface(op)))
return failure();
- if (DepthwiseConvolutionOpInterface conv = dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
+ if (DepthwiseConvolutionOpInterface conv =
+ dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
const auto imageRank = conv.image().getType().cast<ShapedType>().getRank();
- const auto kernelRank = conv.filter().getType().cast<ShapedType>().getRank();
+ const auto kernelRank =
+ conv.filter().getType().cast<ShapedType>().getRank();
const auto initRank = conv.init().getType().cast<ShapedType>().getRank();
if (imageRank != initRank || imageRank != kernelRank + 1)
- return op->emitError("Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
+ return op->emitError(
+ "Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
return success();
}
return failure();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 23c7c4730a16aa..1a2935520604b0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1718,19 +1718,7 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
}
// TODO: Figure out how to move this to interface
-void DepthwiseConv1DOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
-}
-void DepthwiseConv2DOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
-}
-void DepthwiseConv3DOp::getEffects(
+void DepthwiseConvNDOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 40d5078fc1d5fc..04e6c27f19dd81 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -844,7 +844,7 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
return
}
@@ -874,7 +874,7 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: m
func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
return
}
@@ -903,7 +903,7 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: me
// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 26df1ea316e532..ed8ad2bf9e635f 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
-// CHECK-LABEL: func @gen_depthwise_channel_first_memref
-func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
- // CHECK: depthwise_conv_1d {{.*}}channel_first = true
- linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+// CHECK-LABEL: func @gen_depthwise_1D_channel_first_memref
+func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
return
}
@@ -11,17 +11,17 @@ func.func @gen_depthwise_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1
// CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
- // CHECK: depthwise_conv_2d {{.*}}channel_first = true
- linalg.depthwise_conv_2d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
return
}
// -----
// CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
-func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
- // CHECK: depthwise_conv_3d {{.*}}channel_first = true
- linalg.depthwise_conv_3d {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<64x16x8x8x8xf32>) {
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x3x3x3xf32>) outs(%arg2: memref<64x16x8x8x8xf32>)
return
}
@@ -29,8 +29,8 @@ func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10xf32>,
// CHECK-LABEL: func @gen_depthwise_channel_last_memref
func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
- // CHECK: depthwise_conv_1d {{.*}}channel_first = false
- linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = false
+ linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
return
}
@@ -38,8 +38,8 @@ func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1:
// CHECK-LABEL: func @gen_depthwise_channel_first_tensor
func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
- // CHECK: depthwise_conv_1d {{.*}}channel_first = true
- %0 = linalg.depthwise_conv_1d {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32>
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = true
+ %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32>
return %0 : tensor<64x16x8xf32>
}
@@ -47,8 +47,8 @@ func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1
// CHECK-LABEL: func @gen_depthwise_channel_last_tensor
func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
- // CHECK: depthwise_conv_1d {{.*}}channel_first = false
- %0 = linalg.depthwise_conv_1d {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
+ // CHECK: depthwise_conv_nd {{.*}}channel_first = false
+ %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
return %0 : tensor<64x8x16xf32>
}
>From e02cb893f3b51ec606ed1a797337e5d623483404 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 16:18:41 -0600
Subject: [PATCH 13/21] Fix a couple mistakes
---
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 1 -
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 4 +++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 23240052607973..f3ae787dccea69 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -191,7 +191,6 @@ LogicalResult verifyConvolutionInterface(Operation *op);
/// Verify that `op` conforms to the DepthwiseConvolutionOpInterface.
LogicalResult verifyDepthwiseConvolutionInterface(Operation *op);
-LogicalResult verifyConvolutionInterface(Operation *op);
/// Verify that `op` conforms to the FillOpInterface.
LogicalResult verifyFillInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index d524fefbd43d24..c5b40ba6826e44 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -192,7 +192,9 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getOperation()->getResult(2);
+ if ($_op.getOperation()->getResults().empty())
+ return nullptr;
+ return $_op.getOperation()->getResult(0);
}]
>,
];
>From d715ace643ec9df565ad7d71730b5a23c4332c90 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 17:42:34 -0600
Subject: [PATCH 14/21] Fix print function
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1a2935520604b0..908a9817251694 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1686,14 +1686,16 @@ void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
OpAsmParser &parser, OperationState &result) {
- return parseNamedStructuredOp(
+ return ::parseNamedStructuredOp(
parser, result, 3,
mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
}
void mlir::linalg::detail::depthwise_convolution_impl::print(
DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
- printNamedStructuredOp(p, op.getOperation(), op.image(), op.filter());
+ printNamedStructuredOp(p, op.getOperation(),
+ ValueRange{op.image(), op.filter()},
+ ValueRange{op.init()});
}
// Build {mul, add} region for convolution
@@ -1721,6 +1723,8 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
void DepthwiseConvNDOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
+ if (hasTensorSemantics())
+ return;
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
>From 1e9822eb3f3eef1bea1c7455893f94b726db46d6 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 18:25:05 -0600
Subject: [PATCH 15/21] Add bufferization test
---
mlir/test/Dialect/Linalg/bufferize.mlir | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f27e6838e661..c4625841baa5c3 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -217,3 +217,21 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
%3 = func.call @csum(%2) : (tensor<6xi64>) -> tensor<6xi64>
return %3 : tensor<6xi64>
}
+
+// CHECK-LABEL: func @gen_depthwise_3D_channel_first_tensor(
+// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<64x16x26x26x26xf32>,
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<16x3x3x3xf32>,
+// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x16x26x26x26xf32>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x3x3x3xf32>
+// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x8x8x8xf32>
+// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x8x8x8xf32>
+// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x8x8x8xf32> to memref<64x16x8x8x8xf32>
+// CHECK: linalg.depthwise_conv_nd
+// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x3x3x3xf32>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x16x8x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+ %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x3x3x3xf32>) outs(%arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32>
+ return %0 : tensor<64x16x8x8x8xf32>
+}
\ No newline at end of file
>From ce0f47ea6986cf5e222bb22b56fa8777fe48889e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 20:25:35 -0600
Subject: [PATCH 16/21] Add multiplier dimension
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 11 +--
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 26 +++---
mlir/test/Dialect/Linalg/bufferize.mlir | 22 ++---
mlir/test/Dialect/Linalg/loops.mlir | 90 ++++++++++---------
mlir/test/Dialect/Linalg/named-ops.mlir | 28 +++---
5 files changed, 92 insertions(+), 85 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index c5b40ba6826e44..39e0d05768e782 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -903,14 +903,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
let description = [{
A depthwise convolution is defined in general terms:
1. it is a convolution as defined by `ConvolutionOpInterface`.
- 2. `in_channels = K * out_channels` for some integer `K`.
- 3. `input_rank == output_rank == kernel_rank + 1` (including batch dim in input and output)
- 4. The indexing maps of the input have expressions that satisfy
- ```
- AffineExpr ::== AffineDimExpr | ConvolvedExpr
- ConvolvedExpr ::== MulExpr (`+` MulExpr)+
- MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
- ```
+ 2. `in_channels = K * out_channels` for some integer `K`.
+ 3. The dimension of the filter immediately following the channel dim is equal to `K`, the depth multiplier.
+ 4. `input_rank == kernel_rank == output_rank - 1` (including batch dim in input and output)
}];
let cppNamespace = "::mlir::linalg";
let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 145516d8ab0922..33f4cd49eb530a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -676,7 +676,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIteratorTypes(
int64_t numSpatialDims =
op.image().getType().cast<ShapedType>().getRank() - 2;
SmallVector<Attribute> iteratorTypes(
- 2 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
+ 3 + numSpatialDims, IteratorTypeAttr::get(op.getContext(), par));
SmallVector<Attribute> reductions(
numSpatialDims, IteratorTypeAttr::get(op.getContext(), red));
iteratorTypes.insert(iteratorTypes.end(), reductions.begin(),
@@ -694,14 +694,15 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
MLIRContext *ctx = op.getContext();
auto numSpatial = op.image().getType().cast<ShapedType>().getRank() - 2;
- // Domain: (n, w, c, kw)
+ // Domain: (n, w, c, m, kw)
AffineExpr n = getAffineDimExpr(0, ctx);
SmallVector<AffineExpr> s(
llvm::map_range(llvm::seq<int64_t>(1, numSpatial + 1),
[&](int64_t d) { return getAffineDimExpr(d, ctx); }));
AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
+ AffineExpr m = getAffineDimExpr(numSpatial + 2, ctx);
SmallVector<AffineExpr> ks(
- llvm::map_range(llvm::seq<int64_t>(numSpatial + 2, 2 * (numSpatial + 1)),
+ llvm::map_range(llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
[&](int64_t d) { return getAffineDimExpr(d, ctx); }));
// Temp subsitute for channel position attr
int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
@@ -709,6 +710,7 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
// Initialze operand accesses in nw order and insert c according to channel
// position
SmallVector<AffineExpr> inExprs = {n}, outExprs = {n};
+ SmallVector<AffineExpr> cm = {c, m};
for (const auto &[sp, ksp, st, di] :
llvm::zip(s, ks, op.getStridesAttr().getValues<int64_t>(),
op.getDilationsAttr().getValues<int64_t>())) {
@@ -718,13 +720,13 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
SmallVector<AffineExpr> kExprs(ks);
inExprs.insert(inExprs.begin() + channelPos, c);
kExprs.insert(
- channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, c);
- outExprs.insert(outExprs.begin() + channelPos, c);
+ channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, cm.begin(), cm.end());
+ outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
cached = Builder(ctx).getAffineMapArrayAttr(
- {AffineMap::get(2 + 2 * numSpatial, 0, inExprs, ctx),
- AffineMap::get(2 + 2 * numSpatial, 0, kExprs, ctx),
- AffineMap::get(2 + 2 * numSpatial, 0, outExprs, ctx)});
+ {AffineMap::get(3 + 2 * numSpatial, 0, inExprs, ctx),
+ AffineMap::get(3 + 2 * numSpatial, 0, kExprs, ctx),
+ AffineMap::get(3 + 2 * numSpatial, 0, outExprs, ctx)});
op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
return cached;
}
@@ -735,11 +737,13 @@ mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
return failure();
if (DepthwiseConvolutionOpInterface conv =
dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
- const auto imageRank = conv.image().getType().cast<ShapedType>().getRank();
+ const auto imageType = conv.image().getType().cast<ShapedType>();
+ const auto imageRank = imageType.getRank();
const auto kernelRank =
conv.filter().getType().cast<ShapedType>().getRank();
- const auto initRank = conv.init().getType().cast<ShapedType>().getRank();
- if (imageRank != initRank || imageRank != kernelRank + 1)
+ const auto initType = conv.init().getType().cast<ShapedType>();
+ const auto initRank = initType.getRank();
+ if (imageRank != kernelRank || imageRank != initRank - 1)
return op->emitError(
"Rank relationship must be `in_rank == out_rank == kernel_rank + 1`");
return success();
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index c4625841baa5c3..3e04646a183907 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -220,18 +220,18 @@ func.func public @main(%arg0: tensor<2x3xi1>) -> tensor<6xi64> {
// CHECK-LABEL: func @gen_depthwise_3D_channel_first_tensor(
// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<64x16x26x26x26xf32>,
-// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<16x3x3x3xf32>,
-// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
+// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<16x2x3x3x3xf32>,
+// CHECK-SAME: %[[ARG2_TENSOR:.*]]: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32> {
// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<64x16x26x26x26xf32>
-// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x3x3x3xf32>
-// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x8x8x8xf32>
-// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x8x8x8xf32>
-// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x8x8x8xf32> to memref<64x16x8x8x8xf32>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<16x2x3x3x3xf32>
+// CHECK-DAG: %[[ARG2_MEMREF:.*]] = bufferization.to_memref %[[ARG2_TENSOR]] : memref<64x16x2x8x8x8xf32>
+// CHECK-DAG: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<64x16x2x8x8x8xf32>
+// CHECK: memref.copy %[[ARG2_MEMREF]], %[[INIT_BUFFER]] : memref<64x16x2x8x8x8xf32> to memref<64x16x2x8x8x8xf32>
// CHECK: linalg.depthwise_conv_nd
// CHECK-SAME: {channel_first = true, dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>}
-// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x3x3x3xf32>)
-// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x16x8x8x8xf32>)
-func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32> {
- %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x3x3x3xf32>) outs(%arg2: tensor<64x16x8x8x8xf32>) -> tensor<64x16x8x8x8xf32>
- return %0 : tensor<64x16x8x8x8xf32>
+// CHECK-SAME: ins(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]] : memref<64x16x26x26x26xf32>, memref<16x2x3x3x3xf32>)
+// CHECK-SAME: outs(%[[INIT_BUFFER]] : memref<64x16x2x8x8x8xf32>)
+func.func @gen_depthwise_3D_channel_first_tensor(%arg0: tensor<64x16x26x26x26xf32>, %arg1: tensor<16x2x3x3x3xf32>, %arg2: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32> {
+ %0 = linalg.depthwise_conv_nd {channel_first = true, strides = dense<3> : tensor<3xi64>, dilations = dense<2> : tensor<3xi64>} ins(%arg0, %arg1: tensor<64x16x26x26x26xf32>, tensor<16x2x3x3x3xf32>) outs(%arg2: tensor<64x16x2x8x8x8xf32>) -> tensor<64x16x2x8x8x8xf32>
+ return %0 : tensor<64x16x2x8x8x8xf32>
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 04e6c27f19dd81..fb2e5c34331dc8 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -843,99 +843,107 @@ func.func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
-func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+func.func @gen_depthwise_channel_first_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
return
}
// COMMON-LABEL: func @gen_depthwise_channel_first_memref
// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
-// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
// COMMON: %[[c0:.*]] = arith.constant 0 : index
// COMMON: %[[c1:.*]] = arith.constant 1 : index
// COMMON: %[[c2:.*]] = arith.constant 2 : index
+// COMMON: %[[c3:.*]] = arith.constant 3 : index
// COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
// COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?xf32>
-// COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?xf32>
-// COMMON: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
+// COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
+// COMMON: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// COMMON: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
-// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// CHECK: scf.for %[[m:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim4]], %[[dim1]], %[[dim2]])
+// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
// COMMON: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[aff]]] : memref<?x?x?xf32>
-// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kw]]] : memref<?x?xf32>
-// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[m]], %[[kw]]] : memref<?x?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[m]], %[[ow]]] : memref<?x?x?x?xf32>
// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[ow]]] : memref<?x?x?xf32>
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[m]], %[[ow]]] : memref<?x?x?x?xf32>
-func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+ linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
return
}
// COMMON-LABEL: func @gen_depthwise_channel_last_memref
// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?xf32>
-// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
-// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?xf32>
-// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
+// COMMON: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?xf32>
+// COMMON: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<?x?x?xf32>
+// COMMON: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?x?x?xf32>
+// COMMON: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
+// COMMON: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c1]] : memref<?x?x?x?xf32>
// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim3]], %[[dim1]])
+// CHECK: scf.for %[[m:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim4]], %[[dim1]], %[[dim3]])
// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
// COMMON: %[[aff:.*]] = affine.apply #[[$stride3Dilation2]](%[[ow]], %[[kw]])
// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[aff]], %[[c]]] : memref<?x?x?xf32>
-// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]]] : memref<?x?xf32>
-// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[kw]], %[[c]], %[[m]]] : memref<?x?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[ow]], %[[c]], %[[m]]] : memref<?x?x?x?xf32>
// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]]] : memref<?x?x?xf32>
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[ow]], %[[c]], %[[m]]] : memref<?x?x?x?xf32>
-func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?x?xf32>)
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) outs(%arg2: memref<?x?x?x?x?xf32>)
return
}
// COMMON-LABEL: func @gen_depthwise_2D_channel_first_memref
// COMMON-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
-// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
-// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// COMMON-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// COMMON-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?x?xf32>
// COMMON-DAG: %[[c0:.*]] = arith.constant 0 : index
// COMMON-DAG: %[[c1:.*]] = arith.constant 1 : index
// COMMON-DAG: %[[c2:.*]] = arith.constant 2 : index
// COMMON-DAG: %[[c3:.*]] = arith.constant 3 : index
+// COMMON-DAG: %[[c4:.*]] = arith.constant 4 : index
// COMMON-DAG: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?x?x?x?xf32>
// COMMON-DAG: %[[dim1:.*]] = memref.dim %[[arg0]], %[[c1]] : memref<?x?x?x?xf32>
-// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?xf32>
-// COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg2]], %[[c2]] : memref<?x?x?x?xf32>
-// COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim2:.*]] = memref.dim %[[arg1]], %[[c1]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim3:.*]] = memref.dim %[[arg1]], %[[c2]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim4:.*]] = memref.dim %[[arg1]], %[[c3]] : memref<?x?x?x?xf32>
+// COMMON-DAG: %[[dim5:.*]] = memref.dim %[[arg2]], %[[c3]] : memref<?x?x?x?x?xf32>
+// COMMON-DAG: %[[dim6:.*]] = memref.dim %[[arg2]], %[[c4]] : memref<?x?x?x?x?xf32>
// CHECK: scf.for %[[n:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
-// CHECK: scf.for %[[oh:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
-// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+// CHECK: scf.for %[[oh:.*]] = %[[c0]] to %[[dim5]] step %[[c1]] {
+// CHECK: scf.for %[[ow:.*]] = %[[c0]] to %[[dim6]] step %[[c1]] {
// CHECK: scf.for %[[c:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
-// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim4]], %[[dim5]], %[[dim1]])
-// COMMON: scf.for %[[kh:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
-// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// CHECK: scf.for %[[m:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
+// CHECKPARALLEL: scf.parallel (%[[n:.*]], %[[oh:.*]], %[[ow:.*]], %[[c:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim5]], %[[dim6]], %[[dim1]], %[[dim2]])
+// COMMON: scf.for %[[kh:.*]] = %[[c0]] to %[[dim3]] step %[[c1]] {
+// COMMON: scf.for %[[kw:.*]] = %[[c0]] to %[[dim4]] step %[[c1]] {
// COMMON: %[[affh:.*]] = affine.apply #[[$stride1Dilation1]](%[[oh]], %[[kh]])
// COMMON: %[[affw:.*]] = affine.apply #[[$stride1Dilation1]](%[[ow]], %[[kw]])
// COMMON: %[[vb:.*]] = memref.load %[[arg0]][%[[n]], %[[c]], %[[affh]], %[[affw]]] : memref<?x?x?x?xf32>
-// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[kh]], %[[kw]]] : memref<?x?x?xf32>
-// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+// COMMON: %[[va:.*]] = memref.load %[[arg1]][%[[c]], %[[m]], %[[kh]], %[[kw]]] : memref<?x?x?x?xf32>
+// COMMON: %[[vc:.*]] = memref.load %[[arg2]][%[[n]], %[[c]], %[[m]], %[[oh]], %[[ow]]] : memref<?x?x?x?x?xf32>
// COMMON: %[[inc:.*]] = arith.mulf %[[vb]], %[[va]] : f32
// COMMON: %[[res:.*]] = arith.addf %[[vc]], %[[inc]] : f32
-// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[oh]], %[[ow]]] : memref<?x?x?x?xf32>
+// COMMON: store %[[res]], %[[arg2]][%[[n]], %[[c]], %[[m]], %[[oh]], %[[ow]]] : memref<?x?x?x?x?xf32>
// -----
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index ed8ad2bf9e635f..21191ab1cb6f3b 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,55 +1,55 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func @gen_depthwise_1D_channel_first_memref
-func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x3xf32>, %arg2: memref<64x16x8xf32>) {
+func.func @gen_depthwise_1D_channel_first_memref(%arg0: memref<64x16x10xf32>, %arg1: memref<16x1x3xf32>, %arg2: memref<64x16x1x8xf32>) {
// CHECK: depthwise_conv_nd {{.*}}channel_first = true
- linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x3xf32>) outs(%arg2: memref<64x16x8xf32>)
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10xf32>, memref<16x1x3xf32>) outs(%arg2: memref<64x16x1x8xf32>)
return
}
// -----
// CHECK-LABEL: func @gen_depthwise_2D_channel_first_memref
-func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x3x3xf32>, %arg2: memref<64x16x8x8xf32>) {
+func.func @gen_depthwise_2D_channel_first_memref(%arg0: memref<64x16x10x10xf32>, %arg1: memref<16x1x3x3xf32>, %arg2: memref<64x16x1x8x8xf32>) {
// CHECK: depthwise_conv_nd {{.*}}channel_first = true
- linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x3x3xf32>) outs(%arg2: memref<64x16x8x8xf32>)
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10xf32>, memref<16x1x3x3xf32>) outs(%arg2: memref<64x16x1x8x8xf32>)
return
}
// -----
// CHECK-LABEL: func @gen_depthwise_3D_channel_first_memref
-func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<64x16x8x8x8xf32>) {
+func.func @gen_depthwise_3D_channel_first_memref(%arg0: memref<64x16x10x10x10xf32>, %arg1: memref<16x1x3x3x3xf32>, %arg2: memref<64x16x1x8x8x8xf32>) {
// CHECK: depthwise_conv_nd {{.*}}channel_first = true
- linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x3x3x3xf32>) outs(%arg2: memref<64x16x8x8x8xf32>)
+ linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: memref<64x16x10x10x10xf32>, memref<16x1x3x3x3xf32>) outs(%arg2: memref<64x16x1x8x8x8xf32>)
return
}
// -----
// CHECK-LABEL: func @gen_depthwise_channel_last_memref
-func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16xf32>, %arg2: memref<64x8x16xf32>) {
+func.func @gen_depthwise_channel_last_memref(%arg0: memref<64x26x16xf32>, %arg1: memref<3x16x1xf32>, %arg2: memref<64x8x16x1xf32>) {
// CHECK: depthwise_conv_nd {{.*}}channel_first = false
- linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16xf32>) outs(%arg2: memref<64x8x16xf32>)
+ linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: memref<64x26x16xf32>, memref<3x16x1xf32>) outs(%arg2: memref<64x8x16x1xf32>)
return
}
// -----
// CHECK-LABEL: func @gen_depthwise_channel_first_tensor
-func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x3xf32>, %arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32> {
+func.func @gen_depthwise_channel_first_tensor(%arg0: tensor<64x16x10xf32>, %arg1: tensor<16x1x3xf32>, %arg2: tensor<64x16x1x8xf32>) -> tensor<64x16x1x8xf32> {
// CHECK: depthwise_conv_nd {{.*}}channel_first = true
- %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x3xf32>) outs(%arg2: tensor<64x16x8xf32>) -> tensor<64x16x8xf32>
- return %0 : tensor<64x16x8xf32>
+ %0 = linalg.depthwise_conv_nd {channel_first = true} ins(%arg0, %arg1: tensor<64x16x10xf32>, tensor<16x1x3xf32>) outs(%arg2: tensor<64x16x1x8xf32>) -> tensor<64x16x1x8xf32>
+ return %0 : tensor<64x16x1x8xf32>
}
// -----
// CHECK-LABEL: func @gen_depthwise_channel_last_tensor
-func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16xf32>, %arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32> {
+func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1: tensor<3x16x1xf32>, %arg2: tensor<64x8x16x1xf32>) -> tensor<64x8x16x1xf32> {
// CHECK: depthwise_conv_nd {{.*}}channel_first = false
- %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16xf32>) outs(%arg2: tensor<64x8x16xf32>) -> tensor<64x8x16xf32>
- return %0 : tensor<64x8x16xf32>
+ %0 = linalg.depthwise_conv_nd {channel_first = false, dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>} ins(%arg0, %arg1: tensor<64x26x16xf32>, tensor<3x16x1xf32>) outs(%arg2: tensor<64x8x16x1xf32>) -> tensor<64x8x16x1xf32>
+ return %0 : tensor<64x8x16x1xf32>
}
// -----
>From 11c2a9404fdaf45b9378586fd9a1237f636bc67d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 9 Dec 2023 20:51:44 -0600
Subject: [PATCH 17/21] Update docs (still needs more elaboration)
---
.../include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 7 +++----
.../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 10 +++++-----
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 39e0d05768e782..54d8cdf3e35d2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -902,10 +902,9 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
LinalgConvolutionOpInterface]> {
let description = [{
A depthwise convolution is defined in general terms:
- 1. it is a convolution as defined by `ConvolutionOpInterface`.
- 2. `in_channels = K * out_channels` for some integer `m`.
- 3. The dimension of the filter preceding the channel dim is equal to `K`, the depth multiplier
- 4. `input_rank == kernel_rank == output_rank + 1` (including batch dim in input and output)
+ 1. It is a convolution as defined by `ConvolutionOpInterface`.
+ 2. The channel and multiplier dimensions are consecutive and always in `CM` order.
+ 3. `input_rank == kernel_rank == output_rank - 1` (including batch dim in input and output)
}];
let cppNamespace = "::mlir::linalg";
let verify = [{ return detail::verifyDepthwiseConvolutionInterface($_op); }];
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c8e7746b17fac9..395e5219ada263 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -400,18 +400,18 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
`bool` attritbute. When placing the channel dim first or last, the batch dim is excluded. In
any case, the channel and spatial dims are in the same relative order for all operands.
- Domain: N, S, C, KS
+ Domain: N, S, C, M, KS
Layouts:
`channel_first == true`:
Input: `NCS`
- Kernel: `CS`
- Output: `NCS`
+ Kernel: `CMS`
+ Output: `NCMS`
`channel_first == false`:
Input: `NSC`
- Kernel: `SC`
- Output: `NSC`
+ Kernel: `SCM`
+ Output: `NSCM`
}];
>From 221e3b3e724e271dbdae45b1b756638645d856a7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 09:58:01 -0600
Subject: [PATCH 18/21] Add builders
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 41 +++++++++++++++++++
1 file changed, 41 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 395e5219ada263..09e2abaec8dc4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -425,6 +425,47 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$inits, "bool":$channel_first,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, inits,
+ attributes, DepthwiseConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$inits, "bool":$channel_first,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, inits, attributes, DepthwiseConvNDOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$inits, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("channel_first", channel_first);
+ $_state.addAttribute("strides", strides);
+ $_state.addAttribute("dilations", dilations);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, inits,
+ attributes, DepthwiseConvNDOp::getRegionBuilder());
+ }]>
+ ];
+
// TODO: Figure out how to move this to the interface
let extraClassDeclaration = structuredOpsBaseDecls # [{
void print(::mlir::OpAsmPrinter &printer) {
>From cff07a1e747ddec13c0df6b3fef65fb218b938c5 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 10:27:11 -0600
Subject: [PATCH 19/21] Move getStridesAttr and getDilationsAttr to
ConvolutionInterface
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 9 ++++++--
.../Dialect/Linalg/IR/LinalgInterfaces.td | 20 +++++++++---------
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 21 +++++++++----------
3 files changed, 27 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f3ae787dccea69..4cd086ce5de073 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class ConvolutionOpInterface;
class DepthwiseConvolutionOpInterface;
namespace detail {
@@ -116,10 +117,14 @@ bool isaCopyOpInterface(LinalgOp linalgOp);
namespace detail {
+// Common implementations for ConvolutionOpInterface
+namespace convolution_impl {
+DenseIntElementsAttr getStridesAttr(ConvolutionOpInterface op);
+DenseIntElementsAttr getDilationsAttr(ConvolutionOpInterface op);
+} // namespace convolution_impl
+
// Common implementations for DepthwiseConvolutionOpInterface
namespace depthwise_convolution_impl {
-DenseIntElementsAttr getStridesAttr(DepthwiseConvolutionOpInterface op);
-DenseIntElementsAttr getDilationsAttr(DepthwiseConvolutionOpInterface op);
ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 54d8cdf3e35d2a..c2e16c7622e12f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -198,6 +198,16 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
}]
>,
];
+ let extraSharedClassDeclaration = [{
+ // Returns strides attribute.
+ ::mlir::DenseIntElementsAttr getStridesAttr() {
+ return detail::convolution_impl::getStridesAttr($_op);
+ }
+ // Returns dilations attribute.
+ ::mlir::DenseIntElementsAttr getDilationsAttr() {
+ return detail::convolution_impl::getDilationsAttr($_op);
+ }
+ }];
}
@@ -933,16 +943,6 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
return (*this)->getAttrOfType<BoolAttr>("channel_first").getValue();
}
}];
- let extraSharedClassDeclaration = [{
- // Returns strides attribute.
- ::mlir::DenseIntElementsAttr getStridesAttr() {
- return detail::depthwise_convolution_impl::getStridesAttr($_op);
- }
- // Returns dilations attribute.
- ::mlir::DenseIntElementsAttr getDilationsAttr() {
- return detail::depthwise_convolution_impl::getDilationsAttr($_op);
- }
- }];
}
#endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 33f4cd49eb530a..e93598ecd98d99 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -638,9 +638,8 @@ enum class MatchConvolutionResult {
};
} // namespace mlir::linalg::detail
-DenseIntElementsAttr
-mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
- DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr mlir::linalg::detail::convolution_impl::getStridesAttr(
+ ConvolutionOpInterface op) {
auto maybeStridesAttr = op->getAttrOfType<DenseIntElementsAttr>("strides");
if (!maybeStridesAttr) {
OpBuilder builder(op.getContext());
@@ -654,9 +653,8 @@ mlir::linalg::detail::depthwise_convolution_impl::getStridesAttr(
return maybeStridesAttr;
}
-DenseIntElementsAttr
-mlir::linalg::detail::depthwise_convolution_impl::getDilationsAttr(
- DepthwiseConvolutionOpInterface op) {
+DenseIntElementsAttr mlir::linalg::detail::convolution_impl::getDilationsAttr(
+ ConvolutionOpInterface op) {
auto maybeDilationsAttr =
op->getAttrOfType<DenseIntElementsAttr>("dilations");
if (!maybeDilationsAttr) {
@@ -701,9 +699,9 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
[&](int64_t d) { return getAffineDimExpr(d, ctx); }));
AffineExpr c = getAffineDimExpr(numSpatial + 1, ctx);
AffineExpr m = getAffineDimExpr(numSpatial + 2, ctx);
- SmallVector<AffineExpr> ks(
- llvm::map_range(llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
- [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
+ SmallVector<AffineExpr> ks(llvm::map_range(
+ llvm::seq<int64_t>(numSpatial + 3, 2 * (numSpatial + 1) + 1),
+ [&](int64_t d) { return getAffineDimExpr(d, ctx); }));
// Temp subsitute for channel position attr
int64_t channelPos = (op.getChannelFirst()) ? 1 : numSpatial + 1;
@@ -719,8 +717,9 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
}
SmallVector<AffineExpr> kExprs(ks);
inExprs.insert(inExprs.begin() + channelPos, c);
- kExprs.insert(
- channelPos == 0 ? kExprs.begin() : kExprs.begin() + channelPos - 1, cm.begin(), cm.end());
+ kExprs.insert(channelPos == 0 ? kExprs.begin()
+ : kExprs.begin() + channelPos - 1,
+ cm.begin(), cm.end());
outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
cached = Builder(ctx).getAffineMapArrayAttr(
>From dbe552e86d37575517d2ab1970e9b6549531f8ab Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 11:29:09 -0600
Subject: [PATCH 20/21] Add tiling regression test
---
mlir/test/Dialect/Linalg/tile-conv.mlir | 42 ++++++++++++++++++++++++-
1 file changed, 41 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index 4a940f12662e6c..c24a6d77242801 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -canonicalize -split-input-file | FileCheck %s
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
@@ -41,3 +41,43 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.conv_2d
// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
// CHECK-SAME: outs(%[[SVOUT]]
+
+// -----
+
+func.func @depthwise_conv_1D(%arg0 : memref<?x?x?xf32>, %arg1 : memref<?x?x?xf32>, %arg2 : memref<?x?x?x?xf32>) {
+ linalg.depthwise_conv_nd ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 : memref<?x?x?x?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop:2 = transform.structured.tile_using_for %0 [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: func @depthwise_conv_1D
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[BATCH:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[CHANNELS:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[MULTIPLIER:.*]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[KW:.*]] = memref.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[W:.*]] = memref.dim %[[ARG2]], %[[C3]]
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[BATCH]] step %[[C2]]
+// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[I]])[%[[BATCH]]]
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[W]] step %[[C3]]
+// CHECK-DAG: %[[T5:.*]] = affine.min #[[MAP1]](%[[J]])[%[[W]]]
+// CHECK-DAG: %[[T6:.*]] = affine.apply #[[MAP2]](%[[T5]])[%[[KW]]]
+// CHECK-DAG: %[[SVIN:.*]] = memref.subview %[[ARG0]][%[[I]], 0, %[[J]]] [%[[T4]], %[[CHANNELS]], %[[T6]]]
+// CHECK-DAG: %[[SVKER:.*]] = memref.subview %[[ARG1]][0, 0, 0] [%[[CHANNELS]], %[[MULTIPLIER]], %[[KW]]]
+// CHECK-DAG: %[[SVOUT:.*]] = memref.subview %[[ARG2]][%[[I]], 0, 0, %[[J]]] [%[[T4]], %[[CHANNELS]], %[[MULTIPLIER]], %[[T5]]]
+// CHECK: linalg.depthwise_conv_nd {channel_first = true}
+// CHECK-SAME: ins(%[[SVIN]], %[[SVKER]]
+// CHECK-SAME: outs(%[[SVOUT]]
>From 1ccaf0d034d73145ddfa9e3265f856950628a7f7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 10 Dec 2023 15:41:48 -0600
Subject: [PATCH 21/21] Add quantized depthwise op (still needs serious
refactoring)
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 4 +-
.../Dialect/Linalg/IR/LinalgInterfaces.td | 9 +-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 100 ++++++++++++++++++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 54 +++++++++-
mlir/test/Dialect/Linalg/named-ops.mlir | 11 ++
6 files changed, 181 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 4cd086ce5de073..2b73bd80d2340e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -129,11 +129,13 @@ ArrayAttr getIndexingMaps(DepthwiseConvolutionOpInterface op);
ArrayAttr getIteratorTypes(DepthwiseConvolutionOpInterface op);
void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs);
+void quantizedRegionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs);
void getEffects(
DepthwiseConvolutionOpInterface op,
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects);
-ParseResult parse(OpAsmParser &parser, OperationState &result);
+ParseResult parse(OpAsmParser &parser, OperationState &result, bool isQuantized = false);
void print(DepthwiseConvolutionOpInterface op, OpAsmPrinter &p);
} // namespace depthwise_convolution_impl
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index c2e16c7622e12f..7e55c2cc01c3b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -182,7 +182,7 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getOperation()->getOperand(2);
+ return $_op.getOperation()->getOperand($_op.getOperation()->getOperands().size() - 1);
}]
>,
InterfaceMethod<
@@ -934,7 +934,12 @@ def LinalgDepthwiseConvolutionOpInterface : OpInterface<"DepthwiseConvolutionOpI
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return detail::depthwise_convolution_impl::getIteratorTypes($_op);
- }]>
+ }]>,
+ InterfaceMethod<[{
+ Returns whether op is quantized or not.
+ }],
+ "bool", "isQuantized", (ins)
+ >
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 09e2abaec8dc4e..608df597b40107 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -482,9 +482,109 @@ def DepthwiseConvNDOp : LinalgStructuredBase_Op<"depthwise_conv_nd",
}
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ // DepthwiseConvolutionOpInterface implementations
+ bool isQuantized() {return false; }
}];
}
+def DepthwiseConvNDQOp : LinalgStructuredBase_Op<"depthwise_conv_nd_q",
+ [AttrSizedOperandSegments, LinalgDepthwiseConvolutionOpInterface]> {
+
+ let summary = [{
+ Performs quantized N-D depthwise convolution with switchable channel position; either first or last.
+ }];
+ let description = [{
+ Allows any number of spatial dimensions but treats all of them as contiguous. Throughout, `S`
+ will represent all spatial dimensions. Operand layouts are determined by the `channel_first`
+ `bool` attribute. When placing the channel dim first or last, the batch dim is excluded. In
+ any case, the channel and spatial dims are in the same relative order for all operands.
+
+ Domain: N, S, C, M, KS
+
+ Layouts:
+ `channel_first == true`:
+ Input: `NCS`
+ Kernel: `CMS`
+ Output: `NCMS`
+
+ `channel_first == false`:
+ Input: `NSC`
+ Kernel: `SCM`
+ Output: `NSCM`
+
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<TensorOrMemref>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$channel_first,
+ OptionalAttr<I64ElementsAttr>:$strides,
+ OptionalAttr<I64ElementsAttr>:$dilations
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$inits, "bool":$channel_first,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, inits,
+ attributes, DepthwiseConvNDQOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$inits, "bool":$channel_first,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute(getChannelFirstAttrName($_state.name), $_builder.getBoolAttr(channel_first));
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, inits, attributes, DepthwiseConvNDQOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$inits, "Attribute":$channel_first, "Attribute":$strides, "Attribute":$dilations,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("channel_first", channel_first);
+ $_state.addAttribute("strides", strides);
+ $_state.addAttribute("dilations", dilations);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, inits,
+ attributes, DepthwiseConvNDQOp::getRegionBuilder());
+ }]>
+ ];
+
+ // TODO: Figure out how to move this to the interface
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ void print(::mlir::OpAsmPrinter &printer) {
+ return detail::depthwise_convolution_impl::print(*this, printer);
+ }
+ static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ return detail::depthwise_convolution_impl::parse(parser, result, true);
+ }
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return detail::depthwise_convolution_impl::quantizedRegionBuilder;
+ }
+ // Implement functions necessary for DestinationStyleOpInterface.
+ MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+ // DepthwiseConvolutionOpInterface implementations
+ bool isQuantized() { return true; }
+ }];
+}
//===----------------------------------------------------------------------===//
// Transpose op.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e93598ecd98d99..d50952f1169c83 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -722,10 +722,18 @@ ArrayAttr mlir::linalg::detail::depthwise_convolution_impl::getIndexingMaps(
cm.begin(), cm.end());
outExprs.insert(outExprs.begin() + channelPos, cm.begin(), cm.end());
- cached = Builder(ctx).getAffineMapArrayAttr(
+ SmallVector<AffineMap> maps(
{AffineMap::get(3 + 2 * numSpatial, 0, inExprs, ctx),
AffineMap::get(3 + 2 * numSpatial, 0, kExprs, ctx),
AffineMap::get(3 + 2 * numSpatial, 0, outExprs, ctx)});
+
+ if (op.isQuantized()) {
+ SmallVector<AffineMap> scalarMaps(
+ 2, AffineMap::get(3 + 2 * numSpatial, 0, {}, ctx));
+ maps.insert(maps.end() - 1, scalarMaps.begin(), scalarMaps.end());
+ }
+
+ cached = Builder(ctx).getAffineMapArrayAttr(maps);
op->setAttr(LinalgDialect::kMemoizedIndexingMapsAttrName, cached);
return cached;
}
@@ -736,11 +744,11 @@ mlir::linalg::detail::verifyDepthwiseConvolutionInterface(Operation *op) {
return failure();
if (DepthwiseConvolutionOpInterface conv =
dyn_cast<DepthwiseConvolutionOpInterface>(op)) {
- const auto imageType = conv.image().getType().cast<ShapedType>();
+ const auto imageType = conv.image().getType().dyn_cast<ShapedType>();
const auto imageRank = imageType.getRank();
const auto kernelRank =
conv.filter().getType().cast<ShapedType>().getRank();
- const auto initType = conv.init().getType().cast<ShapedType>();
+ const auto initType = conv.init().getType().dyn_cast<ShapedType>();
const auto initRank = initType.getRank();
if (imageRank != kernelRank || imageRank != initRank - 1)
return op->emitError(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 10745a93bd9c60..f9f9aae8c9f388 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1671,17 +1671,28 @@ void mlir::linalg::detail::depthwise_convolution_impl::getEffects(
}
ParseResult mlir::linalg::detail::depthwise_convolution_impl::parse(
- OpAsmParser &parser, OperationState &result) {
- return ::parseNamedStructuredOp(
+ OpAsmParser &parser, OperationState &result, bool isQuantized) {
+ if (isQuantized)
+ return parseNamedStructuredOp(
+ parser, result, 5,
+ mlir::linalg::detail::depthwise_convolution_impl::
+ quantizedRegionBuilder);
+ return parseNamedStructuredOp(
parser, result, 3,
mlir::linalg::detail::depthwise_convolution_impl::regionBuilder);
}
void mlir::linalg::detail::depthwise_convolution_impl::print(
DepthwiseConvolutionOpInterface op, OpAsmPrinter &p) {
- printNamedStructuredOp(p, op.getOperation(),
- ValueRange{op.image(), op.filter()},
- ValueRange{op.init()});
+ if (op.isQuantized())
+ printNamedStructuredOp(p, op.getOperation(),
+ ValueRange{op.image(), op.filter(),
+ op->getOperand(2), op->getOperand(3)},
+ ValueRange{op.init()});
+ else
+ printNamedStructuredOp(p, op.getOperation(),
+ ValueRange{op.image(), op.filter()},
+ ValueRange{op.init()});
}
// Build {mul, add} region for convolution
@@ -1705,6 +1716,31 @@ void mlir::linalg::detail::depthwise_convolution_impl::regionBuilder(
helper.yieldOutputs(yields);
}
+void mlir::linalg::detail::depthwise_convolution_impl::quantizedRegionBuilder(
+ ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs) {
+ assert(block.getNumArguments() == 5 &&
+ "DepthwiseConvNDQOp regionBuilder expects 5 args");
+ RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ Value value1 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(0));
+ Value value2 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(2));
+ Value value3 = helper.buildBinaryFn(BinaryFn::sub, value1, value2);
+ Value value4 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(1));
+ Value value5 =
+ helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(4).getType(),
+ block.getArgument(3));
+ Value value6 = helper.buildBinaryFn(BinaryFn::sub, value4, value5);
+ Value value7 = helper.buildBinaryFn(BinaryFn::mul, value3, value6);
+ Value value8 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(4), value7);
+ helper.yieldOutputs({value8});
+}
+
// TODO: Figure out how to move this to interface
void DepthwiseConvNDOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -1714,6 +1750,14 @@ void DepthwiseConvNDOp::getEffects(
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
+void DepthwiseConvNDQOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getDpsInits());
+}
//===----------------------------------------------------------------------===//
// TransposeOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index afbcf30c87207c..c85d143725cf45 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -54,6 +54,17 @@ func.func @gen_depthwise_channel_last_tensor(%arg0: tensor<64x26x16xf32>, %arg1:
// -----
+// CHECK-LABEL: func @gen_depthwise_1D_channel_first_quantized_memref
+func.func @gen_depthwise_1D_channel_first_quantized_memref(%arg0: memref<64x16x10xi8>, %arg1: memref<16x1x3xi8>, %arg2: memref<64x16x1x8xi64>) {
+ // CHECK: depthwise_conv_nd_q {{.*}}channel_first = true
+ %c0 = arith.constant 0 : i32
+ %c2 = arith.constant 2 : i32
+ linalg.depthwise_conv_nd_q {channel_first = true} ins(%arg0, %arg1, %c0, %c2: memref<64x16x10xi8>, memref<16x1x3xi8>, i32, i32) outs(%arg2: memref<64x16x1x8xi64>)
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @depthwise_conv_1d_nwc_wcm
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<1x12x8xf32>, %filter: tensor<3x8x8xf32>) -> tensor<1x10x8x8xf32> {
%zero = arith.constant 0.000000e+00 : f32
More information about the cfe-commits
mailing list