Address code review comments
This commit is contained in:
parent
aa2963c0c6
commit
a900c54fe6
16 changed files with 74 additions and 62 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,4 +1,4 @@
|
|||
config.yaml
|
||||
*/*_modules/
|
||||
.coverage
|
||||
.cache
|
||||
.cache
|
||||
|
|
|
@ -337,8 +337,7 @@ class Api extends events.EventTarget {
|
|||
req.auth = null;
|
||||
req.set('Authorization', 'Token '
|
||||
+ new Buffer(this.userName + ":" + this.userToken).toString('base64'))
|
||||
}
|
||||
else if (this.userName && this.userPassword) {
|
||||
} else if (this.userName && this.userPassword) {
|
||||
req.auth(
|
||||
this.userName,
|
||||
encodeURIComponent(this.userPassword)
|
||||
|
|
|
@ -45,7 +45,7 @@ class UserToken extends events.EventTarget {
|
|||
let userTokenRequest = {
|
||||
enabled: true
|
||||
};
|
||||
if (note){
|
||||
if (note) {
|
||||
userTokenRequest.note = note;
|
||||
}
|
||||
if (expirationTime) {
|
||||
|
@ -86,4 +86,4 @@ class UserToken extends events.EventTarget {
|
|||
}
|
||||
}
|
||||
|
||||
module.exports = UserToken;
|
||||
module.exports = UserToken;
|
||||
|
|
|
@ -4,4 +4,4 @@ Date.prototype.addDays = function(days) {
|
|||
let dat = new Date(this.valueOf());
|
||||
dat.setDate(dat.getDate() + days);
|
||||
return dat;
|
||||
};
|
||||
};
|
||||
|
|
|
@ -78,7 +78,9 @@ class UserTokenView extends events.EventTarget {
|
|||
this._userTokenNoteInputNode.value :
|
||||
undefined,
|
||||
|
||||
expirationTime: this._userTokenExpirationTimeInputNode && this._userTokenExpirationTimeInputNode.value.length > 0 ?
|
||||
expirationTime:
|
||||
(this._userTokenExpirationTimeInputNode
|
||||
&& this._userTokenExpirationTimeInputNode.value.length > 0) ?
|
||||
new Date(this._userTokenExpirationTimeInputNode.value).toISOString() :
|
||||
undefined,
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Dict
|
||||
|
||||
from szurubooru import model, rest
|
||||
from szurubooru.func import auth, users, user_tokens, serialization, versions
|
||||
|
||||
|
@ -37,8 +36,8 @@ def create_user_token(
|
|||
user_tokens.update_user_token_note(user_token, note)
|
||||
if ctx.has_param('expirationTime'):
|
||||
expiration_time = ctx.get_param_as_string('expirationTime')
|
||||
user_tokens.update_user_token_expiration_time(user_token,
|
||||
expiration_time)
|
||||
user_tokens.update_user_token_expiration_time(
|
||||
user_token, expiration_time)
|
||||
ctx.session.add(user_token)
|
||||
ctx.session.commit()
|
||||
return _serialize(ctx, user_token)
|
||||
|
@ -55,8 +54,8 @@ def update_user_token(
|
|||
versions.bump_version(user_token)
|
||||
if ctx.has_param('enabled'):
|
||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||
user_tokens.update_user_token_enabled(user_token,
|
||||
ctx.get_param_as_bool('enabled'))
|
||||
user_tokens.update_user_token_enabled(
|
||||
user_token, ctx.get_param_as_bool('enabled'))
|
||||
if ctx.has_param('note'):
|
||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||
note = ctx.get_param_as_string('note')
|
||||
|
@ -64,8 +63,8 @@ def update_user_token(
|
|||
if ctx.has_param('expirationTime'):
|
||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||
expiration_time = ctx.get_param_as_string('expirationTime')
|
||||
user_tokens.update_user_token_expiration_time(user_token,
|
||||
expiration_time)
|
||||
user_tokens.update_user_token_expiration_time(
|
||||
user_token, expiration_time)
|
||||
user_tokens.update_user_token_edit_time(user_token)
|
||||
ctx.session.commit()
|
||||
return _serialize(ctx, user_token)
|
||||
|
|
|
@ -28,7 +28,8 @@ def get_password_hash(salt: str, password: str) -> Tuple[str, int]:
|
|||
).decode('utf8'), 3
|
||||
|
||||
|
||||
def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]:
|
||||
def get_sha256_legacy_password_hash(
|
||||
salt: str, password: str) -> Tuple[str, int]:
|
||||
''' Retrieve old-style sha256 password hash. '''
|
||||
digest = hashlib.sha256()
|
||||
digest.update(config.config['secret'].encode('utf8'))
|
||||
|
@ -81,8 +82,10 @@ def is_valid_password(user: model.User, password: str) -> bool:
|
|||
|
||||
|
||||
def is_valid_token(user_token: model.UserToken) -> bool:
|
||||
''' Token must be enabled and if it has an expiration,
|
||||
it must be greater than now. '''
|
||||
'''
|
||||
Token must be enabled and if it has an expiration, it must be
|
||||
greater than now.
|
||||
'''
|
||||
assert user_token
|
||||
if not user_token.enabled:
|
||||
return False
|
||||
|
|
|
@ -74,10 +74,12 @@ def serialize_user_token(
|
|||
|
||||
def get_by_user_and_token(
|
||||
user: model.User, token: str) -> model.UserToken:
|
||||
return (db.session.query(model.UserToken)
|
||||
.filter(model.UserToken.user_id == user.user_id,
|
||||
model.UserToken.token == token)
|
||||
.one_or_none())
|
||||
return (
|
||||
db.session
|
||||
.query(model.UserToken)
|
||||
.filter(model.UserToken.user_id == user.user_id)
|
||||
.filter(model.UserToken.token == token)
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def get_user_tokens(user: model.User) -> List[model.UserToken]:
|
||||
|
@ -111,23 +113,23 @@ def update_user_token_edit_time(user_token: model.UserToken) -> None:
|
|||
|
||||
|
||||
def update_user_token_expiration_time(
|
||||
user_token: model.UserToken, expiration_time: str) -> None:
|
||||
user_token: model.UserToken, expiration_time_str: str) -> None:
|
||||
assert user_token
|
||||
if expiration_time is not None:
|
||||
try:
|
||||
expiration_time = dateutil_parser.parse(expiration_time)
|
||||
except ValueError:
|
||||
raise InvalidExpirationError(
|
||||
'Expiration is in invalid format {}'.format(expiration_time))
|
||||
try:
|
||||
expiration_time = dateutil_parser.parse(expiration_time_str)
|
||||
if expiration_time.tzinfo is None:
|
||||
raise InvalidExpirationError(
|
||||
'Expiration cannot be missing timezone')
|
||||
else:
|
||||
expiration_time = expiration_time.astimezone(pytz.UTC)
|
||||
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
|
||||
raise InvalidExpirationError(
|
||||
'Expiration cannot happen in the past')
|
||||
user_token.expiration_time = expiration_time
|
||||
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
|
||||
raise InvalidExpirationError(
|
||||
'Expiration cannot happen in the past')
|
||||
user_token.expiration_time = expiration_time
|
||||
except ValueError:
|
||||
raise InvalidExpirationError(
|
||||
'Expiration is in an invalid format {}'.format(
|
||||
expiration_time_str))
|
||||
|
||||
|
||||
def update_user_token_note(user_token: model.UserToken, note: str) -> None:
|
||||
|
|
|
@ -170,10 +170,11 @@ def get_user_count() -> int:
|
|||
|
||||
|
||||
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
||||
return (db.session
|
||||
.query(model.User)
|
||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
||||
.one_or_none())
|
||||
return (
|
||||
db.session
|
||||
.query(model.User)
|
||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def get_user_by_name(name: str) -> model.User:
|
||||
|
|
|
@ -5,7 +5,7 @@ from szurubooru.func import auth, users, user_tokens
|
|||
from szurubooru.rest.errors import HttpBadRequest
|
||||
|
||||
|
||||
def _authenticate(username: str, password: str) -> model.User:
|
||||
def _authenticate_basic_auth(username: str, password: str) -> model.User:
|
||||
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||
user = users.get_user_by_name(username)
|
||||
if not auth.is_valid_password(user, password):
|
||||
|
@ -31,7 +31,7 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
|
|||
if auth_type.lower() == 'basic':
|
||||
username, password = base64.decodebytes(
|
||||
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
||||
return _authenticate(username, password)
|
||||
return _authenticate_basic_auth(username, password)
|
||||
elif auth_type.lower() == 'token':
|
||||
username, token = base64.decodebytes(
|
||||
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
||||
|
|
|
@ -27,7 +27,9 @@ def upgrade():
|
|||
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
|
||||
sa.Column('version', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], ondelete='CASCADE'),
|
||||
['user_id'],
|
||||
['user.id'],
|
||||
ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'))
|
||||
op.create_index(
|
||||
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from szurubooru import api
|
||||
from szurubooru.func import user_tokens, users
|
||||
|
||||
|
@ -22,7 +20,10 @@ def test_creating_user_token(
|
|||
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
||||
user_tokens.create_user_token.return_value = user_token
|
||||
result = api.user_token_api.create_user_token(
|
||||
context_factory(user=user_token.user),
|
||||
{'user_name': user_token.user.name})
|
||||
context_factory(
|
||||
user=user_token.user),
|
||||
{
|
||||
'user_name': user_token.user.name
|
||||
})
|
||||
assert result == 'serialized user token'
|
||||
user_tokens.create_user_token.assert_called_once_with(user_token.user)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from szurubooru import api, db
|
||||
from szurubooru.func import user_tokens, users
|
||||
|
||||
|
@ -22,9 +20,12 @@ def test_deleting_user_token(
|
|||
users.get_user_by_name.return_value = user_token.user
|
||||
user_tokens.get_by_user_and_token.return_value = user_token
|
||||
result = api.user_token_api.delete_user_token(
|
||||
context_factory(user=user_token.user),
|
||||
{'user_name': user_token.user.name,
|
||||
'user_token': user_token.token})
|
||||
context_factory(
|
||||
user=user_token.user),
|
||||
{
|
||||
'user_name': user_token.user.name,
|
||||
'user_token': user_token.token
|
||||
})
|
||||
assert result == {}
|
||||
user_tokens.get_by_user_and_token.assert_called_once_with(
|
||||
user_token.user, user_token.token)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from szurubooru import api
|
||||
from szurubooru.func import user_tokens, users
|
||||
|
||||
|
@ -25,7 +23,10 @@ def test_retrieving_user_tokens(
|
|||
user_tokens.get_user_tokens.return_value = [user_token1, user_token2,
|
||||
user_token3]
|
||||
result = api.user_token_api.get_user_tokens(
|
||||
context_factory(user=user_token1.user),
|
||||
{'user_name': user_token1.user.name})
|
||||
context_factory(
|
||||
user=user_token1.user),
|
||||
{
|
||||
'user_name': user_token1.user.name
|
||||
})
|
||||
assert result == {'results': ['serialized user token'] * 3}
|
||||
user_tokens.get_user_tokens.assert_called_once_with(user_token1.user)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from szurubooru import api, db
|
||||
from szurubooru.func import user_tokens, users
|
||||
|
||||
|
@ -25,12 +23,16 @@ def test_edit_user_token(user_token_factory, context_factory, fake_datetime):
|
|||
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
||||
user_tokens.get_by_user_and_token.return_value = user_token
|
||||
result = api.user_token_api.update_user_token(
|
||||
context_factory(params={'version': user_token.version,
|
||||
'enabled': False,
|
||||
},
|
||||
user=user_token.user),
|
||||
{'user_name': user_token.user.name,
|
||||
'user_token': user_token.token})
|
||||
context_factory(
|
||||
params={
|
||||
'version': user_token.version,
|
||||
'enabled': False,
|
||||
},
|
||||
user=user_token.user),
|
||||
{
|
||||
'user_name': user_token.user.name,
|
||||
'user_token': user_token.token
|
||||
})
|
||||
assert result == 'serialized user token'
|
||||
user_tokens.get_by_user_and_token.assert_called_once_with(
|
||||
user_token.user, user_token.token)
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from szurubooru.func import auth, users, user_tokens
|
||||
from szurubooru.middleware import authenticator
|
||||
from szurubooru.rest import errors
|
||||
import pytest
|
||||
|
||||
|
||||
def test_process_request_no_header(context_factory):
|
||||
|
|
Reference in a new issue