[Mlir-commits] [mlir] fd64de3 - [mlir][linalg] Add BroadcastOp to Linalg structured ops.
Oleg Shyshkov
llvmlistbot at llvm.org
Fri Nov 4 04:07:35 PDT 2022
Author: Oleg Shyshkov
Date: 2022-11-04T12:07:18+01:00
New Revision: fd64de32129977f3bb52d874f499ed0a98214db3
URL: https://github.com/llvm/llvm-project/commit/fd64de32129977f3bb52d874f499ed0a98214db3
DIFF: https://github.com/llvm/llvm-project/commit/fd64de32129977f3bb52d874f499ed0a98214db3.diff
LOG: [mlir][linalg] Add BroadcastOp to Linalg structured ops.
[[RFC] Primitive Ops: add BroadcastOp to Linalg](https://discourse.llvm.org/t/rfc-primitive-ops-add-broadcastop-to-linalg/66313?u=olegshyshkov)
Differential Revision: https://reviews.llvm.org/D137331
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index b067a1ddd1e61..9866620fd4892 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -440,7 +440,9 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
- getRegionBuilder();
+ getRegionBuilder() {
+ return nullptr;
+ }
static void createRegion(::mlir::OpBuilder &opBuilder,
::mlir::OperationState & odsState);
@@ -450,6 +452,79 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
let hasVerifier = 1;
}
+
+//===----------------------------------------------------------------------===//
+// Broadcast op.
+//===----------------------------------------------------------------------===//
+
+def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ SameVariadicOperandSize,
+ SingleBlockImplicitTerminator<"YieldOp">]> {
+ let summary = "Static broadcast operator";
+ let description = [{
+ Broadcast the input into the given shape by adding dimensions.
+
+ Each index in `dimensions` attribute maps input dimension into the
+ corresponding target dimension. The length of the `dimensions` list should
+ match the `input` rank and dimensions should be in sorted order. There is no
+ ambiguity at compile-time about shape information.
+
+ Example:
+ ```
+ %bcast = linalg.broadcast
+ ins(%input:tensor<16xf32>)
+ inits(%init:tensor<16x64xf32>)
+ dimensions = [0]
+ ```
+ }];
+
+ let arguments = (ins
+ // Input arg
+ TensorOrMemref:$input,
+ // Output arg
+ TensorOrMemref:$init,
+
+ DenseI64ArrayAttr:$dimensions
+ );
+ let results = (outs Variadic<AnyTensor>:$result);
+ let regions = (region SizedRegion<1>:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$input, "Value":$init,
+ "DenseI64ArrayAttr":$dimensions, CArg<"ArrayRef<NamedAttribute>",
+ "{}">:$attributes)>,
+ OpBuilder<(ins "Value":$input, "Value":$init,
+ "ArrayRef<int64_t>":$dimensions, CArg<"ArrayRef<NamedAttribute>",
+ "{}">:$attributes)>,
+ ];
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Declare functions necessary for LinalgStructuredInterface.
+ SmallVector<StringRef> getIteratorTypesArray();
+ ArrayAttr getIndexingMaps();
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ // Implement functions necessary for DestinationStyleOpInterface.
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ int64_t getNumOperands = this->getNumOperands();
+ return {getNumOperands - 1, getNumOperands};
+ }
+
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return nullptr;
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 568b9317ca364..6377a68bc3c5d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -662,7 +662,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
static void buildGenericRegion(
- OpBuilder &builder, OperationState &result, ValueRange inputs,
+ OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs,
ValueRange outputs,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
SmallVector<Type, 4> blockArgTypes;
@@ -675,10 +675,9 @@ static void buildGenericRegion(
}
OpBuilder::InsertionGuard guard(builder);
- auto ®ion = *result.regions.front();
Block *bodyBlock =
builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
- bodyBuild(builder, result.location, bodyBlock->getArguments());
+ bodyBuild(builder, loc, bodyBlock->getArguments());
}
void GenericOp::getAsmBlockArgumentNames(Region ®ion,
@@ -699,7 +698,8 @@ void GenericOp::build(
iteratorTypes, doc, libraryCall);
result.addAttributes(attributes);
if (bodyBuild)
- buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
+ buildGenericRegion(builder, result.location, *result.regions.front(),
+ inputs, outputs, bodyBuild);
}
void GenericOp::build(
@@ -1346,7 +1346,8 @@ void MapOp::build(
result.addTypes(initType);
if (bodyBuild)
- buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+ buildGenericRegion(builder, result.location, *result.regions.front(),
+ inputs, /*outputs=*/{}, bodyBuild);
}
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1471,7 +1472,8 @@ void ReduceOp::build(
}
if (bodyBuild)
- buildGenericRegion(builder, result, inputs, inits, bodyBuild);
+ buildGenericRegion(builder, result.location, *result.regions.front(),
+ inputs, inits, bodyBuild);
}
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
@@ -1648,13 +1650,13 @@ LogicalResult ReduceOp::verify() {
// TransposeOp
//===----------------------------------------------------------------------===//
-std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
-TransposeOp::getRegionBuilder() {
- return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
- mlir::ArrayRef<mlir::NamedAttribute>) {
- b.create<linalg::YieldOp>(block.getArguments().front());
- };
+static void buildIdentityRegion(OpBuilder &builder, Location loc,
+ Region ®ion, ValueRange inputs,
+ ValueRange outputs) {
+ buildGenericRegion(builder, loc, region, inputs, outputs,
+ [](OpBuilder &b, Location loc, ValueRange args) {
+ b.create<linalg::YieldOp>(loc, args[0]);
+ });
}
void TransposeOp::build(::mlir::OpBuilder &builder,
@@ -1671,11 +1673,8 @@ void TransposeOp::build(::mlir::OpBuilder &builder,
if (initType.isa<RankedTensorType>())
result.addTypes(initType);
- (void)result.addRegion();
- buildGenericRegion(builder, result, input, init,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args[0]);
- });
+ buildIdentityRegion(builder, result.location, *result.addRegion(), input,
+ init);
}
void TransposeOp::build(::mlir::OpBuilder &builder,
@@ -1693,13 +1692,10 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
})))
return failure();
- (void)result.addRegion();
OpBuilder builder(parser.getContext());
- buildGenericRegion(builder, result, /*inputs=*/result.operands,
- /*outputs=*/{},
- [&](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args[0]);
- });
+ buildIdentityRegion(builder, result.location, *result.addRegion(),
+ /*inputs=*/result.operands,
+ /*outputs=*/{});
return success();
}
@@ -1778,6 +1774,144 @@ void TransposeOp::getEffects(
getDpsInputOperands(), getDpsInitOperands());
}
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+void BroadcastOp::build(::mlir::OpBuilder &builder,
+ ::mlir::OperationState &result, Value input, Value init,
+ DenseI64ArrayAttr dimensions,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addOperands(input);
+ result.addOperands(init);
+ result.addAttribute(getDimensionsAttrName(result.name), dimensions);
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ Type initType = init.getType();
+ if (initType.isa<RankedTensorType>())
+ result.addTypes(initType);
+
+ buildIdentityRegion(builder, result.location, *result.addRegion(), input,
+ init);
+}
+
+void BroadcastOp::build(::mlir::OpBuilder &builder,
+ ::mlir::OperationState &result, Value input, Value init,
+ ArrayRef<int64_t> dimensions,
+ ArrayRef<NamedAttribute> attributes) {
+ build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
+ attributes);
+}
+
+ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
+ if (failed(parseDstStyleOp(
+ parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+ return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
+ })))
+ return failure();
+
+ OpBuilder builder(parser.getContext());
+ buildIdentityRegion(builder, result.location, *result.addRegion(),
+ /*inputs=*/result.operands,
+ /*outputs=*/{});
+ return success();
+}
+
+void BroadcastOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ if (!getResults().empty())
+ setNameFn(getResults().front(), "broadcasted");
+}
+
+void BroadcastOp::print(OpAsmPrinter &p) {
+ p.increaseIndent();
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
+ p.printNewline();
+
+ printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
+ p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+ p.decreaseIndent();
+}
+
+LogicalResult BroadcastOp::verify() {
+ ArrayRef<int64_t> dimensionsRef = getDimensions();
+
+ if (!llvm::is_sorted(dimensionsRef))
+ return emitOpError() << "dimensions should be in sorted order, implicit "
+ "transpose is not supported";
+
+ auto inputType = getInput().getType();
+ auto initType = getInit().getType();
+
+ int64_t inputRank = inputType.getRank();
+ int64_t initRank = initType.getRank();
+
+ auto inputShape = inputType.getShape();
+ auto initShape = initType.getShape();
+
+ if (inputRank != dimensionsRef.size())
+ return emitOpError()
+ << "input rank does match the number of dimensions. expected: "
+ << inputRank << ", got: " << dimensionsRef.size();
+
+ // Mapping from init dims to input dims.
+ const int64_t kUnmappedDim = -1;
+ SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+
+ for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
+ if (dim < 0 || dim >= initRank)
+ return emitOpError() << "dimension " << idx
+ << " is out of range. expected range: [0, "
+ << initRank - 1 << "], got: " << dim;
+
+ reverseDimMap[dim] = idx;
+ }
+
+ for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
+ if (inputDimIdx == kUnmappedDim) {
+      // This dimension is being added. Should be statically known.
+ if (ShapedType::isDynamic(initShape[idx]))
+ return emitOpError()
+ << "init dim " << idx
+ << " can't be dynamic, because it's not matched to input";
+ } else {
+      // This dimension is mapped from the input. Init and input dims should
+ // match.
+ if (inputShape[inputDimIdx] != initShape[idx])
+ return emitOpError()
+ << "input dim " << inputDimIdx << " should match init dim "
+ << idx << ". input: " << inputShape[inputDimIdx]
+ << ", init: " << initShape[idx];
+ }
+ }
+
+ return success();
+}
+
+SmallVector<StringRef> BroadcastOp::getIteratorTypesArray() {
+ int64_t rank = getInit().getType().getRank();
+ return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+}
+
+ArrayAttr BroadcastOp::getIndexingMaps() {
+ Builder builder(getContext());
+ int64_t rank = getInit().getType().getRank();
+ return builder.getAffineMapArrayAttr(
+ {builder.getMultiDimIdentityMap(rank).getSubMap(
+ llvm::to_vector_of<unsigned>(getDimensions())),
+ builder.getMultiDimIdentityMap(rank)});
+}
+
+void BroadcastOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ getGenericEffectsImpl(effects, getOperation()->getResults(),
+ getDpsInputOperands(), getDpsInitOperands());
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9200c6117a493..5a1c2afdebbdd 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -673,3 +673,81 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
permutation = [1, 0, 2]
func.return %transpose : tensor<32x64x16xf32>
}
+
+// -----
+
+func.func @broadcast_unsorted_dims(
+ %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+ -> tensor<4x8x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<4x16xf32>)
+ outs(%init:tensor<4x8x16xf32>)
+ dimensions = [1, 0]
+ func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_input_dims_rank_mismatch(
+ %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+ -> tensor<4x8x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<4x16xf32>)
+ outs(%init:tensor<4x8x16xf32>)
+ dimensions = [0]
+ func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_unsorted_dims(
+ %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+ -> tensor<4x8x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<4x16xf32>)
+ outs(%init:tensor<4x8x16xf32>)
+ dimensions = [0, 5]
+ func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_mapped_dim_mismatch(
+ %input: tensor<4x16xf32>, %init: tensor<5x8x16xf32>)
+ -> tensor<5x8x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0. input: 4, init: 5}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<4x16xf32>)
+ outs(%init:tensor<5x8x16xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<5x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_added_dynamic_mismatch(
+ %input: tensor<4x16xf32>, %init: tensor<4x?x16xf32>)
+ -> tensor<4x?x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op init dim 1 can't be dynamic, because it's not matched to input}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<4x16xf32>)
+ outs(%init:tensor<4x?x16xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<4x?x16xf32>
+}
+
+// -----
+
+func.func @broadcast_size_1_extension_not_supported(
+ %input: tensor<1x16xf32>, %init: tensor<4x?x16xf32>)
+ -> tensor<4x?x16xf32> {
+ // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0. input: 1, init: 4}}
+ %bcast = linalg.broadcast
+ ins(%input:tensor<1x16xf32>)
+ outs(%init:tensor<4x?x16xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<4x?x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 58dec2be2373a..9d100d5117fdd 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -388,6 +388,19 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
// -----
+// CHECK-LABEL: func @broadcast
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x32xf32
+func.func @broadcast(%input: tensor<8x32xf32>,
+ %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %bcast = linalg.broadcast
+ ins(%input:tensor<8x32xf32>)
+ outs(%init:tensor<8x16x32xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<8x16x32xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// AllocTensorOp elimination would produce SSA violations for the example below.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index fc0e3e057d9a8..64c2bea1f7ee1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -517,3 +517,53 @@ func.func @transpose_memref(%input: memref<16x32x64xf32>,
func.return
}
// CHECK-LABEL: func @transpose_memref
+
+// -----
+
+func.func @broadcast_static_sizes(%input: tensor<8x32xf32>,
+ %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %bcast = linalg.broadcast
+ ins(%input:tensor<8x32xf32>)
+ outs(%init:tensor<8x16x32xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<8x16x32xf32>
+}
+// CHECK-LABEL: func @broadcast_static_sizes
+// CHECK: linalg.broadcast
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: dimensions
+
+// -----
+
+func.func @broadcast_with_dynamic_sizes(
+ %input: tensor<8x?xf32>, %init: tensor<8x16x?xf32>)
+ -> tensor<8x16x?xf32> {
+ %bcast = linalg.broadcast
+ ins(%input:tensor<8x?xf32>)
+ outs(%init:tensor<8x16x?xf32>)
+ dimensions = [0, 2]
+ func.return %bcast : tensor<8x16x?xf32>
+}
+// CHECK-LABEL: func @broadcast_with_dynamic_sizes
+// CHECK: linalg.broadcast
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: dimensions
+
+// -----
+
+func.func @broadcast_memref(%input: memref<8x32xf32>,
+ %init: memref<8x16x32xf32>) {
+ linalg.broadcast
+ ins(%input:memref<8x32xf32>)
+ outs(%init:memref<8x16x32xf32>)
+ dimensions = [0, 2]
+ func.return
+}
+
+// CHECK-LABEL: func @broadcast_memref
+// CHECK: linalg.broadcast
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: dimensions
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 9addbcc83517c..b2e3fd5eec3b1 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -240,3 +240,29 @@ func.func @reduce(%arg0: memref<16x32x64xf32>,
// CHECK: %[[OUT_ELEM:.*]] = memref.load %[[OUT]][%[[I]], %[[K]]]
// CHECK: %[[ADD:.*]] = arith.addf %[[IN_ELEM]], %[[OUT_ELEM]]
// CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]], %[[K]]]
+
+// -----
+
+func.func @broadcast(%input: memref<8x32xf32>,
+ %init: memref<8x16x32xf32>) {
+ linalg.broadcast
+ ins(%input:memref<8x32xf32>)
+ outs(%init:memref<8x16x32xf32>)
+ dimensions = [0, 2]
+ func.return
+}
+// CHECK-LABEL: func.func @broadcast(
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>,
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C16]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
+// CHECK: memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
More information about the Mlir-commits
mailing list