[Mlir-commits] [mlir] [mlir][sparse] fix error when convolution stride is applied on a dense level. (PR #79521)
Peiming Liu
llvmlistbot at llvm.org
Thu Jan 25 15:29:24 PST 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/79521
Fix error when convolution stride is applied on a dense level.
From fef22ec23fc86f6b5cc3390f04284bdec89ce312 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 25 Jan 2024 23:27:53 +0000
Subject: [PATCH] [mlir][sparse] fix error when convolution stride is applied
on a dense level.
---
.../Transforms/Utils/LoopEmitter.cpp | 4 ++--
.../Transforms/Utils/SparseTensorLevel.cpp | 9 +++++---
.../Transforms/Utils/SparseTensorLevel.h | 2 +-
.../CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir | 21 +++++++++++++++++++
4 files changed, 30 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 6d832fdc0c2201..3fa4004ae460ef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -313,8 +313,8 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
Value loopHi = loopHighs[loop];
size = ADDI(size, MULI(loopHi, C_IDX(stride)));
}
- it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
- size, curDep.second);
+ it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
+ std::move(lvlIt), size, curDep.second);
} else {
Value size = loopHighs[loop];
const SparseIterator &subSectIter = *iters[t][lvl].back();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 226cccbc422b9b..e43896942d7fe6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1271,7 +1271,7 @@ static const IterType *unwrapFilter(const SparseIterator *it) {
}
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
- OpBuilder &b, Location l, const SparseIterator *parent,
+ OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
// Try unwrap the NonEmptySubSectIterator from a filter parent.
@@ -1279,9 +1279,12 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
auto it = std::make_unique<NonEmptySubSectIterator>(
b, l, parent, std::move(delegate), size);
- if (stride != 1)
+ if (stride != 1) {
+ // TODO: We can safely skip bound checking on sparse levels, but for dense
+ // iteration space, we need the bound to infer the dense loop range.
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
- C_IDX(stride), /*size=*/C_IDX(-1));
+ C_IDX(stride), /*size=*/loopBound);
+ }
return it;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 08f7c6a747eb57..d2b3396b72836c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -246,7 +246,7 @@ makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
/// Helper function to create a SparseIterator object that iterate over the
/// non-empty subsections set.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
- OpBuilder &b, Location l, const SparseIterator *parent,
+ OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
/// Helper function to create a SparseIterator object that iterate over a
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
index 98adc26f06e56b..8ee80045afc760 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
return %ret : tensor<?x?x?x?xf32>
}
+func.func @conv_2d_nhwc_hwcf_dual_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32, #CDCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
+ outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %ret : tensor<?x?x?x?xf32>
+}
+
func.func @entry() {
%c0 = arith.constant 0 : index
@@ -87,6 +95,8 @@ func.func @entry() {
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
+ %filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
+ : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
%in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
@@ -94,9 +104,19 @@ func.func @entry() {
%CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
%CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
+ %dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
+ : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
+
// CHECK: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
+ %v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
+ : tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
+ vector.print %v_dual : vector<3x3x3x1xf32>
+
+ // CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
%dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
vector.print %dense_v : vector<3x3x3x1xf32>
@@ -120,6 +140,7 @@ func.func @entry() {
bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
+ bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
return
More information about the Mlir-commits
mailing list