[Mlir-commits] [mlir] [mlir][sparse] handle padding on sparse levels. (PR #90527)
Peiming Liu
llvmlistbot at llvm.org
Mon Apr 29 15:19:24 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/90527
None
>From 6bcb652c52ef0cd671d0e5249f7ae917dca47bbc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 29 Apr 2024 17:49:18 +0000
Subject: [PATCH] [mlir][sparse] handle padding on sparse levels.
---
.../Transforms/Utils/LoopEmitter.cpp | 67 +++++--
.../Transforms/Utils/SparseTensorIterator.cpp | 111 ++++++++++--
.../Transforms/Utils/SparseTensorIterator.h | 7 +
.../CPU/padded_sparse_conv_2d.mlir | 169 ++++++++++++++++++
4 files changed, 328 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 812c288a20c2df..98e315865ba5d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -75,6 +75,40 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
}
+/// Returns true iff `attr` is a FloatAttr or an IntegerAttr whose value is
+/// zero; any other attribute kind yields false.
+static bool isIntOrFPZero(Attribute attr) {
+  if (auto fpAttr = llvm::dyn_cast<FloatAttr>(attr))
+    return fpAttr.getValue().isZero();
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getValue().isZero();
+  return false;
+}
+
+/// Materializes `ofr` as an SSA value: when it holds a constant integer, a new
+/// index constant is created at `loc`; otherwise the wrapped Value is returned.
+static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
+                               OpFoldResult ofr) {
+  std::optional<int64_t> cst = getConstantIntValue(ofr);
+  if (!cst)
+    return ofr.get<Value>();
+  return constantIndex(builder, loc, *cst);
+}
+
+/// Returns the source of `t` when `t` is produced by a `tensor.pad` that can be
+/// folded into the sparse level iterators, i.e. when
+///   * `t` is a sparse tensor whose encoding is an identity,
+///   * the pad source carries the same encoding as `t`, and
+///   * the padding value is an integer or floating-point zero constant.
+/// Otherwise returns `t` unchanged.
+static Value tryFoldTensors(Value t) {
+  // TODO: this should be done through a folding pass after switching to
+  // `sparse_tensor.iterate`-based sparsification.
+  auto stt = tryGetSparseTensorType(t);
+  auto padOp = t.getDefiningOp<tensor::PadOp>();
+  if (!padOp || !stt.has_value() || !stt->hasEncoding())
+    return t;
+  if (padOp.getSourceType().getEncoding() != stt->getEncoding() ||
+      !stt->getEncoding().isIdentity())
+    return t;
+
+  // Only a constant zero padding value can be folded away: the padded entries
+  // are then indistinguishable from implicit (unstored) zeros.
+  Attribute padCst;
+  if (!matchPattern(padOp.getBody()->getTerminator(),
+                    m_Op<tensor::YieldOp>(m_Constant(&padCst))))
+    return t;
+  if (!isIntOrFPZero(padCst))
+    return t;
+  return padOp.getSource();
+}
+
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
@@ -166,15 +200,30 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Level l) {
+  const Value tensor = tensors[t];
+  auto stt = getSparseTensorType(tensor);
auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
- auto stt = getSparseTensorType(tensors[t]);
+
+  // Zero-padded operands are replaced by their unpadded source in
+  // initializeLoopEmit() (see tryFoldTensors), so the padding of level `l`,
+  // if any, is re-applied here by wrapping the level iterator.
+  if (tryFoldTensors(tensor) != tensor) {
+    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
+    assert(padOp);
+    if (padOp.getPaddedDims().test(l)) {
+      // Keep the low constant materialized before the high one.
+      Value lowPad = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
+      Value highPad =
+          unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
+      return makePaddedIterator(std::move(it), lowPad, highPad, emitStrategy);
+    }
+  }
+
if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
- Value offset = genSliceOffset(builder, loc, tensors[t], l);
- Value stride = genSliceStride(builder, loc, tensors[t], l);
+    Value offset = genSliceOffset(builder, loc, tensor, l);
+    Value stride = genSliceStride(builder, loc, tensor, l);
auto slicedIt = makeSlicedLevelIterator(
std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
return slicedIt;
}
+
return it;
}
@@ -200,7 +249,9 @@ void LoopEmitter::initializeLoopEmit(
// on positions.
for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
t++) {
- const Value tensor = tensors[t];
+ // TODO: this should be done through a folding pass after switching to
+ // `sparse_tensor.iterate`-based sparsification.
+ const Value tensor = tryFoldTensors(tensors[t]);
const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
// Skips only scalar, zero ranked tensor still need to be bufferized and
@@ -213,14 +264,6 @@ void LoopEmitter::initializeLoopEmit(
const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
- SmallVector<Value> lvlSzs;
- for (Level l = 0; l < stt.getLvlRank(); l++) {
- if (stt.hasEncoding())
- lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
- else
- lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
- }
-
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
// Find upper bound in current dimension.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 745c081247dee8..252dfc85b528c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -462,11 +462,54 @@ class DedupIterator : public ConcreteIterator {
Value posHi;
};
+// A utility base iterator that forwards every SparseIterator hook to the
+// iterator it wraps. Subclasses override only the hooks whose behavior they
+// need to change and can reach the wrapped iterator through `wrap`.
+class SimpleWrapIterator : public SparseIterator {
+public:
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
+      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+
+  /// Returns the iterator being wrapped.
+  SparseIterator &getWrappedIterator() const { return *wrap; }
+
+  // Iteration-space properties are those of the wrapped iterator.
+  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
+  bool randomAccessible() const override { return wrap->randomAccessible(); }
+  bool iteratableByFor() const override { return wrap->iteratableByFor(); }
+  Value upperBound(OpBuilder &b, Location l) const override {
+    return wrap->upperBound(b, l);
+  }
+
+  // Cursor-related queries are delegated to the wrapped iterator.
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return wrap->getCursorValTypes(b);
+  }
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
+  SmallVector<Value> serialize() const override { return wrap->serialize(); }
+  void deserialize(ValueRange vs) override { wrap->deserialize(vs); }
+
+  // Code-generation hooks, likewise delegated.
+  void genInitImpl(OpBuilder &b, Location l,
+                   const SparseIterator *parent) override {
+    wrap->genInit(b, l, parent);
+  }
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
+    return wrap->genNotEndImpl(b, l);
+  }
+  ValueRange forwardImpl(OpBuilder &b, Location l) override {
+    return wrap->forward(b, l);
+  }
+  Value derefImpl(OpBuilder &b, Location l) override {
+    return wrap->derefImpl(b, l);
+  }
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
+    wrap->locate(b, l, crd);
+  }
+
+protected:
+  std::unique_ptr<SparseIterator> wrap;
+};
+
//
// A filter iterator wrapped from another iterator. The filter iterator update
// the wrapped iterator *in-place*.
//
-class FilterIterator : public SparseIterator {
+class FilterIterator : public SimpleWrapIterator {
// Coorindate translation between crd loaded from the wrap iterator and the
// filter iterator.
Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
@@ -487,8 +530,8 @@ class FilterIterator : public SparseIterator {
// when crd always < size.
FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
Value stride, Value size)
- : SparseIterator(IterKind::kFilter, *wrap), offset(offset),
- stride(stride), size(size), wrap(std::move(wrap)) {}
+ : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
+ stride(stride), size(size) {}
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
@@ -498,19 +541,10 @@ class FilterIterator : public SparseIterator {
std::string getDebugInterfacePrefix() const override {
return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
}
- SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
- return wrap->getCursorValTypes(b);
- }
- bool isBatchIterator() const override { return wrap->isBatchIterator(); }
- bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override { return size; };
- SmallVector<Value> serialize() const override { return wrap->serialize(); };
- void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
- ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
-
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
wrap->genInit(b, l, parent);
@@ -541,7 +575,47 @@ class FilterIterator : public SparseIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override;
Value offset, stride, size;
- std::unique_ptr<SparseIterator> wrap;
+};
+
+//
+// A pad iterator wrapped from another iterator. The pad iterator updates the
+// wrapped iterator *in-place*: every coordinate it yields is shifted by
+// `padLow`, and the upper bound grows by `padLow + padHigh`. Only the entries
+// visited by the wrapped iterator are produced, so the padded region is
+// implicitly zero. Random-accessible wrapped iterators are not supported yet.
+//
+class PadIterator : public SimpleWrapIterator {
+
+public:
+  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
+              Value padHigh)
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
+        padHigh(padHigh) {
+    assert(!randomAccessible() && "Not implemented.");
+  }
+
+  // For LLVM-style RTTI.
+  static bool classof(const SparseIterator *from) {
+    return from->kind == IterKind::kPad;
+  }
+
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
+  }
+
+  // The upper bound after padding becomes `size + padLow + padHigh`.
+  Value upperBound(OpBuilder &b, Location l) const override {
+    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
+  }
+
+  // The padded coordinate is `coord + padLow`.
+  Value derefImpl(OpBuilder &b, Location l) override {
+    updateCrd(ADDI(wrap->deref(b, l), padLow));
+    return getCrd();
+  }
+
+  // locate() requires random access, which the constructor rejects, so this
+  // hook must never be reached. Mark it unreachable explicitly: an
+  // `assert(randomAccessible())` here would contradict the constructor and
+  // compile to an empty body under NDEBUG.
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
+    llvm_unreachable("PadIterator does not support random access (locate)");
+  }
+
+  Value padLow, padHigh;
};
class NonEmptySubSectIterator : public SparseIterator {
@@ -1408,10 +1482,19 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
return ret;
}
+/// Wraps `sit` into a PadIterator, which shifts coordinates by `padLow` and
+/// extends the upper bound by `padLow + padHigh`, then applies `strategy`.
+std::unique_ptr<SparseIterator>
+sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
+                                  Value padLow, Value padHigh,
+                                  SparseEmitStrategy strategy) {
+  auto padded = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
+  padded->setSparseEmitStrategy(strategy);
+  return padded;
+}
+
+/// Returns the iterator wrapped by `it` when `it` is a FilterIterator (one
+/// layer is peeled off); otherwise returns `it` itself (possibly null).
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
if (filter)
- return filter->wrap.get();
+ return &filter->getWrappedIterator();
return it;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index b692848ec67bd8..fe43b15a33698b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -76,6 +76,7 @@ enum class IterKind : uint8_t {
kSubSect,
kNonEmptySubSect,
kFilter,
+ kPad,
};
/// Helper class that generates loop conditions, etc, to traverse a
@@ -303,6 +304,12 @@ std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
Value stride, Value size, SparseEmitStrategy strategy);
+/// Helper function to create a SparseIterator object that iterates over a
+/// padded sparse level (the padded value must be zero). `padLow` and `padHigh`
+/// are the amounts of padding added before and after the level, respectively.
+/// The returned iterator takes ownership of `sit` and emits code according to
+/// `strategy`.
+std::unique_ptr<SparseIterator>
+makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
+ Value padHigh, SparseEmitStrategy strategy);
+
/// Helper function to create a SparseIterator object that iterate over the
/// non-empty subsections set.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
new file mode 100644
index 00000000000000..1deb6f74c0d28b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -0,0 +1,169 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+// All four levels compressed; this is the encoding exercised by
+// @conv_2d_nhwc_hwcf_CCCC below.
+#CCCC = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
+}>
+
+// NOTE(review): #CDCD and #DCCD mix dense and compressed levels and are not
+// referenced by any function in this test.
+#CDCD = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
+}>
+
+#DCCD = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+}>
+
+// Creates and returns a 4-D tensor of shape (%s1, %s2, %s3, %s4) whose
+// elements are all equal to %f.
+func.func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> tensor<?x?x?x?xf32> {
+  %empty = tensor.empty(%s1, %s2, %s3, %s4) : tensor<?x?x?x?xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%empty : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %filled : tensor<?x?x?x?xf32>
+}
+
+// Dense reference: zero-pads dimensions 1 and 2 of the input by 2 on each side
+// and convolves the result with the 5x5 filter into %arg2.
+func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf32>, %arg2: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32> {
+  %zero = arith.constant 0.00000e+00 : f32
+
+  %input_padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+  ^bb0(%n: index, %h: index, %w: index, %c: index):
+    tensor.yield %zero : f32
+  } : tensor<3x8x8x3xf32> to tensor<3x12x12x3xf32>
+
+  %conv = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                    strides = dense<1> : tensor<2xi64>}
+     ins (%input_padded, %arg1: tensor<3x12x12x3xf32>, tensor<5x5x3x1xf32>)
+    outs (%arg2: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
+  return %conv : tensor<3x8x8x1xf32>
+}
+
+// Same padded convolution as @conv_2d_nhwc_hwcf, but the input is an
+// all-compressed (#CCCC) sparse tensor and the output is a freshly
+// zero-initialized dense tensor. The pad value is the constant zero, which is
+// what allows the sparsifier to fold the tensor.pad into the level iterators.
+func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+  %cst_0 = arith.constant 0.00000e+00 : f32
+  %buf = tensor.empty() : tensor<3x8x8x1xf32>
+  %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
+
+  %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+  ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
+    tensor.yield %cst_0 : f32
+  } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+
+  %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                   strides = dense<1> : tensor<2xi64>}
+     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+    outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
+  return %ret : tensor<3x8x8x1xf32>
+}
+
+// Runs the padded convolution once on a dense input and once on an
+// all-compressed (#CCCC) copy of the same input, and prints both results; the
+// CHECK lines expect identical values for the two runs.
+func.func @main() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %c8 = arith.constant 8 : index
+  %f10 = arith.constant 10.00000e+00 : f32
+  %val = arith.constant 2.00000e+00 : f32
+  %zero = arith.constant 0.00000e+00 : f32
+
+  %filter2D_nhwc = call @alloc_4d_filled_f32(%c5, %c5, %c3, %c1, %val) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+  %in2D_tmp = call @alloc_4d_filled_f32(%c3, %c8, %c8, %c3, %val) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+  %in2D_nhwc = tensor.insert %f10 into %in2D_tmp[%c0, %c0, %c3, %c0] : tensor<?x?x?x?xf32>
+  %out2D_nhwc = call @alloc_4d_filled_f32(%c3, %c8, %c8, %c1, %zero) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+
+  %static_filter = tensor.cast %filter2D_nhwc : tensor<?x?x?x?xf32> to tensor<5x5x3x1xf32>
+  %static_input = tensor.cast %in2D_nhwc : tensor<?x?x?x?xf32> to tensor<3x8x8x3xf32>
+  %static_output = tensor.cast %out2D_nhwc : tensor<?x?x?x?xf32> to tensor<3x8x8x1xf32>
+
+  %dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
+
+  %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
+  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+
+  // CHECK: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 208 ), ( 256 ), ( 256 ), ( 256 ), ( 256 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 256 ), ( 316 ), ( 316 ), ( 316 ), ( 316 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME: ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME: ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
+  %dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
+  vector.print %dense_v : vector<3x8x8x1xf32>
+
+  // CHECK-NEXT: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 208 ), ( 256 ), ( 256 ), ( 256 ), ( 256 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 256 ), ( 316 ), ( 316 ), ( 316 ), ( 316 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME: ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME: ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
+  %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
+  vector.print %CCCC_v : vector<3x8x8x1xf32>
+
+  // Release the sparse tensor materialized by sparse_tensor.convert.
+  bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+  return
+}
More information about the Mlir-commits
mailing list