Address code review comments

This commit is contained in:
ReAnzu 2018-03-08 18:55:41 -06:00
parent aa2963c0c6
commit a900c54fe6
16 changed files with 74 additions and 62 deletions

2
.gitignore vendored
View file

@ -1,4 +1,4 @@
config.yaml config.yaml
*/*_modules/ */*_modules/
.coverage .coverage
.cache .cache

View file

@ -337,8 +337,7 @@ class Api extends events.EventTarget {
req.auth = null; req.auth = null;
req.set('Authorization', 'Token ' req.set('Authorization', 'Token '
+ new Buffer(this.userName + ":" + this.userToken).toString('base64')) + new Buffer(this.userName + ":" + this.userToken).toString('base64'))
} } else if (this.userName && this.userPassword) {
else if (this.userName && this.userPassword) {
req.auth( req.auth(
this.userName, this.userName,
encodeURIComponent(this.userPassword) encodeURIComponent(this.userPassword)

View file

@ -45,7 +45,7 @@ class UserToken extends events.EventTarget {
let userTokenRequest = { let userTokenRequest = {
enabled: true enabled: true
}; };
if (note){ if (note) {
userTokenRequest.note = note; userTokenRequest.note = note;
} }
if (expirationTime) { if (expirationTime) {
@ -86,4 +86,4 @@ class UserToken extends events.EventTarget {
} }
} }
module.exports = UserToken; module.exports = UserToken;

View file

@ -4,4 +4,4 @@ Date.prototype.addDays = function(days) {
let dat = new Date(this.valueOf()); let dat = new Date(this.valueOf());
dat.setDate(dat.getDate() + days); dat.setDate(dat.getDate() + days);
return dat; return dat;
}; };

View file

@ -78,7 +78,9 @@ class UserTokenView extends events.EventTarget {
this._userTokenNoteInputNode.value : this._userTokenNoteInputNode.value :
undefined, undefined,
expirationTime: this._userTokenExpirationTimeInputNode && this._userTokenExpirationTimeInputNode.value.length > 0 ? expirationTime:
(this._userTokenExpirationTimeInputNode
&& this._userTokenExpirationTimeInputNode.value.length > 0) ?
new Date(this._userTokenExpirationTimeInputNode.value).toISOString() : new Date(this._userTokenExpirationTimeInputNode.value).toISOString() :
undefined, undefined,

View file

@ -1,5 +1,4 @@
from typing import Dict from typing import Dict
from szurubooru import model, rest from szurubooru import model, rest
from szurubooru.func import auth, users, user_tokens, serialization, versions from szurubooru.func import auth, users, user_tokens, serialization, versions
@ -37,8 +36,8 @@ def create_user_token(
user_tokens.update_user_token_note(user_token, note) user_tokens.update_user_token_note(user_token, note)
if ctx.has_param('expirationTime'): if ctx.has_param('expirationTime'):
expiration_time = ctx.get_param_as_string('expirationTime') expiration_time = ctx.get_param_as_string('expirationTime')
user_tokens.update_user_token_expiration_time(user_token, user_tokens.update_user_token_expiration_time(
expiration_time) user_token, expiration_time)
ctx.session.add(user_token) ctx.session.add(user_token)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user_token) return _serialize(ctx, user_token)
@ -55,8 +54,8 @@ def update_user_token(
versions.bump_version(user_token) versions.bump_version(user_token)
if ctx.has_param('enabled'): if ctx.has_param('enabled'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
user_tokens.update_user_token_enabled(user_token, user_tokens.update_user_token_enabled(
ctx.get_param_as_bool('enabled')) user_token, ctx.get_param_as_bool('enabled'))
if ctx.has_param('note'): if ctx.has_param('note'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
note = ctx.get_param_as_string('note') note = ctx.get_param_as_string('note')
@ -64,8 +63,8 @@ def update_user_token(
if ctx.has_param('expirationTime'): if ctx.has_param('expirationTime'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
expiration_time = ctx.get_param_as_string('expirationTime') expiration_time = ctx.get_param_as_string('expirationTime')
user_tokens.update_user_token_expiration_time(user_token, user_tokens.update_user_token_expiration_time(
expiration_time) user_token, expiration_time)
user_tokens.update_user_token_edit_time(user_token) user_tokens.update_user_token_edit_time(user_token)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user_token) return _serialize(ctx, user_token)

View file

@ -28,7 +28,8 @@ def get_password_hash(salt: str, password: str) -> Tuple[str, int]:
).decode('utf8'), 3 ).decode('utf8'), 3
def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]: def get_sha256_legacy_password_hash(
salt: str, password: str) -> Tuple[str, int]:
''' Retrieve old-style sha256 password hash. ''' ''' Retrieve old-style sha256 password hash. '''
digest = hashlib.sha256() digest = hashlib.sha256()
digest.update(config.config['secret'].encode('utf8')) digest.update(config.config['secret'].encode('utf8'))
@ -81,8 +82,10 @@ def is_valid_password(user: model.User, password: str) -> bool:
def is_valid_token(user_token: model.UserToken) -> bool: def is_valid_token(user_token: model.UserToken) -> bool:
''' Token must be enabled and if it has an expiration, '''
it must be greater than now. ''' Token must be enabled and if it has an expiration, it must be
greater than now.
'''
assert user_token assert user_token
if not user_token.enabled: if not user_token.enabled:
return False return False

View file

@ -74,10 +74,12 @@ def serialize_user_token(
def get_by_user_and_token( def get_by_user_and_token(
user: model.User, token: str) -> model.UserToken: user: model.User, token: str) -> model.UserToken:
return (db.session.query(model.UserToken) return (
.filter(model.UserToken.user_id == user.user_id, db.session
model.UserToken.token == token) .query(model.UserToken)
.one_or_none()) .filter(model.UserToken.user_id == user.user_id)
.filter(model.UserToken.token == token)
.one_or_none())
def get_user_tokens(user: model.User) -> List[model.UserToken]: def get_user_tokens(user: model.User) -> List[model.UserToken]:
@ -111,23 +113,23 @@ def update_user_token_edit_time(user_token: model.UserToken) -> None:
def update_user_token_expiration_time( def update_user_token_expiration_time(
user_token: model.UserToken, expiration_time: str) -> None: user_token: model.UserToken, expiration_time_str: str) -> None:
assert user_token assert user_token
if expiration_time is not None: try:
try: expiration_time = dateutil_parser.parse(expiration_time_str)
expiration_time = dateutil_parser.parse(expiration_time)
except ValueError:
raise InvalidExpirationError(
'Expiration is in invalid format {}'.format(expiration_time))
if expiration_time.tzinfo is None: if expiration_time.tzinfo is None:
raise InvalidExpirationError( raise InvalidExpirationError(
'Expiration cannot be missing timezone') 'Expiration cannot be missing timezone')
else: else:
expiration_time = expiration_time.astimezone(pytz.UTC) expiration_time = expiration_time.astimezone(pytz.UTC)
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC): if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
raise InvalidExpirationError( raise InvalidExpirationError(
'Expiration cannot happen in the past') 'Expiration cannot happen in the past')
user_token.expiration_time = expiration_time user_token.expiration_time = expiration_time
except ValueError:
raise InvalidExpirationError(
'Expiration is in an invalid format {}'.format(
expiration_time_str))
def update_user_token_note(user_token: model.UserToken, note: str) -> None: def update_user_token_note(user_token: model.UserToken, note: str) -> None:

View file

@ -170,10 +170,11 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]: def try_get_user_by_name(name: str) -> Optional[model.User]:
return (db.session return (
.query(model.User) db.session
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) .query(model.User)
.one_or_none()) .filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
def get_user_by_name(name: str) -> model.User: def get_user_by_name(name: str) -> model.User:

View file

@ -5,7 +5,7 @@ from szurubooru.func import auth, users, user_tokens
from szurubooru.rest.errors import HttpBadRequest from szurubooru.rest.errors import HttpBadRequest
def _authenticate(username: str, password: str) -> model.User: def _authenticate_basic_auth(username: str, password: str) -> model.User:
''' Try to authenticate user. Throw AuthError for invalid users. ''' ''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username) user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password): if not auth.is_valid_password(user, password):
@ -31,7 +31,7 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
if auth_type.lower() == 'basic': if auth_type.lower() == 'basic':
username, password = base64.decodebytes( username, password = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1) credentials.encode('ascii')).decode('utf8').split(':', 1)
return _authenticate(username, password) return _authenticate_basic_auth(username, password)
elif auth_type.lower() == 'token': elif auth_type.lower() == 'token':
username, token = base64.decodebytes( username, token = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1) credentials.encode('ascii')).decode('utf8').split(':', 1)

