diff --git a/passerelle/apps/vivaticket/models.py b/passerelle/apps/vivaticket/models.py
index 9cb66855..c4220f1e 100644
--- a/passerelle/apps/vivaticket/models.py
+++ b/passerelle/apps/vivaticket/models.py
@@ -188,7 +188,7 @@ class VivaTicket(BaseResource):
return self.requests.post(url, json=payload, headers=headers)
return response
- def get_setting(self, endpoint, **kwargs):
+ def get_list_of_settings(self, endpoint, **kwargs):
response = self.get(endpoint, **kwargs)
json = response.json()
data = []
@@ -198,25 +198,25 @@ class VivaTicket(BaseResource):
@endpoint(perm='can_access', methods=['get'], description=_('Get event categories'))
def events(self, request):
- return self.get_setting('Settings/GetEventCategory')
+ return self.get_list_of_settings('Settings/GetEventCategory')
@endpoint(perm='can_access', methods=['get'], description=_('Get rooms'))
def rooms(self, request, event=None):
query = {}
if event is not None:
query['eventCategory'] = event
- return self.get_setting('Settings/GetRooms', **query)
+ return self.get_list_of_settings('Settings/GetRooms', **query)
@endpoint(perm='can_access', methods=['get'], description=_('Get themes'))
def themes(self, request, room=None):
query = {}
if room is not None:
query['room'] = room
- return self.get_setting('Settings/GetThemes', **query)
+ return self.get_list_of_settings('Settings/GetThemes', **query)
@endpoint(name='school-levels', perm='can_access', methods=['get'], description=_('Get school levels'))
def school_levels(self, request):
- return self.get_setting('Settings/GetSchoolLevel')
+ return self.get_list_of_settings('Settings/GetSchoolLevel')
def get_or_create_contact(self, data, name_id=None):
contact_payload = {
diff --git a/passerelle/base/models.py b/passerelle/base/models.py
index 1f07c71c..f4b9e940 100644
--- a/passerelle/base/models.py
+++ b/passerelle/base/models.py
@@ -667,6 +667,16 @@ class BaseResource(models.Model):
resource_type=ContentType.objects.get_for_model(self), resource_pk=self.pk, apiuser__key=''
).exists()
+    def get_setting(self, name):
+        """Return the extra setting `name` configured for this connector instance.
+
+        The value is looked up in ``settings.CONNECTORS_SETTINGS`` under the key
+        ``'<connector slug>/<instance slug>'``. ``None`` is returned when the
+        global setting or the entry of this connector is not a dict, or when
+        `name` is not defined in that entry.
+        """
+        all_settings = settings.CONNECTORS_SETTINGS
+        if isinstance(all_settings, dict):
+            own_settings = all_settings.get(f'{self.get_connector_slug()}/{self.slug}')
+            if isinstance(own_settings, dict):
+                return own_settings.get(name)
+        return None
+
class AccessRight(models.Model):
codename = models.CharField(max_length=100, verbose_name='codename')
diff --git a/passerelle/settings.py b/passerelle/settings.py
index 8210cf29..8b3621d0 100644
--- a/passerelle/settings.py
+++ b/passerelle/settings.py
@@ -245,13 +245,36 @@ REQUESTS_TIMEOUT = 25
# }
REQUESTS_MAX_RETRIES = {}
+# Connectors settings - extra settings for connectors, keyed by
+# '<connector slug>/<connector instance slug>', for example:
+# CONNECTORS_SETTINGS = {
+#     'cmis/test': {
+#         'requests_substitutions': [
+#             {
+#                 'url': 'https://service.example.com/api/',
+#                 'search': 'http://service.example.internal/software/api/',
+#                 'replace': 'https://service.example.com/api/',
+#             }
+#         ]
+#     }
+# }
+#
+# * requests_substitutions:
+#   Apply substitutions to HTTP responses obtained through self.requests.
+#   'search' is a Python regular expression for re.sub(), and 'replace' is the replacement string.
+#   The 'url' key is optional; if absent, the replacement is done on responses from all URLs.
+CONNECTORS_SETTINGS = {}
+
+# Content-types, as regular expressions, of the responses on which substitutions may be applied
+REQUESTS_SUBSTITUTIONS_CONTENT_TYPES = [r'text/.*', r'application/(.*\+)?json', r'application/(.*\+)?xml']
+
# Passerelle can receive big requests (for example base64 encoded files)
DATA_UPLOAD_MAX_MEMORY_SIZE = 100 * 1024 * 1024
SITE_BASE_URL = 'http://localhost'
# List of passerelle.utils.Request response Content-Type to log
-LOGGED_CONTENT_TYPES_MESSAGES = (r'text/', r'application/(json|xml)')
+LOGGED_CONTENT_TYPES_MESSAGES = [r'text/.*', r'application/(.*\+)?json', r'application/(.*\+)?xml']
# Max size of the response to log
LOGGED_RESPONSES_MAX_SIZE = 5000
diff --git a/passerelle/utils/__init__.py b/passerelle/utils/__init__.py
index 52dea393..a9942ad3 100644
--- a/passerelle/utils/__init__.py
+++ b/passerelle/utils/__init__.py
@@ -176,7 +176,7 @@ def protected_api(perm):
return decorator
-def content_type_match(ctype):
+def should_content_type_body_be_logged(ctype):
content_types = settings.LOGGED_CONTENT_TYPES_MESSAGES
if not ctype:
return False
@@ -219,7 +219,8 @@ def log_http_request(
if logger.level == 10: # DEBUG
extra['response_headers'] = make_headers_safe(response.headers)
# log body only if content type is allowed
- if content_type_match(response.headers.get('Content-Type')):
+ content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
+ if should_content_type_body_be_logged(content_type):
max_size = settings.LOGGED_RESPONSES_MAX_SIZE
if hasattr(logger, 'connector'):
max_size = logger.connector.logging_parameters.responses_max_size or max_size
@@ -280,6 +281,72 @@ class Request(RequestSession):
self.mount('http://', adapter)
self.timeout = timeout if timeout is not None else settings.REQUESTS_TIMEOUT
+    def _substitute(self, search, replace, value):
+        """Recursively apply ``re.subn(search, replace, ...)`` to the strings found in `value`.
+
+        Strings are substituted (the number of replacements is logged at debug
+        level when non-zero); lists and dicts are rebuilt with their items,
+        keys and values substituted; any other value is returned unchanged.
+        """
+        if isinstance(value, dict):
+            substituted = {}
+            for key, item in value.items():
+                new_key = self._substitute(search, replace, key)
+                substituted[new_key] = self._substitute(search, replace, item)
+            return substituted
+        if isinstance(value, list):
+            return [self._substitute(search, replace, item) for item in value]
+        if not isinstance(value, str):
+            return value
+        result, count = re.subn(search, replace, value)
+        if count:
+            self.logger.debug('substitution: %d occurences', count)
+        return result
+
+    def apply_requests_substitution(self, response, substitution):
+        """Apply one `requests_substitutions` entry to the body of `response`, in place.
+
+        `substitution` must be a dict whose 'search' (a regular expression) and
+        'replace' values are strings. When it also has a string 'url', the entry
+        only applies to responses whose request URL has the same scheme (if the
+        'url' has one), the same host, and a path starting with the path of
+        'url' (a path of '/' or an empty path matches everything); a 'url'
+        without a host disables the entry. Only responses whose Content-Type
+        matches one of settings.REQUESTS_SUBSTITUTIONS_CONTENT_TYPES are
+        modified.
+
+        JSON bodies are decoded, substituted string by string (keys and values)
+        and re-serialized, so that escaped characters in the raw payload do not
+        prevent a match; other bodies are substituted as text. The rewritten
+        body is always encoded as UTF-8.
+
+        Returns True when the body was rewritten, None otherwise (invalid
+        entry, entry filtered out, or failure; failures are logged).
+        """
+        if not isinstance(substitution, dict):
+            self.logger.warning('substitution: invalid substitution, %r', substitution)
+            return
+        for key in ['search', 'replace']:
+            if key not in substitution:
+                self.logger.warning('substitution: missing field "%s": %s', key, substitution)
+                return
+            if not isinstance(substitution[key], str):
+                self.logger.warning(
+                    'substitution: invalid type for field "%s", must be str: %s', key, substitution
+                )
+                return
+        search = substitution['search']
+        replace = substitution['replace']
+
+        # filter on url
+        if isinstance(substitution.get('url'), str):
+            url = urllib.parse.urlparse(substitution['url'])
+            request_url = urllib.parse.urlparse(response.request.url)
+            if url.scheme and url.scheme != request_url.scheme:
+                return
+            # substitution without a netloc are ignored
+            if not url.netloc:
+                return
+            if request_url.netloc != url.netloc:
+                return
+            if url.path and url.path != '/' and not request_url.path.startswith(url.path):
+                return
+
+        # filter on content-type; parameters (charset...) are dropped before matching
+        content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
+        for content_type_re in settings.REQUESTS_SUBSTITUTIONS_CONTENT_TYPES:
+            if re.match(content_type_re, content_type):
+                break
+        else:
+            self.logger.debug('substitution: content_type did not match %s', content_type)
+            return
+
+        self.logger.debug('substitution: try %s', substitution)
+        try:
+            # the optional group must accept a whole subtype prefix before the
+            # "+" (application/ld+json, application/hal+json...), not only a
+            # single character
+            if re.match(r'application/([^;]+\+)?json', content_type):
+                import json
+
+                response._content = json.dumps(self._substitute(search, replace, response.json())).encode()
+            else:
+                response._content = self._substitute(search, replace, response.text).encode()
+            # the new body was produced by str.encode(), i.e. UTF-8
+            response.encoding = 'utf-8'
+            return True
+        except Exception:
+            # best effort: an invalid regular expression or an undecodable body
+            # must not break the request, the original response is kept
+            self.logger.exception('substitution: "%s" failed', substitution)
+
def request(self, method, url, **kwargs):
cache_duration = kwargs.pop('cache_duration', None)
invalidate_cache = kwargs.pop('invalidate_cache', False)
@@ -336,6 +403,13 @@ class Request(RequestSession):
warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
response = super().request(method, url, **kwargs)
+ if self.resource:
+ requests_substitutions = self.resource.get_setting('requests_substitutions')
+ if isinstance(requests_substitutions, list):
+ for requests_substitution in requests_substitutions:
+ if not self.apply_requests_substitution(response, requests_substitution):
+ self.logger.debug('substitution: %s does not match', requests_substitution)
+
if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
cache.set(
cache_key,
diff --git a/tests/test_requests.py b/tests/test_requests.py
index 26a39b3a..5a3d5cd8 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -33,6 +33,10 @@ class MockResource:
verify_cert = True
http_proxy = ''
+    @classmethod
+    def get_setting(cls, name):
+        """Stand-in for BaseResource.get_setting(): this mock has no connector setting."""
+        return None
+
@pytest.fixture(params=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
def log_level(request):
@@ -531,3 +535,132 @@ def test_requests_to_legacy_urls(log_level):
resp = requests.get('https://old.org/foobar')
assert resp.json() == {"foo": "bar"}
assert resp.request.url == 'https://new.org/foobar'
+
+
+@responses.activate
+def test_requests_substitution(settings):
+    """Check the `requests_substitutions` connector setting through Request.
+
+    Every mocked body contains the internal URL, so each assertion tells a
+    rewritten response (matching URL and content-type) apart from an untouched
+    one (binary or missing content-type, other host, other connector).
+    """
+    from passerelle.base.models import BaseResource
+
+    resource = mock.Mock()
+    resource.requests_max_retries = {}
+    resource.slug = 'test'
+    resource.get_connector_slug.return_value = 'cmis'
+    # use the real lookup logic against the mocked slug / connector slug
+    resource.get_setting = lambda name: BaseResource.get_setting(resource, name)
+
+    internal_html = b'<html>\n<a href="http://example.internal/path/">\n</html>'
+    public_html = '<html>\n<a href="https://example.com/path/">\n</html>'
+    internal_xml = b'<a href="http://example.internal/path/"/>'
+    public_xml = '<a href="https://example.com/path/"/>'
+    binary_body = b'\00' + internal_html
+
+    requests = Request(logger=logging.getLogger(), resource=resource)
+    settings.CONNECTORS_SETTINGS = {
+        "cmis/test": {
+            'requests_substitutions': [
+                {
+                    'url': 'https://example.com/',
+                    'search': 'http://example.internal',
+                    'replace': 'https://example.com',
+                }
+            ]
+        }
+    }
+    responses.add(
+        responses.GET,
+        "https://example.com/html",
+        content_type='text/html',
+        body=internal_html,
+        status=200,
+    )
+    assert requests.get('https://example.com/html?bar=foo', params={'foo': 'bar'}).text == public_html
+
+    responses.add(
+        responses.GET,
+        "https://example.com/xml",
+        content_type='application/xml',
+        body=internal_xml,
+        status=200,
+    )
+    assert requests.get('https://example.com/xml').text == public_xml
+
+    # check substitution is applied inside JSON, even if some characters are escaped
+    responses.add(
+        responses.GET,
+        "https://example.com/json",
+        content_type='application/json',
+        body=b'{"url": "http:\\/\\/example.internal/path/"}',
+        status=200,
+    )
+    assert requests.get('https://example.com/json').json() == {'url': 'https://example.com/path/'}
+
+    # content-type not authorized for substitutions: body is left untouched
+    responses.add(
+        responses.GET,
+        "https://example.com/binary",
+        content_type='application/octet-stream',
+        body=binary_body,
+        status=200,
+    )
+    assert requests.get('https://example.com/binary').content == binary_body
+
+    # empty content-type: body is left untouched
+    responses.add(
+        responses.GET,
+        "https://example.com/binary2",
+        content_type='',
+        body=binary_body,
+        status=200,
+    )
+    assert requests.get('https://example.com/binary2').content == binary_body
+
+    responses.add(
+        responses.GET,
+        "https://example2.com/html",
+        content_type='text/html',
+        body=internal_html,
+        status=200,
+    )
+    # wrong hostname
+    assert requests.get('https://example2.com/html?query=1').text == internal_html.decode()
+
+    # check that url field is optional
+    settings.CONNECTORS_SETTINGS = {
+        "cmis/test": {
+            'requests_substitutions': [
+                {
+                    'search': 'http://example.internal',
+                    'replace': 'https://example.com',
+                }
+            ]
+        }
+    }
+    responses.add(
+        responses.GET,
+        "https://whatever.com/html",
+        content_type='text/html',
+        body=internal_html,
+        status=200,
+    )
+    assert requests.get('https://whatever.com/html?bar=foo', params={'foo': 'bar'}).text == public_html
+
+    # check setting is applied per connector slug
+    resource.get_connector_slug.return_value = 'pas-cmis'
+    requests = Request(logger=logging.getLogger(), resource=resource)
+    responses.add(
+        responses.GET,
+        "https://example.com/html",
+        content_type='text/html',
+        body=internal_html,
+        status=200,
+    )
+    assert (
+        requests.get('https://example.com/html?bar=foo', params={'foo': 'bar'}).text
+        == internal_html.decode()
+    )