[Mlir-commits] [mlir] fd64de3 - [mlir][linalg] Add BroadcastOp to Linalg structured ops.

Oleg Shyshkov llvmlistbot at llvm.org
Fri Nov 4 04:07:35 PDT 2022


Author: Oleg Shyshkov
Date: 2022-11-04T12:07:18+01:00
New Revision: fd64de32129977f3bb52d874f499ed0a98214db3

URL: https://github.com/llvm/llvm-project/commit/fd64de32129977f3bb52d874f499ed0a98214db3
DIFF: https://github.com/llvm/llvm-project/commit/fd64de32129977f3bb52d874f499ed0a98214db3.diff

LOG: [mlir][linalg] Add BroadcastOp to Linalg structured ops.

[[RFC] Primitive Ops: add BroadcastOp to Linalg](https://discourse.llvm.org/t/rfc-primitive-ops-add-broadcastop-to-linalg/66313?u=olegshyshkov)

Differential Revision: https://reviews.llvm.org/D137331

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index b067a1ddd1e61..9866620fd4892 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -440,7 +440,9 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
-      getRegionBuilder();
+      getRegionBuilder() {
+      return nullptr;
+    }
 
     static void createRegion(::mlir::OpBuilder &opBuilder,
                              ::mlir::OperationState & odsState);
@@ -450,6 +452,79 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
   let hasVerifier = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// Broadcast op.
+//===----------------------------------------------------------------------===//
+
+def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Static broadcast operator";
+  let description = [{
+    Broadcast the input into the given shape by adding dimensions.
+
+    Each index in `dimensions` attribute maps input dimension into the
+    corresponding target dimension. The length of the `dimensions` list should
+    match the `input` rank and dimensions should be in sorted order. There is no
+    ambiguity at compile-time about shape information.
+
+    Example:
+    ```
+      %bcast = linalg.broadcast
+          ins(%input:tensor<16xf32>)
+          inits(%init:tensor<16x64xf32>)
+          dimensions = [0]
+    ```
+  }];
+
+  let arguments = (ins
+    // Input arg
+    TensorOrMemref:$input,
+    // Output arg
+    TensorOrMemref:$init,
+
+    DenseI64ArrayAttr:$dimensions
+  );
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "DenseI64ArrayAttr":$dimensions, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "ArrayRef<int64_t>":$dimensions, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+  ];
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare functions necessary for LinalgStructuredInterface.
+    SmallVector<StringRef> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - 1, getNumOperands};
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+        mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder() {
+      return nullptr;
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 568b9317ca364..6377a68bc3c5d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -662,7 +662,7 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 static void buildGenericRegion(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
     ValueRange outputs,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
   SmallVector<Type, 4> blockArgTypes;
@@ -675,10 +675,9 @@ static void buildGenericRegion(
   }
 
   OpBuilder::InsertionGuard guard(builder);
-  auto &region = *result.regions.front();
   Block *bodyBlock =
       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
-  bodyBuild(builder, result.location, bodyBlock->getArguments());
+  bodyBuild(builder, loc, bodyBlock->getArguments());
 }
 
 void GenericOp::getAsmBlockArgumentNames(Region &region,
@@ -699,7 +698,8 @@ void GenericOp::build(
         iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
   if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, outputs, bodyBuild);
 }
 
 void GenericOp::build(
@@ -1346,7 +1346,8 @@ void MapOp::build(
     result.addTypes(initType);
 
   if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, /*outputs=*/{}, bodyBuild);
 }
 
 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1471,7 +1472,8 @@ void ReduceOp::build(
   }
 
   if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, inits, bodyBuild);
+    buildGenericRegion(builder, result.location, *result.regions.front(),
+                       inputs, inits, bodyBuild);
 }
 
 SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
@@ -1648,13 +1650,13 @@ LogicalResult ReduceOp::verify() {
 // TransposeOp
 //===----------------------------------------------------------------------===//
 
-std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                   mlir::ArrayRef<mlir::NamedAttribute>)>
-TransposeOp::getRegionBuilder() {
-  return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-            mlir::ArrayRef<mlir::NamedAttribute>) {
-    b.create<linalg::YieldOp>(block.getArguments().front());
-  };
+static void buildIdentityRegion(OpBuilder &builder, Location loc,
+                                Region &region, ValueRange inputs,
+                                ValueRange outputs) {
+  buildGenericRegion(builder, loc, region, inputs, outputs,
+                     [](OpBuilder &b, Location loc, ValueRange args) {
+                       b.create<linalg::YieldOp>(loc, args[0]);
+                     });
 }
 
 void TransposeOp::build(::mlir::OpBuilder &builder,
@@ -1671,11 +1673,8 @@ void TransposeOp::build(::mlir::OpBuilder &builder,
   if (initType.isa<RankedTensorType>())
     result.addTypes(initType);
 
-  (void)result.addRegion();
-  buildGenericRegion(builder, result, input, init,
-                     [&](OpBuilder &b, Location loc, ValueRange args) {
-                       b.create<linalg::YieldOp>(loc, args[0]);
-                     });
+  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
+                      init);
 }
 
 void TransposeOp::build(::mlir::OpBuilder &builder,
@@ -1693,13 +1692,10 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
           })))
     return failure();
 
