Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 73 additions & 34 deletions libp2p/protocols/secure/noise.nim
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ proc encrypt(

inc state.n
if state.n > NonceMax:
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
raise (ref NoiseNonceMaxError)(msg: "Noise max nonce value reached")

proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
{.raises: [NoiseNonceMaxError].} =
Expand All @@ -168,10 +168,11 @@ proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte]
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
if tagIn != tagOut:
debug "decryptWithAd failed", data = shortLog(data)
raise newException(NoiseDecryptTagError, "decryptWithAd failed tag authentication.")
raise (ref NoiseDecryptTagError)(msg:
"decryptWithAd failed tag authentication.")
inc state.n
if state.n > NonceMax:
raise newException(NoiseNonceMaxError, "Noise max nonce value reached")
raise (ref NoiseNonceMaxError)(msg: "Noise max nonce value reached")

# Symmetricstate

Expand All @@ -181,8 +182,7 @@ proc init(_: type[SymmetricState]): SymmetricState =
result.cs = CipherState(k: EmptyKey)

proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
var
temp_keys: array[2, ChaChaPolyKey]
var temp_keys: array[2, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys)
ss.ck = temp_keys[0]
ss.cs = CipherState(k: temp_keys[1])
Expand All @@ -198,8 +198,7 @@ proc mixHash(ss: var SymmetricState, data: openArray[byte]) =

# We might use this for other handshake patterns/tokens
proc mixKeyAndHash(ss: var SymmetricState, ikm: openArray[byte]) {.used.} =
var
temp_keys: array[3, ChaChaPolyKey]
var temp_keys: array[3, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys)
ss.ck = temp_keys[0]
ss.mixHash(temp_keys[1])
Expand Down Expand Up @@ -234,7 +233,8 @@ proc init(_: type[HandshakeState]): HandshakeState =

template write_e: untyped =
trace "noise write e"
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
# Sets e (which must be empty) to GENERATE_KEYPAIR().
# Appends e.public_key to the buffer. Calls MixHash(e.public_key).
hs.e = genKeyPair(p.rng[])
msg.add hs.e.publicKey
hs.ss.mixHash(hs.e.publicKey)
Expand Down Expand Up @@ -275,26 +275,28 @@ template read_e: untyped =
trace "noise read e", size = msg.len

if msg.len < Curve25519Key.len:
raise newException(NoiseHandshakeError, "Noise E, expected more data")
raise (ref NoiseHandshakeError)(msg: "Noise E, expected more data")

# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
# Sets re (which must be empty) to the next DHLEN bytes from the message.
# Calls MixHash(re.public_key).
hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high)
msg.consume(Curve25519Key.len)
hs.ss.mixHash(hs.re)

template read_s: untyped =
trace "noise read s", size = msg.len
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True,
# or to the next DHLEN bytes otherwise.
# Sets rs (which must be empty) to DecryptAndHash(temp).
let
rsLen =
if hs.ss.cs.hasKey:
if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data")
raise (ref NoiseHandshakeError)(msg: "Noise S, expected more data")
Curve25519Key.len + ChaChaPolyTag.len
else:
if msg.len < Curve25519Key.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data")
raise (ref NoiseHandshakeError)(msg: "Noise S, expected more data")
Curve25519Key.len
hs.rs[0..Curve25519Key.high] =
hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1))
Expand All @@ -315,7 +317,11 @@ proc readFrame(
await sconn.readExactly(addr buffer[0], buffer.len)
return buffer

proc writeFrame(sconn: Connection, buf: openArray[byte]): Future[void] =
proc writeFrame(
sconn: Connection,
buf: openArray[byte]
): Future[void] {.async: (raises: [
CancelledError, LPStreamError], raw: true).} =
doAssert buf.len <= uint16.high.int
var
lesize = buf.len.uint16
Expand All @@ -326,13 +332,24 @@ proc writeFrame(sconn: Connection, buf: openArray[byte]): Future[void] =
outbuf &= buf
sconn.write(outbuf)

proc receiveHSMessage(sconn: Connection): Future[seq[byte]] = readFrame(sconn)
proc sendHSMessage(sconn: Connection, buf: openArray[byte]): Future[void] =
proc receiveHSMessage(
sconn: Connection
): Future[seq[byte]] {.async: (raises: [
CancelledError, LPStreamError], raw: true).} =
readFrame(sconn)

proc sendHSMessage(
sconn: Connection,
buf: openArray[byte]
): Future[void] {.async: (raises: [
CancelledError, LPStreamError], raw: true).} =
writeFrame(sconn, buf)

proc handshakeXXOutbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
p2pSecret: seq[byte]
): Future[HandshakeResult] {.async: (raises: [
CancelledError, LPStreamError]).} =
const initiator = true
var
hs = HandshakeState.init()
Expand Down Expand Up @@ -374,13 +391,16 @@ proc handshakeXXOutbound(
await conn.sendHSMessage(msg.data)

let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
return HandshakeResult(
cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)

proc handshakeXXInbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
p2pSecret: seq[byte]
): Future[HandshakeResult] {.async: (raises: [
CancelledError, LPStreamError]).} =
const initiator = false

