[Mlir-commits] [mlir] d0f9271 - [mlir][Standard] Support 0-D vectors in `SplatOp`

Nicolas Vasilache llvmlistbot at llvm.org
Fri Nov 26 09:05:20 PST 2021


Author: Michal Terepeta
Date: 2021-11-26T17:05:15Z
New Revision: d0f927121ecef997d5df4a8879997cc8c9f8f0c9

URL: https://github.com/llvm/llvm-project/commit/d0f927121ecef997d5df4a8879997cc8c9f8f0c9
DIFF: https://github.com/llvm/llvm-project/commit/d0f927121ecef997d5df4a8879997cc8c9f8f0c9.diff

LOG: [mlir][Standard] Support 0-D vectors in `SplatOp`

This changes the op's result type to `AnyVectorOfAnyRank` and implements the 0-D
lowering by just inserting the element (skipping the shuffle that we do for the
1-D case).

Depends On D114549

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114598

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 58c372b938b0c..d9df01bdea9a9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -802,7 +802,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
 
   let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
                                  "integer/index/float type">:$input);
-  let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
+  let results = (outs AnyTypeOf<[AnyVectorOfAnyRank,
+                                 AnyStaticShapeTensor]>:$aggregate);
 
   let builders = [
     OpBuilder<(ins "Value":$element, "Type":$aggregateType),

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 83f45721d853a..fcaba6e874834 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -702,7 +702,7 @@ struct SwitchOpLowering
 };
 
 // The Splat operation is lowered to an insertelement + a shufflevector
-// operation. Splat to only 1-d vector result types are lowered.
+// operation. Splat to only 0-d and 1-d vector result types are lowered.
 struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
@@ -710,7 +710,7 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() != 1)
+    if (!resultType || resultType.getRank() > 1)
       return failure();
 
     // First insert it into an undef vector so we can shuffle it.
@@ -721,6 +721,14 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
         typeConverter->convertType(rewriter.getIntegerType(32)),
         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
 
+    // For 0-d vector, we simply do `insertelement`.
+    if (resultType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          splatOp, vectorType, undef, adaptor.getInput(), zero);
+      return success();
+    }
+
+    // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = rewriter.create<LLVM::InsertElementOp>(
         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
 
@@ -745,7 +753,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() == 1)
+    if (!resultType || resultType.getRank() <= 1)
       return failure();
 
     // First insert it into an undef vector so we can shuffle it.

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 7d0942ca8691b..a5c88455124a7 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -454,6 +454,21 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
   br ^bb1
 }
 
+// -----
+
+// CHECK-LABEL: @splat_0d
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @splat_0d(%a: f32) -> vector<f32> {
+  %v = splat %a : vector<f32>
+  return %v : vector<f32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
+// CHECK-NEXT: llvm.return %[[V]] : vector<1xf32>
+
+// -----
+
 // CHECK-LABEL: @splat
 // CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
 // CHECK-SAME: %[[ELT:arg[0-9]+]]: f32

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index e7cbece4b1ed3..4d74366f80da0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -21,6 +21,13 @@ func @print_vector_0d(%a: vector<f32>) {
   return
 }
 
+func @splat_0d(%a: f32) {
+  %1 = splat %a : vector<f32>
+  // CHECK: ( 42 )
+  vector.print %1: vector<f32>
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -30,5 +37,8 @@ func @entry() {
   %3 = arith.constant dense<42.0> : vector<f32>
   call  @print_vector_0d(%3) : (vector<f32>) -> ()
 
+  %4 = arith.constant 42.0 : f32
+  call  @splat_0d(%4) : (f32) -> ()
+
   return
 }


        


More information about the Mlir-commits mailing list