[Mlir-commits] [mlir] [mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended (PR #73340)

Finn Plummer llvmlistbot at llvm.org
Wed Nov 29 09:17:24 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/73340

>From dcbfc967eb39c1bab149ec789b6aeaf1e8a05e12 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 09:34:44 +0100
Subject: [PATCH 1/3] [mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended

Add the missing constant-propagation folding for IAddCarry and
[S|U]MulExtended. Because spirv.Constant cannot yet hold a
spirv.struct value, the folding is done with canonicalization patterns.

Also fold the case where rhs is 0 for all three ops, and the case where
rhs is 1 for UMulExtended.

This improves the readability of code lowered to SPIR-V.

Part of work for #70704
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |   6 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 190 ++++++++++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 148 ++++++++++++++
 3 files changed, 344 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..951cfe4feb2e63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..cefcbfd87cbdd1c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// spirv.Constant cannot yet materialize spirv.struct values, so this cannot be
+// a fold(); the result struct is assembled from its members instead.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <0, x>
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[1]);
+      constituents.push_back(operands[0]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// spirv.Constant cannot yet materialize spirv.struct values, so this cannot be
+// a fold(); the result struct is assembled from its members instead.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(zero);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    constituents.push_back(lowBitsVal);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    constituents.push_back(highBitsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..16215e21b369584 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant -3
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 1
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant -1
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 1 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant 4
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//

>From fc568a25c6468c96585e55578cfeac0caac0b3e3 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Wed, 29 Nov 2023 17:39:54 +0100
Subject: [PATCH 2/3] review comments

- improve readability with lhs/rhs instead of operands[0]/[1]
- use stack array instead of llvm::SmallVector
- increase strictness of tests to ensure proper CompositeConstruct and
  return order
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 68 +++++++-------
 .../SPIRV/Transforms/canonicalize.mlir        | 89 +++++++++++--------
 2 files changed, 83 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index cefcbfd87cbdd1c..0f63015552e004f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -127,15 +127,13 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
   LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto operands = op.getOperands();
-
-    SmallVector<Value> constituents;
-    Type constituentType = operands[0].getType();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
 
-    // iaddcarry (x, 0) = <0, x>
-    if (matchPattern(operands[1], m_Zero())) {
-      constituents.push_back(operands[1]);
-      constituents.push_back(operands[0]);
+    // iaddcarry (x, 0) = <x, 0>; member 0 is the sum, member 1 the carry.
+    if (matchPattern(rhs, m_Zero())) {
+      Value constituents[2] = {lhs, rhs};
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                                constituents);
       return success();
@@ -152,19 +150,20 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
     //  Member 1 of the result gets the high-order (carry) bit of the result of
     //  the addition. That is, it gets the value 1 if the addition overflowed
     //  the component width, and 0 otherwise.
-    Attribute lhs;
-    Attribute rhs;
-    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
-        !matchPattern(operands[1], m_Constant(&rhs)))
+    Attribute lhsAttr;
+    Attribute rhsAttr;
+    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+        !matchPattern(rhs, m_Constant(&rhsAttr)))
       return failure();
 
     auto adds = constFoldBinaryOp<IntegerAttr>(
-        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+        {lhsAttr, rhsAttr},
+        [](const APInt &a, const APInt &b) { return a + b; });
     if (!adds)
       return failure();
 
     auto carrys = constFoldBinaryOp<IntegerAttr>(
-        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+        ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
           APInt zero = APInt::getZero(a.getBitWidth());
           return a.ult(b) ? (zero + 1) : zero;
         });
@@ -174,12 +173,11 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
 
     Value addsVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
-    constituents.push_back(addsVal);
 
     Value carrysVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
-    constituents.push_back(carrysVal);
 
+    Value constituents[2] = {addsVal, carrysVal};
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                              constituents);
     return success();
