passerelle/passerelle/utils/__init__.py

# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import base64
import hashlib
import re
import time
import urllib.parse
import warnings
from functools import wraps
from io import BytesIO
from itertools import chain, islice

from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.core.exceptions import PermissionDenied
from django.db import transaction
from django.http import HttpResponse, HttpResponseBadRequest
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import lazy
from django.utils.html import mark_safe
from django.utils.translation import ngettext_lazy
from django.views.generic.detail import SingleObjectMixin
from requests import Response as RequestResponse
from requests import Session as RequestSession
from requests.adapters import HTTPAdapter
from requests.structures import CaseInsensitiveDict
from urllib3.exceptions import InsecureRequestWarning
from urllib3.util.retry import Retry

from passerelle.base.signature import check_query, check_url

# legacy import, other modules keep importing to_json from passerelle.utils
from .jsonresponse import to_json  # noqa F401 pylint: disable=unused-import
from .sftp import SFTP, SFTPField  # noqa F401 pylint: disable=unused-import
from .soap import SOAPClient, SOAPTransport  # noqa F401 pylint: disable=unused-import

mark_safe_lazy = lazy(mark_safe, str)


class ImportSiteError(Exception):
    pass


def response_for_json(request, data):
    import json

    response = HttpResponse(content_type='application/json')
    json_str = json.dumps(data)
    for variable in ('jsonpCallback', 'callback'):
        if variable in request.GET:
            identifier = request.GET[variable]
            if not re.match(r'^[$A-Za-z_][0-9A-Za-z_$]*$', identifier):
                return HttpResponseBadRequest('invalid JSONP callback name')
            json_str = '%s(%s);' % (identifier, json_str)
            response['Content-Type'] = 'application/javascript'
            break
    response.write(json_str)
    return response
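

# Illustrative behaviour sketch: the query parameter names come from the loop
# above, the payload is an assumption.
#
#     GET /endpoint/?callback=handleData
#     => Content-Type: application/javascript
#     => handleData({"err": 0});
#
# Without a callback parameter, the JSON document is returned as application/json.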


def get_request_users(request):
    from passerelle.base.models import ApiUser

    users = []
    users.extend(ApiUser.objects.filter(keytype=''))
    if 'orig' in request.GET and 'signature' in request.GET:
        orig = request.GET['orig']
        query = request.META['QUERY_STRING']
        signature_users = ApiUser.objects.filter(keytype='SIGN', username=orig)
        for signature_user in signature_users:
            if check_query(query, signature_user.key):
                users.append(signature_user)
    elif 'apikey' in request.GET:
        users.extend(ApiUser.objects.filter(keytype='API', key=request.GET['apikey']))
    elif 'HTTP_AUTHORIZATION' in request.META:
        http_authorization = request.headers['Authorization'].split(' ', 1)
        scheme = http_authorization[0].lower()
        if scheme == 'basic' and len(http_authorization) > 1:
            param = http_authorization[1]
            try:
                decoded = force_str(base64.b64decode(force_bytes(param.strip())))
                username, password = decoded.split(':', 1)
            except (TypeError, ValueError):
                pass
            else:
                users.extend(ApiUser.objects.filter(keytype='SIGN', username=username, key=password))

    def ip_match(ip, match):
        if not ip:
            return True
        if ip == match:
            return True
        return False

    users = [x for x in users if ip_match(x.ipsource, request.META.get('REMOTE_ADDR'))]
    return users


def get_trusted_services():
    """
    Services in settings.KNOWN_SERVICES defining both a secret and a verif_orig are "trusted"
    """
    trusted_services = []
    for service_type in getattr(settings, 'KNOWN_SERVICES', {}):
        for slug, service in settings.KNOWN_SERVICES[service_type].items():
            if service.get('secret') and service.get('verif_orig'):
                trusted_service = service.copy()
                trusted_service['service_type'] = service_type
                trusted_service['slug'] = slug
                trusted_services.append(trusted_service)
    return trusted_services
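

# Illustrative shape of the setting read above (a hedged sketch: only the
# 'secret' and 'verif_orig' keys are actually required by get_trusted_services(),
# the other keys and all values are assumptions):
#
#     KNOWN_SERVICES = {
#         'wcs': {
#             'eservices': {
#                 'title': 'Online services',
#                 'url': 'https://eservices.example.invalid/',
#                 'secret': 'shared-secret',
#                 'orig': 'passerelle.example.invalid',
#                 'verif_orig': 'eservices.example.invalid',
#             },
#         },
#     }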


def is_trusted(request):
    """
    True if query-string is signed by a trusted service (see get_trusted_services() above)
    """
    if not request.GET.get('orig') or not request.GET.get('signature'):
        return False
    full_path = request.get_full_path()
    for service in get_trusted_services():
        if (
            service.get('verif_orig') == request.GET['orig']
            and service.get('secret')
            and check_url(full_path, service['secret'])
        ):
            return True
    return False


def is_authorized(request, obj, perm):
    from passerelle.base.models import AccessRight

    if request.user.is_superuser:
        return True
    if is_trusted(request):
        return True
    resource_type = ContentType.objects.get_for_model(obj)
    rights = AccessRight.objects.filter(resource_type=resource_type, resource_pk=obj.id, codename=perm)
    users = [x.apiuser for x in rights]
    return set(users).intersection(get_request_users(request))


def protected_api(perm):
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(instance, request, *args, **kwargs):
            if not isinstance(instance, SingleObjectMixin):
                raise Exception("protected_api must be applied on a method of a class based view")
            obj = instance.get_object()
            if not is_authorized(request, obj, perm):
                raise PermissionDenied()
            return view_func(instance, request, *args, **kwargs)

        return _wrapped_view

    return decorator
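

# A hedged usage sketch: protected_api() wraps a method of a class-based view
# whose get_object() returns the connector resource; the view class and the
# access-right codename below are illustrative, not taken from this module.
#
#     class StatusView(SingleObjectMixin, View):
#         model = MyConnector  # hypothetical connector model
#
#         @protected_api('can_access')
#         def get(self, request, *args, **kwargs):
#             ...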


def should_content_type_body_be_logged(ctype):
    content_types = settings.LOGGED_CONTENT_TYPES_MESSAGES
    if not ctype:
        return False
    for content_type in content_types:
        if re.match(content_type, ctype):
            return True
    return False


def make_headers_safe(headers):
    """Convert a dict of HTTP headers to text safely, as some services return 8-bit encodings in headers."""
    return {
        force_str(key, errors='replace'): force_str(value, errors='replace') for key, value in headers.items()
    }


def log_http_request(
    logger, request, response=None, exception=None, error_log=True, extra=None, duration=None
):
    log_function = logger.info
    message = ''
    extra = extra or {}
    kwargs = {}
    if request is not None:
        message = '%s %s' % (request.method, request.url)
        extra['request_url'] = request.url
    if logger.level == 10 and request:  # DEBUG
        extra['request_headers'] = make_headers_safe(request.headers)
        if request.body:
            max_size = settings.LOGGED_REQUESTS_MAX_SIZE
            if hasattr(logger, 'connector'):
                max_size = logger.connector.logging_parameters.requests_max_size or max_size
            extra['request_payload'] = request.body[:max_size]
    if duration is not None:
        extra['request_duration'] = duration
    if response is not None:
        message = message + ' (=> %s)' % response.status_code
        extra['response_status'] = response.status_code
        if logger.level == 10:  # DEBUG
            extra['response_headers'] = make_headers_safe(response.headers)
            # log body only if content type is allowed
            content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
            if should_content_type_body_be_logged(content_type):
                max_size = settings.LOGGED_RESPONSES_MAX_SIZE
                if hasattr(logger, 'connector'):
                    max_size = logger.connector.logging_parameters.responses_max_size or max_size
                content = response.content[:max_size]
                extra['response_content'] = content
        if response.status_code // 100 == 3:
            log_function = logger.warning
        elif response.status_code // 100 >= 4:
            log_function = logger.error
    elif exception:
        if message:
            message = message + ' (=> %s)' % repr(exception)
        else:
            message = repr(exception)
        extra['response_exception'] = repr(exception)
        log_function = logger.error
        kwargs['exc_info'] = exception
    # allow resources to disable any error log at requests level
    if not error_log:
        log_function = logger.info
    log_function(message, extra=extra, **kwargs)


# Wrapper around requests.Session
# - log input and output data
# - use HTTP Basic auth if resource.basic_auth_username and resource.basic_auth_password exist
# - use client side certificate if resource.client_certificate (FileField) exists
# - verify server certificate CA if resource.trusted_certificate_authorities (FileField) exists
# - disable server certificate verification if resource.verify_cert (BooleanField) exists and is False
# - use a proxy for HTTP and HTTPS if resource.http_proxy exists
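#
# A minimal, hedged usage sketch (the resource, URL and cache duration are
# illustrative assumptions, not taken from this module):
#
#     session = Request(logger=resource.logger, resource=resource)
#     response = session.get('https://api.example.invalid/status', cache_duration=60)
#     response.raise_for_status()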
class Request(RequestSession):
    ADAPTER_REGISTRY = {}  # connection pooling
    log_requests_errors = True

    def __init__(self, *args, **kwargs):
        self.logger = kwargs.pop('logger')
        self.resource = kwargs.pop('resource', None)
        resource_log_requests_errors = getattr(self.resource, 'log_requests_errors', True)
        self.log_requests_errors = kwargs.pop('log_requests_errors', resource_log_requests_errors)
        timeout = kwargs.pop('timeout', None)
        super().__init__(*args, **kwargs)

        if self.resource:
            timeout = timeout if timeout is not None else getattr(self.resource, 'requests_timeout', None)
            http_adapter_init_kwargs = {}
            requests_max_retries = dict(settings.REQUESTS_MAX_RETRIES)
            if getattr(self.resource, 'requests_max_retries', None):
                requests_max_retries = dict(self.resource.requests_max_retries)
            if requests_max_retries:
                requests_max_retries.setdefault('read', None)
                http_adapter_init_kwargs['max_retries'] = Retry(**requests_max_retries)
            adapter = Request.ADAPTER_REGISTRY.setdefault(
                type(self.resource), HTTPAdapter(**http_adapter_init_kwargs)
            )
            self.mount('https://', adapter)
            self.mount('http://', adapter)

        self.timeout = timeout if timeout is not None else settings.REQUESTS_TIMEOUT

    def _substitute(self, search, replace, value):
        if isinstance(value, str):
            value, nsub = re.subn(search, replace, value)
            if nsub:
                self.logger.debug('substitution: %d occurrences', nsub)
        elif isinstance(value, list):
            value = [self._substitute(search, replace, v) for v in value]
        elif isinstance(value, dict):
            value = {
                self._substitute(search, replace, k): self._substitute(search, replace, v)
                for k, v in value.items()
            }
        return value

    def apply_requests_substitution(self, response, substitution):
        if not isinstance(substitution, dict):
            self.logger.warning('substitution: invalid substitution, %r', substitution)
            return
        for key in ['search', 'replace']:
            if key not in substitution:
                self.logger.warning('substitution: missing field "%s": %s', key, substitution)
                return
            if not isinstance(substitution[key], str):
                self.logger.warning(
                    'substitution: invalid type for field "%s", must be str: %s', key, substitution
                )
                return
        search = substitution['search']
        replace = substitution['replace']

        # filter on url
        if isinstance(substitution.get('url'), str):
            url = urllib.parse.urlparse(substitution['url'])
            request_url = urllib.parse.urlparse(response.request.url)
            if url.scheme and url.scheme != request_url.scheme:
                return
            # substitutions without a netloc are ignored
            if not url.netloc:
                return
            if request_url.netloc != url.netloc:
                return
            if url.path and url.path != '/' and not request_url.path.startswith(url.path):
                return

        # filter on content-type
        content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
        for content_type_re in settings.REQUESTS_SUBSTITUTIONS_CONTENT_TYPES:
            if re.match(content_type_re, content_type):
                break
        else:
            self.logger.debug('substitution: content_type did not match %s', content_type)
            return

        self.logger.debug('substitution: try %s', substitution)
        try:
            if re.match(r'application/([^;]\+)?json', content_type):
                import json

                response._content = json.dumps(self._substitute(search, replace, response.json())).encode()
            else:
                response._content = self._substitute(search, replace, response.text).encode()
            response.encoding = 'utf-8'
            return True
        except Exception:
            self.logger.exception('substitution: "%s" failed', substitution)

    def request(self, method, url, **kwargs):
        cache_duration = kwargs.pop('cache_duration', None)
        invalidate_cache = kwargs.pop('invalidate_cache', False)

        # search in legacy urls
        legacy_urls_mapping = getattr(settings, 'LEGACY_URLS_MAPPING', None)
        if legacy_urls_mapping:
            splitted_url = urllib.parse.urlparse(url)
            hostname = splitted_url.netloc
            if hostname in legacy_urls_mapping:
                url = splitted_url._replace(netloc=legacy_urls_mapping[hostname]).geturl()

        if self.resource:
            if 'auth' not in kwargs:
                username = getattr(self.resource, 'basic_auth_username', None)
                if username and hasattr(self.resource, 'basic_auth_password'):
                    kwargs['auth'] = (username, self.resource.basic_auth_password)
            if 'cert' not in kwargs:
                keystore = getattr(self.resource, 'client_certificate', None)
                if keystore:
                    kwargs['cert'] = keystore.path
            if 'verify' not in kwargs:
                trusted_certificate_authorities = getattr(
                    self.resource, 'trusted_certificate_authorities', None
                )
                if trusted_certificate_authorities:
                    kwargs['verify'] = trusted_certificate_authorities.path
                elif hasattr(self.resource, 'verify_cert'):
                    kwargs['verify'] = self.resource.verify_cert
            if 'proxies' not in kwargs:
                proxy = getattr(self.resource, 'http_proxy', None)
                if proxy:
                    kwargs['proxies'] = {'http': proxy, 'https': proxy}

        if method == 'GET' and cache_duration:
            cache_key = hashlib.md5(force_bytes('%r;%r' % (url, kwargs))).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                response = RequestResponse()
                response.raw = BytesIO(cache_content.get('content'))
                response.headers = CaseInsensitiveDict(cache_content.get('headers', {}))
                response.status_code = cache_content.get('status_code')
                return response

        if settings.REQUESTS_PROXIES and 'proxies' not in kwargs:
            kwargs['proxies'] = settings.REQUESTS_PROXIES

        if 'timeout' not in kwargs:
            kwargs['timeout'] = self.timeout

        with warnings.catch_warnings():
            if kwargs.get('verify') is False:
                # disable urllib3 warnings
                warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
            response = super().request(method, url, **kwargs)

        if self.resource:
            requests_substitutions = self.resource.get_setting('requests_substitutions')
            if isinstance(requests_substitutions, list):
                for requests_substitution in requests_substitutions:
                    if not self.apply_requests_substitution(response, requests_substitution):
                        self.logger.debug('substitution: %s does not match', requests_substitution)

        if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
            cache.set(
                cache_key,
                {
                    'content': response.content,
                    'headers': response.headers,
                    'status_code': response.status_code,
                },
                cache_duration,
            )
        return response

    def send(self, request, **kwargs):
        start_time = time.time()
        try:
            response = super().send(request, **kwargs)
            duration = time.time() - start_time
        except Exception as exc:
            duration = time.time() - start_time
            self.log_http_request(request, exception=exc, duration=duration)
            raise
        self.log_http_request(request, response=response, duration=duration)
        return response

    def log_http_request(self, request, response=None, exception=None, duration=None):
        error_log = self.log_requests_errors
        log_http_request(
            self.logger,
            request=request,
            response=response,
            exception=exception,
            error_log=error_log,
            duration=duration,
        )


def export_site(slugs=None):
    '''Dump the passerelle configuration (users, resources and ACLs) to a JSON-dumpable dictionary'''
    from passerelle.base.models import ApiUser
    from passerelle.views import get_all_apps

    d = {}
    d['apiusers'] = [apiuser.export_json() for apiuser in ApiUser.objects.all()]
    d['resources'] = resources = []
    for app in get_all_apps():
        for resource in app.objects.all():
            if slugs and resource.slug not in slugs:
                continue
            try:
                resources.append(resource.export_json())
            except NotImplementedError:
                break
    return d
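

# Illustrative shape of the dictionary returned by export_site() and accepted by
# import_site() below (the top-level keys come from the code above, the entry
# contents are assumptions):
#
#     {
#         'apiusers': [{...}, ...],
#         'resources': [{'resource_type': 'sector.sectorresource', ...}, ...],
#     }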


def import_site(d, if_empty=False, clean=False, overwrite=False, import_users=False):
    """Load the passerelle configuration (users, resources and ACLs) from a dictionary loaded
    from JSON
    """
    from passerelle.base.models import ApiUser, BaseResource
    from passerelle.views import get_all_apps

    d = d.copy()

    def is_empty():
        if import_users:
            if ApiUser.objects.count():
                return False
        for app in get_all_apps():
            if app.objects.count():
                return False
        return True

    if if_empty and not is_empty():
        return

    if clean:
        for app in get_all_apps():
            app.objects.all().delete()
        if import_users:
            ApiUser.objects.all().delete()

    with transaction.atomic():
        if import_users:
            for apiuser in d.get('apiusers', []):
                ApiUser.import_json(apiuser, overwrite=overwrite)

        unknown_connectors = []

        def import_resource(res):
            try:
                BaseResource.import_json(res, overwrite=overwrite, import_users=import_users)
            except BaseResource.UnknownBaseResourceError as e:
                unknown_connectors.append(str(e))

        resources = d.get('resources', [])
        # import SectorResource entries first, as AddressResource entries may need them
        for res in [r for r in resources if r['resource_type'] == 'sector.sectorresource']:
            import_resource(res)
        for res in [r for r in resources if r['resource_type'] != 'sector.sectorresource']:
            import_resource(res)

        if unknown_connectors:
            raise ImportSiteError(
                ngettext_lazy('Unknown connector: %s', 'Unknown connectors: %s', len(unknown_connectors))
                % ', '.join(unknown_connectors)
            )


def batch(iterable, size):
    """Batch an iterable into an iterable of iterables, each at most `size` elements long."""
    sourceiter = iter(iterable)
    while True:
        batchiter = islice(sourceiter, size)
        # call next() at least once so we always advance; if the caller does not
        # consume the returned iterators, sourceiter would otherwise never be exhausted.
        try:
            yield chain([next(batchiter)], batchiter)
        except StopIteration:
            return
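

# Illustrative example: each chunk must be consumed as it is produced, since all
# yielded iterators draw from the same underlying iterator.
#
#     >>> [list(chunk) for chunk in batch(range(7), 3)]
#     [[0, 1, 2], [3, 4, 5], [6]]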