var
Expand Down Expand Up @@ -424,7 +444,8 @@ proc handshakeXXInbound(
let
remoteP2psecret = hs.ss.decryptAndHash(msg.data)
(cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
return HandshakeResult(
cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)

Expand Down Expand Up @@ -486,7 +507,8 @@ method write*(
try:
encryptFrame(
sconn,
cipherFrames.toOpenArray(woffset, woffset + chunkSize + FramingSize - 1),
cipherFrames.toOpenArray(
woffset, woffset + chunkSize + FramingSize - 1),
message.toOpenArray(offset, offset + chunkSize - 1))
except NoiseNonceMaxError as exc:
debug "Noise nonce exceeded"
Expand All @@ -509,21 +531,28 @@ method write*(
# sequencing issues
sconn.stream.write(cipherFrames)

method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerId]): Future[SecureConn] {.async.} =
method handshake*(
p: Noise,
conn: Connection,
initiator: bool,
peerId: Opt[PeerId]
): Future[SecureConn] {.async: (raises: [CancelledError, LPStreamError]).} =
trace "Starting Noise handshake", conn, initiator

let timeout = conn.timeout
conn.timeout = HandshakeTimeout

# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
let
signedPayload = p.localPrivateKey.sign(
PayloadString & p.noiseKeys.publicKey.getBytes).tryGet()
let signedPayload = p.localPrivateKey.sign(
PayloadString & p.noiseKeys.publicKey.getBytes)
if signedPayload.isErr():
raise (ref NoiseHandshakeError)(msg:
"Failed to sign public key: " & $signedPayload.error())

var
libp2pProof = initProtoBuffer()
libp2pProof.write(1, p.localPublicKey)
libp2pProof.write(2, signedPayload.getBytes())
libp2pProof.write(2, signedPayload.get().getBytes())
# data field also there but not used!
libp2pProof.finish()

Expand All @@ -542,29 +571,38 @@ method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerI
remoteSigBytes: seq[byte]

if not remoteProof.getField(1, remotePubKeyBytes).valueOr(false):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ")")
raise (ref NoiseHandshakeError)(msg:
"Failed to deserialize remote public key bytes. (initiator: " &
$initiator & ")")
if not remoteProof.getField(2, remoteSigBytes).valueOr(false):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ")")
raise (ref NoiseHandshakeError)(msg:
"Failed to deserialize remote signature bytes. (initiator: " &
$initiator & ")")

if not remotePubKey.init(remotePubKeyBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ")")
raise (ref NoiseHandshakeError)(msg:
"Failed to decode remote public key. (initiator: " & $initiator & ")")
if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ")")
raise (ref NoiseHandshakeError)(msg:
"Failed to decode remote signature. (initiator: " & $initiator & ")")

let verifyPayload = PayloadString & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
raise (ref NoiseHandshakeError)(msg:
"Noise handshake signature verify failed.")
else:
trace "Remote signature verified", conn

let pid = PeerId.init(remotePubKey).valueOr:
raise newException(NoiseHandshakeError, "Invalid remote peer id: " & $error)
raise (ref NoiseHandshakeError)(msg:
"Invalid remote peer id: " & $error)

trace "Remote peer id", pid = $pid

peerId.withValue(targetPid):
if not targetPid.validate():
raise newException(NoiseHandshakeError, "Failed to validate expected peerId.")
raise (ref NoiseHandshakeError)(msg:
"Failed to validate expected peerId.")

if pid != targetPid:
var
Expand All @@ -574,7 +612,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool, peerId: Opt[PeerI
initiator, dealt_peer = conn,
dealt_key = $failedKey, received_peer = $pid,
received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
raise (ref NoiseHandshakeError)(msg:
"Noise handshake, peer id don't match! " & $pid & " != " & $targetPid)
conn.peerId = pid

var tmp = NoiseConnection.new(conn, conn.peerId, conn.observedAddr)
Expand Down
Loading