@@ -204,16 +202,14 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
   LogicalResult matchAndRewrite(MulOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto operands = op.getOperands();
-
-    SmallVector<Value> constituents;
-    Type constituentType = operands[0].getType();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
 
     // [su]mulextended (x, 0) = <0, 0>
-    if (matchPattern(operands[1], m_Zero())) {
+    if (matchPattern(rhs, m_Zero())) {
       Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
-      constituents.push_back(zero);
-      constituents.push_back(zero);
+      Value constituents[2] = {zero, zero};
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                                constituents);
       return success();
@@ -227,20 +223,21 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     // Member 0 of the result gets the low-order bits of the multiplication.
     //
     // Member 1 of the result gets the high-order bits of the multiplication.
-    Attribute lhs;
-    Attribute rhs;
-    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
-        !matchPattern(operands[1], m_Constant(&rhs)))
+    Attribute lhsAttr;
+    Attribute rhsAttr;
+    if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
+        !matchPattern(rhs, m_Constant(&rhsAttr)))
       return failure();
 
     auto lowBits = constFoldBinaryOp<IntegerAttr>(
-        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+        {lhsAttr, rhsAttr},
+        [](const APInt &a, const APInt &b) { return a * b; });
 
     if (!lowBits)
       return failure();
 
     auto highBits = constFoldBinaryOp<IntegerAttr>(
-        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+        {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
           unsigned bitWidth = a.getBitWidth();
           APInt c;
           if (IsSigned) {
@@ -256,12 +253,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
 
     Value lowBitsVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
-    constituents.push_back(lowBitsVal);
 
     Value highBitsVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
-    constituents.push_back(highBitsVal);
 
+    Value constituents[2] = {lowBitsVal, highBitsVal};
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                              constituents);
     return success();
@@ -280,16 +276,14 @@ struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
   LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto operands = op.getOperands();
-
-    SmallVector<Value> constituents;
-    Type constituentType = operands[0].getType();
+    Value lhs = op.getOperand1();
+    Value rhs = op.getOperand2();
+    Type constituentType = lhs.getType();
 
     // umulextended (x, 1) = <x, 0>
-    if (matchPattern(operands[1], m_One())) {
+    if (matchPattern(rhs, m_One())) {
       Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
-      constituents.push_back(operands[0]);
-      constituents.push_back(zero);
+      Value constituents[2] = {lhs, zero};
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
                                                                constituents);
       return success();
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 16215e21b369584..d1c626c1bc1db25 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -342,10 +342,11 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // CHECK-LABEL: @iaddcarry_x_0
 func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: %[[RET:.*]] = spirv.CompositeConstruct
   %c0 = spirv.Constant 0 : i32
-
-  // CHECK: spirv.CompositeConstruct
   %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+
+  // CHECK: return %[[RET]]
   return %0 : !spirv.struct<(i32, i32)>
 }
 
@@ -355,15 +356,16 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-  // CHECK-DAG: spirv.Constant 0
-  // CHECK-DAG: spirv.Constant -3
-  // CHECK-DAG: spirv.CompositeConstruct
-  // CHECK-DAG: spirv.Constant 1
-  // CHECK-DAG: spirv.Constant -13
-  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
+  // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeConstruct %[[CN3]], %[[C0]]
+  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
+  // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
+  // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeConstruct %[[CN13]], %[[C1]]
   %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
+  // CHECK: return %[[CC_CN3_C0]], %[[CC_CN13_C1]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -372,12 +374,13 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
   %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
 
-  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
-  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
-  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
+  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
   %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
-  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
+  // CHECK: return %[[CC_CV1_CV2]]
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 }
 
 // -----
@@ -452,10 +455,12 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // CHECK-LABEL: @smulextended_x_0
 func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
   %c0 = spirv.Constant 0 : i32
-
-  // CHECK: spirv.CompositeConstruct
   %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+
+  // CHECK: return %[[RET]]
   return %0 : !spirv.struct<(i32, i32)>
 }
 
@@ -465,15 +470,16 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-  // CHECK-DAG: spirv.Constant -40
-  // CHECK-DAG: spirv.Constant -1
-  // CHECK-DAG: spirv.CompositeConstruct
-  // CHECK-DAG: spirv.Constant 40
-  // CHECK-DAG: spirv.Constant 0
-  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
+  // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeConstruct %[[CN40]], %[[CN1]]
+  // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeConstruct %[[C40]], %[[C0]]
   %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
+  // CHECK: return %[[CC_CN40_CN1]], %[[CC_C40_C0]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -482,10 +488,12 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
   %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
 
-  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
-  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
-  // CHECK-NEXT: spirv.CompositeConstruct
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
   %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+  // CHECK: return %[[CC_CV1_CV2]]
   return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
 }
@@ -498,19 +506,24 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
 
 // CHECK-LABEL: @umulextended_x_0
 func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[C0]], %[[C0]]
   %c0 = spirv.Constant 0 : i32
-
-  // CHECK: spirv.CompositeConstruct
   %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+
+  // CHECK: return %[[RET]]
   return %0 : !spirv.struct<(i32, i32)>
 }
 
 // CHECK-LABEL: @umulextended_x_1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  // CHECK: %[[RET:.*]] = spirv.CompositeConstruct %[[ARG]], %[[C0]]
   %c0 = spirv.Constant 1 : i32
-
-  // CHECK: spirv.CompositeConstruct
   %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+
+  // CHECK: return %[[RET]]
   return %0 : !spirv.struct<(i32, i32)>
 }
 
@@ -520,15 +533,16 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
-  // CHECK-DAG: spirv.Constant 40
-  // CHECK-DAG: spirv.Constant -13
-  // CHECK-DAG: spirv.CompositeConstruct
-  // CHECK-DAG: spirv.Constant -40
-  // CHECK-DAG: spirv.Constant 4
-  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
+  // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
+  // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeConstruct %[[C40]], %[[CN13]]
+  // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
+  // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
+  // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeConstruct %[[CN40]], %[[C4]]
   %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
+  // CHECK: return %[[CC_CN40_C4]], %[[CC_C40_CN13]]
   return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
 }
 
@@ -537,12 +551,13 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
   %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
   %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
 
-  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
-  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
-  // CHECK-NEXT: spirv.CompositeConstruct
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
   %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
-  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
+  // CHECK: return %[[CC_CV1_CV2]]
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 }
 
 // -----

