[flang-commits] [flang] [llvm] [MVT][TableGen] Extend Machine Value Type to `uint16_t` (PR #99657)

Brandon Wu via flang-commits flang-commits at lists.llvm.org
Fri Jul 19 09:35:06 PDT 2024


https://github.com/4vtomat updated https://github.com/llvm/llvm-project/pull/99657

>From 52181f2183324ed33f63747b1051cd8d345db21b Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Fri, 19 Jul 2024 15:55:36 +0100
Subject: [PATCH 1/2] [flang][OpenMP] Implement lastprivate with collapse
 (#99500)

This patch enables the lastprivate clause to be used in the presence of
the collapse clause.

Note: the way we currently implement lastprivate means that this adds a
large number of compare instructions to the end of every iteration of
the loop. This is a clearly non-optimal thing to do, but lastprivate in
general will need re-implementing to prevent this. This is planned as
part of the delayed privatization work. This current implementation is
just a stop-gap measure as generating sub-optimal but working code is
better than crashing out.
---
 .../lib/Lower/OpenMP/DataSharingProcessor.cpp |  55 ++--
 flang/lib/Lower/OpenMP/DataSharingProcessor.h |   7 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |   7 +-
 .../Lower/OpenMP/parallel-wsloop-lastpriv.f90 | 239 ++++++++++++++++++
 4 files changed, 276 insertions(+), 32 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90

diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7df3905c29990..c1d4c089df3b2 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -139,7 +139,6 @@ void DataSharingProcessor::collectOmpObjectListSymbol(
 }
 
 void DataSharingProcessor::collectSymbolsForPrivatization() {
-  bool hasCollapse = false;
   for (const omp::Clause &clause : clauses) {
     if (const auto &privateClause =
             std::get_if<omp::clause::Private>(&clause.u)) {
@@ -153,16 +152,11 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
       const ObjectList &objects = std::get<ObjectList>(lastPrivateClause->t);
       collectOmpObjectListSymbol(objects, explicitlyPrivatizedSymbols);
       hasLastPrivateOp = true;
-    } else if (std::get_if<omp::clause::Collapse>(&clause.u)) {
-      hasCollapse = true;
     }
   }
 
   for (auto *sym : explicitlyPrivatizedSymbols)
     allPrivatizedSymbols.insert(sym);
-
-  if (hasCollapse && hasLastPrivateOp)
-    TODO(converter.getCurrentLocation(), "Collapse clause with lastprivate");
 }
 
 bool DataSharingProcessor::needBarrier() {
@@ -225,28 +219,39 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
       firOpBuilder.setInsertionPoint(lastOper);
 
-      mlir::Value iv = loopOp.getIVs()[0];
-      mlir::Value ub = loopOp.getUpperBound()[0];
-      mlir::Value step = loopOp.getStep()[0];
-
-      // v = iv + step
-      // cmp = step < 0 ? v < ub : v > ub
-      mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
-      mlir::Value zero =
-          firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
-      mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
-          loc, mlir::arith::CmpIPredicate::slt, step, zero);
-      mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
-          loc, mlir::arith::CmpIPredicate::slt, v, ub);
-      mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
-          loc, mlir::arith::CmpIPredicate::sgt, v, ub);
-      mlir::Value cmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
-          loc, negativeStep, vLT, vGT);
+      mlir::Value cmpOp;
+      llvm::SmallVector<mlir::Value> vs;
+      vs.reserve(loopOp.getIVs().size());
+      for (auto [iv, ub, step] : llvm::zip_equal(
+               loopOp.getIVs(), loopOp.getUpperBound(), loopOp.getStep())) {
+        // v = iv + step
+        // cmp = step < 0 ? v < ub : v > ub
+        mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
+        vs.push_back(v);
+        mlir::Value zero =
+            firOpBuilder.createIntegerConstant(loc, step.getType(), 0);
+        mlir::Value negativeStep = firOpBuilder.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::slt, step, zero);
+        mlir::Value vLT = firOpBuilder.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::slt, v, ub);
+        mlir::Value vGT = firOpBuilder.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::sgt, v, ub);
+        mlir::Value icmpOp = firOpBuilder.create<mlir::arith::SelectOp>(
+            loc, negativeStep, vLT, vGT);
+
+        if (cmpOp) {
+          cmpOp = firOpBuilder.create<mlir::arith::AndIOp>(loc, cmpOp, icmpOp);
+        } else {
+          cmpOp = icmpOp;
+        }
+      }
 
       auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
       firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      assert(loopIV && "loopIV was not set");
