Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 53 additions & 51 deletions cryptarchia/cryptarchia.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,36 @@ class Config:
def cryptarchia_v0_0_1(initial_total_active_stake) -> "Config":
return Config(
k=2160,
active_slot_coeff=0.05,
active_slot_coeff=1 / 30,
epoch_stake_distribution_stabilization=3,
epoch_period_nonce_buffer=3,
epoch_period_nonce_stabilization=4,
initial_total_active_stake=initial_total_active_stake,
total_active_stake_learning_rate=0.8,
total_active_stake_learning_rate=1.0,
time=TimeConfig(
slot_duration=1,
chain_start_time=0,
),
)

@property
@functools.cached_property
def base_period_length(self) -> int:
return int(floor(self.k / self.active_slot_coeff))

@property
@functools.cached_property
def epoch_relative_nonce_slot(self) -> int:
return (
self.epoch_stake_distribution_stabilization + self.epoch_period_nonce_buffer
) * self.base_period_length

@property
@functools.cached_property
def epoch_length(self) -> int:
return (
self.epoch_relative_nonce_slot
+ self.epoch_period_nonce_stabilization * self.base_period_length
)

@property
@functools.cached_property
def s(self):
"""
The Security Paramater. This paramter controls how many slots one must
Expand Down Expand Up @@ -162,6 +162,7 @@ def encode_sk(self) -> bytes:
def encode_pk(self) -> bytes:
return int.to_bytes(self.pk, length=32, byteorder="big")

@functools.cached_property
def commitment(self) -> Hash:
value_bytes = int.to_bytes(self.value, length=32, byteorder="big")
return Hash(
Expand All @@ -174,8 +175,9 @@ def commitment(self) -> Hash:
self.zone_id,
)

@functools.cached_property
def nullifier(self) -> Hash:
return Hash(b"NOMOS_NOTE_NF", self.commitment(), self.encode_sk())
return Hash(b"NOMOS_NOTE_NF", self.commitment, self.encode_sk())


@dataclass
Expand All @@ -188,7 +190,7 @@ def epoch_nonce_contribution(self) -> Hash:
return Hash(
b"NOMOS_NONCE_CONTRIB",
self.slot.encode(),
self.note.commitment(),
self.note.commitment,
self.note.encode_sk(),
)

Expand All @@ -199,8 +201,8 @@ def verify(
return (
slot == self.slot
and parent == self.parent
and self.note.commitment() in commitments
and self.note.nullifier() not in nullifiers
and self.note.commitment in commitments
and self.note.nullifier not in nullifiers
)


Expand All @@ -217,6 +219,7 @@ class BlockHeader:
# as serialized in the format specified by the 'HEADER' rule in 'messages.abnf'.
#
# The following code is to be considered as a reference implementation, mostly to be used for testing.
@functools.cached_property
def id(self) -> Hash:
return Hash(
b"BLOCK_ID",
Expand All @@ -231,7 +234,7 @@ def id(self) -> Hash:
)

def __hash__(self):
return hash(self.id())
return hash(self.id)


@dataclass
Expand Down Expand Up @@ -271,7 +274,7 @@ def replace(self, **kwarg) -> "LedgerState":
return replace(self, **kwarg)

def apply(self, block: BlockHeader):
assert block.parent == self.block.id()
assert block.parent == self.block.id

self.nonce = Hash(
b"EPOCH_NONCE",
Expand Down Expand Up @@ -317,12 +320,12 @@ class Follower:
def __init__(self, genesis_state: LedgerState, config: Config):
self.config = config
self.forks: list[Hash] = []
self.local_chain = genesis_state.block.id()
self.local_chain = genesis_state.block.id
self.genesis_state = genesis_state
self.ledger_state = {genesis_state.block.id(): genesis_state.copy()}
self.ledger_state = {genesis_state.block.id: genesis_state.copy()}
self.epoch_state = {}
self.state = State.BOOTSTRAPPING
self.lib = genesis_state.block.id() # Last immutable block, initially the genesis block
self.lib = genesis_state.block.id # Last immutable block, initially the genesis block

def to_online(self):
"""
Expand Down Expand Up @@ -359,22 +362,22 @@ def validate_header(self, block: BlockHeader):
raise InvalidLeaderProof

def on_block(self, block: BlockHeader):
if block.id() in self.ledger_state:
if block.id in self.ledger_state:
logger.warning("dropping already processed block")
return

self.validate_header(block)

new_state = self.ledger_state[block.parent].copy()
new_state.apply(block)
self.ledger_state[block.id()] = new_state
self.ledger_state[block.id] = new_state

if block.parent == self.local_chain:
# simply extending the local chain
self.local_chain = block.id()
self.local_chain = block.id
else:
# otherwise, this block creates a fork
self.forks.append(block.id())
self.forks.append(block.id)

