#
# Copyright (c) 2019 Yutaro Hayakawa
# Licensed under the Apache License, Version 2.0 (the "License")
#
import os
import re
import sys
import socket
import argparse
import textwrap
import importlib
import importlib.util
import ipaddress
import subprocess
import dataclasses
from ctypes import *

import yaml
import click
import tabulate
from bcc import BPF


# Protocol name <=> Protocol number mapping.
# Both tables are filled by init_protocol_mapping(); numbers are kept as
# strings because they are spliced into compiler "-D" options.
PROTO_TO_ID = {}
ID_TO_PROTO = {}


def init_protocol_mapping():
    """Populate PROTO_TO_ID and ID_TO_PROTO from /etc/protocols.

    Each entry has the form ``name number [aliases...]``, optionally followed
    by a ``#`` comment.  The first alias (falling back to the official name
    when there is none) is used as the display name in ID_TO_PROTO.  Every
    name and alias of an entry can be used as a key of PROTO_TO_ID.
    """
    with open("/etc/protocols") as f:
        for line in f:
            # Drop full-line and trailing comments before tokenizing.
            fields = line.split("#", 1)[0].split()
            if len(fields) < 2 or not fields[1].isdigit():
                continue

            name, number, *aliases = fields
            display = aliases[0] if aliases else name

            ID_TO_PROTO[number] = display
            # The display name always wins (last entry wins on duplicates);
            # the remaining spellings never override an existing mapping.
            PROTO_TO_ID[display] = number
            for key in (name, *aliases):
                PROTO_TO_ID.setdefault(key, number)


class V4Addrs(Structure):
    """Source/destination IPv4 address pair."""

    _fields_ = [("saddr", c_uint32), ("daddr", c_uint32)]


class V6Addrs(Structure):
    """Source/destination IPv6 address pair (16 raw bytes each)."""

    _fields_ = [("saddr", c_uint8 * 16), ("daddr", c_uint8 * 16)]


class IPAddrs(Union):
    """Address storage shared by the IPv4 and IPv6 representations."""

    _fields_ = [("v4", V4Addrs), ("v6", V6Addrs)]


class EventData(Structure):
    """Layout of one record read from the "events" perf buffer.

    NOTE(review): the layout presumably mirrors a C struct in
    ipftrace.bpf.h, which is not visible here -- confirm before reordering.
    """

    # ctypes requires the name "_anonymous_" (a sequence of field names) and
    # it must be set before "_fields_".  "addrs" stays accessible by name.
    _anonymous_ = ("addrs",)
    _fields_ = [
        ("tstamp", c_uint64),
        ("faddr", c_uint64),
        ("l4_protocol", c_uint8),
        ("l3_protocol", c_uint16),
        ("addrs", IPAddrs),
        ("sport", c_uint16),
        ("dport", c_uint16),
        ("data", c_uint8 * 64),
    ]


@dataclasses.dataclass(eq=True, frozen=True)
class Flow:
    """Hashable flow key used to group events of the same packet flow."""

    l4_protocol: str
    saddr: str
    daddr: str
    sport: int = 0
    dport: int = 0

    def __str__(self):
        # A port of 0 is rendered as "no port".
        src = self.saddr + (":" + str(self.sport) if self.sport != 0 else "")
        dst = self.daddr + (":" + str(self.dport) if self.dport != 0 else "")
        return f"{self.l4_protocol}\t{src}\t->\t{dst}"


@dataclasses.dataclass(eq=True, frozen=True)
class EventLog:
    """One row of a function trace."""

    time_stamp: str
    event_name: str
    custom_data: str


class DefaultModule:
    """Built-in match module that accepts every packet and adds no data."""

    def get_name(self):
        return "Default"

    def generate_header(self):
        return ""

    def generate_body(self):
        return """
static inline bool
custom_match(void *ctx, struct sk_buff *skb, uint8_t *data) {
  return true;
}
"""

    def parse_data(self, data):
        return ""