-  (void)result.addRegion();
   OpBuilder builder(parser.getContext());
-  buildGenericRegion(builder, result, /*inputs=*/result.operands,
-                     /*outputs=*/{},
-                     [&](OpBuilder &b, Location loc, ValueRange args) {
-                       b.create<linalg::YieldOp>(loc, args[0]);
-                     });
+  buildIdentityRegion(builder, result.location, *result.addRegion(),
+                      /*inputs=*/result.operands,
+                      /*outputs=*/{});
   return success();
 }
 
@@ -1778,6 +1774,144 @@ void TransposeOp::getEffects(
                         getDpsInputOperands(), getDpsInitOperands());
 }
 
+//===----------------------------------------------------------------------===//
+// BroadcastOp
+//===----------------------------------------------------------------------===//
+
+void BroadcastOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        DenseI64ArrayAttr dimensions,
+                        ArrayRef<NamedAttribute> attributes) {
+  result.addOperands(input);
+  result.addOperands(init);
+  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
+
+  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
+                      init);
+}
+
+void BroadcastOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        ArrayRef<int64_t> dimensions,
+                        ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
+        attributes);
+}
+
+ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (failed(parseDstStyleOp(
+          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
+          })))
+    return failure();
+
+  OpBuilder builder(parser.getContext());
+  buildIdentityRegion(builder, result.location, *result.addRegion(),
+                      /*inputs=*/result.operands,
+                      /*outputs=*/{});
+  return success();
+}
+
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "broadcasted");
+}
+
+void BroadcastOp::print(OpAsmPrinter &p) {
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getDpsInputOperands()),
+      SmallVector<Value>(getDpsInitOperands()));
+  p.printNewline();
+
+  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
+  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+  p.decreaseIndent();
+}
+
+LogicalResult BroadcastOp::verify() {
+  ArrayRef<int64_t> dimensionsRef = getDimensions();
+
+  if (!llvm::is_sorted(dimensionsRef))
+    return emitOpError() << "dimensions should be in sorted order, implicit "
+                            "transpose is not supported";
+
+  auto inputType = getInput().getType();
+  auto initType = getInit().getType();
+
+  int64_t inputRank = inputType.getRank();
+  int64_t initRank = initType.getRank();
+
+  auto inputShape = inputType.getShape();
+  auto initShape = initType.getShape();
+
+  if (inputRank != dimensionsRef.size())
+    return emitOpError()
+           << "input rank does match the number of dimensions. expected: "
+           << inputRank << ", got: " << dimensionsRef.size();
+
+  // Mapping from init dims to input dims.
+  const int64_t kUnmappedDim = -1;
+  SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+
+  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
+    if (dim < 0 || dim >= initRank)
+      return emitOpError() << "dimension " << idx
+                           << " is out of range. expected range: [0, "
+                           << initRank - 1 << "], got: " << dim;
+
+    reverseDimMap[dim] = idx;
+  }
+
+  for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
+    if (inputDimIdx == kUnmappedDim) {
+      // This dimension is being added. It should be statically known.
+      if (ShapedType::isDynamic(initShape[idx]))
+        return emitOpError()
+               << "init dim " << idx
+               << " can't be dynamic, because it's not matched to input";
+    } else {
+      // This dimension is mapped from the input. Init and input dims should
+      // match.
+      if (inputShape[inputDimIdx] != initShape[idx])
+        return emitOpError()
+               << "input dim " << inputDimIdx << " should match init dim "
+               << idx << ". input: " << inputShape[inputDimIdx]
+               << ", init: " << initShape[idx];
+    }
+  }
+
+  return success();
+}
+
+SmallVector<StringRef> BroadcastOp::getIteratorTypesArray() {
+  int64_t rank = getInit().getType().getRank();
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+}
+
+ArrayAttr BroadcastOp::getIndexingMaps() {
+  Builder builder(getContext());
+  int64_t rank = getInit().getType().getRank();
+  return builder.getAffineMapArrayAttr(
+      {builder.getMultiDimIdentityMap(rank).getSubMap(
+           llvm::to_vector_of<unsigned>(getDimensions())),
+       builder.getMultiDimIdentityMap(rank)});
+}
+
+void BroadcastOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getDpsInputOperands(), getDpsInitOperands());
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9200c6117a493..5a1c2afdebbdd 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -673,3 +673,81 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
       permutation = [1, 0, 2]
   func.return %transpose : tensor<32x64x16xf32>
 }
