about summary refs log tree commit diff stats
path: root/src/network_set.cpp
blob: 3238dcd0a3f73a081086a5206e51b9819ffcde80 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "network_set.h"

// Resets the set to its initial empty state: forgets every item-to-network
// assignment and every recorded network.
void NetworkSet::Clear() {
  network_by_item_.clear();
  networks_.clear();
}

// Records a link between items `id1` and `id2` and returns the index of the
// network that contains both of them afterwards.
//
// - Neither item known: a new network is appended to `networks_` and both
//   items are assigned to it.
// - Exactly one item known: the other item joins that item's network.
// - Both items known: the link is added to the network of the larger id. If
//   the two items were in different networks, every item of the other
//   network is reassigned, its links are merged in, and the old network is
//   left in place but emptied (its index is not erased or reused).
int NetworkSet::AddLink(int id1, int id2) {
  // Normalize the pair so the same link always produces the same std::pair.
  // NOTE: this puts the LARGER id first (id1 >= id2 after the swap); stored
  // pairs have always used this order, so it is preserved here.
  if (id2 > id1) {
    std::swap(id1, id2);
  }

  // Look each endpoint up once (instead of count() followed by operator[]).
  // The iterators are only dereferenced before any insertion below, so a
  // rehash in an unordered map cannot invalidate them while in use.
  const auto it1 = network_by_item_.find(id1);
  const auto it2 = network_by_item_.find(id2);
  const bool has1 = it1 != network_by_item_.end();
  const bool has2 = it2 != network_by_item_.end();

  if (!has1 && !has2) {
    // Neither endpoint is in a network yet: start a new one.
    const int network_id = static_cast<int>(networks_.size());
    network_by_item_[id1] = network_id;
    network_by_item_[id2] = network_id;
    networks_.emplace_back();
    networks_[network_id].emplace(id1, id2);

    return network_id;
  }

  if (has1 && has2) {
    const int network_id1 = it1->second;
    const int network_id2 = it2->second;

    networks_[network_id1].emplace(id1, id2);

    if (network_id1 != network_id2) {
      // Move every item of the absorbed network over to network_id1.
      for (const auto& [other_id1, other_id2] : networks_[network_id2]) {
        network_by_item_[other_id1] = network_id1;
        network_by_item_[other_id2] = network_id1;
      }

      // std::set::merge leaves behind any link already present in the
      // destination, so clear the source explicitly afterwards.
      networks_[network_id1].merge(networks_[network_id2]);
      networks_[network_id2].clear();
    }

    return network_id1;
  }

  // Exactly one endpoint already belongs to a network: the other joins it.
  const int network_id = has1 ? it1->second : it2->second;
  network_by_item_[has1 ? id2 : id1] = network_id;
  networks_[network_id].emplace(id1, id2);

  return network_id;
}

// Returns true when `id` has been assigned to some network by AddLink.
bool NetworkSet::IsItemInNetwork(int id) const {
  const auto found = network_by_item_.find(id);
  return found != network_by_item_.end();
}

// Returns the index of the network that item `id` belongs to. The lookup
// goes through .at(), so an unknown id propagates that call's exception.
int NetworkSet::GetNetworkWithItem(int id) const {
  const int network_id = network_by_item_.at(id);
  return network_id;
}

// Returns the set of links recorded for network index `id`. Each pair was
// normalized by AddLink; networks absorbed by a merge are empty. Access goes
// through .at(), so an out-of-range index propagates that call's exception.
const std::set<std::pair<int, int>>& NetworkSet::GetNetworkGraph(int id) const {
  const auto& graph = networks_.at(id);
  return graph;
}