# remove any existing fork that is superceded by this block
if block.parent in self.forks:
Expand All @@ -389,7 +392,6 @@ def on_block(self, block: BlockHeader):
if self.state == State.ONLINE:
self.update_lib()


# Update the lib, and prune forks that do not descend from it.
def update_lib(self):
"""
Expand All @@ -401,7 +403,7 @@ def update_lib(self):
return
# prune forks that do not descend from the last immutable block, this is needed to avoid Genesis rule to roll back
# past the LIB
self.lib = next(islice(iter_chain(self.local_chain, self.ledger_state), self.config.k, None), self.genesis_state).block.id()
self.lib = next(islice(iter_chain(self.local_chain, self.ledger_state), self.config.k, None), self.genesis_state).block.id
self.forks = [
f for f in self.forks if is_ancestor(self.lib, f, self.ledger_state)
]
Expand All @@ -411,7 +413,6 @@ def update_lib(self):
if is_ancestor(self.lib, k, self.ledger_state) or is_ancestor(k, self.lib, self.ledger_state)
}


# Evaluate the fork choice rule and return the chain we should be following
def fork_choice(self) -> Hash:
if self.state == State.BOOTSTRAPPING:
Expand Down Expand Up @@ -450,20 +451,18 @@ def state_at_slot_beginning(self, tip: Hash, slot: Slot) -> LedgerState:
def epoch_start_slot(self, epoch) -> Slot:
return Slot(epoch.epoch * self.config.epoch_length)

def stake_distribution_snapshot(self, epoch, tip: Hash):
def stake_distribution_snapshot_slot(self, epoch):
# stake distribution snapshot happens at the beginning of the previous epoch,
# i.e. for epoch e, the snapshot is taken at the last block of epoch e-2
slot = Slot(epoch.prev().epoch * self.config.epoch_length)
return self.state_at_slot_beginning(tip, slot)
return self.epoch_start_slot(epoch.prev())

def nonce_snapshot(self, epoch, tip):
def nonce_snapshot_slot(self, epoch):
# nonce snapshot happens partway through the previous epoch after the
# stake distribution has stabilized
slot = Slot(
return Slot(
self.config.epoch_relative_nonce_slot
+ self.epoch_start_slot(epoch.prev()).absolute_slot
)
return self.state_at_slot_beginning(tip, slot)

def compute_epoch_state(self, epoch: Epoch, tip: Hash) -> EpochState:
if epoch.epoch == 0:
Expand All @@ -473,18 +472,21 @@ def compute_epoch_state(self, epoch: Epoch, tip: Hash) -> EpochState:
inferred_total_active_stake=self.config.initial_total_active_stake,
)

stake_distribution_snapshot = self.stake_distribution_snapshot(epoch, tip)
nonce_snapshot = self.nonce_snapshot(epoch, tip)

# we memoize epoch states to avoid recursion killing our performance
memo_block_id = nonce_snapshot.block.id()
if state := self.epoch_state.get((epoch, memo_block_id)):
if state := self.epoch_state.get((epoch, tip)):
return state

nonce_slot = self.nonce_snapshot_slot(epoch)
stake_distribution_slot = self.stake_distribution_snapshot_slot(epoch)

stake_distribution_snapshot = self.state_at_slot_beginning(
tip, stake_distribution_slot
)
nonce_snapshot = self.state_at_slot_beginning(tip, nonce_slot)

# To update our inference of total stake, we need the prior estimate which
# was calculated last epoch. Thus we recurse here to retreive the previous
# estimate of total stake.
prev_epoch = self.compute_epoch_state(epoch.prev(), tip)
prev_epoch = self.compute_epoch_state(epoch.prev(), nonce_snapshot.block.id)
inferred_total_active_stake = self._infer_total_active_stake(
prev_epoch, nonce_snapshot, stake_distribution_snapshot
)
Expand All @@ -495,7 +497,7 @@ def compute_epoch_state(self, epoch: Epoch, tip: Hash) -> EpochState:
inferred_total_active_stake=inferred_total_active_stake,
)

self.epoch_state[(epoch, memo_block_id)] = state
self.epoch_state[(epoch, tip)] = state
return state

def _infer_total_active_stake(
Expand All @@ -509,19 +511,19 @@ def _infer_total_active_stake(
# Since we need a stable inference of total stake for the start of this epoch,
# we limit our look back period to the start of last epoch until when the nonce
# snapshot was taken.
block_proposals_last_epoch = (
period_block_density = (
nonce_snapshot.leader_count - stake_distribution_snapshot.leader_count
)
T = self.config.epoch_relative_nonce_slot
mean_blocks_per_slot = block_proposals_last_epoch / T
expected_blocks_per_slot = np.log(1 / (1 - self.config.active_slot_coeff))
blocks_per_slot_err = expected_blocks_per_slot - mean_blocks_per_slot
h = (
self.config.total_active_stake_learning_rate
* prev_epoch.inferred_total_active_stake
/ expected_blocks_per_slot
)
return int(prev_epoch.inferred_total_active_stake - h * blocks_per_slot_err)
# Use epoch_relative_nonce_slot as this is the actual observation window
# (the slot range from stake_distribution_snapshot to nonce_snapshot)
period = self.config.epoch_relative_nonce_slot
f = self.config.active_slot_coeff
beta = self.config.total_active_stake_learning_rate
total_stake_estimate = prev_epoch.inferred_total_active_stake

slot_activation_error = 1 - period_block_density / (period * f)
coefficient = total_stake_estimate * beta
return max(1, int(total_stake_estimate - coefficient * slot_activation_error))

def blocks_by_slot(self, from_slot: Slot) -> Generator[BlockHeader, None, None]:
# Returns blocks in the given range of slots in order of slot
Expand Down Expand Up @@ -564,7 +566,7 @@ def _is_slot_leader(self, epoch: EpochState, slot: Slot):
b"LEAD",
epoch.nonce(),
slot.encode(),
self.note.commitment(),
self.note.commitment,
self.note.encode_sk(),
)
ticket = int.from_bytes(ticket)
Expand Down Expand Up @@ -605,7 +607,7 @@ def is_ancestor(a: Hash, b: Hash, states: Dict[Hash, LedgerState]) -> bool:
Returns True if `a` is an ancestor of `b` in the chain.
"""
for state in iter_chain(b, states):
if state.block.id() == a:
if state.block.id == a:
return True
return False

