[Mlir-commits] [mlir] [mlir][sparse] handle padding on sparse levels. (PR #90527)

Peiming Liu llvmlistbot at llvm.org
Mon Apr 29 15:19:24 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/90527

None

>From 6bcb652c52ef0cd671d0e5249f7ae917dca47bbc Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 29 Apr 2024 17:49:18 +0000
Subject: [PATCH] [mlir][sparse] handle padding on sparse levels.

---
 .../Transforms/Utils/LoopEmitter.cpp          |  67 +++++--
 .../Transforms/Utils/SparseTensorIterator.cpp | 111 ++++++++++--
 .../Transforms/Utils/SparseTensorIterator.h   |   7 +
 .../CPU/padded_sparse_conv_2d.mlir            | 169 ++++++++++++++++++
 4 files changed, 328 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 812c288a20c2df..98e315865ba5d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -75,6 +75,40 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
   return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
+static bool isIntOrFPZero(Attribute attr) { // True iff attr is an int/FP zero.
+  if (auto f = llvm::dyn_cast<FloatAttr>(attr); f && f.getValue().isZero())
+    return true; // FloatAttr zero (presumably -0.0 as well; TODO confirm).
+  if (auto i = llvm::dyn_cast<IntegerAttr>(attr); i && i.getValue().isZero())
+    return true; // IntegerAttr zero.
+  return false; // Any other attribute kind, or a non-zero value.
+}
+
+static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
+                               OpFoldResult ofr) {
+  if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
+    return constantIndex(builder, loc, *i); // Known constant: build a Value.
+  return ofr.get<Value>(); // Assumes a non-constant `ofr` holds a Value.
+}
+
+static Value tryFoldTensors(Value t) {
+  // TODO: this should be done through a folding pass after switching to
+  // `sparse_tensor.iterate`-based sparsification.
+  auto stt = tryGetSparseTensorType(t);
+  auto padOp = t.getDefiningOp<tensor::PadOp>();
+  if (padOp && stt.has_value() && stt->hasEncoding() &&
+      padOp.getSourceType().getEncoding() == stt->getEncoding() &&
+      stt->getEncoding().isIdentity()) {
+    // Fold only when the pad body yields a constant integer/FP zero.
+    Attribute padCst;
+    if (matchPattern(padOp.getBody()->getTerminator(),
+                     m_Op<tensor::YieldOp>(m_Constant(&padCst))) &&
+        isIntOrFPZero(padCst)) {
+      return padOp.getSource(); // Use the unpadded source tensor instead.
+    }
+  }
+  return t; // Not a foldable pad: hand back `t` unchanged.
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
@@ -166,15 +200,30 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 std::unique_ptr<SparseIterator>
 LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
                                Level l) {
+  Value tensor = tensors[t];
+  auto stt = getSparseTensorType(tensor);
   auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
-  auto stt = getSparseTensorType(tensors[t]);
+
+  Value folded = tryFoldTensors(tensor);
+  if (folded != tensor) { // Only differs when a zero pad was folded away.
+    auto padOp = tensor.getDefiningOp<tensor::PadOp>();
+    assert(padOp);
+    if (padOp.getPaddedDims().test(l)) { // Level == dim (identity map).
+      Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
+      Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
+      auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
+      return padIt; // NOTE(review): skips the slice handling below.
+    }
+  }
+
   if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
-    Value offset = genSliceOffset(builder, loc, tensors[t], l);
-    Value stride = genSliceStride(builder, loc, tensors[t], l);
+    Value offset = genSliceOffset(builder, loc, tensor, l);
+    Value stride = genSliceStride(builder, loc, tensor, l);
     auto slicedIt = makeSlicedLevelIterator(
         std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
     return slicedIt;
   }
+
   return it;
 }
 