View file

@ -27,7 +27,9 @@ def upgrade():
sa.Column('last_edit_time', sa.DateTime(), nullable=True), sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('version', sa.Integer(), nullable=False), sa.Column('version', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
['user_id'], ['user.id'], ondelete='CASCADE'), ['user_id'],
['user.id'],
ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')) sa.PrimaryKeyConstraint('id'))
op.create_index( op.create_index(
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False) op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api from szurubooru import api
from szurubooru.func import user_tokens, users from szurubooru.func import user_tokens, users
@ -22,7 +20,10 @@ def test_creating_user_token(
user_tokens.serialize_user_token.return_value = 'serialized user token' user_tokens.serialize_user_token.return_value = 'serialized user token'
user_tokens.create_user_token.return_value = user_token user_tokens.create_user_token.return_value = user_token
result = api.user_token_api.create_user_token( result = api.user_token_api.create_user_token(
context_factory(user=user_token.user), context_factory(
{'user_name': user_token.user.name}) user=user_token.user),
{
'user_name': user_token.user.name
})
assert result == 'serialized user token' assert result == 'serialized user token'
user_tokens.create_user_token.assert_called_once_with(user_token.user) user_tokens.create_user_token.assert_called_once_with(user_token.user)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db from szurubooru import api, db
from szurubooru.func import user_tokens, users from szurubooru.func import user_tokens, users
@ -22,9 +20,12 @@ def test_deleting_user_token(
users.get_user_by_name.return_value = user_token.user users.get_user_by_name.return_value = user_token.user
user_tokens.get_by_user_and_token.return_value = user_token user_tokens.get_by_user_and_token.return_value = user_token
result = api.user_token_api.delete_user_token( result = api.user_token_api.delete_user_token(
context_factory(user=user_token.user), context_factory(
{'user_name': user_token.user.name, user=user_token.user),
'user_token': user_token.token}) {
'user_name': user_token.user.name,
'user_token': user_token.token
})
assert result == {} assert result == {}
user_tokens.get_by_user_and_token.assert_called_once_with( user_tokens.get_by_user_and_token.assert_called_once_with(
user_token.user, user_token.token) user_token.user, user_token.token)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api from szurubooru import api
from szurubooru.func import user_tokens, users from szurubooru.func import user_tokens, users
@ -25,7 +23,10 @@ def test_retrieving_user_tokens(
user_tokens.get_user_tokens.return_value = [user_token1, user_token2, user_tokens.get_user_tokens.return_value = [user_token1, user_token2,
user_token3] user_token3]
result = api.user_token_api.get_user_tokens( result = api.user_token_api.get_user_tokens(
context_factory(user=user_token1.user), context_factory(
{'user_name': user_token1.user.name}) user=user_token1.user),
{
'user_name': user_token1.user.name
})
assert result == {'results': ['serialized user token'] * 3} assert result == {'results': ['serialized user token'] * 3}
user_tokens.get_user_tokens.assert_called_once_with(user_token1.user) user_tokens.get_user_tokens.assert_called_once_with(user_token1.user)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db from szurubooru import api, db
from szurubooru.func import user_tokens, users from szurubooru.func import user_tokens, users
@ -25,12 +23,16 @@ def test_edit_user_token(user_token_factory, context_factory, fake_datetime):
user_tokens.serialize_user_token.return_value = 'serialized user token' user_tokens.serialize_user_token.return_value = 'serialized user token'
user_tokens.get_by_user_and_token.return_value = user_token user_tokens.get_by_user_and_token.return_value = user_token
result = api.user_token_api.update_user_token( result = api.user_token_api.update_user_token(
context_factory(params={'version': user_token.version, context_factory(
'enabled': False, params={
}, 'version': user_token.version,
user=user_token.user), 'enabled': False,
{'user_name': user_token.user.name, },
'user_token': user_token.token}) user=user_token.user),
{
'user_name': user_token.user.name,
'user_token': user_token.token
})
assert result == 'serialized user token' assert result == 'serialized user token'
user_tokens.get_by_user_and_token.assert_called_once_with( user_tokens.get_by_user_and_token.assert_called_once_with(
user_token.user, user_token.token) user_token.user, user_token.token)

View file

@ -1,9 +1,8 @@
from unittest.mock import patch from unittest.mock import patch
import pytest
from szurubooru.func import auth, users, user_tokens from szurubooru.func import auth, users, user_tokens
from szurubooru.middleware import authenticator from szurubooru.middleware import authenticator
from szurubooru.rest import errors from szurubooru.rest import errors
import pytest
def test_process_request_no_header(context_factory): def test_process_request_no_header(context_factory):