Address code review comments

This commit is contained in:
ReAnzu 2018-03-08 18:55:41 -06:00
parent aa2963c0c6
commit a900c54fe6
16 changed files with 74 additions and 62 deletions

2
.gitignore vendored
View file

@ -1,4 +1,4 @@
config.yaml
*/*_modules/
.coverage
.cache
.cache

View file

@ -337,8 +337,7 @@ class Api extends events.EventTarget {
req.auth = null;
req.set('Authorization', 'Token '
+ new Buffer(this.userName + ":" + this.userToken).toString('base64'))
}
else if (this.userName && this.userPassword) {
} else if (this.userName && this.userPassword) {
req.auth(
this.userName,
encodeURIComponent(this.userPassword)

View file

@ -45,7 +45,7 @@ class UserToken extends events.EventTarget {
let userTokenRequest = {
enabled: true
};
if (note){
if (note) {
userTokenRequest.note = note;
}
if (expirationTime) {
@ -86,4 +86,4 @@ class UserToken extends events.EventTarget {
}
}
module.exports = UserToken;
module.exports = UserToken;

View file

@ -4,4 +4,4 @@ Date.prototype.addDays = function(days) {
let dat = new Date(this.valueOf());
dat.setDate(dat.getDate() + days);
return dat;
};
};

View file

@ -78,7 +78,9 @@ class UserTokenView extends events.EventTarget {
this._userTokenNoteInputNode.value :
undefined,
expirationTime: this._userTokenExpirationTimeInputNode && this._userTokenExpirationTimeInputNode.value.length > 0 ?
expirationTime:
(this._userTokenExpirationTimeInputNode
&& this._userTokenExpirationTimeInputNode.value.length > 0) ?
new Date(this._userTokenExpirationTimeInputNode.value).toISOString() :
undefined,

View file

@ -1,5 +1,4 @@
from typing import Dict
from szurubooru import model, rest
from szurubooru.func import auth, users, user_tokens, serialization, versions
@ -37,8 +36,8 @@ def create_user_token(
user_tokens.update_user_token_note(user_token, note)
if ctx.has_param('expirationTime'):
expiration_time = ctx.get_param_as_string('expirationTime')
user_tokens.update_user_token_expiration_time(user_token,
expiration_time)
user_tokens.update_user_token_expiration_time(
user_token, expiration_time)
ctx.session.add(user_token)
ctx.session.commit()
return _serialize(ctx, user_token)
@ -55,8 +54,8 @@ def update_user_token(
versions.bump_version(user_token)
if ctx.has_param('enabled'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
user_tokens.update_user_token_enabled(user_token,
ctx.get_param_as_bool('enabled'))
user_tokens.update_user_token_enabled(
user_token, ctx.get_param_as_bool('enabled'))
if ctx.has_param('note'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
note = ctx.get_param_as_string('note')
@ -64,8 +63,8 @@ def update_user_token(
if ctx.has_param('expirationTime'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
expiration_time = ctx.get_param_as_string('expirationTime')
user_tokens.update_user_token_expiration_time(user_token,
expiration_time)
user_tokens.update_user_token_expiration_time(
user_token, expiration_time)
user_tokens.update_user_token_edit_time(user_token)
ctx.session.commit()
return _serialize(ctx, user_token)

View file

@ -28,7 +28,8 @@ def get_password_hash(salt: str, password: str) -> Tuple[str, int]:
).decode('utf8'), 3
def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]:
def get_sha256_legacy_password_hash(
salt: str, password: str) -> Tuple[str, int]:
''' Retrieve old-style sha256 password hash. '''
digest = hashlib.sha256()
digest.update(config.config['secret'].encode('utf8'))
@ -81,8 +82,10 @@ def is_valid_password(user: model.User, password: str) -> bool:
def is_valid_token(user_token: model.UserToken) -> bool:
''' Token must be enabled and if it has an expiration,
it must be greater than now. '''
'''
Token must be enabled and if it has an expiration, it must be
greater than now.
'''
assert user_token
if not user_token.enabled:
return False

View file

@ -74,10 +74,12 @@ def serialize_user_token(
def get_by_user_and_token(
user: model.User, token: str) -> model.UserToken:
return (db.session.query(model.UserToken)
.filter(model.UserToken.user_id == user.user_id,
model.UserToken.token == token)
.one_or_none())
return (
db.session
.query(model.UserToken)
.filter(model.UserToken.user_id == user.user_id)
.filter(model.UserToken.token == token)
.one_or_none())
def get_user_tokens(user: model.User) -> List[model.UserToken]:
@ -111,23 +113,23 @@ def update_user_token_edit_time(user_token: model.UserToken) -> None:
def update_user_token_expiration_time(
user_token: model.UserToken, expiration_time: str) -> None:
user_token: model.UserToken, expiration_time_str: str) -> None:
assert user_token
if expiration_time is not None:
try:
expiration_time = dateutil_parser.parse(expiration_time)
except ValueError:
raise InvalidExpirationError(
'Expiration is in invalid format {}'.format(expiration_time))
try:
expiration_time = dateutil_parser.parse(expiration_time_str)
if expiration_time.tzinfo is None:
raise InvalidExpirationError(
'Expiration cannot be missing timezone')
else:
expiration_time = expiration_time.astimezone(pytz.UTC)
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
raise InvalidExpirationError(
'Expiration cannot happen in the past')
user_token.expiration_time = expiration_time
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
raise InvalidExpirationError(
'Expiration cannot happen in the past')
user_token.expiration_time = expiration_time
except ValueError:
raise InvalidExpirationError(
'Expiration is in an invalid format {}'.format(
expiration_time_str))
def update_user_token_note(user_token: model.UserToken, note: str) -> None:

View file

@ -170,10 +170,11 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]:
return (db.session
.query(model.User)
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
return (
db.session
.query(model.User)
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
def get_user_by_name(name: str) -> model.User:

View file

@ -5,7 +5,7 @@ from szurubooru.func import auth, users, user_tokens
from szurubooru.rest.errors import HttpBadRequest
def _authenticate(username: str, password: str) -> model.User:
def _authenticate_basic_auth(username: str, password: str) -> model.User:
''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password):
@ -31,7 +31,7 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
if auth_type.lower() == 'basic':
username, password = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1)
return _authenticate(username, password)
return _authenticate_basic_auth(username, password)
elif auth_type.lower() == 'token':
username, token = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1)

View file

@ -27,7 +27,9 @@ def upgrade():
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('version', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
['user_id'], ['user.id'], ondelete='CASCADE'),
['user_id'],
['user.id'],
ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'))
op.create_index(
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch
import pytest
from szurubooru import api
from szurubooru.func import user_tokens, users
@ -22,7 +20,10 @@ def test_creating_user_token(
user_tokens.serialize_user_token.return_value = 'serialized user token'
user_tokens.create_user_token.return_value = user_token
result = api.user_token_api.create_user_token(
context_factory(user=user_token.user),
{'user_name': user_token.user.name})
context_factory(
user=user_token.user),
{
'user_name': user_token.user.name
})
assert result == 'serialized user token'
user_tokens.create_user_token.assert_called_once_with(user_token.user)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db
from szurubooru.func import user_tokens, users
@ -22,9 +20,12 @@ def test_deleting_user_token(
users.get_user_by_name.return_value = user_token.user
user_tokens.get_by_user_and_token.return_value = user_token
result = api.user_token_api.delete_user_token(
context_factory(user=user_token.user),
{'user_name': user_token.user.name,
'user_token': user_token.token})
context_factory(
user=user_token.user),
{
'user_name': user_token.user.name,
'user_token': user_token.token
})
assert result == {}
user_tokens.get_by_user_and_token.assert_called_once_with(
user_token.user, user_token.token)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch
import pytest
from szurubooru import api
from szurubooru.func import user_tokens, users
@ -25,7 +23,10 @@ def test_retrieving_user_tokens(
user_tokens.get_user_tokens.return_value = [user_token1, user_token2,
user_token3]
result = api.user_token_api.get_user_tokens(
context_factory(user=user_token1.user),
{'user_name': user_token1.user.name})
context_factory(
user=user_token1.user),
{
'user_name': user_token1.user.name
})
assert result == {'results': ['serialized user token'] * 3}
user_tokens.get_user_tokens.assert_called_once_with(user_token1.user)

View file

@ -1,7 +1,5 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db
from szurubooru.func import user_tokens, users
@ -25,12 +23,16 @@ def test_edit_user_token(user_token_factory, context_factory, fake_datetime):
user_tokens.serialize_user_token.return_value = 'serialized user token'
user_tokens.get_by_user_and_token.return_value = user_token
result = api.user_token_api.update_user_token(
context_factory(params={'version': user_token.version,
'enabled': False,
},
user=user_token.user),
{'user_name': user_token.user.name,
'user_token': user_token.token})
context_factory(
params={
'version': user_token.version,
'enabled': False,
},
user=user_token.user),
{
'user_name': user_token.user.name,
'user_token': user_token.token
})
assert result == 'serialized user token'
user_tokens.get_by_user_and_token.assert_called_once_with(
user_token.user, user_token.token)

View file

@ -1,9 +1,8 @@
from unittest.mock import patch
import pytest
from szurubooru.func import auth, users, user_tokens
from szurubooru.middleware import authenticator
from szurubooru.rest import errors
import pytest
def test_process_request_no_header(context_factory):