@@ -200,7 +249,9 @@ void LoopEmitter::initializeLoopEmit(
   //     on positions.
   for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
        t++) {
-    const Value tensor = tensors[t];
+    // TODO: this should be done through a folding pass after switching to
+    // `sparse_tensor.iterate`-based sparsification.
+    const Value tensor = tryFoldTensors(tensors[t]);
     const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
     if (!rtp)
       // Skips only scalar, zero ranked tensor still need to be bufferized and
@@ -213,14 +264,6 @@ void LoopEmitter::initializeLoopEmit(
     const Level lvlRank = stt.getLvlRank();
     const auto shape = rtp.getShape();
 
-    SmallVector<Value> lvlSzs;
-    for (Level l = 0; l < stt.getLvlRank(); l++) {
-      if (stt.hasEncoding())
-        lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
-      else
-        lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
-    }
-
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
       // Find upper bound in current dimension.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 745c081247dee8..252dfc85b528c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -462,11 +462,54 @@ class DedupIterator : public ConcreteIterator {
   Value posHi;
 };
 
+// A utility base iterator that forwards every method to the wrapped iterator.
+class SimpleWrapIterator : public SparseIterator {
+public:
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
+      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return wrap->getCursorValTypes(b);
+  }
+  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
+  bool randomAccessible() const override { return wrap->randomAccessible(); };
+  bool iteratableByFor() const override { return wrap->iteratableByFor(); };
+  SmallVector<Value> serialize() const override { return wrap->serialize(); };
+  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
+  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
+  void genInitImpl(OpBuilder &b, Location l,
+                   const SparseIterator *parent) override {
+    wrap->genInit(b, l, parent); // Uses the wrapped non-Impl entry point.
+  }
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
+    return wrap->genNotEndImpl(b, l); // Calls the wrapped *Impl directly.
+  }
+  ValueRange forwardImpl(OpBuilder &b, Location l) override {
+    return wrap->forward(b, l); // Uses the wrapped non-Impl entry point.
+  };
+  Value upperBound(OpBuilder &b, Location l) const override {
+    return wrap->upperBound(b, l);
+  };
+
+  Value derefImpl(OpBuilder &b, Location l) override {
+    return wrap->derefImpl(b, l); // Calls the wrapped *Impl directly.
+  }
+
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
+    return wrap->locate(b, l, crd); // Uses the wrapped non-Impl entry point.
+  }
+
+  SparseIterator &getWrappedIterator() const { return *wrap; }
+
+protected:
+  std::unique_ptr<SparseIterator> wrap; // Owned wrapped iterator.
+};
+
 //
 // A filter iterator wrapped from another iterator. The filter iterator update
 // the wrapped iterator *in-place*.
 //
-class FilterIterator : public SparseIterator {
+class FilterIterator : public SimpleWrapIterator {
   // Coorindate translation between crd loaded from the wrap iterator and the
   // filter iterator.
   Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
@@ -487,8 +530,8 @@ class FilterIterator : public SparseIterator {
   // when crd always < size.
   FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                  Value stride, Value size)
-      : SparseIterator(IterKind::kFilter, *wrap), offset(offset),
-        stride(stride), size(size), wrap(std::move(wrap)) {}
+      : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
+        stride(stride), size(size) {}
 
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
@@ -498,19 +541,10 @@ class FilterIterator : public SparseIterator {
   std::string getDebugInterfacePrefix() const override {
     return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
   }
-  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
-    return wrap->getCursorValTypes(b);
-  }
 
-  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
-  bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return randomAccessible(); };
   Value upperBound(OpBuilder &b, Location l) const override { return size; };
 
-  SmallVector<Value> serialize() const override { return wrap->serialize(); };
-  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
-  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
-
   void genInitImpl(OpBuilder &b, Location l,
                    const SparseIterator *parent) override {
     wrap->genInit(b, l, parent);
@@ -541,7 +575,47 @@ class FilterIterator : public SparseIterator {
   ValueRange forwardImpl(OpBuilder &b, Location l) override;
 
   Value offset, stride, size;
-  std::unique_ptr<SparseIterator> wrap;
+};
+
+//
+// A pad iterator wrapped around another iterator. The pad iterator updates
+// the wrapped iterator *in-place*.
+//
+class PadIterator : public SimpleWrapIterator {
+
+public:
+  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
+              Value padHigh)
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
+        padHigh(padHigh) {
+    assert(!randomAccessible() && "Not implemented.");
+  }
+
+  // For LLVM-style RTTI: matches iterators constructed with IterKind::kPad.
+  static bool classof(const SparseIterator *from) {
+    return from->kind == IterKind::kPad;
+  }
+
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
+  }
+
+  // The upper bound after padding becomes `size + padLow + padHigh`.
+  Value upperBound(OpBuilder &b, Location l) const override {
+    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
+  };
+
+  // The padded coordinate is the wrapped iterator's coordinate plus padLow.
+  Value derefImpl(OpBuilder &b, Location l) override {
+    updateCrd(ADDI(wrap->deref(b, l), padLow));
+    return getCrd();
+  }
+
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
+    assert(randomAccessible()); // Unsupported: ctor asserts the opposite.
+  }
+
+  Value padLow, padHigh; // Low/high padding amounts passed to the ctor.
 };
 
 class NonEmptySubSectIterator : public SparseIterator {
@@ -1408,10 +1482,19 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
   return ret;
 }
 
+std::unique_ptr<SparseIterator>
+sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
+                                  Value padLow, Value padHigh,
+                                  SparseEmitStrategy strategy) {
+  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
+  ret->setSparseEmitStrategy(strategy); // Applied to the pad wrapper itself.
+  return ret; // `sit` is now owned by the returned PadIterator.
+}
+
 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
   auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
   if (filter)
-    return filter->wrap.get();
+    return &filter->getWrappedIterator(); // Filters only; pads stay wrapped.
   return it;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index b692848ec67bd8..fe43b15a33698b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -76,6 +76,7 @@ enum class IterKind : uint8_t {
   kSubSect,
   kNonEmptySubSect,
   kFilter,
+  kPad,
 };
 
 /// Helper class that generates loop conditions, etc, to traverse a