+
+// -----
+
+func.func @broadcast_unsorted_dims(
+    %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+    -> tensor<4x8x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<4x16xf32>)
+      outs(%init:tensor<4x8x16xf32>)
+      dimensions = [1, 0]
+  func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_input_dims_rank_mismatch(
+    %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+    -> tensor<4x8x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<4x16xf32>)
+      outs(%init:tensor<4x8x16xf32>)
+      dimensions = [0]
+  func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_unsorted_dims(
+    %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
+    -> tensor<4x8x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<4x16xf32>)
+      outs(%init:tensor<4x8x16xf32>)
+      dimensions = [0, 5]
+  func.return %bcast : tensor<4x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_mapped_dim_mismatch(
+    %input: tensor<4x16xf32>, %init: tensor<5x8x16xf32>)
+    -> tensor<5x8x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0. input: 4, init: 5}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<4x16xf32>)
+      outs(%init:tensor<5x8x16xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<5x8x16xf32>
+}
+
+// -----
+
+func.func @broadcast_added_dynamic_mismatch(
+    %input: tensor<4x16xf32>, %init: tensor<4x?x16xf32>)
+    -> tensor<4x?x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op init dim 1 can't be dynamic, because it's not matched to input}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<4x16xf32>)
+      outs(%init:tensor<4x?x16xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<4x?x16xf32>
+}
+
+// -----
+
+func.func @broadcast_size_1_extension_not_supported(
+    %input: tensor<1x16xf32>, %init: tensor<4x?x16xf32>)
+    -> tensor<4x?x16xf32> {
+  // expected-error @+1 {{'linalg.broadcast' op input dim 0 should match init dim 0. input: 1, init: 4}}
+  %bcast = linalg.broadcast
+      ins(%input:tensor<1x16xf32>)
+      outs(%init:tensor<4x?x16xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<4x?x16xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 58dec2be2373a..9d100d5117fdd 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -388,6 +388,19 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
 
 // -----
 
+// CHECK-LABEL: func @broadcast
+// CHECK-SAME:  %[[ARG0:.*]]: memref<8x32xf32
+func.func @broadcast(%input: tensor<8x32xf32>,
+                     %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %bcast = linalg.broadcast
+      ins(%input:tensor<8x32xf32>)
+      outs(%init:tensor<8x16x32xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<8x16x32xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // AllocTensorOp elimination would produce SSA violations for the example below.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index fc0e3e057d9a8..64c2bea1f7ee1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -517,3 +517,53 @@ func.func @transpose_memref(%input: memref<16x32x64xf32>,
   func.return
 }
 // CHECK-LABEL: func @transpose_memref
+
+// -----
+
+func.func @broadcast_static_sizes(%input: tensor<8x32xf32>,
+                            %init: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %bcast = linalg.broadcast
+      ins(%input:tensor<8x32xf32>)
+      outs(%init:tensor<8x16x32xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<8x16x32xf32>
+}
+// CHECK-LABEL: func @broadcast_static_sizes
+//      CHECK:    linalg.broadcast
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    dimensions
+
+// -----
+
+func.func @broadcast_with_dynamic_sizes(
+              %input: tensor<8x?xf32>, %init: tensor<8x16x?xf32>)
+              -> tensor<8x16x?xf32> {
+  %bcast = linalg.broadcast
+      ins(%input:tensor<8x?xf32>)
+      outs(%init:tensor<8x16x?xf32>)
+      dimensions = [0, 2]
+  func.return %bcast : tensor<8x16x?xf32>
+}
+// CHECK-LABEL: func @broadcast_with_dynamic_sizes
+//      CHECK:    linalg.broadcast
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    dimensions
+
+// -----
+
+func.func @broadcast_memref(%input: memref<8x32xf32>,
+                            %init: memref<8x16x32xf32>) {
+  linalg.broadcast
+      ins(%input:memref<8x32xf32>)
+      outs(%init:memref<8x16x32xf32>)
+      dimensions = [0, 2]
+  func.return
+}
+
+// CHECK-LABEL: func @broadcast_memref
+//      CHECK:    linalg.broadcast
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    dimensions

diff  --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 9addbcc83517c..b2e3fd5eec3b1 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -240,3 +240,29 @@ func.func @reduce(%arg0: memref<16x32x64xf32>,
 // CHECK:           %[[OUT_ELEM:.*]] = memref.load %[[OUT]][%[[I]], %[[K]]]
 // CHECK:           %[[ADD:.*]] = arith.addf %[[IN_ELEM]], %[[OUT_ELEM]]
 // CHECK:           memref.store %[[ADD]], %[[OUT]][%[[I]], %[[K]]]
+
+// -----
+
+func.func @broadcast(%input: memref<8x32xf32>,
+                     %init: memref<8x16x32xf32>) {
+  linalg.broadcast
+      ins(%input:memref<8x32xf32>)
+      outs(%init:memref<8x16x32xf32>)
+      dimensions = [0, 2]
+  func.return
+}
+// CHECK-LABEL: func.func @broadcast(
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>,
+// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+
+// CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:       scf.for %[[J:.*]] = %[[C0]] to %[[C16]] step %[[C1]] {
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:           %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
+// CHECK:           memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]


        


More information about the Mlir-commits mailing list