[Mlir-commits] [mlir] 04af7d2 - [mlir][spirv][vector] Fix vector shuffle conversion for scalar inputs

Jakub Kuderski llvmlistbot at llvm.org
Mon Jul 31 17:13:41 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-31T20:13:30-04:00
New Revision: 04af7d224bd8ee2b346525b02ed04f57d8d9d1f8

URL: https://github.com/llvm/llvm-project/commit/04af7d224bd8ee2b346525b02ed04f57d8d9d1f8
DIFF: https://github.com/llvm/llvm-project/commit/04af7d224bd8ee2b346525b02ed04f57d8d9d1f8.diff

LOG: [mlir][spirv][vector] Fix vector shuffle conversion for scalar inputs

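Handle the cases where either or both vector.shuffle operands (V1/V2 in the op accessors, %v0/%v1 in the tests) are single-element vectors that are converted to SPIR-V scalars.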

Fixes: https://github.com/llvm/llvm-project/issues/64271

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D156717

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c15b99b5a62d3e..57191ce8dd4c94 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -27,7 +27,10 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include <cassert>
+#include <cstdint>
 #include <numeric>
 
 using namespace mlir;
@@ -444,24 +447,48 @@ struct VectorShuffleOpConvert final
       return rewriter.notifyMatchFailure(shuffleOp,
                                          "unsupported result vector type");
 
-    auto oldSourceType = shuffleOp.getV1VectorType();
-    if (oldSourceType.getNumElements() > 1) {
-      SmallVector<int32_t, 4> components = llvm::to_vector<4>(
-          llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
-            return cast<IntegerAttr>(attr).getValue().getZExtValue();
-          }));
+    SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
+        shuffleOp.getMask(), [](Attribute attr) -> int32_t {
+          return cast<IntegerAttr>(attr).getValue().getZExtValue();
+        });
+
+    auto oldV1Type = shuffleOp.getV1VectorType();
+    auto oldV2Type = shuffleOp.getV2VectorType();
+
+    // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
-          rewriter.getI32ArrayAttr(components));
+          rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(oldResultType.getNumElements());
-    for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
-      newOperands.push_back(oldOperands[i.getZExtValue()]);
+    // When at least one of the operands becomes a scalar after type conversion
+    // for SPIR-V, extract all the required elements and construct the result
+    // vector.
+    auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
+                               Value scalarOrVec, int32_t idx) -> Value {
+      if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
+        return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
+                                                          idx);
+
+      assert(idx == 0 && "Invalid scalar element index");
+      return scalarOrVec;
+    };
+
+    int32_t numV1Elems = oldV1Type.getNumElements();
+    SmallVector<Value> newOperands(mask.size());
+    for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
+      Value vec = adaptor.getV1();
+      int32_t elementIdx = shuffleIdx;
+      if (elementIdx >= numV1Elems) {
+        vec = adaptor.getV2();
+        elementIdx -= numV1Elems;
+      }
+
+      newOperand = getElementAtIdx(vec, elementIdx);
     }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
 

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 922351265e38b9..310df8030db300 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -442,6 +442,47 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<3xi32>
+//       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+//       CHECK:    %[[S1:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<3xi32>
+//       CHECK:    %[[S2:.+]] = spirv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xi32>
+//       CHECK:    %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[S1]], %[[S2]] : (i32, i32, i32) -> vector<3xi32>
+//       CHECK:    return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<3xi32>) -> vector<3xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<1xi32>, vector<3xi32>
+  return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<1xi32>
+//       CHECK:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+//       CHECK:    %[[S0:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<3xi32>
+//       CHECK:    %[[S1:.+]] = spirv.CompositeExtract %[[ARG0]][2 : i32] : vector<3xi32>
+//       CHECK:    %[[RES:.+]] = spirv.CompositeConstruct %[[S0]], %[[S1]], %[[V1]] : (i32, i32, i32) -> vector<3xi32>
+//       CHECK:    return %[[RES]]
+func.func @shuffle(%v0 : vector<3xi32>, %v1: vector<1xi32>) -> vector<3xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<3xi32>, vector<1xi32>
+  return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>
+//       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+//       CHECK:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+//       CHECK:    %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (i32, i32) -> vector<2xi32>
+//       CHECK:    return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1] : vector<1xi32>, vector<1xi32>
+  return %shuffle : vector<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>


        


More information about the Mlir-commits mailing list