diff --git a/API.md b/API.md index b28e4e37..11e0d6c9 100644 --- a/API.md +++ b/API.md @@ -1522,9 +1522,9 @@ data. ```json5 { - "enabled": , // optional - "note": , // optional - "expiration": , // optional + "enabled": , // optional + "note": , // optional + "expirationTime": , // optional } ``` @@ -1550,10 +1550,10 @@ data. ```json5 { - "version": , - "enabled": , // optional - "note": , // optional - "expiration": , // optional + "version": , + "enabled": , // optional + "note": , // optional + "expirationTime": , // optional } ``` @@ -1842,14 +1842,14 @@ A single user token. ```json5 { - "user": , - "token": , - "note": , - "enabled": , - "expiration": , - "version": , - "creationTime": , - "lastEditTime": , + "user": , + "token": , + "note": , + "enabled": , + "expirationTime": , + "version": , + "creationTime": , + "lastEditTime": , } ``` @@ -1858,7 +1858,7 @@ A single user token. - ``: the token that can be used to authenticate the user. - ``: a note that describes the token. - ``: whether the token is still valid for authentication. -- ``: time when the token expires. +- ``: time when the token expires. It must include the timezone as per RFC3339. - ``: resource version. See [versioning](#versioning). - ``: time the user token was created , formatted as per RFC 3339. - ``: time the user token was edited, formatted as per RFC 3339. diff --git a/server/requirements.txt b/server/requirements.txt index 7cc47868..b11c3b5d 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -11,4 +11,6 @@ scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 scikit-image>=0.12 -pynacl>=1.2.1 \ No newline at end of file +pynacl>=1.2.1 +pytz>=2018.3 +pyRFC3339>=1.0 \ No newline at end of file diff --git a/server/szurubooru/func/user_tokens.py b/server/szurubooru/func/user_tokens.py index f7969c84..e9d926c5 100644 --- a/server/szurubooru/func/user_tokens.py +++ b/server/szurubooru/func/user_tokens.py @@ -1,6 +1,6 @@ import pytz -from dateutil import parser as dateutil_parser from datetime import datetime +from pyrfc3339 import parser as rfc3339_parser from typing import Any, Optional, List, Dict, Callable from szurubooru import db, model, rest, errors from szurubooru.func import auth, serialization, users, util @@ -118,16 +118,12 @@ def update_user_token_expiration_time( user_token: model.UserToken, expiration_time_str: str) -> None: assert user_token try: - expiration_time = dateutil_parser.parse(expiration_time_str) - if expiration_time.tzinfo is None: + expiration_time = rfc3339_parser.parse(expiration_time_str, utc=True) + expiration_time = expiration_time.astimezone(pytz.UTC) + if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC): raise InvalidExpirationError( - 'Expiration cannot be missing timezone') - else: - expiration_time = expiration_time.astimezone(pytz.UTC) - if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC): - raise InvalidExpirationError( - 'Expiration cannot happen in the past') - user_token.expiration_time = expiration_time + 'Expiration cannot happen in the past') + user_token.expiration_time = expiration_time except ValueError: raise InvalidExpirationError( 'Expiration is in an invalid format {}'.format( diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index ba2d4dc9..5e822866 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -160,6 +160,12 @@ def value_exceeds_column_size(value: Optional[str], column: Any) -> bool: return len(value) > max_length +def get_column_size(column: Any) -> Optional[int]: + if not column: + return None + return column.property.columns[0].type.length + + def chunks(source_list: List[Any], part_size: int) -> Generator: for i in range(0, len(source_list), part_size): yield source_list[i:i + part_size] diff --git a/server/szurubooru/tests/api/test_user_token_creating.py b/server/szurubooru/tests/api/test_user_token_creating.py index 05b2caa6..f550f63f 100644 --- a/server/szurubooru/tests/api/test_user_token_creating.py +++ b/server/szurubooru/tests/api/test_user_token_creating.py @@ -25,4 +25,5 @@ def test_creating_user_token( 'user_name': user_token.user.name }) assert result == 'serialized user token' - user_tokens.create_user_token.assert_called_once_with(user_token.user) + user_tokens.create_user_token.assert_called_once_with( + user_token.user, True) diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index ecf22ff6..f6eeee18 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -135,14 +135,20 @@ def user_factory(): @pytest.fixture def user_token_factory(user_factory): - def factory(user=None, token=None, enabled=None, creation_time=None): + def factory( + user=None, + token=None, + expiration_time=None, + enabled=None, + creation_time=None): if user is None: user = user_factory() db.session.add(user) user_token = model.UserToken() user_token.user = user user_token.token = token or 'dummy' - user_token.enabled = enabled or True + user_token.expiration_time = expiration_time + user_token.enabled = enabled if enabled is not None else True user_token.creation_time = creation_time or datetime(1997, 1, 1) return user_token return factory diff --git a/server/szurubooru/tests/func/test_auth.py b/server/szurubooru/tests/func/test_auth.py index 482a8f54..e6ce9ea8 100644 --- a/server/szurubooru/tests/func/test_auth.py +++ b/server/szurubooru/tests/func/test_auth.py @@ -1,5 +1,8 @@ -from szurubooru.func import auth +from datetime import datetime, timedelta import pytest +import pytz + +from szurubooru.func import auth @pytest.fixture(autouse=True) @@ -48,6 +51,17 @@ def test_is_valid_token(user_token_factory): assert auth.is_valid_token(user_token) +def test_expired_token_is_invalid(user_token_factory): + past_expiration = (datetime.utcnow() - timedelta(minutes=30)) + user_token = user_token_factory(expiration_time=past_expiration) + assert not auth.is_valid_token(user_token) + + +def test_disabled_token_is_invalid(user_token_factory): + user_token = user_token_factory(enabled=False) + assert not auth.is_valid_token(user_token) + + def test_generate_authorization_token(): result = auth.generate_authorization_token() assert result != auth.generate_authorization_token() diff --git a/server/szurubooru/tests/func/test_user_tokens.py b/server/szurubooru/tests/func/test_user_tokens.py index 5fba92db..7e101ebc 100644 --- a/server/szurubooru/tests/func/test_user_tokens.py +++ b/server/szurubooru/tests/func/test_user_tokens.py @@ -1,8 +1,11 @@ -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import patch - -from szurubooru import db -from szurubooru.func import user_tokens, users, auth +import pytest +import pytz +import random +import string +from szurubooru import db, model +from szurubooru.func import user_tokens, users, auth, util def test_serialize_user_token(user_token_factory): @@ -15,7 +18,9 @@ def test_serialize_user_token(user_token_factory): assert result == { 'creationTime': datetime(1997, 1, 1, 0, 0), 'enabled': True, + 'expirationTime': None, 'lastEditTime': None, + 'note': None, 'token': 'dummy', 'user': { 'avatarUrl': 'https://example.com/avatar.png', @@ -34,8 +39,8 @@ def test_get_by_user_and_token(user_token_factory): db.session.add(user_token) db.session.flush() db.session.commit() - result = user_tokens.get_by_user_and_token(user_token.user, - user_token.token) + result = user_tokens.get_by_user_and_token( + user_token.user, user_token.token) assert result == user_token @@ -57,7 +62,7 @@ def test_create_user_token(user_factory): db.session.commit() with patch('szurubooru.func.auth.generate_authorization_token'): auth.generate_authorization_token.return_value = 'test' - result = user_tokens.create_user_token(user) + result = user_tokens.create_user_token(user, True) assert result.token == 'test' assert result.user == user @@ -73,3 +78,67 @@ def test_update_user_token_edit_time(user_token_factory): assert user_token.last_edit_time is None user_tokens.update_user_token_edit_time(user_token) assert user_token.last_edit_time is not None + + +def test_update_user_token_note(user_token_factory): + user_token = user_token_factory() + assert user_token.note is None + user_tokens.update_user_token_note(user_token, ' Test Note ') + assert user_token.note == 'Test Note' + + +def test_update_user_token_note_input_too_long(user_token_factory): + user_token = user_token_factory() + assert user_token.note is None + note_max_length = util.get_column_size(model.UserToken.note) + 1 + note = ''.join( + random.choice(string.ascii_letters) for _ in range(note_max_length)) + with pytest.raises(user_tokens.InvalidNoteError): + user_tokens.update_user_token_note(user_token, note) + + +def test_update_user_token_expiration_time(user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + expiration_time_str = ( + (datetime.utcnow() + timedelta(days=1)) + .replace(tzinfo=pytz.utc) + ).isoformat() + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str) + assert user_token.expiration_time.isoformat() == expiration_time_str + + +def test_update_user_token_expiration_time_in_past(user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + expiration_time_str = ( + (datetime.utcnow() - timedelta(days=1)) + .replace(tzinfo=pytz.utc) + ).isoformat() + with pytest.raises( + user_tokens.InvalidExpirationError, + match='Expiration cannot happen in the past'): + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str) + + +@pytest.mark.parametrize('expiration_time_str', [ + datetime.utcnow().isoformat(), + (datetime.utcnow() - timedelta(days=1)).ctime(), + '1970/01/01 00:00:01.0000Z', + '70/01/01 00:00:01.0000Z', + ''.join(random.choice(string.ascii_letters) for _ in range(15)), + ''.join(random.choice(string.digits) for _ in range(8)) +]) +def test_update_user_token_expiration_time_invalid_format( + expiration_time_str, user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + + with pytest.raises( + user_tokens.InvalidExpirationError, + match='Expiration is in an invalid format %s' + % expiration_time_str): + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str)