Expand All @@ -623,7 +625,7 @@ def common_prefix_depth(
try:
a_block = next(a_blocks)
a_suffix.append(a_block)
a_block_id = a_block.id()
a_block_id = a_block.id
if a_block_id in seen:
# we had seen this block from the fork chain
return (
Expand All @@ -640,7 +642,7 @@ def common_prefix_depth(
try:
b_block = next(b_blocks)
b_suffix.append(b_block)
b_block_id = b_block.id()
b_block_id = b_block.id
if b_block_id in seen:
# we had seen the fork in the local chain
return (
Expand Down
21 changes: 9 additions & 12 deletions cryptarchia/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def sync(local: Follower, peers: list[Follower], checkpoint: LedgerState | None
num_blocks += 1
# Reject blocks that have been rejected in the past
# or whose parent has been rejected.
if {block.id(), block.parent} & rejected_blocks:
rejected_blocks.add(block.id())
if {block.id, block.parent} & rejected_blocks:
rejected_blocks.add(block.id)
continue

try:
Expand All @@ -49,7 +49,7 @@ def sync(local: Follower, peers: list[Follower], checkpoint: LedgerState | None
except ParentNotFound:
orphans.add(block)
except Exception:
rejected_blocks.add(block.id())
rejected_blocks.add(block.id)

# Finish the sync process if no block has been fetched,
# which means that no peer has a tip ahead of the local tip.
Expand All @@ -63,14 +63,11 @@ def sync(local: Follower, peers: list[Follower], checkpoint: LedgerState | None
# Skip the orphan block if it has been processed during the previous backfillings
# (i.e. if it has been already added to the local block tree).
# Or, skip if it has been rejected during the previous backfillings.
if (
orphan.id() not in local.ledger_state
and orphan.id() not in rejected_blocks
):
if orphan.id not in local.ledger_state and orphan.id not in rejected_blocks:
try:
backfill_fork(local, orphan, block_fetcher)
except InvalidBlockFromBackfillFork as e:
rejected_blocks.update(block.id() for block in e.invalid_suffix)
rejected_blocks.update(block.id for block in e.invalid_suffix)


def backfill_fork(
Expand All @@ -83,7 +80,7 @@ def backfill_fork(

suffix = find_missing_part(
local,
block_fetcher.fetch_chain_backward(fork_tip.id(), local),
block_fetcher.fetch_chain_backward(fork_tip.id, local),
)

# Add blocks in the fork suffix with applying fork choice rule.
Expand All @@ -105,7 +102,7 @@ def find_missing_part(

suffix: list[BlockHeader] = []
for block in fork:
if block.id() in local.ledger_state:
if block.id in local.ledger_state:
break
suffix.append(block)
suffix.reverse()
Expand Down Expand Up @@ -167,15 +164,15 @@ def fetch_chain_backward(
# First, try to iterate the chain from the local block tree.
for block in iter_chain_blocks(id, local.ledger_state):
yield block
if block.id() == local.genesis_state.block.id():
if block.id == local.genesis_state.block.id:
return
id = block.parent

# Try to continue by fetching the remaining blocks from the peers
for peer in self.peers:
for block in iter_chain_blocks(id, peer.ledger_state):
yield block
if block.id() == local.genesis_state.block.id():
if block.id == local.genesis_state.block.id:
return
id = block.parent

Expand Down
Loading