[Mlir-commits] [mlir] 3483fc5 - [mlir][SPIRVToLLVM] Add shufflevector conversion

Weiwei Li llvmlistbot at llvm.org
Mon Nov 1 08:05:58 PDT 2021


Author: Weiwei Li
Date: 2021-11-01T23:05:37+08:00
New Revision: 3483fc5a31179123ffd11494488074683706a492

URL: https://github.com/llvm/llvm-project/commit/3483fc5a31179123ffd11494488074683706a492
DIFF: https://github.com/llvm/llvm-project/commit/3483fc5a31179123ffd11494488074683706a492.diff

LOG: [mlir][SPIRVToLLVM] Add shufflevector conversion

Add the shufflevector conversion. It only handles static (i.e., IntegerAttr) indices.

Co-authored-by: Xinyi Liu <xyliuhelen at gmail.com>

Reviewed by: antiagainst

Differential revision: https://reviews.llvm.org/D112161

Added: 
    

Modified: 
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 51df6c763c3fe..b416c303ad51a 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1364,6 +1364,65 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorShuffleOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts spv.VectorShuffle to the LLVM dialect.
+///
+/// If both source vectors have the same number of elements, the op maps
+/// directly onto llvm.shufflevector with the component list reused as the
+/// mask. Otherwise the result is built lane by lane. Starting from an
+/// llvm.mlir.undef vector, each selected element is read with
+/// llvm.extractelement and written with llvm.insertelement. Indices below
+/// vector1's size read from vector1; all other indices read from vector2.
+/// A component of -1 (0xffffffff) leaves that result lane undefined.
+///
+/// Only constant (IntegerAttr) components are supported. Any other component
+/// makes the pattern fail with an error rather than asserting in a cast.
+class VectorShufflePattern
+    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
+  LogicalResult
+  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto components = adaptor.components();
+    auto vector1 = adaptor.vector1();
+    auto vector2 = adaptor.vector2();
+    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
+    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+    if (vector1Size == vector2Size) {
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
+                                                         components);
+      return success();
+    }
+
+    // Validate every component before creating any IR, so that bailing out
+    // leaves the function untouched and the cast in the loop below is safe.
+    auto componentsArray = components.getValue();
+    for (Attribute component : componentsArray)
+      if (!component.isa<IntegerAttr>())
+        return op.emitError("unable to support non-constant component");
+
+    Type dstType = typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+    Type scalarType = dstType.cast<VectorType>().getElementType();
+    Type llvmI32Type = rewriter.getI32Type();
+    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
+    for (unsigned i = 0, e = componentsArray.size(); i < e; ++i) {
+      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      // -1 marks an undefined lane: keep the undef value there.
+      if (indexVal == -1)
+        continue;
+
+      // Indices at or past vector1's size select from vector2.
+      int offsetVal = 0;
+      Value baseVector = vector1;
+      if (indexVal >= vector1Size) {
+        offsetVal = vector1Size;
+        baseVector = vector2;
+      }
+
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+      Value index = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type,
+          rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
+
+      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
+          loc, scalarType, baseVector, index);
+      targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
+                                                        extractOp, dstIndex);
+    }
+    rewriter.replaceOp(op, targetOp);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1489,6 +1548,7 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       CompositeExtractPattern, CompositeInsertPattern,
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
+      VectorShufflePattern,
 
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
index c8528d062f7cb..38d55b9a659b8 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -58,6 +58,32 @@ spv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
   spv.Return
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+// Equal-length operands lower directly to llvm.shufflevector. The undefined
+// component 0xffffffff appears as -1 in the mask.
+spv.func @vector_shuffle_same_size(%vector1: vector<2xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  //      CHECK: %[[RES:.*]] = llvm.shufflevector {{.*}} [0 : i32, 2 : i32, -1 : i32] : vector<2xf32>, vector<2xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 2: i32, 0xffffffff: i32] %vector1: vector<2xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+// vector1 has 3 elements and vector2 has 2, so the result is assembled lane
+// by lane. Index 0 reads lane 0 of %arg0, index 4 reads lane 1 (4 - 3) of
+// %arg1, and the 0xffffffff lane gets no insertelement.
+spv.func @vector_shuffle_different_size(%vector1: vector<3xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  //      CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<3xf32>
+  // CHECK-NEXT: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[EXT0:.*]] = llvm.extractelement %arg0[%[[C0_1]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[INSERT0:.*]] = llvm.insertelement %[[EXT0]], %[[UNDEF]][%[[C0_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[C1_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[C1_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[EXT1:.*]] = llvm.extractelement %arg1[%[[C1_1]] : i32] : vector<2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[EXT1]], %[[INSERT0]][%[[C1_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 4: i32, 0xffffffff: i32] %vector1: vector<3xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint and spv.ExecutionMode
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list