-      firOpBuilder.createStoreWithConvert(loc, v, loopIV);
+      for (auto [v, loopIV] : llvm::zip_equal(vs, loopIVs)) {
+        assert(loopIV && "loopIV was not set");
+        firOpBuilder.createStoreWithConvert(loc, v, loopIV);
+      }
       lastPrivIP = firOpBuilder.saveInsertionPoint();
     } else if (mlir::isa<mlir::omp::SectionsOp>(op)) {
       // Already handled by genOMP()
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 7c0c831ed43a8..1e9d07ec13857 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -71,7 +71,7 @@ class DataSharingProcessor {
   bool hasLastPrivateOp;
   mlir::OpBuilder::InsertPoint lastPrivIP;
   mlir::OpBuilder::InsertPoint insPt;
-  mlir::Value loopIV;
+  llvm::SmallVector<mlir::Value> loopIVs;
   // Symbols in private, firstprivate, and/or lastprivate clauses.
   llvm::SetVector<const semantics::Symbol *> explicitlyPrivatizedSymbols;
   llvm::SetVector<const semantics::Symbol *> defaultSymbols;
@@ -147,10 +147,7 @@ class DataSharingProcessor {
   void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr);
   void processStep2(mlir::Operation *op, bool isLoop);
 
-  void setLoopIV(mlir::Value iv) {
-    assert(!loopIV && "Loop iteration variable already set");
-    loopIV = iv;
-  }
+  void pushLoopIV(mlir::Value iv) { loopIVs.push_back(iv); }
 
   const llvm::SetVector<const semantics::Symbol *> &
   getAllSymbolsToPrivatize() const {
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 47ce2d8c8db87..625a8c83c369b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -677,8 +677,11 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
         assert(tempDsp.has_value());
         tempDsp->processStep2(privatizationTopLevelOp, isLoop);
       } else {
-        if (isLoop && regionArgs.size() > 0)
-          info.dsp->setLoopIV(info.converter.getSymbolAddress(*regionArgs[0]));
+        if (isLoop && regionArgs.size() > 0) {
+          for (const auto &regionArg : regionArgs) {
+            info.dsp->pushLoopIV(info.converter.getSymbolAddress(*regionArg));
+          }
+        }
         info.dsp->processStep2(privatizationTopLevelOp, isLoop);
       }
     }
diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
new file mode 100644
index 0000000000000..9b90ac28f693c
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90
@@ -0,0 +1,239 @@
+! This test checks lowering of OpenMP parallel DO, with the loop bound being
+! a lastprivate variable
+
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+
+! CHECK: func @_QPomp_do_lastprivate(%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "a"})
+subroutine omp_do_lastprivate(a)
+  ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFomp_do_lastprivateEa"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer::a
+  integer::n
+  n = a+1
+  !$omp parallel do lastprivate(a)
+  ! CHECK:  omp.parallel {
+
+  ! CHECK: %[[A_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "a", pinned, uniq_name = "_QFomp_do_lastprivateEa"}
+  ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivateEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "i", pinned, {{.*}}}
+  ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivateEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[LB:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: omp.wsloop {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+  ! CHECK-NEXT: fir.store %[[ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.call @_QPfoo(%[[I_PVT_DECL]]#1, %[[A_PVT_DECL]]#1) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> ()
+  ! CHECK:      %[[NEXT_ARG1:.*]] = arith.addi %[[ARG1]], %[[STEP]] : i32
+  ! CHECK:      %[[ZERO:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP_DIR:.*]] = arith.cmpi slt, %[[STEP]], %[[ZERO]] : i32
+  ! CHECK:      %[[LT_UB:.*]] = arith.cmpi slt, %[[NEXT_ARG1]], %[[UB]] : i32
+  ! CHECK:      %[[GT_UB:.*]] = arith.cmpi sgt, %[[NEXT_ARG1]], %[[UB]] : i32
+  ! CHECK:      %[[SEL:.*]] = arith.select %[[STEP_DIR]], %[[LT_UB]], %[[GT_UB]] : i1
+  ! CHECK:      fir.if %[[SEL]] {
+  ! CHECK:        fir.store %[[NEXT_ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        %[[A_PVT_LOAD:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK:        hlfir.assign %[[A_PVT_LOAD]] to %[[ARG0_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
+  ! CHECK:      }
+
+  ! CHECK-NEXT: omp.yield
+  ! CHECK-NEXT: }
+  ! CHECK-NEXT: omp.terminator
+  ! CHECK-NEXT: }
+    do i=1, a
+      call foo(i, a)
+    end do
+  !$omp end parallel do
+  !CHECK: fir.call @_QPbar(%[[ARG0_DECL]]#1) {{.*}}: (!fir.ref<i32>) -> ()
+  call bar(a)
+end subroutine omp_do_lastprivate
+
+! CHECK: func @_QPomp_do_lastprivate2(%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<i32> {fir.bindc_name = "n"})
+subroutine omp_do_lastprivate2(a, n)
+  ! CHECK:  %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFomp_do_lastprivate2Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  ! CHECK:  %[[ARG1_DECL:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFomp_do_lastprivate2En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer::a
+  integer::n
+  n = a+1
+  !$omp parallel do lastprivate(a, n)
+  ! CHECK:  omp.parallel {
+
+  ! CHECK: %[[A_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "a", pinned, {{.*}}}
+  ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate2Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[N_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "n", pinned, uniq_name = "_QFomp_do_lastprivate2En"}
+  ! CHECK: %[[N_PVT_DECL:.*]]:2 = hlfir.declare %[[N_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate2En"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "i", pinned, {{.*}}}
+  ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[LB:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK: %[[UB:.*]] = fir.load %[[N_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK: %[[STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: omp.wsloop {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG2:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+  ! CHECK: fir.store %[[ARG2]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK: fir.call @_QPfoo(%[[I_PVT_DECL]]#1, %[[A_PVT_DECL]]#1) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> ()
+  ! CHECK: %[[NEXT_ARG2:.*]] = arith.addi %[[ARG2]], %[[STEP]] : i32
+  ! CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+  ! CHECK: %[[STEP_DIR:.*]] = arith.cmpi slt, %[[STEP]], %[[ZERO]] : i32
+  ! CHECK: %[[LT_UB:.*]] = arith.cmpi slt, %[[NEXT_ARG2]], %[[UB]] : i32
+  ! CHECK: %[[GT_UB:.*]] = arith.cmpi sgt, %[[NEXT_ARG2]], %[[UB]] : i32
+  ! CHECK: %[[SEL:.*]] = arith.select %[[STEP_DIR]], %[[LT_UB]], %[[GT_UB]] : i1
+  ! CHECK: fir.if %[[SEL]] {
+  ! CHECK:   fir.store %[[NEXT_ARG2]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:   %[[A_PVT_LOAD:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK:   hlfir.assign %[[A_PVT_LOAD]] to %[[ARG0_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
+  ! CHECK:   %[[N_PVT_LOAD:.*]] = fir.load %[[N_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK:   hlfir.assign %[[N_PVT_LOAD]] to %[[ARG1_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
+  ! CHECK: }
+
+  ! CHECK: omp.yield
+  ! CHECK: omp.terminator
+    do i= a, n
+      call foo(i, a)
+    end do
+  !$omp end parallel do
+  !CHECK: fir.call @_QPbar(%[[ARG1_DECL]]#1) {{.*}}: (!fir.ref<i32>) -> ()
+  call bar(n)
+end subroutine omp_do_lastprivate2
+
+! CHECK: func @_QPomp_do_lastprivate_collapse2(%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "a"})
+subroutine omp_do_lastprivate_collapse2(a)
+  ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFomp_do_lastprivate_collapse2Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer::a
+  !$omp parallel do lastprivate(a) collapse(2)
+  ! CHECK:  omp.parallel {
+
+  ! CHECK: %[[A_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "a", pinned, uniq_name = "_QFomp_do_lastprivate_collapse2Ea"}
+  ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "i", pinned, {{.*}}}
+  ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  !
+  ! CHECK: %[[J_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "j", pinned, {{.*}}}
+  ! CHECK: %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB1:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP1:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB2:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP2:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: omp.wsloop {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]]) : i32 = (%[[LB1]], %[[LB2]]) to (%[[UB1]], %[[UB2]]) inclusive step (%[[STEP1]], %[[STEP2]]) {
+  ! CHECK-NEXT: fir.store %[[ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.store %[[ARG2]] to %[[J_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.call @_QPfoo(%[[I_PVT_DECL]]#1, %[[A_PVT_DECL]]#1) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> ()
+  ! CHECK:      %[[NEXT_ARG1:.*]] = arith.addi %[[ARG1]], %[[STEP1]] : i32
+  ! CHECK:      %[[ZERO1:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP1_END:.*]] = arith.cmpi slt, %[[STEP1]], %[[ZERO1]] : i32
+  ! CHECK:      %[[LT_UB1:.*]] = arith.cmpi slt, %[[NEXT_ARG1]], %[[UB1]] : i32
+  ! CHECK:      %[[GT_UB1:.*]] = arith.cmpi sgt, %[[NEXT_ARG1]], %[[UB1]] : i32
+  ! CHECK:      %[[SEL1:.*]] = arith.select %[[STEP1_END]], %[[LT_UB1]], %[[GT_UB1]] : i1
+  ! CHECK:      %[[NEXT_ARG2:.*]] = arith.addi %[[ARG2]], %[[STEP2]] : i32
+  ! CHECK:      %[[ZERO2:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP2_END:.*]] = arith.cmpi slt, %[[STEP2]], %[[ZERO2]] : i32
+  ! CHECK:      %[[LT_UB2:.*]] = arith.cmpi slt, %[[NEXT_ARG2]], %[[UB2]] : i32
+  ! CHECK:      %[[GT_UB2:.*]] = arith.cmpi sgt, %[[NEXT_ARG2]], %[[UB2]] : i32
+  ! CHECK:      %[[SEL2:.*]] = arith.select %[[STEP2_END]], %[[LT_UB2]], %[[GT_UB2]] : i1
+  ! CHECK:      %[[AND:.*]] = arith.andi %[[SEL1]], %[[SEL2]] : i1
+  ! CHECK:      fir.if %[[AND]] {
+  ! CHECK:        fir.store %[[NEXT_ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        fir.store %[[NEXT_ARG2]] to %[[J_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        %[[A_PVT_LOAD:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK:        hlfir.assign %[[A_PVT_LOAD]] to %[[ARG0_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
+  ! CHECK:      }
+
+  ! CHECK-NEXT: omp.yield
+  ! CHECK-NEXT: }
+  ! CHECK-NEXT: omp.terminator
+  ! CHECK-NEXT: }
+    do i=1, a
+      do j=1, a
+        call foo(i, a)
+      end do
+    end do
+  !$omp end parallel do
+  !CHECK: fir.call @_QPbar(%[[ARG0_DECL]]#1) {{.*}}: (!fir.ref<i32>) -> ()
+  call bar(a)
+end subroutine omp_do_lastprivate_collapse2
+
+! CHECK: func @_QPomp_do_lastprivate_collapse3(%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "a"})
+subroutine omp_do_lastprivate_collapse3(a)
+  ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFomp_do_lastprivate_collapse3Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer::a
+  !$omp parallel do lastprivate(a) collapse(3)
+  ! CHECK:  omp.parallel {
+
+  ! CHECK: %[[A_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "a", pinned, uniq_name = "_QFomp_do_lastprivate_collapse3Ea"}
+  ! CHECK: %[[A_PVT_DECL:.*]]:2 = hlfir.declare %[[A_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[I_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "i", pinned, {{.*}}}
+  ! CHECK: %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[J_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "j", pinned, {{.*}}}
+  ! CHECK: %[[J_PVT_DECL:.*]]:2 = hlfir.declare %[[J_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[K_PVT_REF:.*]] = fir.alloca i32 {bindc_name = "k", pinned, {{.*}}}
+  ! CHECK: %[[K_PVT_DECL:.*]]:2 = hlfir.declare %[[K_PVT_REF]] {uniq_name = "_QFomp_do_lastprivate_collapse3Ek"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+  ! CHECK: %[[LB1:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB1:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP1:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[LB2:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB2:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP2:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[LB3:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: %[[UB3:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK-NEXT: %[[STEP3:.*]] = arith.constant 1 : i32
+  ! CHECK-NEXT: omp.wsloop {
+  ! CHECK-NEXT: omp.loop_nest (%[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) : i32 = (%[[LB1]], %[[LB2]], %[[LB3]]) to (%[[UB1]], %[[UB2]], %[[UB3]]) inclusive step (%[[STEP1]], %[[STEP2]], %[[STEP3]]) {
+  ! CHECK-NEXT: fir.store %[[ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.store %[[ARG2]] to %[[J_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.store %[[ARG3]] to %[[K_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK-NEXT: fir.call @_QPfoo(%[[I_PVT_DECL]]#1, %[[A_PVT_DECL]]#1) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> ()
+  ! CHECK:      %[[NEXT_ARG1:.*]] = arith.addi %[[ARG1]], %[[STEP1]] : i32
+  ! CHECK:      %[[ZERO1:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP1_END:.*]] = arith.cmpi slt, %[[STEP1]], %[[ZERO1]] : i32
+  ! CHECK:      %[[LT_UB1:.*]] = arith.cmpi slt, %[[NEXT_ARG1]], %[[UB1]] : i32
+  ! CHECK:      %[[GT_UB1:.*]] = arith.cmpi sgt, %[[NEXT_ARG1]], %[[UB1]] : i32
+  ! CHECK:      %[[SEL1:.*]] = arith.select %[[STEP1_END]], %[[LT_UB1]], %[[GT_UB1]] : i1
+  ! CHECK:      %[[NEXT_ARG2:.*]] = arith.addi %[[ARG2]], %[[STEP2]] : i32
+  ! CHECK:      %[[ZERO2:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP2_END:.*]] = arith.cmpi slt, %[[STEP2]], %[[ZERO2]] : i32
+  ! CHECK:      %[[LT_UB2:.*]] = arith.cmpi slt, %[[NEXT_ARG2]], %[[UB2]] : i32
+  ! CHECK:      %[[GT_UB2:.*]] = arith.cmpi sgt, %[[NEXT_ARG2]], %[[UB2]] : i32
+  ! CHECK:      %[[SEL2:.*]] = arith.select %[[STEP2_END]], %[[LT_UB2]], %[[GT_UB2]] : i1
+  ! CHECK:      %[[AND1:.*]] = arith.andi %[[SEL1]], %[[SEL2]] : i1
+  ! CHECK:      %[[NEXT_ARG3:.*]] = arith.addi %[[ARG3]], %[[STEP3]] : i32
+  ! CHECK:      %[[ZERO3:.*]] = arith.constant 0 : i32
+  ! CHECK:      %[[STEP3_END:.*]] = arith.cmpi slt, %[[STEP3]], %[[ZERO3]] : i32
+  ! CHECK:      %[[LT_UB3:.*]] = arith.cmpi slt, %[[NEXT_ARG3]], %[[UB3]] : i32
+  ! CHECK:      %[[GT_UB3:.*]] = arith.cmpi sgt, %[[NEXT_ARG3]], %[[UB3]] : i32
+  ! CHECK:      %[[SEL3:.*]] = arith.select %[[STEP3_END]], %[[LT_UB3]], %[[GT_UB3]] : i1
+  ! CHECK:      %[[AND2:.*]] = arith.andi %[[AND1]], %[[SEL3]] : i1
+  ! CHECK:      fir.if %[[AND2]] {
+  ! CHECK:        fir.store %[[NEXT_ARG1]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        fir.store %[[NEXT_ARG2]] to %[[J_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        fir.store %[[NEXT_ARG3]] to %[[K_PVT_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:        %[[A_PVT_LOAD:.*]] = fir.load %[[A_PVT_DECL]]#0 : !fir.ref<i32>
+  ! CHECK:        hlfir.assign %[[A_PVT_LOAD]] to %[[ARG0_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
+  ! CHECK:      }
+
+  ! CHECK-NEXT: omp.yield
+  ! CHECK-NEXT: }
+  ! CHECK-NEXT: omp.terminator
+  ! CHECK-NEXT: }
+    do i=1, a
+      do j=1, a
+        do k=1, a
+          call foo(i, a)
+        end do
+      end do
+    end do
+  !$omp end parallel do
+  !CHECK: fir.call @_QPbar(%[[ARG0_DECL]]#1) {{.*}}: (!fir.ref<i32>) -> ()
+  call bar(a)
+end subroutine omp_do_lastprivate_collapse3

>From b470e337cdc1c1dc63b4450dda61ea2ff94386c9 Mon Sep 17 00:00:00 2001
From: Brandon Wu <brandon.wu at sifive.com>
Date: Fri, 19 Jul 2024 05:23:55 -0700
Subject: [PATCH 2/2] [MVT][TableGen] Extend Machine Value Type to `uint16_t`

Currently 208 out of 256 MVTs are used, so we will run out soon;
ultimately we need to extend the original `MVT::SimpleValueType` from
`uint8_t` to `uint16_t` to accommodate more types.
The `MatcherTable` uses `unsigned char` for encoding the matcher code,
so the extended MVTs no longer fit into the table; thus we need to use
VBR to encode them, as we do for other values that are wider than 8 bits.

The statistics below show the difference in "Total Array size" of the
matcher table that appears in each file:
Table                       Before     After     Change(%)
WebAssemblyGenDAGISel.inc   23576      23775     0.844
NVPTXGenDAGISel.inc         173498     173498    0
RISCVGenDAGISel.inc         2179121    2369929   8.756
AVRGenDAGISel.inc           2754       2754      0
PPCGenDAGISel.inc           163315     163617    0.185
MipsGenDAGISel.inc          47280      47447     0.353
SystemZGenDAGISel.inc       56243      56461     0.388
AArch64GenDAGISel.inc       467893     487830    4.261
MSP430GenDAGISel.inc        8069       8069      0
LoongArchGenDAGISel.inc     78928      79131     0.257
XCoreGenDAGISel.inc         3432       3432      0
BPFGenDAGISel.inc           3733       3733      0
VEGenDAGISel.inc            65174      66456     1.967
LanaiGenDAGISel.inc         2067       2067      0
X86GenDAGISel.inc           628787     636987    1.304
ARMGenDAGISel.inc           170968     171036    0.040
HexagonGenDAGISel.inc       155764     155764    0
SparcGenDAGISel.inc         5762       5798      0.625
AMDGPUGenDAGISel.inc        504356     504463    0.021
R600GenDAGISel.inc          29785      29785     0

The statistics below show the runtime peak memory usage when compiling a
simple C program:
`/bin/time -v clang -target $TARGET -O3 -c test.c`
```
  int test(int a) {
    return a * 3;
  }
```
Target        Before(kbytes)    After(kbytes)    Change(%)
wasm64        110172            110088           -0.076
nvptx64       109784            109980            0.179
riscv64       114020            113656           -0.319
avr           110352            110068           -0.257
ppc64         112612            112476           -0.120
mips64        113588            113668            0.070
systemz       110860            110760           -0.090
aarch64       113704            113432           -0.239
msp430        110284            110200           -0.076
loongarch64   111052            110756           -0.267
xcore         108340            108020           -0.295
bpf           110620            110708            0.080
ve            110960            110920           -0.036
lanai         110180            109960           -0.200
x86_64        113640            113304           -0.296
arm64         113540            113172           -0.324
hexagon       114620            114684            0.056
sparc         110412            110136           -0.250
amdgcn        118164            117144           -0.863
r600          111200            110508           -0.622
---
 .../llvm/CodeGenTypes/MachineValueType.h      |  2 +-
 .../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 69 ++++++++++++++-----
 .../TableGen/dag-isel-regclass-emit-enum.td   |  4 +-
 .../TableGen/Common/CodeGenDAGPatterns.h      |  8 +--
 llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 66 +++++++++++-------
 llvm/utils/TableGen/VTEmitter.cpp             |  4 +-
 6 files changed, 99 insertions(+), 54 deletions(-)

diff --git a/llvm/include/llvm/CodeGenTypes/MachineValueType.h b/llvm/include/llvm/CodeGenTypes/MachineValueType.h
index 7e7608c34fb70..556531c3147e5 100644
--- a/llvm/include/llvm/CodeGenTypes/MachineValueType.h
+++ b/llvm/include/llvm/CodeGenTypes/MachineValueType.h
@@ -33,7 +33,7 @@ namespace llvm {
   /// type can be represented by an MVT.
   class MVT {
   public:
-    enum SimpleValueType : uint8_t {
+    enum SimpleValueType : uint16_t {
       // Simple value types that aren't explicitly part of this enumeration
       // are considered extended value types.
       INVALID_SIMPLE_VALUE_TYPE = 0,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index df3d207d85d35..63612ffdc9397 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2896,8 +2896,10 @@ CheckChild2CondCode(const unsigned char *MatcherTable, unsigned &MatcherIndex,
 LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
 CheckValueType(const unsigned char *MatcherTable, unsigned &MatcherIndex,
                SDValue N, const TargetLowering *TLI, const DataLayout &DL) {
-  MVT::SimpleValueType VT =
-      static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+  unsigned SimpleVT = MatcherTable[MatcherIndex++];
+  if (SimpleVT & 128)
+    SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+  MVT::SimpleValueType VT = static_cast<MVT::SimpleValueType>(SimpleVT);
   if (cast<VTSDNode>(N)->getVT() == VT)
     return true;
 
@@ -3027,7 +3029,10 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
       VT = MVT::i64;
       break;
     default:
-      VT = static_cast<MVT::SimpleValueType>(Table[Index++]);
+      unsigned SimpleVT = Table[Index++];
+      if (SimpleVT & 128)
+        SimpleVT = GetVBR(SimpleVT, Table, Index);
+      VT = static_cast<MVT::SimpleValueType>(SimpleVT);
       break;
     }
     Result = !::CheckType(VT, N, SDISel.TLI, SDISel.CurDAG->getDataLayout());
@@ -3035,7 +3040,10 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
   }
   case SelectionDAGISel::OPC_CheckTypeRes: {
     unsigned Res = Table[Index++];
-    Result = !::CheckType(static_cast<MVT::SimpleValueType>(Table[Index++]),
+    unsigned SimpleVT = Table[Index++];
+    if (SimpleVT & 128)
+      SimpleVT = GetVBR(SimpleVT, Table, Index);
+    Result = !::CheckType(static_cast<MVT::SimpleValueType>(SimpleVT),
                           N.getValue(Res), SDISel.TLI,
                           SDISel.CurDAG->getDataLayout());
     return Index;
@@ -3075,7 +3083,10 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
       VT = MVT::i64;
       ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0TypeI64;
     } else {
-      VT = static_cast<MVT::SimpleValueType>(Table[Index++]);
+      unsigned SimpleVT = Table[Index++];
+      if (SimpleVT & 128)
+        SimpleVT = GetVBR(SimpleVT, Table, Index);
+      VT = static_cast<MVT::SimpleValueType>(SimpleVT);
       ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0Type;
     }
     Result = !::CheckChildType(VT, N, SDISel.TLI,
@@ -3579,7 +3590,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         VT = MVT::i64;
         break;
       default:
-        VT = static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        unsigned SimpleVT = MatcherTable[MatcherIndex++];
+        if (SimpleVT & 128)
+          SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+        VT = static_cast<MVT::SimpleValueType>(SimpleVT);
         break;
       }
       if (!::CheckType(VT, N, TLI, CurDAG->getDataLayout()))
@@ -3588,9 +3602,11 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
 
     case OPC_CheckTypeRes: {
       unsigned Res = MatcherTable[MatcherIndex++];
-      if (!::CheckType(
-              static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]),
-              N.getValue(Res), TLI, CurDAG->getDataLayout()))
+      unsigned SimpleVT = MatcherTable[MatcherIndex++];
+      if (SimpleVT & 128)
+        SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+      if (!::CheckType(static_cast<MVT::SimpleValueType>(SimpleVT),
+                       N.getValue(Res), TLI, CurDAG->getDataLayout()))
         break;
       continue;
     }
@@ -3629,7 +3645,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
     case OPC_SwitchType: {
       MVT CurNodeVT = N.getSimpleValueType();
       unsigned SwitchStart = MatcherIndex-1; (void)SwitchStart;
-      unsigned CaseSize;
+      unsigned CaseSize, CaseSimpleVT;
       while (true) {
         // Get the size of this case.
         CaseSize = MatcherTable[MatcherIndex++];
@@ -3637,8 +3653,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
           CaseSize = GetVBR(CaseSize, MatcherTable, MatcherIndex);
         if (CaseSize == 0) break;
 
-        MVT CaseVT =
-            static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        CaseSimpleVT = MatcherTable[MatcherIndex++];
+        if (CaseSimpleVT & 128)
+          CaseSimpleVT = GetVBR(CaseSimpleVT, MatcherTable, MatcherIndex);
+        MVT CaseVT = static_cast<MVT::SimpleValueType>(CaseSimpleVT);
         if (CaseVT == MVT::iPTR)
           CaseVT = TLI->getPointerTy(CurDAG->getDataLayout());
 
@@ -3694,7 +3712,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         VT = MVT::i64;
         ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0TypeI64;
       } else {
-        VT = static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        unsigned SimpleVT = MatcherTable[MatcherIndex++];
+        if (SimpleVT & 128)
+          SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+        VT = static_cast<MVT::SimpleValueType>(SimpleVT);
         ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0Type;
       }
       if (!::CheckChildType(VT, N, TLI, CurDAG->getDataLayout(), ChildNo))
@@ -3788,7 +3809,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         VT = MVT::i64;
         break;
       default:
-        VT = static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        unsigned SimpleVT = MatcherTable[MatcherIndex++];
+        if (SimpleVT & 128)
+          SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+        VT = static_cast<MVT::SimpleValueType>(SimpleVT);
         break;
       }
       int64_t Val = MatcherTable[MatcherIndex++];
@@ -3812,7 +3836,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         VT = MVT::i64;
         break;
       default:
-        VT = static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        unsigned SimpleVT = MatcherTable[MatcherIndex++];
+        if (SimpleVT & 128)
+          SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+        VT = static_cast<MVT::SimpleValueType>(SimpleVT);
         break;
       }
       unsigned RegNo = MatcherTable[MatcherIndex++];
@@ -3824,8 +3851,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
       // For targets w/ more than 256 register names, the register enum
       // values are stored in two bytes in the matcher table (just like
       // opcodes).
-      MVT::SimpleValueType VT =
-          static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+      unsigned SimpleVT = MatcherTable[MatcherIndex++];
+      if (SimpleVT & 128)
+        SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+      MVT::SimpleValueType VT = static_cast<MVT::SimpleValueType>(SimpleVT);
       unsigned RegNo = MatcherTable[MatcherIndex++];
       RegNo |= MatcherTable[MatcherIndex++] << 8;
       RecordedNodes.push_back(std::pair<SDValue, SDNode*>(
@@ -4063,8 +4092,10 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
         NumVTs = MatcherTable[MatcherIndex++];
       SmallVector<EVT, 4> VTs;
       for (unsigned i = 0; i != NumVTs; ++i) {
-        MVT::SimpleValueType VT =
-            static_cast<MVT::SimpleValueType>(MatcherTable[MatcherIndex++]);
+        unsigned SimpleVT = MatcherTable[MatcherIndex++];
+        if (SimpleVT & 128)
+          SimpleVT = GetVBR(SimpleVT, MatcherTable, MatcherIndex);
+        MVT::SimpleValueType VT = static_cast<MVT::SimpleValueType>(SimpleVT);
         if (VT == MVT::iPTR)
           VT = TLI->getPointerTy(CurDAG->getDataLayout()).SimpleTy;
         VTs.push_back(VT);
diff --git a/llvm/test/TableGen/dag-isel-regclass-emit-enum.td b/llvm/test/TableGen/dag-isel-regclass-emit-enum.td
index ba60df2bbcf6e..0fd7fa02f064c 100644
--- a/llvm/test/TableGen/dag-isel-regclass-emit-enum.td
+++ b/llvm/test/TableGen/dag-isel-regclass-emit-enum.td
@@ -27,13 +27,13 @@ def GPRAbove127 : RegisterClass<"TestTarget", [i32], 32,
 // CHECK-NEXT: OPC_CheckChild1Integer, 0,
 // CHECK-NEXT: OPC_EmitInteger32, 0|128,2/*256*/,
 // CHECK-NEXT: OPC_MorphNodeTo1None, TARGET_VAL(TargetOpcode::COPY_TO_REGCLASS),
-// CHECK-NEXT:     MVT::i32, 2/*#Ops*/, 1, 0,
+// CHECK-NEXT:     7, 2/*#Ops*/, 1, 0,
 def : Pat<(i32 (add i32:$src, (i32 0))),
           (COPY_TO_REGCLASS GPRAbove127, GPR0:$src)>;
 
 // CHECK:      OPC_CheckChild1Integer, 2,
 // CHECK-NEXT: OPC_EmitStringInteger32, TestNamespace::GPR127RegClassID,
 // CHECK-NEXT: OPC_MorphNodeTo1None, TARGET_VAL(TargetOpcode::COPY_TO_REGCLASS),
-// CHECK-NEXT:     MVT::i32, 2/*#Ops*/, 1, 0,
+// CHECK-NEXT:     7, 2/*#Ops*/, 1, 0,
 def : Pat<(i32 (add i32:$src, (i32 1))),
           (COPY_TO_REGCLASS GPR127, GPR0:$src)>;
diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h
index 1f4d45d81fd33..bac213a356d84 100644
--- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h
+++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h
@@ -48,15 +48,15 @@ class CodeGenDAGPatterns;
 using TreePatternNodePtr = IntrusiveRefCntPtr<TreePatternNode>;
 
 /// This represents a set of MVTs. Since the underlying type for the MVT
-/// is uint8_t, there are at most 256 values. To reduce the number of memory
+/// is uint16_t, there are at most 65536 values. To reduce the number of memory
 /// allocations and deallocations, represent the set as a sequence of bits.
 /// To reduce the allocations even further, make MachineValueTypeSet own
 /// the storage and use std::array as the bit container.
 struct MachineValueTypeSet {
   static_assert(std::is_same<std::underlying_type_t<MVT::SimpleValueType>,
-                             uint8_t>::value,
-                "Change uint8_t here to the SimpleValueType's type");
-  static unsigned constexpr Capacity = std::numeric_limits<uint8_t>::max() + 1;
+                             uint16_t>::value,
+                "Change uint16_t here to the SimpleValueType's type");
+  static unsigned constexpr Capacity = std::numeric_limits<uint16_t>::max() + 1;
   using WordType = uint64_t;
   static unsigned constexpr WordWidth = CHAR_BIT * sizeof(WordType);
   static unsigned constexpr NumWords = Capacity / WordWidth;
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index 229245ff40504..178ceb1443335 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -339,7 +339,8 @@ unsigned MatcherTableEmitter::SizeMatcher(Matcher *N, raw_ostream &OS) {
         Size += 2; // Count the child's opcode.
       } else {
         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
-        ++Size; // Count the child's type.
+        Size += GetVBRSize(cast<SwitchTypeMatcher>(N)->getCaseType(
+            i)); // Count the child's type.
       }
       const unsigned ChildSize = SizeMatcherList(Child, OS);
       assert(ChildSize != 0 && "Matcher cannot have child of size 0");
@@ -599,7 +600,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
         IdxSize = 2; // size of opcode in table is 2 bytes.
       } else {
         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
-        IdxSize = 1; // size of type in table is 1 byte.
+        IdxSize = GetVBRSize(cast<SwitchTypeMatcher>(N)->getCaseType(
+            i)); // size of type in table is its VBR-encoded length.
       }
 
       if (i != 0) {
@@ -616,7 +618,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
         OS << "TARGET_VAL(" << SOM->getCaseOpcode(i).getEnumName() << "),";
       else
-        OS << getEnumName(cast<SwitchTypeMatcher>(N)->getCaseType(i)) << ',';
+        EmitVBRValue(cast<SwitchTypeMatcher>(N)->getCaseType(i),
+                     OS); // emit the case type as a VBR value.
       if (!OmitComments)
         OS << "// ->" << CurrentIdx + ChildSize;
       OS << '\n';
@@ -639,7 +642,7 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
     return CurrentIdx - StartIdx + 1;
   }
 
-  case Matcher::CheckType:
+  case Matcher::CheckType: {
     if (cast<CheckTypeMatcher>(N)->getResNo() == 0) {
       MVT::SimpleValueType VT = cast<CheckTypeMatcher>(N)->getType();
       switch (VT) {
@@ -648,13 +651,17 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
         OS << "OPC_CheckTypeI" << MVT(VT).getSizeInBits() << ",\n";
         return 1;
       default:
-        OS << "OPC_CheckType, " << getEnumName(VT) << ",\n";
-        return 2;
+        OS << "OPC_CheckType, ";
+        unsigned NumBytes = EmitVBRValue(VT, OS);
+        OS << "\n";
+        return NumBytes + 1;
       }
     }
-    OS << "OPC_CheckTypeRes, " << cast<CheckTypeMatcher>(N)->getResNo() << ", "
-       << getEnumName(cast<CheckTypeMatcher>(N)->getType()) << ",\n";
-    return 3;
+    OS << "OPC_CheckTypeRes, " << cast<CheckTypeMatcher>(N)->getResNo() << ", ";
+    unsigned NumBytes = EmitVBRValue(cast<CheckTypeMatcher>(N)->getType(), OS);
+    OS << "\n";
+    return NumBytes + 2;
+  }
 
   case Matcher::CheckChildType: {
     MVT::SimpleValueType VT = cast<CheckChildTypeMatcher>(N)->getType();
@@ -666,8 +673,10 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
       return 1;
     default:
       OS << "OPC_CheckChild" << cast<CheckChildTypeMatcher>(N)->getChildNo()
-         << "Type, " << getEnumName(VT) << ",\n";
-      return 2;
+         << "Type, ";
+      unsigned NumBytes = EmitVBRValue(VT, OS);
+      OS << "\n";
+      return NumBytes + 1;
     }
   }
 
@@ -696,10 +705,13 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
        << cast<CheckChild2CondCodeMatcher>(N)->getCondCodeName() << ",\n";
     return 2;
 
-  case Matcher::CheckValueType:
-    OS << "OPC_CheckValueType, "
-       << getEnumName(cast<CheckValueTypeMatcher>(N)->getVT()) << ",\n";
-    return 2;
+  case Matcher::CheckValueType: {
+    OS << "OPC_CheckValueType, ";
+    unsigned NumBytes =
+        EmitVBRValue(cast<CheckValueTypeMatcher>(N)->getVT(), OS);
+    OS << "\n";
+    return NumBytes + 1;
+  }
 
   case Matcher::CheckComplexPat: {
     const CheckComplexPatMatcher *CCPM = cast<CheckComplexPatMatcher>(N);
@@ -766,8 +778,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
       OS << "OPC_EmitInteger" << MVT(VT).getSizeInBits() << ", ";
       break;
     default:
-      OpBytes = 2;
-      OS << "OPC_EmitInteger, " << getEnumName(VT) << ", ";
+      OS << "OPC_EmitInteger, ";
+      OpBytes = EmitVBRValue(VT, OS) + 1;
       break;
     }
     unsigned Bytes = OpBytes + EmitSignedVBRValue(Val, OS);
@@ -785,8 +797,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
       OS << "OPC_EmitStringInteger" << MVT(VT).getSizeInBits() << ", ";
       break;
     default:
-      OpBytes = 2;
-      OS << "OPC_EmitStringInteger, " << getEnumName(VT) << ", ";
+      OS << "OPC_EmitStringInteger, ";
+      OpBytes = EmitVBRValue(VT, OS) + 1;
       break;
     }
     OS << Val << ",\n";
@@ -797,14 +809,15 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
     const EmitRegisterMatcher *Matcher = cast<EmitRegisterMatcher>(N);
     const CodeGenRegister *Reg = Matcher->getReg();
     MVT::SimpleValueType VT = Matcher->getVT();
+    unsigned OpBytes;
     // If the enum value of the register is larger than one byte can handle,
     // use EmitRegister2.
     if (Reg && Reg->EnumValue > 255) {
-      OS << "OPC_EmitRegister2, " << getEnumName(VT) << ", ";
+      OS << "OPC_EmitRegister2, ";
+      OpBytes = EmitVBRValue(VT, OS);
       OS << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n";
-      return 4;
+      return OpBytes + 3;
     }
-    unsigned OpBytes;
     switch (VT) {
     case MVT::i32:
     case MVT::i64:
@@ -812,8 +825,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
       OS << "OPC_EmitRegisterI" << MVT(VT).getSizeInBits() << ", ";
       break;
     default:
-      OpBytes = 2;
-      OS << "OPC_EmitRegister, " << getEnumName(VT) << ", ";
+      OS << "OPC_EmitRegister, ";
+      OpBytes = EmitVBRValue(VT, OS) + 1;
       break;
     }
     if (Reg) {
@@ -958,8 +971,9 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
         OS << "/*#VTs*/";
       OS << ", ";
     }
+    unsigned NumTypeBytes = 0;
     for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i)
-      OS << getEnumName(EN->getVT(i)) << ", ";
+      NumTypeBytes += EmitVBRValue(EN->getVT(i), OS);
 
     OS << EN->getNumOperands();
     if (!OmitComments)
@@ -992,7 +1006,7 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
     } else
       OS << '\n';
 
-    return 4 + !CompressVTs + !CompressNodeInfo + EN->getNumVTs() +
+    return 4 + !CompressVTs + !CompressNodeInfo + NumTypeBytes +
            NumOperandBytes + NumCoveredBytes;
   }
   case Matcher::CompleteMatch: {
diff --git a/llvm/utils/TableGen/VTEmitter.cpp b/llvm/utils/TableGen/VTEmitter.cpp
index 851a525cf48c0..77d3514b4c22c 100644
--- a/llvm/utils/TableGen/VTEmitter.cpp
+++ b/llvm/utils/TableGen/VTEmitter.cpp
@@ -79,12 +79,12 @@ static void VTtoGetLLVMTyString(raw_ostream &OS, const Record *VT) {
 void VTEmitter::run(raw_ostream &OS) {
   emitSourceFileHeader("ValueTypes Source Fragment", OS, Records);
 
-  std::array<const Record *, 256> VTsByNumber = {};
+  std::array<const Record *, 65536> VTsByNumber = {};
   auto ValueTypes = Records.getAllDerivedDefinitions("ValueType");
   for (auto *VT : ValueTypes) {
     auto Number = VT->getValueAsInt("Value");
     assert(0 <= Number && Number < (int)VTsByNumber.size() &&
-           "ValueType should be uint8_t");
+           "ValueType should be uint16_t");
     assert(!VTsByNumber[Number] && "Duplicate ValueType");
     VTsByNumber[Number] = VT;
   }



More information about the flang-commits mailing list