Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 69 additions & 126 deletions daphne/http_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer

from .utils import HEADER_NAME_RE, parse_x_forwarded_for
from .utils import parse_x_forwarded_for

logger = logging.getLogger(__name__)

Expand All @@ -24,132 +23,77 @@ class WebRequest(http.Request):
"""

error_template = (
"""
<html>
<head>
<title>%(title)s</title>
<style>
body { font-family: sans-serif; margin: 0; padding: 0; }
h1 { padding: 0.6em 0 0.2em 20px; color: #896868; margin: 0; }
p { padding: 0 0 0.3em 20px; margin: 0; }
footer { padding: 1em 0 0.3em 20px; color: #999; font-size: 80%%; font-style: italic; }
</style>
</head>
<body>
<h1>%(title)s</h1>
<p>%(body)s</p>
<footer>Daphne</footer>
</body>
</html>
""".replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
b"<!DOCTYPE html>"
b"<html>"
b"<head><title>%(status)d %(status_text)s</title></head>"
b"<body><h1>%(status)d %(status_text)s</h1>%(text)s</body>"
b"</html>"
)

def __init__(self, *args, **kwargs):
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.client_addr = None
self.server_addr = None
try:
http.Request.__init__(self, *args, **kwargs)
# Easy server link
self.server = self.channel.factory.server
self.application_queue = None
self._response_started = False
self.server.protocol_connected(self)
except Exception:
logger.error(traceback.format_exc())
raise

### Twisted progress callbacks
self.client_scheme = None
# Build the client address
if self.transport:
peer = self.transport.getPeer()
host = self.transport.getHost()
# Always set scheme if we have a transport
self.client_scheme = (
"https" if hasattr(peer, "is_ssl") and peer.is_ssl else "http"
)
if hasattr(peer, "host") and hasattr(peer, "port"):
self.client_addr = [str(peer.host), peer.port]
self.server_addr = [str(host.host), host.port]
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
self.is_websocket = upgrade_header and upgrade_header.lower() == b"websocket"
# Hook up request parsing
self.socket_opened = time.time()
self.server.protocol_connected(self)

@inlineCallbacks
def process(self):
"""
Called when all headers have been received and we can start processing content.
"""
# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if forwarded
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
{name: value for name, value in self.requestHeaders.getAllRawHeaders()},
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for maximum request body size
if self.server.request_max_size:
self.channel.maxData = self.server.request_max_size
# Get query string
self.query_string = self.uri.split(b"?", 1)[1] if b"?" in self.uri else b""
try:
self.request_start = time.time()

# Validate header names.
for name, _ in self.requestHeaders.getAllRawHeaders():
if not HEADER_NAME_RE.fullmatch(name):
self.basic_error(400, b"Bad Request", "Invalid header name")
return

# Get upgrade header
upgrade_header = None
if self.requestHeaders.hasHeader(b"Upgrade"):
upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
# Get client address if possible
if hasattr(self.client, "host") and hasattr(self.client, "port"):
# client.host and host.host are byte strings in Python 2, but spec
# requires unicode string.
self.client_addr = [str(self.client.host), self.client.port]
self.server_addr = [str(self.host.host), self.host.port]

self.client_scheme = "https" if self.isSecure() else "http"

# See if we need to get the address from a proxy header instead
if self.server.proxy_forwarded_address_header:
self.client_addr, self.client_scheme = parse_x_forwarded_for(
self.requestHeaders,
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
self.path.decode("ascii")
except UnicodeDecodeError:
self.path = b"/"
self.basic_error(400, b"Bad Request", "Invalid characters in path")
return
# Calculate query string
self.query_string = b""
if b"?" in self.uri:
self.query_string = self.uri.split(b"?", 1)[1]
try:
self.query_string.decode("ascii")
except UnicodeDecodeError:
self.basic_error(400, b"Bad Request", "Invalid query string")
return
# Is it WebSocket? IS IT?!
# Process WebSocket requests via HTTP upgrade
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
logger.warn("Could not make WebSocket protocol")
self.finish()
# Give it the raw query string
protocol._raw_query_string = self.query_string
# Port across transport
transport, self.transport = self.transport, None
if isinstance(transport, ProtocolWrapper):
# i.e. TLS is a wrapping protocol
transport.wrappedProtocol = protocol
else:
transport.protocol = protocol
protocol.makeConnection(transport)
# Re-inject request
data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
for h in self.requestHeaders.getAllRawHeaders():
data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
data += b"\x0d\x0a"
data += self.content.read()
protocol.dataReceived(data)
# Remove our HTTP reply channel association
# Pass request to WebSocketResource for handling
self.server.ws_resource.render_GET(self)
# The WebSocketResource will handle the rest of the connection
logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
self.server.protocol_disconnected(self)
# Resume the producer so we keep getting data, if it's available as a method
self.channel._networkProducer.resumeProducing()

# Boring old HTTP.
# Don't continue with HTTP processing
return
# Handle normal HTTP requests
else:
# Sanitize and decode headers, potentially extracting root path
self.clean_headers = []
Expand Down Expand Up @@ -339,9 +283,9 @@ def duration(self):
"""
Returns the time since the start of the request.
"""
if not hasattr(self, "request_start"):
if not hasattr(self, "socket_opened"):
return 0
return time.time() - self.request_start
return time.time() - self.socket_opened

def basic_error(self, status, status_text, body):
"""
Expand All @@ -357,13 +301,12 @@ def basic_error(self, status, status_text, body):
self.handle_reply(
{
"type": "http.response.body",
"body": (
self.error_template
% {
"title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}
).encode("utf8"),
"body": self.error_template
% {
"status": status,
"status_text": status_text,
"text": body,
},
}
)

Expand Down
16 changes: 10 additions & 6 deletions daphne/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http
from twisted.web.websocket import WebSocketResource

from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.request_buffer_size = request_buffer_size
self.request_max_size = None # No limit by default
self.proxy_forwarded_address_header = proxy_forwarded_address_header
self.proxy_forwarded_port_header = proxy_forwarded_port_header
self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
Expand All @@ -99,12 +101,14 @@ def run(self):
self.connections = {}
# Make the factory
self.http_factory = HTTPFactory(self)
self.ws_factory = WebSocketFactory(self, server=self.server_name)
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout,
)

# Create WebSocket factory
self.ws_factory = WebSocketFactory(server_class=self)

# Create WebSocket resource for handling upgrade requests
self.ws_resource = WebSocketResource(self.ws_factory)

# Configure logging
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo(
Expand Down
Loading
Loading