Linux/Python: Block countries with iptables

The number of attacks against certain services running on my server has risen from year to year. This isn't unusual; it's quite normal for common internet-facing applications like SSH, mail servers or WordPress.

Some countries have a few things in common: I have no regular visitors from them, but I get a lot of attacks or malicious traffic from them. Abuse emails sent to their providers also very often bounce or simply go unanswered. One example for my server would be China, with port 22 and potentially 443 as well.

So I looked for a tool which blocks traffic on a per-country basis. But I didn't find one with all of the features I wanted:

  • Automatic database update (IP ownership changes)
  • Any country
  • Whitelist
  • Port definition per protocol and country
  • Cronjob

So I wrote a Python script which does that job. It uses the database provided by lite.ip2location.com (thanks).

Maybe it is useful for other people...

Download Link: https://d0m.me/wp-content/uploads/2020/09/iptables_country_filter.py

The example configuration should be adapted to your needs. As shipped, it does the following:

China
 - TCP: Filter port 22, 80 and 443
 - UDP: Filter any port

Faroe
 - TCP: Filter no port
 - UDP: Filter port 123 

The two-character ISO country code used in the configuration must match the one used in the database. A full list of the country codes known to the database is available here: https://www.ip2location.com/free/iso3166-2

# Installation
apt-get install python3-requests python3-pandas python3-iptables
wget https://d0m.me/wp-content/uploads/2020/09/iptables_country_filter.py -O /usr/local/sbin/iptables_country_filter
chmod 750 /usr/local/sbin/iptables_country_filter
editor /usr/local/sbin/iptables_country_filter
# Usage
/usr/local/sbin/iptables_country_filter (block|flush|help)

Once the configuration is done and the script has been tested successfully, it should be set up as a cron job that runs once a month:

# Cronjob
cat >> /etc/cron.d/iptables_country_filter << EOF
MAILTO=root
0 10 28 * * root /usr/local/sbin/iptables_country_filter block
EOF

iptables_country_filter

#!/usr/bin/env python3
import pandas as pd
import ipaddress
import iptc
import sys
import zipfile
import requests
import io
import os
import time

# installation: apt-get install python3-requests python3-pandas python3-iptables
# configuration:
#############################################################################################
# Countries to block, keyed by the two-letter ISO code as it appears in the database.
#   chain_name:       name of the iptables chain that receives the DROP rules of that country
#   filter_tcp_ports: destination ports to filter for TCP
#   filter_udp_ports: destination ports to filter for UDP
# For both port settings: a list of ports filters only those ports, an empty list
# filters all ports of that protocol, and None skips the protocol entirely.
block_countries = { 
  'CN': {
    'chain_name': 'BLOCK_CHINA',
    'filter_tcp_ports': [22,80,443],
    'filter_udp_ports': []
  },
  'FO': {
    'chain_name': 'BLOCK_FAROE',
    'filter_tcp_ports': None,
    'filter_udp_ports': [123]
  }
}
# Exclude list. Entries need to be the decimal (integer) representation of an IP,
# see: https://www.ipaddressguide.com/ip
# If a database range contains one of these IPs, the whole range is skipped (not blocked).
# Leave the list empty to skip nothing, i.e. to block every range of the configured countries.
exclude_ips = []
# Download URL and the folder the database is extracted to.
# Check https://lite.ip2location.com for usage conditions.
url = 'https://download.ip2location.com/lite/IP2LOCATION-LITE-DB1.CSV.ZIP'
db_folder = '/var/tmp/ip2location'
#############################################################################################


def update_db():
  """Return the IP2Location country database as a DataFrame, refreshing it if stale.

  The CSV is (re-)downloaded from ``url`` and extracted into ``db_folder`` when it
  is missing or older than 14 days; otherwise the copy on disk is used as is.

  Returns:
    A pandas DataFrame with the columns ``start_num``, ``end_num``,
    ``country_iso`` and ``country_name``.

  Raises:
    Exception: if the download answers with a non-OK HTTP status.
  """
  dbfile = os.path.join(db_folder, 'IP2LOCATION-LITE-DB1.CSV')
  max_db_age = 14  # days
  max_age_seconds = max_db_age * 24 * 3600
  if not os.path.isfile(dbfile) or time.time() - os.path.getmtime(dbfile) > max_age_seconds:
    # A timeout keeps an unattended cron run from hanging forever on a stalled connection.
    r = requests.get(url, timeout=60)
    if r.status_code != requests.codes.ok:
      raise Exception('Database download failed: ' + str(r.status_code))
    with zipfile.ZipFile(io.BytesIO(r.content), 'r') as zip_ref:
      zip_ref.extractall(db_folder)
  # pandas treats strings such as "NA" as missing values by default, which would
  # turn the ISO code of Namibia ("NA") into NaN. Only empty fields count as missing.
  return pd.read_csv(
    dbfile,
    sep=',',
    header=None,
    names=['start_num', 'end_num', 'country_iso', 'country_name'],
    keep_default_na=False,
    na_values=[''],
  )

def iptables_cleanup(country):
  """Remove everything this script created in iptables for one country.

  First every rule in the INPUT chain of the filter table whose target is the
  country's chain is deleted, then the country's chain itself is flushed and
  deleted if it exists.

  Args:
    country: key of ``block_countries`` whose chain should be removed.
  """
  chain_name = block_countries[country]['chain_name']
  table = iptc.Table(iptc.Table.FILTER)
  input_chain = iptc.Chain(table, 'INPUT')
  # Re-read the rule list after every deletion and stop once no rule
  # pointing at the country chain is left.
  while True:
    jump_rule = next((r for r in input_chain.rules if r.target.name == chain_name), None)
    if jump_rule is None:
      break
    input_chain.delete_rule(jump_rule)
  for candidate in table.chains:
    if candidate.name != chain_name:
      continue
    candidate.flush()
    iptc.Table(iptc.Table.FILTER).delete_chain(chain_name)

