Address code review comments
This commit is contained in:
parent
aa2963c0c6
commit
a900c54fe6
16 changed files with 74 additions and 62 deletions
|
@ -337,8 +337,7 @@ class Api extends events.EventTarget {
|
||||||
req.auth = null;
|
req.auth = null;
|
||||||
req.set('Authorization', 'Token '
|
req.set('Authorization', 'Token '
|
||||||
+ new Buffer(this.userName + ":" + this.userToken).toString('base64'))
|
+ new Buffer(this.userName + ":" + this.userToken).toString('base64'))
|
||||||
}
|
} else if (this.userName && this.userPassword) {
|
||||||
else if (this.userName && this.userPassword) {
|
|
||||||
req.auth(
|
req.auth(
|
||||||
this.userName,
|
this.userName,
|
||||||
encodeURIComponent(this.userPassword)
|
encodeURIComponent(this.userPassword)
|
||||||
|
|
|
@ -45,7 +45,7 @@ class UserToken extends events.EventTarget {
|
||||||
let userTokenRequest = {
|
let userTokenRequest = {
|
||||||
enabled: true
|
enabled: true
|
||||||
};
|
};
|
||||||
if (note){
|
if (note) {
|
||||||
userTokenRequest.note = note;
|
userTokenRequest.note = note;
|
||||||
}
|
}
|
||||||
if (expirationTime) {
|
if (expirationTime) {
|
||||||
|
|
|
@ -78,7 +78,9 @@ class UserTokenView extends events.EventTarget {
|
||||||
this._userTokenNoteInputNode.value :
|
this._userTokenNoteInputNode.value :
|
||||||
undefined,
|
undefined,
|
||||||
|
|
||||||
expirationTime: this._userTokenExpirationTimeInputNode && this._userTokenExpirationTimeInputNode.value.length > 0 ?
|
expirationTime:
|
||||||
|
(this._userTokenExpirationTimeInputNode
|
||||||
|
&& this._userTokenExpirationTimeInputNode.value.length > 0) ?
|
||||||
new Date(this._userTokenExpirationTimeInputNode.value).toISOString() :
|
new Date(this._userTokenExpirationTimeInputNode.value).toISOString() :
|
||||||
undefined,
|
undefined,
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from szurubooru import model, rest
|
from szurubooru import model, rest
|
||||||
from szurubooru.func import auth, users, user_tokens, serialization, versions
|
from szurubooru.func import auth, users, user_tokens, serialization, versions
|
||||||
|
|
||||||
|
@ -37,8 +36,8 @@ def create_user_token(
|
||||||
user_tokens.update_user_token_note(user_token, note)
|
user_tokens.update_user_token_note(user_token, note)
|
||||||
if ctx.has_param('expirationTime'):
|
if ctx.has_param('expirationTime'):
|
||||||
expiration_time = ctx.get_param_as_string('expirationTime')
|
expiration_time = ctx.get_param_as_string('expirationTime')
|
||||||
user_tokens.update_user_token_expiration_time(user_token,
|
user_tokens.update_user_token_expiration_time(
|
||||||
expiration_time)
|
user_token, expiration_time)
|
||||||
ctx.session.add(user_token)
|
ctx.session.add(user_token)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user_token)
|
return _serialize(ctx, user_token)
|
||||||
|
@ -55,8 +54,8 @@ def update_user_token(
|
||||||
versions.bump_version(user_token)
|
versions.bump_version(user_token)
|
||||||
if ctx.has_param('enabled'):
|
if ctx.has_param('enabled'):
|
||||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||||
user_tokens.update_user_token_enabled(user_token,
|
user_tokens.update_user_token_enabled(
|
||||||
ctx.get_param_as_bool('enabled'))
|
user_token, ctx.get_param_as_bool('enabled'))
|
||||||
if ctx.has_param('note'):
|
if ctx.has_param('note'):
|
||||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||||
note = ctx.get_param_as_string('note')
|
note = ctx.get_param_as_string('note')
|
||||||
|
@ -64,8 +63,8 @@ def update_user_token(
|
||||||
if ctx.has_param('expirationTime'):
|
if ctx.has_param('expirationTime'):
|
||||||
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
|
||||||
expiration_time = ctx.get_param_as_string('expirationTime')
|
expiration_time = ctx.get_param_as_string('expirationTime')
|
||||||
user_tokens.update_user_token_expiration_time(user_token,
|
user_tokens.update_user_token_expiration_time(
|
||||||
expiration_time)
|
user_token, expiration_time)
|
||||||
user_tokens.update_user_token_edit_time(user_token)
|
user_tokens.update_user_token_edit_time(user_token)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user_token)
|
return _serialize(ctx, user_token)
|
||||||
|
|
|
@ -28,7 +28,8 @@ def get_password_hash(salt: str, password: str) -> Tuple[str, int]:
|
||||||
).decode('utf8'), 3
|
).decode('utf8'), 3
|
||||||
|
|
||||||
|
|
||||||
def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]:
|
def get_sha256_legacy_password_hash(
|
||||||
|
salt: str, password: str) -> Tuple[str, int]:
|
||||||
''' Retrieve old-style sha256 password hash. '''
|
''' Retrieve old-style sha256 password hash. '''
|
||||||
digest = hashlib.sha256()
|
digest = hashlib.sha256()
|
||||||
digest.update(config.config['secret'].encode('utf8'))
|
digest.update(config.config['secret'].encode('utf8'))
|
||||||
|
@ -81,8 +82,10 @@ def is_valid_password(user: model.User, password: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def is_valid_token(user_token: model.UserToken) -> bool:
|
def is_valid_token(user_token: model.UserToken) -> bool:
|
||||||
''' Token must be enabled and if it has an expiration,
|
'''
|
||||||
it must be greater than now. '''
|
Token must be enabled and if it has an expiration, it must be
|
||||||
|
greater than now.
|
||||||
|
'''
|
||||||
assert user_token
|
assert user_token
|
||||||
if not user_token.enabled:
|
if not user_token.enabled:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -74,9 +74,11 @@ def serialize_user_token(
|
||||||
|
|
||||||
def get_by_user_and_token(
|
def get_by_user_and_token(
|
||||||
user: model.User, token: str) -> model.UserToken:
|
user: model.User, token: str) -> model.UserToken:
|
||||||
return (db.session.query(model.UserToken)
|
return (
|
||||||
.filter(model.UserToken.user_id == user.user_id,
|
db.session
|
||||||
model.UserToken.token == token)
|
.query(model.UserToken)
|
||||||
|
.filter(model.UserToken.user_id == user.user_id)
|
||||||
|
.filter(model.UserToken.token == token)
|
||||||
.one_or_none())
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,14 +113,10 @@ def update_user_token_edit_time(user_token: model.UserToken) -> None:
|
||||||
|
|
||||||
|
|
||||||
def update_user_token_expiration_time(
|
def update_user_token_expiration_time(
|
||||||
user_token: model.UserToken, expiration_time: str) -> None:
|
user_token: model.UserToken, expiration_time_str: str) -> None:
|
||||||
assert user_token
|
assert user_token
|
||||||
if expiration_time is not None:
|
|
||||||
try:
|
try:
|
||||||
expiration_time = dateutil_parser.parse(expiration_time)
|
expiration_time = dateutil_parser.parse(expiration_time_str)
|
||||||
except ValueError:
|
|
||||||
raise InvalidExpirationError(
|
|
||||||
'Expiration is in invalid format {}'.format(expiration_time))
|
|
||||||
if expiration_time.tzinfo is None:
|
if expiration_time.tzinfo is None:
|
||||||
raise InvalidExpirationError(
|
raise InvalidExpirationError(
|
||||||
'Expiration cannot be missing timezone')
|
'Expiration cannot be missing timezone')
|
||||||
|
@ -128,6 +126,10 @@ def update_user_token_expiration_time(
|
||||||
raise InvalidExpirationError(
|
raise InvalidExpirationError(
|
||||||
'Expiration cannot happen in the past')
|
'Expiration cannot happen in the past')
|
||||||
user_token.expiration_time = expiration_time
|
user_token.expiration_time = expiration_time
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidExpirationError(
|
||||||
|
'Expiration is in an invalid format {}'.format(
|
||||||
|
expiration_time_str))
|
||||||
|
|
||||||
|
|
||||||
def update_user_token_note(user_token: model.UserToken, note: str) -> None:
|
def update_user_token_note(user_token: model.UserToken, note: str) -> None:
|
||||||
|
|
|
@ -170,7 +170,8 @@ def get_user_count() -> int:
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
||||||
return (db.session
|
return (
|
||||||
|
db.session
|
||||||
.query(model.User)
|
.query(model.User)
|
||||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
||||||
.one_or_none())
|
.one_or_none())
|
||||||
|
|
|
@ -5,7 +5,7 @@ from szurubooru.func import auth, users, user_tokens
|
||||||
from szurubooru.rest.errors import HttpBadRequest
|
from szurubooru.rest.errors import HttpBadRequest
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(username: str, password: str) -> model.User:
|
def _authenticate_basic_auth(username: str, password: str) -> model.User:
|
||||||
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||||
user = users.get_user_by_name(username)
|
user = users.get_user_by_name(username)
|
||||||
if not auth.is_valid_password(user, password):
|
if not auth.is_valid_password(user, password):
|
||||||
|
@ -31,7 +31,7 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
|
||||||
if auth_type.lower() == 'basic':
|
if auth_type.lower() == 'basic':
|
||||||
username, password = base64.decodebytes(
|
username, password = base64.decodebytes(
|
||||||
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
||||||
return _authenticate(username, password)
|
return _authenticate_basic_auth(username, password)
|
||||||
elif auth_type.lower() == 'token':
|
elif auth_type.lower() == 'token':
|
||||||
username, token = base64.decodebytes(
|
username, token = base64.decodebytes(
|
||||||
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
credentials.encode('ascii')).decode('utf8').split(':', 1)
|
||||||
|
|
|
@ -27,7 +27,9 @@ def upgrade():
|
||||||
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
|
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
|
||||||
sa.Column('version', sa.Integer(), nullable=False),
|
sa.Column('version', sa.Integer(), nullable=False),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
['user_id'], ['user.id'], ondelete='CASCADE'),
|
['user_id'],
|
||||||
|
['user.id'],
|
||||||
|
ondelete='CASCADE'),
|
||||||
sa.PrimaryKeyConstraint('id'))
|
sa.PrimaryKeyConstraint('id'))
|
||||||
op.create_index(
|
op.create_index(
|
||||||
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)
|
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from szurubooru import api
|
from szurubooru import api
|
||||||
from szurubooru.func import user_tokens, users
|
from szurubooru.func import user_tokens, users
|
||||||
|
|
||||||
|
@ -22,7 +20,10 @@ def test_creating_user_token(
|
||||||
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
||||||
user_tokens.create_user_token.return_value = user_token
|
user_tokens.create_user_token.return_value = user_token
|
||||||
result = api.user_token_api.create_user_token(
|
result = api.user_token_api.create_user_token(
|
||||||
context_factory(user=user_token.user),
|
context_factory(
|
||||||
{'user_name': user_token.user.name})
|
user=user_token.user),
|
||||||
|
{
|
||||||
|
'user_name': user_token.user.name
|
||||||
|
})
|
||||||
assert result == 'serialized user token'
|
assert result == 'serialized user token'
|
||||||
user_tokens.create_user_token.assert_called_once_with(user_token.user)
|
user_tokens.create_user_token.assert_called_once_with(user_token.user)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from szurubooru import api, db
|
from szurubooru import api, db
|
||||||
from szurubooru.func import user_tokens, users
|
from szurubooru.func import user_tokens, users
|
||||||
|
|
||||||
|
@ -22,9 +20,12 @@ def test_deleting_user_token(
|
||||||
users.get_user_by_name.return_value = user_token.user
|
users.get_user_by_name.return_value = user_token.user
|
||||||
user_tokens.get_by_user_and_token.return_value = user_token
|
user_tokens.get_by_user_and_token.return_value = user_token
|
||||||
result = api.user_token_api.delete_user_token(
|
result = api.user_token_api.delete_user_token(
|
||||||
context_factory(user=user_token.user),
|
context_factory(
|
||||||
{'user_name': user_token.user.name,
|
user=user_token.user),
|
||||||
'user_token': user_token.token})
|
{
|
||||||
|
'user_name': user_token.user.name,
|
||||||
|
'user_token': user_token.token
|
||||||
|
})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
user_tokens.get_by_user_and_token.assert_called_once_with(
|
user_tokens.get_by_user_and_token.assert_called_once_with(
|
||||||
user_token.user, user_token.token)
|
user_token.user, user_token.token)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from szurubooru import api
|
from szurubooru import api
|
||||||
from szurubooru.func import user_tokens, users
|
from szurubooru.func import user_tokens, users
|
||||||
|
|
||||||
|
@ -25,7 +23,10 @@ def test_retrieving_user_tokens(
|
||||||
user_tokens.get_user_tokens.return_value = [user_token1, user_token2,
|
user_tokens.get_user_tokens.return_value = [user_token1, user_token2,
|
||||||
user_token3]
|
user_token3]
|
||||||
result = api.user_token_api.get_user_tokens(
|
result = api.user_token_api.get_user_tokens(
|
||||||
context_factory(user=user_token1.user),
|
context_factory(
|
||||||
{'user_name': user_token1.user.name})
|
user=user_token1.user),
|
||||||
|
{
|
||||||
|
'user_name': user_token1.user.name
|
||||||
|
})
|
||||||
assert result == {'results': ['serialized user token'] * 3}
|
assert result == {'results': ['serialized user token'] * 3}
|
||||||
user_tokens.get_user_tokens.assert_called_once_with(user_token1.user)
|
user_tokens.get_user_tokens.assert_called_once_with(user_token1.user)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from szurubooru import api, db
|
from szurubooru import api, db
|
||||||
from szurubooru.func import user_tokens, users
|
from szurubooru.func import user_tokens, users
|
||||||
|
|
||||||
|
@ -25,12 +23,16 @@ def test_edit_user_token(user_token_factory, context_factory, fake_datetime):
|
||||||
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
user_tokens.serialize_user_token.return_value = 'serialized user token'
|
||||||
user_tokens.get_by_user_and_token.return_value = user_token
|
user_tokens.get_by_user_and_token.return_value = user_token
|
||||||
result = api.user_token_api.update_user_token(
|
result = api.user_token_api.update_user_token(
|
||||||
context_factory(params={'version': user_token.version,
|
context_factory(
|
||||||
|
params={
|
||||||
|
'version': user_token.version,
|
||||||
'enabled': False,
|
'enabled': False,
|
||||||
},
|
},
|
||||||
user=user_token.user),
|
user=user_token.user),
|
||||||
{'user_name': user_token.user.name,
|
{
|
||||||
'user_token': user_token.token})
|
'user_name': user_token.user.name,
|
||||||
|
'user_token': user_token.token
|
||||||
|
})
|
||||||
assert result == 'serialized user token'
|
assert result == 'serialized user token'
|
||||||
user_tokens.get_by_user_and_token.assert_called_once_with(
|
user_tokens.get_by_user_and_token.assert_called_once_with(
|
||||||
user_token.user, user_token.token)
|
user_token.user, user_token.token)
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru.func import auth, users, user_tokens
|
from szurubooru.func import auth, users, user_tokens
|
||||||
from szurubooru.middleware import authenticator
|
from szurubooru.middleware import authenticator
|
||||||
from szurubooru.rest import errors
|
from szurubooru.rest import errors
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_request_no_header(context_factory):
|
def test_process_request_no_header(context_factory):
|
||||||
|
|
Reference in a new issue