[Mlir-commits] [mlir] 6c7be41 - Support buffers in LinalgFoldUnitExtentDims

Tres Popp llvmlistbot at llvm.org
Mon Jun 14 23:22:32 PDT 2021


Author: Tres Popp
Date: 2021-06-15T08:22:22+02:00
New Revision: 6c7be4176703fff69d20acc466a879e080346f30

URL: https://github.com/llvm/llvm-project/commit/6c7be4176703fff69d20acc466a879e080346f30
DIFF: https://github.com/llvm/llvm-project/commit/6c7be4176703fff69d20acc466a879e080346f30.diff

LOG: Support buffers in LinalgFoldUnitExtentDims

This doesn't add any canonicalizations, but applies the same
simplification to buffer-semantic linalg.generic ops by using
linalg::ReshapeOp instead of linalg::TensorReshapeOp.
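
For illustration, here is a minimal sketch (modeled on the new buffer test
cases added below; indexing maps are elided) of what the rewrite now does to a
memref-based generic op whose input carries a unit-extent leading dimension:

  // Before: the input has shape memref<1x5xf32> with a unit leading dimension.
  linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
     ins(%arg0 : memref<1x5xf32>)
    outs(%out : memref<5xf32>) { ... }

  // After: the unit dimension is folded away by collapsing the operand first.
  %0 = linalg.collapse_shape %arg0 [[0, 1]] : memref<1x5xf32> into memref<5xf32>
  linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
     ins(%0 : memref<5xf32>)
    outs(%out : memref<5xf32>) { ... }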

Differential Revision: https://reviews.llvm.org/D103513
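
For context, a minimal sketch of how these patterns are typically driven
(assuming the standard greedy rewrite driver and the usual includes, e.g.
mlir/Transforms/GreedyPatternRewriteDriver.h; this snippet is not part of the
change itself):

  // Collect the unit-extent-dim folding patterns and apply them to `funcOp`.
  RewritePatternSet patterns(funcOp.getContext());
  mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));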

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 1deea9476674..68102cb0f480 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@ struct UnitExtentReplacementInfo {
 } // namespace
 
 /// Utility function for replacing operands/results to a linalg generic
-/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// operation with unit-extent dimensions. These can be replaced with
 /// an operand/result with the unit-extent dimension removed. This is only done
 /// if the indexing map used to access that dimension has an
 /// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
@@ -301,10 +302,19 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     ++dim;
   }
   // Compute the tensor or scalar replacement type.
+  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
-  Type replacementType = elementType == opOperand->get().getType()
-                             ? elementType
-                             : RankedTensorType::get(newShape, elementType);
+  Type replacementType;
+  if (elementType == opOperand->get().getType()) {
+    replacementType = elementType;
+  } else if (actualType.isa<RankedTensorType>()) {
+    replacementType = RankedTensorType::get(newShape, elementType);
+  } else if (actualType.isa<MemRefType>()) {
+    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
+           "unsupported strided memrefs");
+    replacementType = MemRefType::get(newShape, elementType);
+  }
+  assert(replacementType && "unsupported shaped type");
   UnitExtentReplacementInfo info = {replacementType,
                                     AffineMap::get(indexingMap.getNumDims(),
                                                    indexingMap.getNumSymbols(),
@@ -324,14 +334,53 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
   return reassociationExprs;
 }
 
-/// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+/// Pattern to replace tensor/buffer operands/results that have unit-extent dimensions.
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  // Return the original value if the type is unchanged, or reshape it. Return a
+  // nullptr if this is an unsupported type.
+  Value maybeExpand(Value result, Type origResultType,
+                    ArrayAttr reassociationMap, Location loc,
+                    PatternRewriter &rewriter) const {
+    if (origResultType == result.getType())
+      return result;
+    if (origResultType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (origResultType.isa<MemRefType>()) {
+      return rewriter.create<linalg::ExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
+  // Return the original value if the type is unchanged, or reshape it. Return a
+  // nullptr if this is an unsupported type.
+  Value maybeCollapse(Value operand, Type newInputOutputType,
+                      ArrayAttr reassociationMap, Location loc,
+                      PatternRewriter &rewriter) const {
+    auto operandType = operand.getType();
+    if (operandType == newInputOutputType)
+      return operand;
+    if (operandType.isa<MemRefType>()) {
+      return rewriter.create<linalg::CollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (operandType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorCollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 
@@ -339,7 +388,6 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     SmallVector<ArrayAttr> reassociationMaps;
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
-
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
       UnitExtentReplacementInfo replacementInfo =
           replaceUnitExtents(genericOp, opOperand, context);
@@ -362,14 +410,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     auto insertReshapes = [&](ValueRange values) {
       SmallVector<Value, 4> res;
       res.reserve(values.size());
-      for (auto operand : llvm::enumerate(values)) {
-        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
-          res.push_back(operand.value());
-        else {
-          res.push_back(rewriter.create<TensorCollapseShapeOp>(
-              loc, newInputOutputTypes[flattenedIdx], operand.value(),
-              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
-        }
+      for (auto operand : values) {
+        auto reshapedValue =
+            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
+                          reassociationMaps[flattenedIdx], loc, rewriter);
+        assert(reshapedValue &&
+               "expected ranked MemRef or Tensor operand type");
+        res.push_back(reshapedValue);
         ++flattenedIdx;
       }
       return res;
@@ -396,15 +443,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumInputs();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
-                                            .getType()
-                                            .template cast<RankedTensorType>();
-      if (origResultType != result.value().getType()) {
-        resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
-            loc, origResultType, result.value(),
-            convertAffineMapArrayToExprs(reassociationMaps[index])));
-      } else
-        resultReplacements.push_back(result.value());
+      auto origResultType = genericOp.getResult(result.index()).getType();
+
+      auto newResult = maybeExpand(result.value(), origResultType,
+                                   reassociationMaps[index], loc, rewriter);
+      assert(newResult &&
+             "unexpected output type other than ranked MemRef or Tensor");
+      resultReplacements.push_back(newResult);
     }
     rewriter.replaceOp(genericOp, resultReplacements);
     return success();
@@ -501,9 +546,8 @@ struct UseRankReducedSubTensorInsertOp
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
-               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
-      context);
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
+               UseRankReducedSubTensorInsertOp>(context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
 }

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5a53c228bea5..f5359e54a7d5 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -451,3 +451,303 @@ func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>)
 //       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
 //  CHECK-SAME:     tensor<f32> into tensor<1x3xf32>
 //       CHECK:   return %[[RESULT]]
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+     ins(%arg0, %arg1 : memref<?x1x?xf32>, f32)
+    outs(%shape : memref<?x1x?x1x?xf32>) {
+       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+         linalg.yield %arg3 : f32
+       }
+  return %shape : memref<?x1x?x1x?xf32>
+}
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+//       CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed
+  (%arg0 : memref<?x1x?xi32>, %shape: memref<?x1x?x1x?xi32>) -> memref<?x1x?x1x?xi32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<?x1x?xi32>)
+    outs(%shape: memref<?x1x?x1x?xi32>) {
+       ^bb0(%arg6 : i32, %arg7 : i32) :
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %idx2 = linalg.index 2 : index
+         %idx3 = linalg.index 3 : index
+         %idx4 = linalg.index 4 : index
+         %1 = addi %idx0, %idx1 : index
+         %2 = subi %1, %idx2 : index
+         %3 = subi %2, %idx3 : index
+         %4 = addi %3, %idx4 : index
+         %5 = index_cast %4 : index to i32
+         %6 = addi %5, %arg6 : i32
+         linalg.yield %6 : i32
+       }
+  return %shape : memref<?x1x?x1x?xi32>
+}
+// The subtractions disappear because the access map of the output memref maps
+// its unit dimensions 1 and 3 to the index dimensions 2 and 3.
+// CHECK-LABEL: func @drop_one_trip_loops_indexed
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
+//       CHECK:     %[[IDX0:.+]] = linalg.index 0 : index
+//       CHECK:     %[[IDX1:.+]] = linalg.index 1 : index
+//       CHECK:     %[[IDX2:.+]] = linalg.index 2 : index
+//       CHECK:     %[[T3:.+]] = addi %[[IDX0]], %[[IDX1]]
+//       CHECK:     %[[T4:.+]] = addi %[[T3]], %[[IDX2]]
+//       CHECK:     %[[T5:.+]] = index_cast %[[T4]] : index to i32
+//       CHECK:     %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+//       CHECK:     linalg.yield %[[T6]] : i32
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32>)
+    outs(%arg0 : memref<1x1xf32>) {
+       ^bb0(%arg1: f32, %arg2: f32) :
+         linalg.yield %arg1 : f32
+       }
+  return %arg0 : memref<1x1xf32>
+}
+//       CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+//       CHECK:   linalg.collapse_shape %{{.*}} []
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+//  CHECK-SAME:     iterator_types = []
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed
+  (%arg0 : memref<1x1xi32>) -> memref<1x1xi32>{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xi32>)
+    outs(%arg0 : memref<1x1xi32>) {
+       ^bb0(%arg3: i32, %arg4: i32) :
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %1 = addi %idx0, %idx1 : index
+         %2 = index_cast %1 : index to i32
+         %3 = addi %2, %arg3 : i32
+         linalg.yield %3 : i32
+       }
+  return %arg0 : memref<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+//       CHECK:     linalg.yield %[[ARG1]] : i32
+
+// -----
+
+#accesses = [
+  affine_map<(d0) -> (0, d0)>,
+  affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf32>) -> memref<5xf32> {
+  linalg.generic #trait
+     ins(%arg0 : memref<1x5xf32>)
+    outs(%shape : memref<5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):     // no predecessors
+    linalg.yield %arg2 : f32
+  }
+  return %shape : memref<5xf32>
+}
+//   CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+//       CHECK:   linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, d1)>,
+  affine_map<(d0, d1) -> (d0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
+{
+  %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
+  %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+  linalg.generic #trait
+     ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
+    outs(%shape : memref<5x5xf32>) {
+       ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+         %3 = addf %arg3, %arg4 : f32
+         linalg.yield %3 : f32
+       }
+  return %shape : memref<5x5xf32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+//   CHECK-NOT:   linalg.memref_{{.*}}shape
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+//   CHECK-NOT:   linalg.memref_{{.*}}shape
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32>
+{
+   linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32>)
+    outs(%shape : memref<?x?xf32>) {
+      ^bb0(%arg2 : f32, %arg3 : f32):
+        linalg.yield %arg2 : f32
+   }
+   return %shape : memref<?x?xf32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<1x1xf32>
+//       CHECK:   %[[A:.*]] = linalg.collapse_shape %[[ARG0]] []
+//  CHECK-SAME:     memref<1x1xf32> into memref<f32>
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+//  CHECK-SAME:     %[[A]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
+{
+  %1 = memref.alloc() : memref<1x2x5xf32>
+  linalg.generic {i64, indexing_maps = [#map1, #map0],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      linalg.yield %arg1 : f32
+    }
+  %3 = linalg.collapse_shape %1 [[0, 1], [2]]
+    : memref<1x2x5xf32> into memref<2x5xf32>
+  return %3 : memref<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
+//       CHECK:   %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
+//       CHECK:   %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]]
+//       CHECK:   linalg.generic
+//       CHECK-SAME:   outs(%[[OUT:.*]] :
+//       CHECK:   %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32> {
+  %cst = constant 0.0 : f32
+  %init = memref.alloc() : memref<1xf32>
+  linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+    ins(%input : memref<1x1000xf32>)outs(%init : memref<1xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %1823 = addf %arg1, %arg2 : f32
+    linalg.yield %1823 : f32
+  }
+  return %init : memref<1xf32>
+}
+
+
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>
+
+//       CHECK: func @fold_unit_dim_for_init_memref
+//       CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
+//       CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
+//       CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
+//       CHECK: linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
+//  CHECK-SAME:     iterator_types = ["reduction"]
+//  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : memref<1000xf32>)
+//  CHECK-SAME:   outs(%[[INIT_RESHAPE]] : memref<f32>)
+//       CHECK: return %[[INIT:.+]] : memref<1xf32>
+
+
+


        

