[Mlir-commits] [mlir] 6301574 - [mlir][SparseTensor] Enable VLA ops in index value generation

Javier Setoain llvmlistbot at llvm.org
Thu Apr 28 01:45:48 PDT 2022


Author: Javier Setoain
Date: 2022-04-28T09:39:07+01:00
New Revision: 6301574206b39f72edb957f3b069f3892c117d4b

URL: https://github.com/llvm/llvm-project/commit/6301574206b39f72edb957f3b069f3892c117d4b
DIFF: https://github.com/llvm/llvm-project/commit/6301574206b39f72edb957f3b069f3892c117d4b.diff

LOG: [mlir][SparseTensor] Enable VLA ops in index value generation

Current index value generation uses fixed-length vector ops; this patch
adds an alternative codegen path compatible with scalable vectors by
using `LLVM::StepVectorOp`.

Differential Revision: https://reviews.llvm.org/D124454

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 237cfa9724b5c..e7d03e8fd609d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -889,11 +890,18 @@ static Value genIndexValue(Merger &merger, CodeGen &codegen,
     VectorType vtp = vectorType(codegen, itype);
     ival = rewriter.create<vector::BroadcastOp>(loc, vtp, ival);
     if (idx == ldx) {
-      SmallVector<APInt, 4> integers;
-      for (unsigned i = 0; i < vl; i++)
-        integers.push_back(APInt(/*width=*/64, i));
-      auto values = DenseElementsAttr::get(vtp, integers);
-      Value incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
+      Value incr;
+      if (vtp.isScalable()) {
+        Type stepvty = vectorType(codegen, rewriter.getI64Type());
+        Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
+        incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
+      } else {
+        SmallVector<APInt, 4> integers;
+        for (unsigned i = 0; i < vl; i++)
+          integers.push_back(APInt(/*width=*/64, i));
+        auto values = DenseElementsAttr::get(vtp, integers);
+        incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
+      }
       ival = rewriter.create<arith::AddIOp>(loc, ival, incr);
     }
   }


        


More information about the Mlir-commits mailing list