[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)
Peiming Liu
llvmlistbot at llvm.org
Wed Nov 1 11:25:06 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70868
From 236ce30cc54febeb14b98ba25f0085c6b18a755b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 00:04:22 +0000
Subject: [PATCH 1/3] [mlir][sparse] add helper class to implement common
rewriter to re/demap sparse tensors.
---
.../Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a822effbb2ab78c..b301943c8732dab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -142,7 +142,7 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
};
//===----------------------------------------------------------------------===//
-// Rewriting rules for operations other than linalg generic ops.
+// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//
// CRTP to help implementing a rewriter that demaps all its inputs and remaps
From f00f63901eaa508857e46ba1b4aa7dc44845927b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 22:34:58 +0000
Subject: [PATCH 2/3] [mlir][sparse] Implement rewriters to reinterpret maps on
foreach operands.
---
.../Transforms/SparseReinterpretMap.cpp | 190 ++++++++++++------
.../Transforms/SparseTensorRewriting.cpp | 27 ++-
.../Dialect/Affine/decompose-affine-ops.mlir | 10 +-
.../SparseTensor/sparse_reinterpret_map.mlir | 48 +++--
4 files changed, 200 insertions(+), 75 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index b301943c8732dab..2d50205f1e96b8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -25,8 +25,8 @@ namespace {
//===----------------------------------------------------------------------===//
// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
- AffineMap map) {
+AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+ AffineMap map) {
unsigned lvlRank = stt.getLvlRank();
AffineMap lvl2dim = stt.getLvlToDim();
assert(lvl2dim.getNumInputs() == lvlRank);
@@ -39,18 +39,37 @@ static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
}
// Generates a "de"mapping reinterpretation of the map.
-static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
- Value val) {
+Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
val);
}
// Generates a "re"mapping reinterpretation of the map.
-static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
- Value val) {
+Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}
+SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
+ ValueRange outs) {
+ SmallVector<Value> ret(outs);
+ assert(outs.size() == types.size());
+ for (auto [r, t] : llvm::zip(ret, types))
+ if (r.getType() != t)
+ r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
+ return ret;
+}
+
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+bool hasNonIdentityOperandsOrResults(Operation *op) {
+ auto hasNonIdentityMap = [](Value v) {
+ auto stt = tryGetSparseTensorType(v);
+ return stt && !stt->isIdentity();
+ };
+
+ return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+ llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
// Generates a clone of the given linalg generic operation, but with
// remapped arguments, index maps, and iteration types.
//
@@ -141,10 +160,6 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
}
};
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
// CRTP to help implementing a rewriter that demaps all its inputs and remaps
// all its outputs.
template <typename SubClass, typename SourceOp>
@@ -154,9 +169,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const override {
- if (!static_cast<const SubClass *>(this)->matchOp(op))
- return failure();
-
Location loc = op.getLoc();
// Demaps non-trivial inputs.
SmallVector<Value> deMappedIns(op->getOperands());
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
// CRTP call.
OpAdaptor adaptor(deMappedIns);
- ValueRange outs =
- static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
- assert(outs.size() == op->getResults().size());
-
- // Remap outputs.
- SmallVector<Value> reMappedOuts(outs);
- for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
- if (r.getType() != a.getType())
- r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
- rewriter.replaceOp(op, reMappedOuts);
- return success();
+ return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+ rewriter);
}
};
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(CrdTranslateOp op,
- PatternRewriter &rewriter) const override {
- AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
- ? op.getEncoder().getDimToLvl()
- : op.getEncoder().getLvlToDim();
-
- SmallVector<Value> outCrds;
- for (AffineExpr result : map.getResults()) {
- // TODO: we should probably expand the affine map to IR using our own
- // rules, since affine.apply assume signed value, while the cooridinates
- // we provided must always be signless.
- Value trans = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
- op.getInCrds());
- outCrds.push_back(trans);
- }
- rewriter.replaceOp(op, outCrds);
- return success();
- }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
-struct TensorInsertRewriter
- : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+struct TensorInsertDemapper
+ : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ if (!hasAnySparseResult(op))
+ return failure();
- bool matchOp(tensor::InsertOp op) const {
- return op.getResult().getType().getEncoding() != nullptr;
- }
-
- ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
- return insertOp->getResults();
+
+    SmallVector<Value> outs = remapValueRange(
+        rewriter, op->getResultTypes(), insertOp->getResults());
+    rewriter.replaceOp(op, outs);
+ return success();
+ }
+};
+
+struct ForeachOpDemapper
+ : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+ using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+    // Only handle operations with non-identity sparse operands or results.
+ if (!hasNonIdentityOperandsOrResults(op))
+ return failure();
+
+ // TODO: demap constant as well.
+ if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+ if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+ return failure();
+
+ Location loc = op.getLoc();
+ // Cache the type information since we update the foreach op in-place.
+ auto srcStt = getSparseTensorType(op.getTensor());
+ SmallVector<Type> prevRetTps(op.getResultTypes());
+
+ rewriter.startRootUpdate(op);
+ op.getTensorMutable().assign(adaptor.getTensor());
+ op.getInitArgsMutable().assign(adaptor.getInitArgs());
+ // Update results' types.
+ for (auto r : op.getResults())
+ if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+ r.setType(stt->getDemappedType());
+
+ Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+ // Update the foreach body.
+ SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+ blockArgTps.push_back(srcStt.getElementType());
+ blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+ adaptor.getInitArgs().getTypes().end());
+ Block *body = op.getBody();
+ // Block Args: [dimCrd, val, initArgs]
+ unsigned preArgNum = body->getNumArguments();
+ for (Type t : blockArgTps)
+ body->addArgument(t, loc);
+
+ // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
+ rewriter.setInsertionPointToStart(body);
+ ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
+ ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+ CrdTransDirectionKind::lvl2dim);
+ rewriter.replaceAllUsesWith(
+ body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+ body->eraseArguments(0, srcStt.getDimRank());
+ // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+ unsigned numInitArgs = op.getInitArgs().size();
+ rewriter.replaceAllUsesWith(body->getArgument(0),
+ body->getArgument(lvlRank + numInitArgs + 1));
+ body->eraseArgument(0);
+ // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+ ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+ SmallVector<Value> dstArgs(body->getArguments().take_back(numInitArgs));
+    // Remap back to the original types before replacement.
+ for (auto [s, d] : llvm::zip(srcArgs, dstArgs))
+ if (s.getType() != d.getType())
+ d = rewriter.create<ReinterpretMapOp>(loc, s.getType(), d);
+ rewriter.replaceAllUsesWith(srcArgs, dstArgs);
+ body->eraseArguments(0, numInitArgs);
+ // Block Args: [lvlCrds, DemappedArgs] and we are done.
+
+ // Update yield operations.
+ if (numInitArgs != 0) {
+ rewriter.setInsertionPointToEnd(body);
+ auto yield = llvm::cast<YieldOp>(body->getTerminator());
+ if (auto stt = tryGetSparseTensorType(yield.getResult());
+ stt && !stt->isIdentity()) {
+ Value y = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(),
+ yield.getResult());
+ rewriter.create<YieldOp>(loc, y);
+ rewriter.eraseOp(yield);
+ }
+ }
+ rewriter.finalizeRootUpdate(op);
+
+ rewriter.setInsertionPointAfter(op);
+    SmallVector<Value> outs =
+        remapValueRange(rewriter, prevRetTps, op.getResults());
+
+    // Replace all uses of the foreach results, except the use in the
+    // reinterpret_map that remaps the output.
+ for (auto [from, to] : llvm::zip(op.getResults(), outs))
+ rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
+
+ return success();
}
};
@@ -234,7 +310,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
- patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
+ patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
patterns.getContext());
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02796bc9a7e7df6..c00f19916e49fbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1063,6 +1063,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
}
};
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CrdTranslateOp op,
+ PatternRewriter &rewriter) const override {
+ AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+ ? op.getEncoder().getDimToLvl()
+ : op.getEncoder().getLvlToDim();
+
+ SmallVector<Value> outCrds;
+ for (AffineExpr result : map.getResults()) {
+ // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the coordinates
+      // we provide must always be signless.
+ Value trans = rewriter.create<affine::AffineApplyOp>(
+ op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+ op.getInCrds());
+ outCrds.push_back(trans);
+ }
+ rewriter.replaceOp(op, outCrds);
+ return success();
+ }
+};
+
/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
@@ -1284,5 +1307,7 @@ void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
}
void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
- patterns.add<ForeachRewriter>(patterns.getContext());
+  // Run CrdTranslateRewriter later in the pipeline so that the operation can
+  // be folded before it is lowered to affine.apply.
+ patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}
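To illustrate the intended lowering (a sketch only, not part of the patch; %i, %j and %l0..%l3 are made-up SSA names): for a BSR-style dimToLvl map (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2), CrdTranslateRewriter would expand a dim2lvl coordinate translation into one affine.apply per result expression of the map, roughly:

  %l0 = affine.apply affine_map<(d0, d1) -> (d0 floordiv 2)>(%i, %j)
  %l1 = affine.apply affine_map<(d0, d1) -> (d1 floordiv 2)>(%i, %j)
  %l2 = affine.apply affine_map<(d0, d1) -> (d0 mod 2)>(%i, %j)
  %l3 = affine.apply affine_map<(d0, d1) -> (d1 mod 2)>(%i, %j)

The lvl2dim direction (the one emitted inside the rewritten foreach body) works the same way with the inverse lvlToDim map, here (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3):

  // Independent sketch: recover the dim coordinates from level coordinates.
  %i2 = affine.apply affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2)>(%l0, %l1, %l2, %l3)
  %j2 = affine.apply affine_map<(d0, d1, d2, d3) -> (d1 * 2 + d3)>(%l0, %l1, %l2, %l3)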
diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
index 6acdc436fe6774a..654b73a0bb6df09 100644
--- a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
+++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
@@ -108,12 +108,12 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
// The hoisted part is %b.
%a = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%0, %1, %2, %i]
- // Gets completely hoisted
+ // Gets completely hoisted
%b = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- // Gets completely hoisted
+ // Gets completely hoisted
%c = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
-
+
// 32 * %j + %c remains here, the rest is hoisted.
// CHECK-DAG: %[[R10:.*]] = affine.apply #[[$times32]]()[%[[j]]]
// CHECK-DAG: %[[d:.*]] = affine.apply #[[$add]]()[%[[c]], %[[R10]]]
@@ -134,7 +134,7 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
// CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]
%g = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%k, %0, %1, %2]
-
+
// CHECK-NEXT: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
"some_side_effecting_consumer"(%a) : (index) -> ()
// CHECK-NEXT: "some_side_effecting_consumer"(%[[b]]) : (index) -> ()
@@ -151,6 +151,6 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
"some_side_effecting_consumer"(%g) : (index) -> ()
}
}
- }
+ }
return
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 149c0bc46e25118..be3ab37e9cbd182 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,15 +1,4 @@
-// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
-
-#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-
-// CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
-// CHECK: return %[[A0]]
-func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
- return %arg0 : tensor<?xf64, #SparseVector>
-}
-
-// -----
+// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
#trait_mul = {
indexing_maps = [
@@ -55,3 +44,38 @@ func.func @mul(%arg0: tensor<32x32xf32>,
return %0 : tensor<32x32xf32, #BSR>
}
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 2 : compressed,
+ i mod 2 : dense,
+ j mod 2 : dense
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_foreach_reinterpret_map(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK: %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
+// CHECK: sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
+// CHECK: }
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<1x2x2x2xf64
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.load %[[VAL_12]] hasInserts : tensor<2x4xf64
+// CHECK: return %[[VAL_13]] : tensor<2x4xf64
+// CHECK: }
+func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<2x4xf64, #BSR> {
+ %7 = bufferization.alloc_tensor() : tensor<2x4xf64, #BSR>
+ %8 = sparse_tensor.foreach in %6 init(%7) : tensor<2x4xf64, #BSR>, tensor<2x4xf64, #BSR> -> tensor<2x4xf64, #BSR> do {
+ ^bb0(%arg0: index, %arg1: index, %arg2: f64, %arg3: tensor<2x4xf64, #BSR>):
+ %inserted = tensor.insert %arg2 into %arg3[%arg0, %arg1] : tensor<2x4xf64, #BSR>
+ sparse_tensor.yield %inserted : tensor<2x4xf64, #BSR>
+ }
+ %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
+ return %9 : tensor<2x4xf64, #BSR>
+}
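A side note on what the CHECK lines above elide (illustrative sketch, not CHECK output; %bsr, %out and #DemapBSR are made-up names, and the printed form of reinterpret_map is assumed to be "src : srcType to dstType"): the demapped view of the 2x4 #BSR tensor is a 1x2x2x2 tensor whose encoding keeps the same level types but drops the dimToLvl map, roughly:

  #DemapBSR = #sparse_tensor.encoding<{
    map = (d0, d1, d2, d3) ->
      (d0 : dense, d1 : compressed, d2 : dense, d3 : dense)
  }>

  // Demap the input before the foreach ...
  %demapped = sparse_tensor.reinterpret_map %bsr
      : tensor<2x4xf64, #BSR> to tensor<1x2x2x2xf64, #DemapBSR>
  // ... and remap the foreach result back afterwards.
  %remapped = sparse_tensor.reinterpret_map %out
      : tensor<1x2x2x2xf64, #DemapBSR> to tensor<2x4xf64, #BSR>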
From 862684df2ee601785c3c0e3aa462832d4b1380f3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 22:45:09 +0000
Subject: [PATCH 3/3] revert unintended change
---
mlir/test/Dialect/Affine/decompose-affine-ops.mlir | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
index 654b73a0bb6df09..6acdc436fe6774a 100644
--- a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
+++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
@@ -108,12 +108,12 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
// The hoisted part is %b.
%a = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%0, %1, %2, %i]
- // Gets completely hoisted
+ // Gets completely hoisted
%b = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
- // Gets completely hoisted
+ // Gets completely hoisted
%c = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
-
+
// 32 * %j + %c remains here, the rest is hoisted.
// CHECK-DAG: %[[R10:.*]] = affine.apply #[[$times32]]()[%[[j]]]
// CHECK-DAG: %[[d:.*]] = affine.apply #[[$add]]()[%[[c]], %[[R10]]]
@@ -134,7 +134,7 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
// CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]
%g = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%k, %0, %1, %2]
-
+
// CHECK-NEXT: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
"some_side_effecting_consumer"(%a) : (index) -> ()
// CHECK-NEXT: "some_side_effecting_consumer"(%[[b]]) : (index) -> ()
@@ -151,6 +151,6 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
"some_side_effecting_consumer"(%g) : (index) -> ()
}
}
- }
+ }
return
}