@@ -303,6 +304,12 @@ std::unique_ptr<SparseIterator>
 makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                         Value stride, Value size, SparseEmitStrategy strategy);
 
+/// Helper function to create a SparseIterator object that iterates over a
+/// padded sparse level (the padded value must be zero).
+std::unique_ptr<SparseIterator>
+makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
+                   Value padHigh, SparseEmitStrategy strategy);
+
 /// Helper function to create a SparseIterator object that iterate over the
 /// non-empty subsections set.
 std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
new file mode 100644
index 00000000000000..1deb6f74c0d28b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -0,0 +1,169 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+//  do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#CCCC = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
+}> // The only encoding the functions in this test actually use.
+
+#CDCD = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
+}> // NOTE(review): #CDCD is not referenced anywhere else in this test.
+
+#DCCD = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+}> // NOTE(review): #DCCD is not referenced anywhere else in this test.
+
+// Creates and returns a 4-D tensor of size (%s1, %s2, %s3, %s4) with every element set to %f.
+func.func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> tensor<?x?x?x?xf32> {
+  %buf = tensor.empty(%s1, %s2, %s3, %s4) : tensor<?x?x?x?xf32>
+  %ret = linalg.fill ins(%f : f32) outs(%buf : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %ret : tensor<?x?x?x?xf32>
+}
+
+func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf32>, %arg2: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32> {
+  %cst_0 = arith.constant 0.00000e+00 : f32
+
+  %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+   ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
+     tensor.yield %cst_0 : f32 // Padding value is 0.0.
+   } : tensor<3x8x8x3xf32> to tensor<3x12x12x3xf32> // Dims 1 and 2 grow from 8 to 12.
+
+  %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                     strides = dense<1> : tensor<2xi64>}
+     ins (%padded, %arg1: tensor<3x12x12x3xf32>, tensor<5x5x3x1xf32>)
+    outs (%arg2: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
+  return %ret : tensor<3x8x8x1xf32> // All-dense version of the padded convolution.
+}
+
+func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c8 = arith.constant 8 : index // NOTE(review): %c1, %c3 and %c8 are unused in this function.
+  %cst_0 = arith.constant 0.00000e+00 : f32
+  %buf = tensor.empty() : tensor<3x8x8x1xf32>
+  %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32> // Zero-initialized output.
+
+  %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+   ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
+     tensor.yield %cst_0 : f32 // Padding value is 0.0.
+   } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC> // Source and result share the #CCCC encoding.
+
+  %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                     strides = dense<1> : tensor<2xi64>}
+     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+    outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
+  return %ret : tensor<3x8x8x1xf32>
+}
+
+func.func @main() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index // NOTE(review): %c6 is not used below.
+  %c8 = arith.constant 8 : index
+  %f10 = arith.constant 10.00000e+00 : f32
+  %val = arith.constant 2.00000e+00 : f32
+  %zero = arith.constant 0.00000e+00 : f32
+
+  %filter2D_nhwc = call @alloc_4d_filled_f32(%c5, %c5, %c3, %c1, %val) :(index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+  %in2D_tmp = call @alloc_4d_filled_f32(%c3, %c8, %c8, %c3, %val) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+  %in2D_nhwc = tensor.insert %f10 into %in2D_tmp[%c0, %c0, %c3, %c0] : tensor<?x?x?x?xf32> // One 10.0 among the 2.0s.
+  %out2D_nhwc = call @alloc_4d_filled_f32(%c3, %c8, %c8, %c1, %zero) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
+
+  %static_filter = tensor.cast %filter2D_nhwc : tensor<?x?x?x?xf32> to tensor<5x5x3x1xf32>
+  %static_input = tensor.cast %in2D_nhwc : tensor<?x?x?x?xf32> to tensor<3x8x8x3xf32>
+  %static_output = tensor.cast %out2D_nhwc : tensor<?x?x?x?xf32> to tensor<3x8x8x1xf32>
+
+  %dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
+
+  %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
+  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+
+
+  // CHECK:   ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 208 ), ( 256 ), ( 256 ), ( 256 ), ( 256 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 256 ), ( 316 ), ( 316 ), ( 316 ), ( 316 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME:( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME:( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:  ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:  ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
+  %dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
+  vector.print %dense_v : vector<3x8x8x1xf32>
+
+  // CHECK-NEXT: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 208 ), ( 256 ), ( 256 ), ( 256 ), ( 256 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 256 ), ( 316 ), ( 316 ), ( 316 ), ( 316 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME:   ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ),
+  // CHECK-SAME:   ( ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
+  // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
+  // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
+  %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
+  vector.print %CCCC_v : vector<3x8x8x1xf32>
+  return
+}



More information about the Mlir-commits mailing list