>From a0cb3597b0c5d8d7aab9d07f04f094ba3be76534 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Wed, 29 Nov 2023 18:07:17 +0100
Subject: [PATCH 3/3] switch to using CompositeInsert from Construct

- commit to demonstrate how we could potentially use CompositeInsert
  instead of CompositeConstruct
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 22 ++++++++---
 .../SPIRV/Transforms/canonicalize.mlir        | 37 ++++++++++++++-----
 2 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 0f63015552e004f..48523cd31c13046 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -177,9 +177,14 @@ struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
     Value carrysVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
 
-    Value constituents[2] = {addsVal, carrysVal};
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
-                                                             constituents);
+    // Create empty struct
+    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    // Fill in adds at id 0
+    Value intermediate =
+        rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
+    // Fill in carrys at id 1
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
+                                                          intermediate, 1);
     return success();
   }
 };
@@ -257,9 +262,14 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
     Value highBitsVal =
         rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
 
-    Value constituents[2] = {lowBitsVal, highBitsVal};
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
-                                                             constituents);
+    // Create empty struct
+    Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
+    // Fill in lowBits at id 0
+    Value intermediate =
+        rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
+    // Fill in highBits at id 1
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
+                                                          intermediate, 1);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index d1c626c1bc1db25..7ba20b3a5d5c377 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -358,10 +358,14 @@ func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.s
 
   // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
   // CHECK-DAG: %[[CN3:.*]] = spirv.Constant -3
-  // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeConstruct %[[CN3]], %[[C0]]
+  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN3]], %[[UNDEF1]][0 : i32]
+  // CHECK-DAG: %[[CC_CN3_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER1]][1 : i32]
   // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1
   // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
-  // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeConstruct %[[CN13]], %[[C1]]
+  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[CN13]], %[[UNDEF2]][0 : i32]
+  // CHECK-DAG: %[[CC_CN13_C1:.*]] = spirv.CompositeInsert %[[C1]], %[[INTER2]][1 : i32]
   %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
@@ -376,7 +380,9 @@ func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector
 
   // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[-3, -11, 0]>
   // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[0, 1, 1]>
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
+  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
   %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
   // CHECK: return %[[CC_CV1_CV2]]
@@ -472,10 +478,14 @@ func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spir
 
   // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
   // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1
-  // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeConstruct %[[CN40]], %[[CN1]]
+  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
+  // CHECK-DAG: %[[CC_CN40_CN1:.*]] = spirv.CompositeInsert %[[CN1]], %[[INTER1]]
   // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
   // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
-  // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeConstruct %[[C40]], %[[C0]]
+  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
+  // CHECK-DAG: %[[CC_C40_C0:.*]] = spirv.CompositeInsert %[[C0]], %[[INTER2]][1 : i32]
   %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
@@ -490,7 +500,9 @@ func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vec
 
   // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
   // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, 0, -1]>
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
+  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]][0 : i32]
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]][1 : i32]
   %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
   // CHECK: return %[[CC_CV1_CV2]]
@@ -533,12 +545,17 @@ func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spir
   %cn5 = spirv.Constant -5 : i32
   %cn8 = spirv.Constant -8 : i32
 
+
   // CHECK-DAG: %[[C40:.*]] = spirv.Constant 40
   // CHECK-DAG: %[[CN13:.*]] = spirv.Constant -13
-  // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeConstruct %[[C40]], %[[CN13]]
   // CHECK-DAG: %[[CN40:.*]] = spirv.Constant -40
   // CHECK-DAG: %[[C4:.*]] = spirv.Constant 4
-  // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeConstruct %[[CN40]], %[[C4]]
+  // CHECK-DAG: %[[UNDEF1:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER1:.*]] = spirv.CompositeInsert %[[CN40]], %[[UNDEF1]][0 : i32]
+  // CHECK-DAG: %[[CC_CN40_C4:.*]] = spirv.CompositeInsert %[[C4]], %[[INTER1]][1 : i32]
+  // CHECK-DAG: %[[UNDEF2:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER2:.*]] = spirv.CompositeInsert %[[C40]], %[[UNDEF2]][0 : i32]
+  // CHECK-DAG: %[[CC_C40_CN13:.*]] = spirv.CompositeInsert %[[CN13]], %[[INTER2]][1 : i32]
   %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
   %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
 
@@ -553,7 +570,9 @@ func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vec
 
   // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<[2147483643, 40, -1]>
   // CHECK-DAG: %[[CV2:.*]] = spirv.Constant dense<[2, -13, 0]>
-  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeConstruct %[[CV1]], %[[CV2]]
+  // CHECK-DAG: %[[UNDEF:.*]] = spirv.Undef
+  // CHECK-DAG: %[[INTER:.*]] = spirv.CompositeInsert %[[CV1]], %[[UNDEF]]
+  // CHECK-DAG: %[[CC_CV1_CV2:.*]] = spirv.CompositeInsert %[[CV2]], %[[INTER]]
   %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
 
   // CHECK: return %[[CC_CV1_CV2]]



More information about the Mlir-commits mailing list