def filter_ports(port_array, protocol, chain, chain_name):
  """Insert rules into ``chain`` that send the given ports to ``chain_name``.

  The iptables ``multiport`` match accepts at most 15 ports per rule, so longer
  port lists are split across several rules.

  Args:
    port_array: destination ports to match; assumed to be single port numbers
      (a range such as "1000:2000" would count as two ports towards the limit).
    protocol: protocol name set on the rule, e.g. 'tcp' or 'udp'.
    chain: iptc chain the jump rules are inserted into.
    chain_name: name of the chain the rules jump to.
  """
  max_ports_per_rule = 15  # limit of the multiport match
  ports = [str(p) for p in port_array]
  for offset in range(0, len(ports), max_ports_per_rule):
    rule = iptc.Rule()
    rule.target = iptc.Target(rule, chain_name)
    rule.protocol = protocol
    match = iptc.Match(rule, 'multiport')
    match.dports = ','.join(ports[offset:offset + max_ports_per_rule])
    rule.add_match(match)
    chain.insert_rule(rule)

def iptables_initialize(country):
  """Create the country's chain and hook it into the INPUT chain.

  Which jump rules are inserted depends on the country's port settings:
    * both port lists empty: one rule without protocol or port restriction
    * non-empty list: one multiport rule for that protocol (via filter_ports)
    * empty list: one rule matching the whole protocol
    * None: no rule for that protocol

  Args:
    country: key of ``block_countries`` to set up.
  """
  tcp_ports = block_countries[country]['filter_tcp_ports']
  udp_ports = block_countries[country]['filter_udp_ports']
  chain_name = block_countries[country]['chain_name']
  iptc.Table(iptc.Table.FILTER).create_chain(chain_name)
  input_chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'INPUT')

  both_configured = tcp_ports is not None and udp_ports is not None
  if both_configured and len(tcp_ports) == 0 and len(udp_ports) == 0:
    # Everything is to be filtered: a single unrestricted jump is enough.
    catch_all = iptc.Rule()
    catch_all.target = iptc.Target(catch_all, chain_name)
    input_chain.insert_rule(catch_all)
    return

  for protocol, ports in (('tcp', tcp_ports), ('udp', udp_ports)):
    if ports is None:
      continue
    if len(ports) > 0:
      filter_ports(ports, protocol, input_chain, chain_name)
    else:
      whole_protocol = iptc.Rule()
      whole_protocol.target = iptc.Target(whole_protocol, chain_name)
      whole_protocol.protocol = protocol
      input_chain.insert_rule(whole_protocol)

def iptables_add_rule(start_ip, end_ip, country):
  """Insert a DROP rule for one source IP range into the country's chain.

  Args:
    start_ip: first address of the range, as a string.
    end_ip: last address of the range, as a string.
    country: key of ``block_countries`` whose chain receives the rule.
  """
  country_chain = iptc.Chain(
    iptc.Table(iptc.Table.FILTER),
    block_countries[country]['chain_name'],
  )
  drop_rule = iptc.Rule()
  range_match = iptc.Match(drop_rule, 'iprange')
  range_match.src_range = '-'.join((start_ip, end_ip))
  drop_rule.add_match(range_match)
  drop_rule.target = iptc.Target(drop_rule, 'DROP')
  country_chain.insert_rule(drop_rule)

def flush():
  """Remove the chains and INPUT jump rules of every configured country."""
  for iso_code in block_countries.keys():
    iptables_cleanup(iso_code)

def block():
  """Rebuild the blocking rules of every configured country from the database.

  For each country the existing chain and jump rules are removed first. Countries
  with both port settings set to None are left unblocked. For all others the chain
  is re-created and one DROP rule is added per database range of that country,
  except for ranges that contain an address listed in ``exclude_ips``.

  Raises:
    ValueError / TypeError: if an entry of ``exclude_ips`` cannot be converted to
      an integer. This is checked before any iptables change is made.
  """
  df = update_db()
  # Convert the exclusions once (instead of once per database row) and before
  # touching iptables, so a bad entry cannot leave the rules half rebuilt.
  excluded = [int(item) for item in exclude_ips]
  for country, settings in block_countries.items():
    iptables_cleanup(country)
    if settings['filter_udp_ports'] is None and settings['filter_tcp_ports'] is None:
      continue
    iptables_initialize(country)
    # Select the country's rows in one vectorized step rather than probing
    # every row of the database with df.at.
    ranges = df.loc[df['country_iso'] == country, ['start_num', 'end_num']]
    for raw_start, raw_end in ranges.itertuples(index=False, name=None):
      start_num = int(raw_start)
      end_num = int(raw_end)
      if any(start_num <= ip <= end_num for ip in excluded):
        continue
      start_ip = ipaddress.ip_address(start_num)
      end_ip = ipaddress.ip_address(end_num)
      iptables_add_rule(str(start_ip), str(end_ip), country)

def main():
  """Run the action named by the first command line argument.

  'flush' removes all chains created by this script, 'block' rebuilds them.
  Anything else, including no argument at all, prints the usage line.
  """
  command = sys.argv[1] if len(sys.argv) > 1 else None
  if command == 'flush':
    flush()
    return
  if command == 'block':
    block()
    return
  script_name = os.path.basename(__file__)
  print('Usage: %s (flush|block|help)' % script_name)

if __name__ == "__main__":
  main()

Schreibe einen Kommentar