Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions da/assignations/refill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from heapq import heappush, heappop, heapify


DeclarationId: TypeAlias = bytes
Assignations: TypeAlias = List[Set[DeclarationId]]
ProviderId: TypeAlias = bytes
Assignations: TypeAlias = List[Set[ProviderId]]
BlakeRng: TypeAlias = Any


@dataclass(order=True)
class Participant:
# Participant's wrapper class
# Used for keeping ordering in the heap by the participation first and the declaration id second
# Used for keeping ordering in the heap by the participation first and the provider id second
participation: int # prioritize participation count first
declaration_id: DeclarationId # sort by id on default
provider_id: ProviderId # sort by id on default


@dataclass
class Subnetwork:
# Subnetwork wrapper that keeps the subnetwork id [0..2048) and the set of participants in that subnetwork
participants: Set[DeclarationId]
participants: Set[ProviderId]
subnetwork_id: int

def __lt__(self, other):
Expand All @@ -47,7 +47,7 @@ def all_nodes_assigned(participants: Sequence[Participant], average_participatio
def heappop_next_for_subnetwork(subnetwork: Subnetwork, participants: List[Participant]) -> Participant:
poped = []
participant = heappop(participants)
while participant.declaration_id in subnetwork.participants:
while participant.provider_id in subnetwork.participants:
poped.append(participant)
participant = heappop(participants)
for poped in poped:
Expand Down Expand Up @@ -79,11 +79,11 @@ def fill_subnetworks(
# take the fewest participants subnetwork
subnetwork = heappop(subnetworks)

# take the declaration with the lowest participation that is not included in the subnetwork
# take the provider with the lowest participation that is not included in the subnetwork
participant = heappop_next_for_subnetwork(subnetwork, available_nodes)

# fill into subnetwork
subnetwork.participants.add(participant.declaration_id)
subnetwork.participants.add(participant.provider_id)
participant.participation += 1
# push to heaps
heappush(available_nodes, participant)
Expand Down Expand Up @@ -112,11 +112,11 @@ def balance_subnetworks_grow(
):
for participant in filter(lambda x: x.participation > average_participation, sorted(participants)):
for subnework in sample(
sorted(filter(lambda subnetwork: participant.declaration_id in subnetwork.participants, subnetworks)),
sorted(filter(lambda subnetwork: participant.provider_id in subnetwork.participants, subnetworks)),
random,
k=participant.participation - average_participation
):
subnework.participants.remove(participant.declaration_id)
subnework.participants.remove(participant.provider_id)
participant.participation -= 1


Expand All @@ -129,7 +129,7 @@ def rand(seed: bytes):


def calculate_subnetwork_assignations(
new_nodes_list: Sequence[DeclarationId],
new_nodes_list: Sequence[ProviderId],
previous_subnets: Assignations,
replication_factor: int,
random_seed: bytes,
Expand All @@ -146,7 +146,7 @@ def calculate_subnetwork_assignations(
# 1) For each (sorted) participant, remove the participant from random subnetworks (coming from sorted list)
# until the participation of is equal to the average participation.
# 4. Create a heap with the set of active nodes ordered by, primary the number of subnetworks each participant is at
# and secondary by the DeclarationId of the participant (ascending order).
# and secondary by the ProviderId of the participant (ascending order).
# 5. Create a heap with the subnetworks ordered by the number of participants in each subnetwork
# 6. Until all subnetworks are filled up to a replication factor and all nodes are assigned:
# 1) pop the subnetwork with the fewest participants
Expand All @@ -169,11 +169,11 @@ def calculate_subnetwork_assignations(
active_assignations = [subnet - unavailable_nodes for subnet in previous_subnets]

# count participation per assigned node
assigned_count: Counter[DeclarationId] = Counter(chain.from_iterable(active_assignations))
assigned_count: Counter[ProviderId] = Counter(chain.from_iterable(active_assignations))

# available nodes heap
available_nodes = [
Participant(participation=assigned_count.get(_id, 0), declaration_id=_id) for _id in new_nodes
Participant(participation=assigned_count.get(_id, 0), provider_id=_id) for _id in new_nodes
]

# subnetworks heap
Expand Down
10 changes: 5 additions & 5 deletions da/assignations/test_refill.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import chain
from typing import List, Counter
from unittest import TestCase
from da.assignations.refill import calculate_subnetwork_assignations, Assignations, DeclarationId
from da.assignations.refill import calculate_subnetwork_assignations, Assignations, ProviderId


class TestRefill(TestCase):
Expand Down Expand Up @@ -71,21 +71,21 @@ def test_random_increase_decrease_network(self):


@classmethod
def mutate_nodes(cls, nodes: List[DeclarationId], count: int):
def mutate_nodes(cls, nodes: List[ProviderId], count: int):
assert count < len(nodes)
for i in random.choices(list(range(len(nodes))), k=count):
nodes[i] = random.randbytes(32)

@classmethod
def expand_nodes(cls, nodes: List[DeclarationId], count: int) -> List[DeclarationId]:
def expand_nodes(cls, nodes: List[ProviderId], count: int) -> List[ProviderId]:
return [*nodes, *(random.randbytes(32) for _ in range(count))]

@classmethod
def shrink_nodes(cls, nodes: List[DeclarationId], count: int) -> List[DeclarationId]:
def shrink_nodes(cls, nodes: List[ProviderId], count: int) -> List[ProviderId]:
return list(random.sample(nodes, k=count))


def assert_assignations(self, assignations: Assignations, nodes: List[DeclarationId], replication_factor: int):
def assert_assignations(self, assignations: Assignations, nodes: List[ProviderId], replication_factor: int):
self.assertEqual(
len(set(chain.from_iterable(assignations))),
len(nodes),
Expand Down