summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsimpleddns.py128
1 files changed, 105 insertions, 23 deletions
diff --git a/simpleddns.py b/simpleddns.py
index 45e8308..dc9742b 100755
--- a/simpleddns.py
+++ b/simpleddns.py
@@ -20,6 +20,9 @@ def parse_args() -> argparse.Namespace:
epilog='''ENVIRONMENT VARIABLES
SIMPLEDDNS_CONFIG_DIR - directory where configburation files are stored''')
parser.add_argument('--setup', help='Set up configuration', action='store_true')
+ parser.add_argument('--dry-run',
+ help='Print what API calls would be made, without actually making any non-GET calls.',
+ action='store_true')
return parser.parse_args()
def fatal_error(message: str) -> NoReturn:
@@ -77,6 +80,7 @@ Domain {type_name} {domain_name}
@dataclasses.dataclass
class Settings:
'''Global (i.e. not per-domain) settings'''
+ dry_run: bool = False
interval: int = 15
timeout: int = 20
@@ -103,6 +107,9 @@ class Domain(ABC):
self.root_domain = domain[second_last_dot+1:]
self.last_ips = []
self._init()
+ def _info(self, message: str) -> None:
+ '''Print informational message prefixed with domain name'''
+ print(f'{self.full_domain}: {message}')
@abstractmethod
def _init(self) -> None:
'''Extra provider-specific initialization'''
@@ -136,18 +143,22 @@ class Domain(ABC):
ips.append(ip)
ips.sort()
return ips
- def check_for_update(self) -> None:
- '''Update DNS records if IP has changed'''
+ def check_for_update(self) -> bool:
+ '''Update DNS records if IP has changed. Returns False on error.'''
ips = self.get_ips()
if ips == self.last_ips:
- return
+ return True
+ self._info('Dealing with new IP address(es)... ')
+ sys.stdout.flush()
if self.update(ips):
self.last_ips = ips
+ return True
+ return False
class LinodeDomain(Domain):
'''Domain registered with Linode Domains'''
access_token: str
- _error: bool
+ _had_error: bool
# Domain ID for Linode API
_id: Optional[int]
def _init(self) -> None:
@@ -156,44 +167,69 @@ class LinodeDomain(Domain):
return f'<LinodeDomain domain={self.full_domain} ' \
'getip={repr(self.getip)} ' \
'access_token={repr(self.access_token)}>'
- def _headers(self, method: str = 'GET') -> dict[str, str]:
+ def _headers(self, has_body: bool = False) -> dict[str, str]:
'''Get HTTP headers for making requests'''
headers = {'Accept': 'application/json', 'Authorization': f'Bearer {self.access_token}'}
- if method in ['POST', 'PUT']:
+ if has_body:
headers['Content-Type'] = 'application/json'
return headers
+ def _error(self, message: str) -> None:
+ self._had_error = True
+ warn(message)
+ def _make_request(self, method: str, url: str, body: Any = None) -> Any:
+ headers = self._headers(has_body = body is not None)
+ options: dict[Any, Any] = {
+ 'headers': headers,
+ 'timeout': self.settings.timeout,
+ 'method': method,
+ 'url': url
+ }
+ if body is not None:
+ options['json'] = body
+ try:
+ response = requests.request(**options)
+ if not response.ok:
+ self._error(f'''Got error response from endpoint {repr(url)} (code {response.status_code}):
+{response.text}''')
+ response_json = response.json()
+ if response_json is None:
+ self._error(f'Got null response from endpoint {repr(url)}')
+ return response_json
+ except requests.JSONDecodeError as e:
+ self._error(f'Invalid JSON at endpoint {repr(url)}: {e}')
+ return None
+ except requests.RequestException as e:
+ self._error(f'Error making request to {repr(url)}: {e}')
+ return None
def _paginated_get(self, url: str) -> list[Any]:
'''Get results for a Linode paginated GET API endpoint.'''
- headers = self._headers()
results = []
page_size = 500
- for page in range(1,100):
+ for page in range(1, 100):
page_url = f'{url}{"&" if "?" in url else "?"}page={page}&page_size={page_size}'
- try:
- response = requests.get(page_url, headers=headers, timeout=self.settings.timeout)
- except requests.RequestException as e:
- warn(f'Error making request to {url}: {e}')
sleep(0.33) # should prevent us from hitting Linode's rate limit
+ response_json = self._make_request('GET', page_url)
+ if response_json is None:
+ break
try:
- response_json = response.json()
response_data = response_json['data']
if not isinstance(response_data, list):
raise ValueError('"data" member is not a list')
- except (requests.JSONDecodeError, KeyError, ValueError) as e:
- warn(f'Invalid JSON at endpoint {repr(page_url)}: {e}')
+ except (KeyError, ValueError) as e:
+ self._error(f'Bad JSON format at endpoint {repr(page_url)}: {e}')
break
results.extend(response_data)
if len(response_data) < page_size:
# Reached last page, presumably
break
if page == 99:
- warn(f'Giving up after 99 pages of responses to API endpoint {repr(url)}')
+ self._error(f'Giving up after 99 pages of responses to API endpoint {repr(url)}')
return results
def update(self, ips: list[str]) -> bool:
- self._error = False
+ self._had_error = False
if self._id is None:
domains = self._paginated_get('https://api.linode.com/v4/domains')
- if self._error:
+ if self._had_error:
return False
for domain in domains:
domain_name = domain['domain']
@@ -204,9 +240,50 @@ class LinodeDomain(Domain):
warn(f'Domain {self.root_domain} not found in Linode. Are you sure it is set up there?')
return False
records = self._paginated_get(f'https://api.linode.com/v4/domains/{self._id}/records')
- print(records)
- print(self.subdomain)
- return not self._error
+ remaining_ips = set(ips)
+ unused_records: dict[str, list[int]] = {'A': [], 'AAAA': []}
+ for record in records:
+ if record['type'] not in ['A', 'AAAA']:
+ continue
+ if record['name'] != self.subdomain:
+ continue
+ if record['target'] in remaining_ips:
+ # We're all good
+ remaining_ips.remove(record['target'])
+ continue
+ unused_records[record['type']].append(record['id'])
+ could_update = set()
+ # Update existing records if possible (to save on API calls)
+ for ip in remaining_ips:
+ kind = 'AAAA' if ':' in ip else 'A'
+ if unused_records[kind]:
+ record_id = unused_records[kind].pop()
+ if self.settings.dry_run:
+ self._info(f'Update DNS record {record_id} to point to {ip}')
+ else:
+ print(f'TODO: update DNS record {record_id} to point to {ip}')
+ could_update.add(ip)
+ remaining_ips -= could_update
+ for record_id in (x for rs in unused_records.values() for x in rs):
+ if self.settings.dry_run:
+ self._info(f'Delete DNS record {record_id}')
+ else:
+ print(f'TODO: delete DNS record {record_id}')
+ for ip in remaining_ips:
+ kind = 'AAAA' if ':' in ip else 'A'
+ if self.settings.dry_run:
+ self._info(f'Add {kind} DNS record pointing to {ip}')
+ else:
+ payload = {
+ 'type': kind,
+ 'name': self.subdomain,
+ 'target': ip,
+ 'ttl': 300, # TODO: make this configurable
+ }
+ self._make_request('POST', f'https://api.linode.com/v4/domains/{self._id}/records', payload)
+ if not self._had_error:
+ self._info(f'Successfully created {kind} record pointing to {ip}')
+ return not self._had_error
def validate_specifics(self) -> str:
if not getattr(self, 'access_token', ''):
return 'Access token not set'
@@ -283,12 +360,17 @@ def main() -> None:
If there are API tokens in there, revoke them immediately!!''')
settings, domains = parse_config(config_path)
+ settings.dry_run = args.dry_run
if not domains:
fatal_error('No domains defined. Try running with --setup?')
while True:
- print('Checking...')
+ print('Checking if IP changed... ', end='')
+ sys.stdout.flush()
+ all_ok = True
for domain in domains:
- domain.check_for_update()
+ all_ok = domain.check_for_update() and all_ok
+ if all_ok:
+ print('All OK')
sleep(settings.interval)
if __name__ == '__main__':