class IPFTracer:
    """Attach kprobes to the functions of a manifest and print per-flow traces."""

    def __init__(
        self,
        iv,
        saddr,
        daddr,
        proto,
        sport,
        dport,
        module,
        regex,
        length,
        manifest,
        verbose,
    ):
        """Build the compiler options and load the manifest and match module.

        Args:
            iv: IP version, "4" or "6".
            saddr: Source address, or "any".
            daddr: Destination address, or "any".
            proto: Protocol name known to PROTO_TO_ID, or "any".
            sport: Source port, or "any".
            dport: Destination port, or "any".
            module: Path of a custom match module, or None for the default.
            regex: Pattern applied with re.match to function names, or None.
            length: Number of events after which a flow's trace is printed.
            manifest: Path of the YAML manifest listing the functions.
            verbose: Print the manifest functions that are not traceable.
        """
        self._verbose = verbose
        self._opts = self._build_opts(iv, saddr, daddr, proto, sport, dport)
        self._functions = self._read_functions(manifest)
        self._module = self._load_module(module)
        self._regex = regex
        self._length = length
        self._egress_functions = []
        self._flows = {}

    def _read_functions(self, manifest):
        """Return the manifest entries whose function is listed as traceable."""
        funcs = self._read_manifest(manifest)
        available_funcs = self._read_available_filter_functions()

        filtered_funcs = []
        for f in funcs:
            if f["name"] in available_funcs:
                filtered_funcs.append(f)
            elif self._verbose:
                print(f'Function {f["name"]} is not traceable. Skip.')
        return filtered_funcs

    def _read_manifest(self, manifest):
        """Return the "functions" list of the YAML manifest file."""
        with open(manifest) as f:
            # SECURITY(review): FullLoader can construct more than plain data
            # types.  If manifests may come from untrusted sources, switch to
            # yaml.safe_load after confirming no manifest relies on tags.
            return yaml.load(f, Loader=yaml.FullLoader)["functions"]

    def _read_available_filter_functions(self):
        """Return the set of function names listed by the tracing subsystem."""
        path = "/sys/kernel/debug/tracing/available_filter_functions"
        available = set()
        with open(path, "r") as f:
            for line in f:
                fields = line.split()
                # Only the first token is the name; skip blank lines.
                if fields:
                    available.add(fields[0])
        return available

    def _load_module(self, module_path):
        """Load the match module from module_path, or use DefaultModule.

        Raises:
            ImportError: If no loader can be found for module_path.
        """
        if module_path is None:
            module = DefaultModule()
        else:
            spec = importlib.util.spec_from_file_location("module", module_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load module from {module_path}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        print(f'Loading module "{module.get_name()}"')
        return module

    def _matches_regex(self, name):
        """Return True when no regex is set or name matches it (re.match)."""
        return self._regex is None or re.match(self._regex, name) is not None

    def _inet_addr4(self, addr):
        """Return the packed IPv4 address as a little-endian decimal string."""
        a = ipaddress.IPv4Address(addr).packed
        return str(int.from_bytes(a, byteorder="little"))

    def _build_addr4_opt(self, addr, direction):
        """Return the "-D" option for an IPv4 address; direction is "S" or "D"."""
        if addr == "any":
            return ["-D", direction + "ADDRV4_ANY"]
        return ["-D", direction + "ADDRV4=" + self._inet_addr4(addr)]

    def _inet_addr6(self, addr):
        """Return the 16 bytes of an IPv6 address as comma-separated decimals."""
        packed = ipaddress.IPv6Address(addr).packed
        return ",".join(str(b) for b in packed)

    def _build_addr6_opt(self, addr, direction):
        """Return the "-D" option for an IPv6 address; direction is "S" or "D"."""
        if addr == "any":
            return ["-D", direction + "ADDRV6_ANY"]
        return ["-D", direction + "ADDRV6=" + self._inet_addr6(addr)]

    def _build_ip_opt(self, iv, saddr, daddr):
        """Return the L3 protocol and address options for IP version iv.

        Raises:
            ValueError: If iv is neither "4" nor "6".
        """
        # The L3_PROTO values are the ones _parse_l3_proto compares against.
        if iv == "4":
            ret = ["-D", "L3_PROTO=0x0008"]
            ret += self._build_addr4_opt(saddr, "S")
            ret += self._build_addr4_opt(daddr, "D")
        elif iv == "6":
            ret = ["-D", "L3_PROTO=0xdd86"]
            ret += self._build_addr6_opt(saddr, "S")
            ret += self._build_addr6_opt(daddr, "D")
        else:
            raise ValueError("Unknown IP version {}".format(iv))
        return ret

    def _build_proto_opt(self, proto):
        """Return the L4 protocol option; raises KeyError for unknown names."""
        if proto == "any":
            return ["-D", "PROTO_ANY"]
        return ["-D", "PROTO=" + PROTO_TO_ID[proto]]

    def _build_port_opt(self, port, direction):
        """Return the port option, with the port converted by socket.htons."""
        if port == "any":
            return ["-D", direction + "PORT_ANY"]
        port = str(socket.htons(int(port)))
        return ["-D", direction + "PORT=" + port]

    def _build_opts(self, iv, saddr, daddr, proto, sport, dport):
        """Return the full list of compiler flags describing the filter."""
        opts = []
        opts += self._build_ip_opt(iv, saddr, daddr)
        opts += self._build_proto_opt(proto)
        opts += self._build_port_opt(sport, "S")
        opts += self._build_port_opt(dport, "D")
        return opts

    def _build_bpf_prog(self):
        """Return the BPF program text.

        The text is, in order: ipftrace.bpf.h, the module header,
        ipftrace.bpf.c and the module body.  Both files are read from the
        directory of this file.
        """
        base_dir = os.path.dirname(__file__)
        with open(os.path.join(base_dir, "ipftrace.bpf.h")) as f:
            bpf_hdr = f.read()
        with open(os.path.join(base_dir, "ipftrace.bpf.c")) as f:
            bpf_src = f.read()

        prog = ""
        prog += bpf_hdr
        prog += self._module.generate_header()
        prog += bpf_src
        prog += self._module.generate_body()
        return prog

    def _attach_probes(self):
        """Compile the program, attach a kprobe per function and return the BPF.

        Functions that do not match the regex are skipped.  A failure to
        attach is reported and tracing continues with the other functions.
        """
        probes = self._build_bpf_prog()
        b = BPF(text=probes, cflags=self._opts)

        for f in self._functions:
            name = f["name"]
            skb_pos = f["skb_pos"]

            # NOTE(review): the check accepts 4 while the message says
            # "lower than 4" -- confirm which one matches the BPF source.
            if skb_pos > 4:
                print(
                    f"Invalid skb_pos for function {name}. It should be lower than 4."
                )
                sys.exit(1)

            if not self._matches_regex(name):
                continue

            try:
                b.attach_kprobe(event=name, fn_name=f"ipftrace_main{skb_pos}")
            except Exception:
                # Best effort: keep going, but let Ctrl-C/SystemExit through.
                print(f"Couldn't attach kprobe to function {name}")

            if f.get("egress", False):
                self._egress_functions.append(name)

        return b

    def list_functions(self):
        """Print the traceable manifest functions that match the regex."""
        for f in self._functions:
            name = f["name"]
            if not self._matches_regex(name):
                continue
            print(f"{name}")

    def _parse_l3_proto(self, event):
        """Return (saddr, daddr) as strings; exits on an unknown L3 protocol."""
        if event.l3_protocol == 0x0008:  # IPv4
            saddr = ipaddress.IPv4Address(socket.ntohl(event.addrs.v4.saddr))
            daddr = ipaddress.IPv4Address(socket.ntohl(event.addrs.v4.daddr))
        elif event.l3_protocol == 0xDD86:  # IPv6
            saddr = ipaddress.IPv6Address(bytes(event.addrs.v6.saddr))
            daddr = ipaddress.IPv6Address(bytes(event.addrs.v6.daddr))
        else:
            print(f"Unsupported l3 protocol {event.l3_protocol}")
            sys.exit(1)
        return (str(saddr), str(daddr))

    def _parse_l4_proto(self, event):
        """Return (protocol name, source port, destination port)."""
        p = str(event.l4_protocol)
        l4_proto = ID_TO_PROTO.get(p, f"Unknown({p})")
        sport = socket.ntohs(event.sport)
        dport = socket.ntohs(event.dport)
        return (l4_proto, sport, dport)

    def _parse_custom_data(self, event):
        """Return the module's rendering of event.data, or "" if it raises."""
        try:
            custom_data = self._module.parse_data(event.data)
        except Exception as e:
            # A faulty module must not stop the trace; report and move on.
            print(e)
            custom_data = ""
        return custom_data

    def _print_function_trace(self, flow, event_logs):
        """Print the flow followed by a table of its logged events."""
        header = ["Time Stamp", "Function", "Custom Data"]
        table = [[e.time_stamp, e.event_name, e.custom_data] for e in event_logs]
        print(flow)
        print(tabulate.tabulate(table, header, tablefmt="plain"))
        print("")

    def _handle_lost(self, lost):
        """Lost-event callback: discard all partially collected traces."""
        self._flows.clear()

    def _handle_event(self, cpu, data, size):
        """Perf buffer callback: log the event under its flow.

        The flow's trace is printed and dropped once it is complete.
        """
        event = cast(data, POINTER(EventData)).contents
        fname = BPF.ksym(event.faddr).decode("utf-8")
        tstamp = str(event.tstamp)
        saddr, daddr = self._parse_l3_proto(event)
        l4_proto, sport, dport = self._parse_l4_proto(event)
        custom_data = self._parse_custom_data(event)

        flow = Flow(
            l4_protocol=l4_proto,
            saddr=saddr,
            daddr=daddr,
            sport=sport,
            dport=dport,
        )

        event_logs = self._flows.setdefault(flow, [])
        event_logs.append(EventLog(tstamp, fname, custom_data))

        #
        # Print the function trace if
        #
        # 1. The skb reaches to the egress function
        # 2. The function trace reaches to the length limit
        #
        if fname in self._egress_functions or len(event_logs) == self._length:
            self._print_function_trace(flow, event_logs)
            del self._flows[flow]

    def run_tracing(self):
        """Attach the probes and poll the perf buffer until interrupted."""
        b = self._attach_probes()
        event = b["events"]
        event.open_perf_buffer(
            self._handle_event, lost_cb=self._handle_lost, page_cnt=64
        )
        print("Trace ready!")
        while True:
            b.perf_buffer_poll()


@click.command()
@click.option(
    "-iv",
    "--ipversion",
    default="4",
    type=click.Choice(["4", "6"]),
    help="Specify IP version",
)
@click.option("-s", "--saddr", default="any", help="Specify IP source address")
@click.option("-d", "--daddr", default="any", help="Specify IP destination address")
@click.option("-p", "--proto", default="any", help="Specify protocol")
@click.option("-sp", "--sport", default="any", help="Specify source port number")
@click.option("-dp", "--dport", default="any", help="Specify destination port number")
@click.option("-m", "--module", default=None, help="Specify custom match module path")
@click.option("-r", "--regex", default=None, help="Filter the function names by regex")
@click.option("-l", "--length", default=80, help="Specify the length of function trace")
@click.option("-ls", "--list-func", is_flag=True, help="List available functions")
@click.option("-v", "--verbose", is_flag=True, help="Verbose output")
@click.argument("manifest")
def main(
    ipversion,
    saddr,
    daddr,
    proto,
    sport,
    dport,
    module,
    regex,
    length,
    list_func,
    verbose,
    manifest,
):
    """
    Track the journey of the packets in Linux network stack
    """
    # The docstring above is the command's help text; keep it unchanged.
    init_protocol_mapping()

    ift = IPFTracer(
        iv=ipversion,
        saddr=saddr,
        daddr=daddr,
        proto=proto,
        sport=sport,
        dport=dport,
        module=module,
        regex=regex,
        length=length,
        manifest=manifest,
        verbose=verbose,
    )

    if list_func:
        ift.list_functions()
        sys.exit(0)

    try:
        ift.run_tracing()
    except KeyboardInterrupt:
        print("Tracing finished. Detaching probes...")


if __name__ == "__main__":
    main()