[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.maskedload and vector.maskedstore to SPIR-V (PR #74834)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 8 05:20:14 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
Use spirv.mlir.loop and spirv.mlir.selection to lower vector.maskedload and vector.maskedstore.
---
Patch is 21.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74834.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+324-1)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+111)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e48f29a4f1702..b32c004e28a1e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -647,6 +647,328 @@ struct VectorStoreOpConverter final
}
};
+/// Creates an empty `spirv.mlir.loop` at the current insertion point whose
+/// region contains, in order: entry, header, continue and merge blocks.
+/// The entry and merge blocks come from `addEntryAndMergeBlock()`. The header
+/// and continue blocks are left without arguments or terminators; callers
+/// retrieve them with `getHeaderBlock()` (second block) and
+/// `getContinueBlock()` (second-to-last block) and fill them in.
+/// The rewriter's insertion point is unchanged on return.
+mlir::spirv::LoopOp createSpirvLoop(ConversionPatternRewriter &rewriter,
+                                    Location loc) {
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  // createBlock() moves the insertion point into the new block; restore the
+  // caller's position when we return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Region &body = loopOp.getBody();
+  // Insert both blocks directly before the merge block so the physical order
+  // really is entry, header, continue, merge (matching the getters above).
+  // Going through the rewriter (instead of a raw `new Block()`) also notifies
+  // it about the new blocks.
+  rewriter.createBlock(&body, std::prev(body.end())); // header
+  rewriter.createBlock(&body, std::prev(body.end())); // continue
+
+  return loopOp;
+}
+
+/// Creates an empty `spirv.mlir.selection` at the current insertion point
+/// whose region contains three blocks: a header (front), an if-true block
+/// (second) and a merge block (back) already terminated by
+/// `spirv.mlir.merge`. The header and if-true blocks are left without
+/// terminators for the caller to fill in.
+/// The rewriter's insertion point is unchanged on return, so the caller never
+/// accidentally appends after the merge terminator.
+mlir::spirv::SelectionOp
+createSpirvSelection(ConversionPatternRewriter &rewriter, Location loc) {
+  auto selectionOp =
+      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  Region &body = selectionOp.getBody();
+  rewriter.createBlock(&body, body.end()); // header
+  rewriter.createBlock(&body, body.end()); // if-true
+  rewriter.createBlock(&body, body.end()); // merge
+  // The insertion point is now inside the merge block.
+  rewriter.create<spirv::MergeOp>(loc);
+
+  return selectionOp;
+}
+
+/// Adds `offset` to the innermost (last) entry of `indices`, updating
+/// `indices` in place, and returns the element pointer for
+/// `base[indices...]` as computed by `spirv::getElementPtr`.
+///
+/// If `offset` and the innermost index have different integer types (e.g. an
+/// i32 lane counter against 64-bit converted indices), `offset` is first
+/// converted to the index type with `spirv.UConvert` so that the `spirv.IAdd`
+/// operands agree; `offset` is assumed to be a non-negative lane number.
+///
+/// Returns a null Value when `indices` is empty (there is no innermost index
+/// to offset). Otherwise returns whatever `spirv::getElementPtr` produces;
+/// callers should null-check the result.
+Value addOffsetToIndices(ConversionPatternRewriter &rewriter, Location loc,
+                         SmallVectorImpl<Value> &indices, const Value offset,
+                         const SPIRVTypeConverter &typeConverter,
+                         const MemRefType memrefType, const Value base) {
+  if (indices.empty())
+    return nullptr;
+
+  Value &innermost = indices.back();
+  Value laneOffset = offset;
+  if (laneOffset.getType() != innermost.getType())
+    laneOffset = rewriter.create<spirv::UConvertOp>(loc, innermost.getType(),
+                                                    laneOffset);
+  innermost = rewriter.create<spirv::IAddOp>(loc, innermost, laneOffset);
+
+  return spirv::getElementPtr(typeConverter, memrefType, base, indices, loc,
+                              rewriter);
+}
+
+/// Emits a `spirv.VectorExtractDynamic` that reads lane `offset` of `mask`
+/// as an i1 value.
+Value extractMaskBit(ConversionPatternRewriter &rewriter, Location loc,
+                     Value mask, Value offset) {
+  Type bitType = rewriter.getI1Type();
+  Value laneBit = rewriter.create<spirv::VectorExtractDynamicOp>(
+      loc, bitType, mask, offset);
+  return laneBit;
+}
+
+/// Emits a `spirv.VectorExtractDynamic` that reads lane `offset` of `vector`
+/// as a value of `type`.
+Value extractVectorElement(ConversionPatternRewriter &rewriter, Location loc,
+                           Type type, Value vector, Value offset) {
+  Value element = rewriter.create<spirv::VectorExtractDynamicOp>(
+      loc, type, vector, offset);
+  return element;
+}
+
+/// Materializes `value` as an i32 `spirv.Constant` at the current insertion
+/// point.
+Value createConstantInteger(ConversionPatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  IntegerType i32Ty = rewriter.getI32Type();
+  IntegerAttr valueAttr = IntegerAttr::get(i32Ty, value);
+  Value constant = rewriter.create<spirv::ConstantOp>(loc, i32Ty, valueAttr);
+  return constant;
+}
+
+/// Convert vector.maskedload to the SPIR-V dialect.
+///
+/// Memory is only read for lanes whose mask bit is set. As with LLVM's
+/// llvm.masked.load, masked-off lanes may refer to addresses that must not be
+/// accessed (e.g. past the end of a buffer in a tail iteration), so loading
+/// every lane and selecting afterwards is not an option. Instead the result
+/// is assembled in a Function-storage variable seeded with the pass-through
+/// vector, and each lane's load + insert is guarded by a spirv.mlir.selection.
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %buffer = spirv.Variable
+///   spirv.Store %buffer, %pass_thru
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0)
+///   ^bb1(%i: i32):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     spirv.mlir.selection {
+///       spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+///     ^if_bb1:
+///       %value = spirv.Load %base[%idx_0, %idx_1 + %i]
+///       %partial = spirv.Load %buffer
+///       %v = spirv.VectorInsertDynamic %value, %partial[%i]
+///       spirv.Store %buffer, %v
+///       spirv.Branch ^if_bb2
+///     ^if_bb2:
+///       spirv.mlir.merge
+///     }
+///     spirv.Branch ^bb2(%i)
+///   ^bb2(%i: i32):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1(%update_i), ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///   %ret = spirv.Load %buffer
+///
+struct VectorMaskedLoadOpConverter final
+    : public OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType memrefType = maskedLoadOp.getMemRefType();
+    // A memref in the default memory space has a null memory-space
+    // attribute, on which a plain isa<> would assert.
+    if (!dyn_cast_or_null<spirv::StorageClassAttr>(
+            memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected spirv.storage_class memory space");
+
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getRank() != 1 || maskVType.isScalable())
+      return rewriter.notifyMatchFailure(maskedLoadOp,
+                                         "expected a 1-D fixed-length mask");
+
+    // Lanes are accessed with dynamic indices, so the result must still be a
+    // SPIR-V vector after type conversion (a null or non-vector converted
+    // type makes this lowering inapplicable).
+    VectorType vectorType = maskedLoadOp.getVectorType();
+    auto spirvVectorType = dyn_cast_or_null<VectorType>(
+        getTypeConverter()->convertType(vectorType));
+    if (!spirvVectorType)
+      return rewriter.notifyMatchFailure(maskedLoadOp,
+                                         "unsupported result vector type");
+    Type scalarType = spirvVectorType.getElementType();
+
+    // Local variable holding the partially assembled result. Seeding it with
+    // the pass-through vector gives masked-off lanes their final value.
+    Location loc = maskedLoadOp.getLoc();
+    auto pointerType =
+        spirv::PointerType::get(spirvVectorType, spirv::StorageClass::Function);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+    rewriter.create<spirv::StoreOp>(loc, alloc, adaptor.getPassThru());
+
+    // Loop over lanes [0, length(mask)).
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getDimSize(0));
+
+    spirv::LoopOp loopOp = createSpirvLoop(rewriter, loc);
+    Block *headerBlock = loopOp.getHeaderBlock();
+    Block *continueBlock = loopOp.getContinueBlock();
+
+    Type i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+    // Loop entry block: start at lane 0.
+    rewriter.setInsertionPointToEnd(&loopOp.getBody().front());
+    rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+    // Loop header block: test the mask bit of the current lane.
+    rewriter.setInsertionPointToEnd(headerBlock);
+    Value maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    spirv::SelectionOp selectionOp = createSpirvSelection(rewriter, loc);
+    Block *selectionHeaderBlock = selectionOp.getHeaderBlock();
+    Block *selectionMergeBlock = selectionOp.getMergeBlock();
+    Block *selectionTrueBlock = &*std::next(selectionOp.getBody().begin());
+
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+        std::nullopt);
+
+    // Selection true block: the only place where `base` is read.
+    rewriter.setInsertionPointToEnd(selectionTrueBlock);
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    Value elementPtr =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    if (!elementPtr)
+      return rewriter.notifyMatchFailure(maskedLoadOp,
+                                         "failed to compute element pointer");
+    // Loading `scalarType` is only valid if that is what the pointer actually
+    // points to (it is not, e.g., for packed/emulated storage elements).
+    auto elementPtrType = dyn_cast<spirv::PointerType>(elementPtr.getType());
+    if (!elementPtrType || elementPtrType.getPointeeType() != scalarType)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "memref element does not match vector element type");
+    Value loadedScalar =
+        rewriter.create<spirv::LoadOp>(loc, scalarType, elementPtr);
+    Value partialVector = rewriter.create<spirv::LoadOp>(loc, alloc);
+    Value updatedVector = rewriter.create<spirv::VectorInsertDynamicOp>(
+        loc, spirvVectorType, partialVector, loadedScalar, indVar);
+    rewriter.create<spirv::StoreOp>(loc, alloc, updatedVector);
+    rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+    // Back in the loop header, after the selection: go to the continue block.
+    rewriter.setInsertionPointAfter(selectionOp);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar}));
+
+    // Loop continue block: advance and loop while lane + 1 < length(mask).
+    rewriter.setInsertionPointToEnd(continueBlock);
+    Value nextIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+    Value hasMoreLanes =
+        rewriter.create<spirv::SLessThanOp>(loc, nextIndVar, maskLength);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, hasMoreLanes, headerBlock, ArrayRef<Value>({nextIndVar}),
+        loopOp.getMergeBlock(), std::nullopt);
+
+    // After the loop the variable holds the complete result.
+    rewriter.setInsertionPointAfter(loopOp);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(maskedLoadOp, alloc);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore to the SPIR-V dialect.
+///
+/// Each lane is written only when its mask bit is set: the store sits inside
+/// a spirv.mlir.selection, so masked-off lanes never touch memory.
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0)
+///   ^bb1(%i: i32):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     spirv.mlir.selection {
+///       spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+///     ^if_bb1:
+///       %v = spirv.VectorExtractDynamic %value[%i]
+///       spirv.Store %base[%idx_0, %idx_1 + %i], %v
+///       spirv.Branch ^if_bb2
+///     ^if_bb2:
+///       spirv.mlir.merge
+///     }
+///     spirv.Branch ^bb2(%i)
+///   ^bb2(%i: i32):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1(%update_i), ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///
+struct VectorMaskedStoreOpConverter final
+    : public OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType memrefType = maskedStoreOp.getMemRefType();
+    // A memref in the default memory space has a null memory-space
+    // attribute, on which a plain isa<> would assert.
+    if (!dyn_cast_or_null<spirv::StorageClassAttr>(
+            memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected spirv.storage_class memory space");
+
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getRank() != 1 || maskVType.isScalable())
+      return rewriter.notifyMatchFailure(maskedStoreOp,
+                                         "expected a 1-D fixed-length mask");
+
+    // Lanes are extracted with dynamic indices, so the stored value must
+    // still be a SPIR-V vector after type conversion.
+    auto spirvVectorType =
+        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(
+            maskedStoreOp.getValueToStore().getType()));
+    if (!spirvVectorType)
+      return rewriter.notifyMatchFailure(maskedStoreOp,
+                                         "unsupported value vector type");
+    Type scalarType = spirvVectorType.getElementType();
+
+    // Loop over lanes [0, length(mask)).
+    Location loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getDimSize(0));
+
+    spirv::LoopOp loopOp = createSpirvLoop(rewriter, loc);
+    Block *headerBlock = loopOp.getHeaderBlock();
+    Block *continueBlock = loopOp.getContinueBlock();
+
+    Type i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+    // Loop entry block: start at lane 0.
+    rewriter.setInsertionPointToEnd(&loopOp.getBody().front());
+    rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+    // Loop header block: test the mask bit of the current lane.
+    rewriter.setInsertionPointToEnd(headerBlock);
+    Value maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    spirv::SelectionOp selectionOp = createSpirvSelection(rewriter, loc);
+    Block *selectionHeaderBlock = selectionOp.getHeaderBlock();
+    Block *selectionMergeBlock = selectionOp.getMergeBlock();
+    Block *selectionTrueBlock = &*std::next(selectionOp.getBody().begin());
+
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+        std::nullopt);
+
+    // Selection true block: the only place where `base` is written.
+    rewriter.setInsertionPointToEnd(selectionTrueBlock);
+    Value laneValue = extractVectorElement(rewriter, loc, scalarType,
+                                           adaptor.getValueToStore(), indVar);
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    Value elementPtr =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    if (!elementPtr)
+      return rewriter.notifyMatchFailure(maskedStoreOp,
+                                         "failed to compute element pointer");
+    // Storing `scalarType` is only valid if that is what the pointer actually
+    // points to (it is not, e.g., for packed/emulated storage elements).
+    auto elementPtrType = dyn_cast<spirv::PointerType>(elementPtr.getType());
+    if (!elementPtrType || elementPtrType.getPointeeType() != scalarType)
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "memref element does not match vector element type");
+    rewriter.create<spirv::StoreOp>(loc, elementPtr, laneValue);
+    rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+    // Back in the loop header, after the selection: go to the continue block.
+    rewriter.setInsertionPointAfter(selectionOp);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar}));
+
+    // Loop continue block: advance and loop while lane + 1 < length(mask).
+    rewriter.setInsertionPointToEnd(continueBlock);
+    Value nextIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+    Value hasMoreLanes =
+        rewriter.create<spirv::SLessThanOp>(loc, nextIndVar, maskLength);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, hasMoreLanes, headerBlock, ArrayRef<Value>({nextIndVar}),
+        loopOp.getMergeBlock(), std::nullopt);
+
+    // vector.maskedstore has no results; the loop fully replaces it.
+    rewriter.eraseOp(maskedStoreOp);
+
+    return success();
+  }
+};
+
struct VectorReductionToIntDotProd final
: OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -821,7 +1143,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+ VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc..bc9e92981644b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -805,4 +805,115 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
return
}
+// Masked load from a 2-D StorageBuffer memref<4x5xf32> starting at [0, 4].
+// The mask is `vector.create_mask %idx_1`, so only lane 0 is enabled; the
+// lane number is added to the innermost index (%idx_4), so lanes 1-3 would
+// address columns 5-7, beyond the last column (4) of row 0.
+// CHECK-LABEL: @vector_maskedload
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+// CHECK: %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: %[[CST_F0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S4:.*]] = spirv.CompositeConstruct %[[CST_F0]], %[[CST_F0]], %[[CST_F0]], %[[CST_F0]] : (f32, f32, f32, f32) -> vector<4xf32>
+// CHECK: %[[S5:.*]] = spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+// CHECK: %[[C0_1:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[C4_1:.*]] = spirv.Constant 4 : i32
+// CHECK: %[[CV0:.*]] = spirv.Constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^bb1(%[[C0_1]], %[[CV0]] : i32, vector<4xf32>)
+// CHECK: ^bb1(%[[S7:.*]]: i32, %[[S8:.*]]: vector<4xf32>): // 2 preds: ^bb0, ^bb2
+// CHECK: %[[S9:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S7]]] : vector<4xi1>, i32
+// CHECK: %[[S10:.*]] = spirv.VectorExtractDynamic %[[S4]][%[[S7]]] : vector<4xf32>, i32
+// CHECK: %[[S11:.*]] = spirv.IAdd %[[S2]], %[[S7]] : i32
+// CHECK: %[[C0_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C0_3:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C5:.*]] = spirv.Constant 5 : i32
+// CHECK: %[[S12:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+// CHECK: %[[S13:.*]] = spirv.IAdd %[[C0_3]], %[[S12]] : i32
+// CHECK: %[[C1_2:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[S14:.*]] = spirv.IMul %[[C1_2]], %[[S11]] : i32
+// CHECK: %[[S15:.*]] = spirv.IAdd %[[S13]], %[[S14]] : i32
+// CHECK: %[[S16:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S15]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S17:.*]] = spirv.Load "StorageBuffer" %[[S16]] : f32
+// CHECK: %[[S18:.*]] = spirv.Select %[[S9]], %[[S17]], %[[S10]] : i1, f32
+// CHECK: %[[S19:.*]] = spirv.VectorInsertDynamic %[[S18]], %[[S8]][%[[S7]]] : vector<4xf32>, i32
+// CHECK: spirv.Store "Function" %[[S5]], %[[S19]] : vector<4xf32>
+// CHECK: spirv.Branch ^bb2(%[[S7]], %[[S19]] : i32, vector<4xf32>)
+// CHECK: ^bb2(%[[S20:.*]]: i32, %[[S21:.*]]: vector<4xf32>): // pred: ^bb1
+// CHECK: %[[S22:.*]] = spirv.IAdd %[[S20]], %[[C1_1]] : i32
+// CHECK: %[[S23:.*]] = spirv.SLessThan %[[S22]], %[[C4_1]] : i32
+// CHECK: spirv.BranchConditional %[[S23]], ^bb1(%[[S22]], %[[S21]] : i32, vector<4xf32>), ^bb3
+// CHECK: ^bb3: // pred: ^bb2
+// CHECK: spirv.mlir.merge
+// CHECK: }
+// CHECK: %[[S6:.*]] = spirv.Load "Function" %[[S5]] : vector<4xf32>
+// CHECK: return %[[S6]] : vector<4xf32>
+// CHECK: }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_1 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_maskedstore
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+// CHECK: %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+// CHECK: %[[C0_1:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_1:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[C4_1:.*]] = spirv.Constant 4 : i32
+// CHECK: spirv.mlir.loop {
+// CHECK: spirv.Branch ^bb1(%[[C0_1]] : i32)
+// CHECK: ^bb1(%[[S4:.*]]: i32): // 2 preds: ^bb0, ^bb2
+// CHECK: %[[S5:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S4]]] : vector<4xi1>, i32
+// CHECK: spirv.mlir.selection {
+// CHECK: spirv.BranchConditional %[[S5]], ^bb1, ^bb2
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %[[S9:.*]] = spirv.VectorExtractDynamic %[[ARG1]][%[[S4]]] : vector<4xf32>, i32
+// CHECK: %[[S10:.*]] = spirv.IAdd %[[S2]], %[[S4]] : i32
+// CHECK: %[[C0_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C1_2:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[C5:.*]] = spirv.Constant 5 : i32
+// CHECK: %[[S11:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+// CHECK: %[[S12:.*]] = spirv.IAdd %[[C1_2]], %[[S11]] : i32
+// CHECK: %[[C1_3:.*]] = spirv.Constant 1 : i32
+// CHECK: %[[S13:.*]] = spirv.IMul %[[C1_3]], %[[S10]] : i32
+// CHECK: %[[S14:.*]] = spirv.I...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/74834
More information about the Mlir-commits
mailing list