[Mlir-commits] [mlir] 431f6a5 - [sparse][mlir][vectorization] add support for shift-by-invariant
Aart Bik
llvmlistbot at llvm.org
Tue Dec 27 11:07:24 PST 2022
Author: Aart Bik
Date: 2022-12-27T11:07:13-08:00
New Revision: 431f6a543e858825ff9d258310e8d4eb9592e326
URL: https://github.com/llvm/llvm-project/commit/431f6a543e858825ff9d258310e8d4eb9592e326
DIFF: https://github.com/llvm/llvm-project/commit/431f6a543e858825ff9d258310e8d4eb9592e326.diff
LOG: [sparse][mlir][vectorization] add support for shift-by-invariant
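Shifts (shli, shrui, shrsi) are now vectorized when the shift factor
is invariant, i.e. when the same shift amount applies to all packed
elements. A minimal sketch of the idea (value names are illustrative,
not taken from this patch): a loop body computing

  // %s is defined outside the loop body, hence invariant.
  %0 = arith.shrsi %x, %s : i32

is vectorized by expanding the invariant shift amount into a full
vector on the right-hand side, so code generation needs no special
case:

  %vs = vector.broadcast %s : i32 to vector<8xi32>
  %v0 = arith.shrsi %vx, %vs : vector<8xi32>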
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D140596
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 65af4e0e1e86..e652ebdff5cc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -50,6 +50,16 @@ static bool isIntValue(Value val, int64_t idx) {
return false;
}
+/// Helper test for invariant value (defined outside given block).
+static bool isInvariantValue(Value val, Block *block) {
+ return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
+}
+
+/// Helper test for invariant argument (defined outside given block).
+static bool isInvariantArg(BlockArgument arg, Block *block) {
+ return arg.getOwner() != block;
+}
+
/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
unsigned numScalableDims = vl.enableVLAVectorization;
@@ -236,13 +246,15 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
Value vmask, SmallVectorImpl<Value> &idxs) {
unsigned d = 0;
unsigned dim = subs.size();
+ Block *block = &forOp.getRegion().front();
for (auto sub : subs) {
bool innermost = ++d == dim;
// Invariant subscripts in outer dimensions simply pass through.
// Note that we rely on LICM to hoist loads where all subscripts
// are invariant in the innermost loop.
- if (sub.getDefiningOp() &&
- sub.getDefiningOp()->getBlock() != &forOp.getRegion().front()) {
+ // Example:
+ // a[inv][i] for inv
+ if (isInvariantValue(sub, block)) {
if (innermost)
return false;
if (codegen)
@@ -252,9 +264,10 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
// Invariant block arguments (including outer loop indices) in outer
// dimensions simply pass through. Direct loop indices in the
// innermost loop simply pass through as well.
- if (auto barg = sub.dyn_cast<BlockArgument>()) {
- bool invariant = barg.getOwner() != &forOp.getRegion().front();
- if (invariant == innermost)
+ // Example:
+ // a[i][j] for both i and j
+ if (auto arg = sub.dyn_cast<BlockArgument>()) {
+ if (isInvariantArg(arg, block) == innermost)
return false;
if (codegen)
idxs.push_back(sub);
@@ -281,6 +294,8 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
// values, there is no good way to state that the indices are unsigned,
// which creates the potential of incorrect address calculations in the
// unlikely case we need such extremely large offsets.
+ // Example:
+ // a[ ind[i] ]
if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
if (!innermost)
return false;
@@ -303,18 +318,20 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
continue; // success so far
}
// Address calculation 'i = add inv, idx' (after LICM).
+ // Example:
+ // a[base + i]
if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
Value inv = load.getOperand(0);
Value idx = load.getOperand(1);
- if (inv.getDefiningOp() &&
- inv.getDefiningOp()->getBlock() != &forOp.getRegion().front() &&
- idx.dyn_cast<BlockArgument>()) {
- if (!innermost)
- return false;
- if (codegen)
- idxs.push_back(
- rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
- continue; // success so far
+ if (isInvariantValue(inv, block)) {
+ if (auto arg = idx.dyn_cast<BlockArgument>()) {
+ if (isInvariantArg(arg, block) || !innermost)
+ return false;
+ if (codegen)
+ idxs.push_back(
+ rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
+ continue; // success so far
+ }
}
}
return false;
@@ -389,7 +406,8 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
}
// Something defined outside the loop-body is invariant.
Operation *def = exp.getDefiningOp();
- if (def->getBlock() != &forOp.getRegion().front()) {
+ Block *block = &forOp.getRegion().front();
+ if (def->getBlock() != block) {
if (codegen)
vexp = genVectorInvariantValue(rewriter, vl, exp);
return true;
@@ -450,6 +468,17 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
vx) &&
vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
vy)) {
+ // We only accept shift-by-invariant (where the same shift factor applies
+ // to all packed elements). In the vector dialect, however, this is still
+ // represented with an expanded vector at the right-hand side, so that
+ // code generation needs no special case.
+ if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
+ isa<arith::ShRSIOp>(def)) {
+ Value shiftFactor = def->getOperand(1);
+ if (!isInvariantValue(shiftFactor, block))
+ return false;
+ }
+ // Generate code.
BINOP(arith::MulFOp)
BINOP(arith::MulIOp)
BINOP(arith::DivFOp)
@@ -462,8 +491,10 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
BINOP(arith::AndIOp)
BINOP(arith::OrIOp)
BINOP(arith::XOrIOp)
+ BINOP(arith::ShLIOp)
+ BINOP(arith::ShRUIOp)
+ BINOP(arith::ShRSIOp)
// TODO: complex?
- // TODO: shift by invariant?
}
}
return false;
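For contrast, a shift whose amount varies inside the loop is still
rejected (isInvariantValue fails on the shift factor above, so the
vectorizer bails out). A hypothetical example, not part of this patch:

  scf.for %i = %c0 to %c1024 step %c1 {
    ...
    // Shift amount depends on the loop index: not vectorizable here.
    %s = arith.index_cast %i : index to i32
    %0 = arith.shli %x, %s : i32
    ...
  }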
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
index 32900d93c0bf..bf885f1920ad 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_ops.mlir
@@ -17,6 +17,8 @@
// CHECK-DAG: %[[C1:.*]] = arith.constant dense<2.000000e+00> : vector<8xf32>
// CHECK-DAG: %[[C2:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
// CHECK-DAG: %[[C3:.*]] = arith.constant dense<255> : vector<8xi64>
+// CHECK-DAG: %[[C4:.*]] = arith.constant dense<4> : vector<8xi32>
+// CHECK-DAG: %[[C5:.*]] = arith.constant dense<1> : vector<8xi32>
// CHECK: scf.for
// CHECK: %[[VAL_14:.*]] = vector.load
// CHECK: %[[VAL_15:.*]] = math.absf %[[VAL_14]] : vector<8xf32>
@@ -38,8 +40,11 @@
// CHECK: %[[VAL_31:.*]] = arith.andi %[[VAL_30]], %[[C3]] : vector<8xi64>
// CHECK: %[[VAL_32:.*]] = arith.trunci %[[VAL_31]] : vector<8xi64> to vector<8xi16>
// CHECK: %[[VAL_33:.*]] = arith.extsi %[[VAL_32]] : vector<8xi16> to vector<8xi32>
-// CHECK: %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : vector<8xi32> to vector<8xf32>
-// CHECK: vector.store %[[VAL_34]]
+// CHECK: %[[VAL_34:.*]] = arith.shrsi %[[VAL_33]], %[[C4]] : vector<8xi32>
+// CHECK: %[[VAL_35:.*]] = arith.shrui %[[VAL_34]], %[[C4]] : vector<8xi32>
+// CHECK: %[[VAL_36:.*]] = arith.shli %[[VAL_35]], %[[C5]] : vector<8xi32>
+// CHECK: %[[VAL_37:.*]] = arith.uitofp %[[VAL_36]] : vector<8xi32> to vector<8xf32>
+// CHECK: vector.store %[[VAL_37]]
// CHECK: }
func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
%argb: tensor<1024xf32, #DenseVector>) -> tensor<1024xf32> {
@@ -47,6 +52,8 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
%o = arith.constant 1.0 : f32
%c = arith.constant 2.0 : f32
%i = arith.constant 255 : i64
+ %s = arith.constant 4 : i32
+ %t = arith.constant 1 : i32
%0 = linalg.generic #trait
ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32, #DenseVector>)
outs(%init: tensor<1024xf32>) {
@@ -69,8 +76,11 @@ func.func @vops(%arga: tensor<1024xf32, #DenseVector>,
%15 = arith.andi %14, %i : i64
%16 = arith.trunci %15 : i64 to i16
%17 = arith.extsi %16 : i16 to i32
- %18 = arith.uitofp %17 : i32 to f32
- linalg.yield %18 : f32
+ %18 = arith.shrsi %17, %s : i32
+ %19 = arith.shrui %18, %s : i32
+ %20 = arith.shli %19, %t : i32
+ %21 = arith.uitofp %20 : i32 to f32
+ linalg.yield %21 : f32
} -> tensor<1024xf32>
return %0 : tensor<1024xf32>
}
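Note that in the CHECK lines above, the broadcast of each invariant
shift amount has been constant-folded into a dense splat, which is why
the test matches constants like %[[C4]] and %[[C5]] rather than
explicit vector.broadcast ops. As a sketch of that folding:

  %s  = arith.constant 4 : i32
  %vs = vector.broadcast %s : i32 to vector<8xi32>
  // folds into:
  %vs = arith.constant dense<4> : vector<8xi32>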