# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (C) 2021 Freie Universität Berlin
#
# Distributed under terms of the MIT license.
"""DNS over DTLS serving side of the proxy."""
import abc
import asyncio
import logging
import time
import typing
from DTLSSocket import dtls
from .config import Config
from .dns_server import BaseServerFactory, BaseDNSServer
from .dns_upstream import DNSUpstreamServerMixin
logger = logging.getLogger()
[docs]
class BaseDTLSWrapper(abc.ABC):
"""An abstract wrapper for a DTLS implementation
:param transport: The datagram transport the datagrams should be encrypted
and decrypted for.
:type transport: :py:class:`asyncio.DatagramTransport`
"""
def __init__(self, transport: asyncio.DatagramTransport):
self.transport = transport
[docs]
@abc.abstractmethod
def is_connected(self, addr: typing.Any) -> bool:
"""Check if a session with ``addr`` was established.
:param addr: A remote endpoint (implementation-specific)
:returns: ``True``, when a session with ``addr`` is established,
``False`` if not.
"""
raise NotImplementedError
[docs]
@abc.abstractmethod
def sessions(self) -> typing.Sequence:
"""Returns all currently established sessions.
:returns: A sequence of (implementation-specific) remote endpoints with
which a session is established.
"""
raise NotImplementedError
[docs]
@abc.abstractmethod
def connect(self, addr: typing.Any) -> typing.NoReturn:
"""Establish a session with ``addr``.
:param addr: An (implementation-specific) remote endpoint
"""
raise NotImplementedError
[docs]
@abc.abstractmethod
def close(self, addr: typing.Any) -> typing.NoReturn:
"""Closes a session with ``addr``.
:param addr: An (implementation-specific) remote endpoint
"""
raise NotImplementedError
[docs]
@abc.abstractmethod
def handle_message(
self, msg: bytes, addr: typing.Any
) -> typing.Tuple[bytes, typing.Any, bool]:
"""Handles a DTLS message that came over the datagram transport.
:param msg: An incoming DTLS message.
:type msg: bytes
:param addr: The remote endpoint as served by the datagram transport.
:returns: A 3-tuple, containing
1. The unencrypted message,
2. The (implementation-specific) remote endpoint the
message was received from, an
3. A boolean, indicating if the last message established a
session with the remote endpoint.
If ``msg`` was a control message, the first and second
elements will be ``None``.
"""
raise NotImplementedError
[docs]
@abc.abstractmethod
def write(self, msg: bytes, addr: typing.Any) -> typing.NoReturn:
"""Send a ``msg`` encrypted to ``addr``
:param msg: The message to encrypt via DTLS.
:type msg: bytes
:param addr: An (implementation-specific) remote endpoint to send the
encrypted message to.
"""
raise NotImplementedError
[docs]
class TinyDTLSWrapper(BaseDTLSWrapper):
"""A wrapper for
`tinydtls <https://projects.eclipse.org/projects/iot.tinydtls>`_.
"""
EVENT_CONNECTED = 0x1DE
_CT_HANDSHAKE = 22
_HT_SERVER_HELLO_DONE = 14
LOG_LEVEL = {
# pylint: disable=c-extension-no-member
logging.DEBUG: dtls.DTLS_LOG_DEBUG,
logging.INFO: dtls.DTLS_LOG_INFO,
logging.WARNING: dtls.DTLS_LOG_WARN,
logging.ERROR: dtls.DTLS_LOG_CRIT,
logging.CRITICAL: dtls.DTLS_LOG_EMERG,
}
def __init__(self, transport):
super().__init__(transport)
# pylint: disable=c-extension-no-member
credentials = Config()["dtls_credentials"]
client_identity = credentials["client_identity"].encode()
psk = credentials["psk"].encode()
self._dtls = dtls.DTLS(
read=self._read,
write=self._write,
event=self._event,
pskId=client_identity,
pskStore={client_identity: psk},
)
dtls.setLogLevel(self.LOG_LEVEL[logger.level])
self._active_sessions = {}
self._app_data = None
self._last_event = None
def __del__(self):
self._active_sessions.clear()
del self._dtls
self._dtls = None
def _read(self, addr, data):
self._app_data = (data, addr + (0, 0))
return len(data)
def _write(self, addr, data):
if (
len(data) > 13
and data[0] == self._CT_HANDSHAKE
and data[13] == self._HT_SERVER_HELLO_DONE
):
delay = Config().get("dtls", {}).get("server_hello_done_delay")
if delay:
# Delay Server Hello Done for a bit
time.sleep(delay)
self.transport.sendto(data, addr + (0, 0))
return len(data)
def _event(self, level, code):
# pylint: disable=unused-argument
self._last_event = code
[docs]
def is_connected(self, addr):
return addr in self._active_sessions
[docs]
def sessions(self):
return list(self._active_sessions)
[docs]
def connect(self, addr):
self._dtls.connect(*addr)
[docs]
def close(self, addr):
if self.is_connected(addr):
self._dtls.close(self._active_sessions[addr])
del self._active_sessions[addr]
[docs]
def handle_message(self, msg, addr):
connected = False
if isinstance(addr, tuple):
ret = self._dtls.handleMessageAddr(addr[0], addr[1], msg)
elif isinstance(addr, dtls.Session): # pylint: disable=I1101
ret = self._dtls.handleMessage(addr, msg)
addr = addr.addr, addr.port, addr.flowinfo, addr.scope_id
else:
raise ValueError(f"Unexpected session type {type(addr)}")
if ret < 0:
logger.warning("Unable to handle incoming DTLS message from %s", addr)
return None, None, connected
if self._last_event == self.EVENT_CONNECTED and not self.is_connected(addr):
# pylint: disable=c-extension-no-member
self._active_sessions[addr] = dtls.Session(*addr[:4])
connected = True
self._last_event = None
if self._app_data is None:
logger.debug("Unable to fetch application data from DTLS")
return None, None, connected
data, addr = self._app_data
self._app_data = None
return data, addr, connected
[docs]
def write(self, msg, addr):
if isinstance(addr, tuple):
if not self.is_connected(addr):
logger.warning("%s does not have an active session", addr)
return
addr = self._active_sessions[addr]
self._dtls.write(addr, msg)
[docs]
class DNSOverDTLSServerFactory(BaseServerFactory):
"""Factory to create DNS over DLTS servers"""
# pylint: disable=too-few-public-methods
DODTLS_PORT = 853
dtls_class = TinyDTLSWrapper
[docs]
class DNSOverDTLSServer(BaseDNSServer, DNSUpstreamServerMixin):
"""DNS over DTLS server implementation.
:param factory: The factory that created the DNS over DTLS server.
:type factory: :py:class:`DNSOverDTLSServerFactory`
"""
def __init__(self, factory):
super().__init__(dns_upstream=factory.dns_upstream)
self.factory = factory
self.transport = None
self._dtls = None
def __del__(self):
del self._dtls
self._dtls = None
[docs]
def connection_made(self, transport):
# pylint: disable=line-too-long
"""See `connection_made()`_
.. _`connection_made()`: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseProtocol.connection_made
""" # noqa: E501
self.transport = transport
self._dtls = self.factory.dtls_class(self.transport)
[docs]
def datagram_received(self, data, addr):
# pylint: disable=line-too-long
"""See `datagram_received()`_
.. _`datagram_received()`: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.DatagramProtocol.datagram_received
""" # noqa: E501
data, addr, _ = self._dtls.handle_message(data, addr)
if data is None:
return
self.dns_query_received(data, addr)
[docs]
def error_received(self, exc):
# pylint: disable=line-too-long
"""See `error_received()`_
.. _`error_received()`: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.DatagramProtocol.error_received
""" # noqa: E501
self._dtls = None # pragma: no cover
raise exc # pragma: no cover
[docs]
def send_response_to_requester(self, response, requester):
self._dtls.write(response, requester)
[docs]
async def close(self):
if self.transport is not None:
if self._dtls is not None: # pragma: no cover
for session in self._dtls.sessions():
self._dtls.close(session)
self.transport.close()
self.transport = None
[docs]
def connection_lost(self, exc): # pylint: disable=unused-argument
# pylint: disable=line-too-long
"""See `connection_lost()`_
.. _`connection_lost()`: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseProtocol.connection_lost
""" # noqa: E501
self._dtls = None
def _create_server_protocol(self, *args, **kwargs):
try:
config = Config()
_ = config["dtls_credentials"]["client_identity"]
_ = config["dtls_credentials"]["psk"]
except KeyError as exc:
raise RuntimeError(f"DTLS credential option {exc} not found") from exc
return self.DNSOverDTLSServer(self, *args, **kwargs)
[docs]
async def create_server(self, loop, *args, local_addr=None, **kwargs):
"""Creates an :py:class:`DNSOverDTLSServer` object.
:param loop: the asyncio event loop the server should run in
:type loop: :py:class:`asyncio.AbstractEventLoop`
:param local_addr: A tuple for the created server to bind to. The first
element is the host part, the second element the
port.
:type local_addr: :py:class:`typing.Tuple[str, int]`
:returns: An :py:class:`DNSOverDTLSServer` object representing an
DNS over DTLS server.
:rtype: :py:class:`DNSOverDTLSServer`
"""
if local_addr is None:
local_addr = ("localhost", self.DODTLS_PORT)
if local_addr[1] is None:
local_addr = (local_addr[0], self.DODTLS_PORT)
_, protocol = await loop.create_datagram_endpoint(
self._create_server_protocol,
*args,
local_addr=local_addr,
**kwargs,
)
return protocol