exhuma - 22 days ago

How can I count each UDP packet sent out by subprocesses?

I have a Python application that orchestrates calls to an underlying process. The processes are called using subprocess.check_output, and they make SNMP calls to remote network devices.

For performance monitoring, I would like to count the number of SNMP packets that are sent. I am primarily interested in the packet count. The sizes of the request/response packets would be interesting too, but are less important. The aim is to get an idea of how much stress this application puts on the firewall.

So, for the sake of argument, let's assume the following silly application:

from subprocess import check_output
output = check_output(['snmpget', '-v2c', '-c', 'private', '192.168.1.1', '1.3.6.1.2.1.1.2.0'])
print(output)


This would cause a new UDP packet to be sent to port 161 on the remote device.

How can I count them in such a case?

Here's another version with stubbed functions (could also be a context manager):

from subprocess import check_call


def start_monitoring():
    pass


def stop_monitoring():
    return 0


start_monitoring()
check_call(['snmpget', '-v2c', '-c', 'private', '192.168.1.1', '1.3.6.1.2.1.1.2.0'])
check_call(['snmpget', '-v2c', '-c', 'private', '192.168.1.1', '1.3.6.1.2.1.1.2.0'])
check_call(['snmpget', '-v2c', '-c', 'private', '192.168.1.1', '1.3.6.1.2.1.1.2.0'])
num_connections = stop_monitoring()
assert num_connections == 3


In this contrived example, there will obviously be 3 calls, as I execute the SNMP calls manually. But in the real application, the number of SNMP requests is not equal to the number of subprocess calls. Sometimes one or more GETs are executed, sometimes it's simple walks (that is, a lot of sequential UDP requests), and sometimes it's bulk walks (an unknown number of requests).

So I can't simply count how many times the external program is called. I really have to monitor the UDP requests.

Is something like that even possible? If yes, how?

It's likely important to know that this runs on Linux as a non-root user. All subprocesses run as the same user.
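Since this runs on Linux without root, one coarse option that needs no privileges is to read the kernel's UDP counters in /proc/net/snmp before and after the work. A minimal sketch of the stubbed start_monitoring()/stop_monitoring() pair as a context manager, assuming a namespace-wide count is good enough: the counters cover every UDP datagram sent by any process in the network namespace (DNS lookups, NTP, other programs), not only the snmpget subprocesses, so on a busy host the result is an upper bound. The names count_udp_datagrams, parse_udp_counters and read_udp_counters are hypothetical.

```python
from contextlib import contextmanager

PROC_NET_SNMP = '/proc/net/snmp'  # Linux-only, readable by non-root users


def parse_udp_counters(text):
    """Turn the 'Udp:' header/value line pair of /proc/net/snmp into a dict."""
    udp_lines = [line.split()[1:] for line in text.splitlines()
                 if line.startswith('Udp:')]  # does not match 'UdpLite:'
    header, values = udp_lines[0], udp_lines[1]
    return dict(zip(header, map(int, values)))


def read_udp_counters(path=PROC_NET_SNMP):
    with open(path) as fh:
        return parse_udp_counters(fh.read())


@contextmanager
def count_udp_datagrams(read=read_udp_counters):
    """Yield a dict that is filled with datagram deltas when the block exits.

    Assumption: the counts are namespace-wide, not limited to child processes.
    """
    result = {}
    before = read()
    try:
        yield result
    finally:
        after = read()
        for key in ('OutDatagrams', 'InDatagrams'):
            result[key] = after[key] - before[key]
```

Used like the stub: `with count_udp_datagrams() as stats:` around the check_call() lines, then `stats['OutDatagrams']` holds the number of UDP datagrams sent in the meantime.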

Answer

Following this answer and this GitHub repo (found via yet another answer), I came up with the following implementation of a UDP proxy/relay:

#!/usr/bin/env python

from collections import namedtuple
from contextlib import contextmanager
from random import randint
from time import sleep
import logging
import socket
import threading

import snmp  # own helper module wrapping snmpget (see the note at the end)


MSG_DONTWAIT = 0x40  # from socket.h
LOCK = threading.Lock()

MSG_TYPE_REQUEST = 1
MSG_TYPE_RESPONSE = 2
Statistics = namedtuple('Statistics', 'msgtype packet_size')


def visible_octets(data: bytes) -> str:
    """
    Returns a geek-friendly (hexdump)  output of a bytes object.

    Developer note:
        This is not super performant. But it's not something that's supposed to
        be run during normal operations (mostly for testing and debugging).  So
        performance should not be an issue, and this is less obfuscated than
        existing solutions.

    Example::

        >>> from os import urandom
        >>> print(visible_octets(urandom(40)))
        99 1f 56 a9 25 50 f7 9b  95 7e ff 80 16 14 88 c5   ..V.%P...~......
        f3 b4 83 d4 89 b2 34 b4  71 4e 5a 69 aa 9f 1d f8   ......4.qNZi....
        1d 33 f9 8e f1 b9 12 e9                            .3......

    """
    from binascii import hexlify, unhexlify

    hexed = hexlify(data).decode('ascii')
    tuples = [''.join((a, b)) for a, b in zip(hexed[::2], hexed[1::2])]
    line = []
    output = []
    ascii_column = []
    for idx, octet in enumerate(tuples):
        line.append(octet)
        # only use printable characters in ascii output
        ascii_column.append(octet if 32 <= int(octet, 16) < 127 else '2e')
        if (idx + 1) % 8 == 0:
            line.append('')  # extra gap after every 8 octets
        if (idx + 1) % 16 == 0:
            # a full row of 16 octets: emit hex and ASCII columns
            raw_ascii = unhexlify(''.join(ascii_column))
            ascii_column = []
            output.append('%-50s %s' % (' '.join(line),
                                        raw_ascii.decode('ascii')))
            line = []
    if line:  # flush a trailing, partially filled row
        raw_ascii = unhexlify(''.join(ascii_column))
        output.append('%-50s %s' % (' '.join(line), raw_ascii.decode('ascii')))
    return '\n'.join(output)


@contextmanager
def UdpProxy(remote_host, remote_port, queue=None):
    thread = UdpProxyThread(remote_host, remote_port, stats_queue=queue)
    thread.prime()
    thread.start()
    try:
        yield thread.local_port
    finally:
        # stop the relay even if the body of the with-block raises
        thread.stop()
        thread.join()


class UdpProxyThread(threading.Thread):

    def __init__(self, remote_host, remote_port, stats_queue=None):
        super().__init__()
        self.local_port = randint(60000, 65535)
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.daemon = True
        self.log = logging.getLogger('%s.%s' % (
            __name__, self.__class__.__name__))
        self.running = True
        self._socket = None
        self.stats_queue = stats_queue

    def fail(self, reason):
        self.log.debug('UDP Proxy Failure: %s', reason)
        self.running = False

    def prime(self):
        """
        We need to set up a socket on a FREE port for this thread. Retry until
        we find a free port.

        This is used as a separate method to ensure proper locking and to ensure
        that each thread has its own port.

        The port can be retrieved by accessing the *local_port* member of the
        thread.
        """
        with LOCK:
            while True:
                self._socket = socket.socket(socket.AF_INET,
                                             socket.SOCK_DGRAM)
                try:
                    self._socket.bind(('', self.local_port))
                    break
                except OSError as exc:
                    # close the unbound socket before retrying or re-raising
                    self._socket.close()
                    if exc.errno != 98:  # 98 == EADDRINUSE on Linux
                        raise
                    self.log.warning('Port %d already in use. Shuffling...',
                                     self.local_port)
                    self.local_port = randint(60000, 65535)

    @property
    def name(self):
        return 'UDP Proxy Thread {} -> {}:{}'.format(self.local_port,
                                                     self.remote_host,
                                                     self.remote_port)

    def start(self):
        if not self._socket:
            raise ValueError('Socket was not set. Call prime() first!')
        super().start()

    def run(self):
        try:
            known_client = None
            known_server = (self.remote_host, self.remote_port)
            self.log.info('UDP Proxy set up: %s -> %s:%s',
                          self.local_port, self.remote_host, self.remote_port)
            while self.running:
                try:
                    data, addr = self._socket.recvfrom(32768, MSG_DONTWAIT)
                    self.log.debug('Packet received via %s\n%s', addr,
                                   visible_octets(data))
                except BlockingIOError:
                    sleep(0.1)  # Give self.stop() a chance to trigger
                else:
                    if known_client is None:
                        known_client = addr
                    if addr == known_client:
                        self.log.debug('Proxying request packet to %s\n%s',
                                       known_server, visible_octets(data))
                        self._socket.sendto(data, known_server)
                        if self.stats_queue:
                            self.stats_queue.put(Statistics(
                                MSG_TYPE_REQUEST, len(data)))
                    else:
                        self.log.debug('Proxying response packet to %s\n%s',
                                       known_client, visible_octets(data))
                        self._socket.sendto(data, known_client)
                        if self.stats_queue:
                            self.stats_queue.put(Statistics(
                                MSG_TYPE_RESPONSE, len(data)))
            self.log.info('%s stopped!', self.name)
        finally:
            self._socket.close()

    def stop(self):
        self.log.debug('Stopping %s...', self.name)
        self.running = False


if __name__ == '__main__':
    logging.basicConfig(level=0)
    from queue import Queue
    stat_queue = Queue()

    with UdpProxy('192.168.1.1', 161, stat_queue) as proxied_port:
        print(snmp.get('1.3.6.1.2.1.1.2.0', '127.0.0.1:%s' % proxied_port,
                       'testing'))
    with UdpProxy('192.168.1.1', 161, stat_queue) as proxied_port:
        print(snmp.get('1.3.6.1.2.1.1.2.0', '127.0.0.1:%s' % proxied_port,
                       'testing'))

    while not stat_queue.empty():
        stat_item = stat_queue.get()
        print(stat_item)
        stat_queue.task_done()
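The run() loop above polls with MSG_DONTWAIT and sleeps 0.1 s whenever the socket is empty. Because snmpget waits for each response before sending the next request, that sleep can land on many round trips of a walk, slowing it down (the packet count itself is unaffected). A minimal sketch of an alternative loop, assuming a blocking socket with a timeout is acceptable: relay() and open_proxy_socket() are hypothetical helpers, not part of the code above, and they put the same Statistics items on the queue. Binding to port 0 on 127.0.0.1 lets the kernel pick a free port, which would replace the random-port retry loop in prime() and keeps the relay off external interfaces.

```python
import socket
from collections import namedtuple

Statistics = namedtuple('Statistics', 'msgtype packet_size')
MSG_TYPE_REQUEST = 1
MSG_TYPE_RESPONSE = 2


def relay(sock, server, stats, stop_event, poll_interval=0.1):
    """Hypothetical helper: relay datagrams between the first client and *server*.

    A socket timeout replaces MSG_DONTWAIT + sleep(): recvfrom() returns as
    soon as a datagram arrives, and *poll_interval* only bounds how long it
    takes to notice that *stop_event* was set.
    """
    sock.settimeout(poll_interval)
    client = None
    try:
        while not stop_event.is_set():
            try:
                data, addr = sock.recvfrom(65535)
            except socket.timeout:
                continue  # nothing arrived; re-check stop_event
            if client is None:
                client = addr  # first sender is treated as the client
            if addr == client:
                sock.sendto(data, server)
                stats.put(Statistics(MSG_TYPE_REQUEST, len(data)))
            else:
                # as in the original, anything not from the client is a response
                sock.sendto(data, client)
                stats.put(Statistics(MSG_TYPE_RESPONSE, len(data)))
    finally:
        sock.close()


def open_proxy_socket():
    """Bind to port 0 on loopback so the kernel picks a free port."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(('127.0.0.1', 0))
    return sock
```

It would run in a thread, e.g. threading.Thread(target=relay, args=(sock, ('192.168.1.1', 161), stat_queue, stop_event), daemon=True), with stop_event.set() taking the place of stop() and sock.getsockname()[1] giving the port to pass to snmpget.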

As seen in the __main__ section, it can simply be used as follows:

    from queue import Queue
    stat_queue = Queue()

    with UdpProxy('192.168.1.1', 161, stat_queue) as proxied_port:
        print(snmp.get('1.3.6.1.2.1.1.2.0', '127.0.0.1:%s' % proxied_port,
                       'testing'))

    while not stat_queue.empty():
        stat_item = stat_queue.get()
        print(stat_item)
        stat_queue.task_done()
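To turn the queued items into the numbers the question asks for (packet count first, sizes second), the queue can be drained into per-type counters. A minimal sketch; summarize() is a hypothetical helper, and Statistics and the MSG_TYPE_* constants are repeated from the code above so the snippet runs on its own.

```python
from collections import Counter, namedtuple
from queue import Queue

Statistics = namedtuple('Statistics', 'msgtype packet_size')
MSG_TYPE_REQUEST = 1
MSG_TYPE_RESPONSE = 2


def summarize(stat_queue):
    """Drain *stat_queue*; return (packet counts, byte totals) per message type."""
    packets, octets = Counter(), Counter()
    while not stat_queue.empty():
        item = stat_queue.get()
        packets[item.msgtype] += 1
        octets[item.msgtype] += item.packet_size
        stat_queue.task_done()
    return packets, octets
```

packets[MSG_TYPE_REQUEST] is then the number of SNMP requests that went through the proxy, which is the figure relevant for firewall load.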

One thing to note: the snmp module in this case simply calls subprocess.check_output() to spawn an snmpget subprocess.
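The snmp module itself is not shown. A hypothetical stand-in, assuming it wraps snmpget with the same arguments as the question's example, might look like this. Net-SNMP accepts a host:port agent address such as 127.0.0.1:61234, which is how the proxied port reaches snmpget. The run parameter is an illustrative addition so the command can be checked without snmpget installed.

```python
from subprocess import check_output


def build_snmpget_command(oid, target, community, version='2c'):
    """Build the snmpget argv; *target* may include a port, e.g. '127.0.0.1:61234'."""
    return ['snmpget', '-v' + version, '-c', community, target, oid]


def get(oid, target, community, run=check_output):
    """Hypothetical stand-in for snmp.get(): run snmpget, return decoded output."""
    return run(build_snmpget_command(oid, target, community)).decode()
```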